Iterative clustering

This commit is contained in:
Geoffrey Frogeye 2019-01-24 15:30:36 +01:00
parent 72a187112c
commit e94e4c0ce1

View file

@ -24,6 +24,7 @@ import org.apache.flink.configuration.Configuration;
import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.utils.ParameterTool; import org.apache.flink.api.java.utils.ParameterTool;
@ -34,6 +35,7 @@ public class KMeans {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
final Integer k = params.getInt("k", 3); final Integer k = params.getInt("k", 3);
final Integer maxIterations = params.getInt("maxIterations", 25);
// Read CSV input // Read CSV input
DataSet<Tuple1<Double>> csvInput = env.readCsvFile(params.get("input")).types(Double.class); DataSet<Tuple1<Double>> csvInput = env.readCsvFile(params.get("input")).types(Double.class);
@ -44,9 +46,7 @@ public class KMeans {
// Generate random centroids // Generate random centroids
final RandomCentroids r = new RandomCentroids(k); final RandomCentroids r = new RandomCentroids(k);
DataSet<Double> centroids = env.fromCollection(r, Double.class); IterativeDataSet<Double> centroids = env.fromCollection(r, Double.class).iterate(maxIterations);
centroids.print();
// Assign points to centroids // Assign points to centroids
DataSet<Tuple2<Double, Integer>> assigned = input DataSet<Tuple2<Double, Integer>> assigned = input
@ -59,11 +59,12 @@ public class KMeans {
.reduce(new MeanSum()) .reduce(new MeanSum())
.map(new MeanDivide()); .map(new MeanDivide());
// Re-assign points to centroids DataSet<Double> finalCentroids = centroids.closeWith(newCentroids);
assigned = input
.map(new AssignCentroid()).withBroadcastSet(newCentroids, "centroids"); // Final assignment of points to centroids
assigned = input
.map(new AssignCentroid()).withBroadcastSet(finalCentroids, "centroids");
newCentroids.print();
assigned.writeAsCsv(params.get("output", "output.csv")); assigned.writeAsCsv(params.get("output", "output.csv"));
env.execute("K-Means clustering"); env.execute("K-Means clustering");