Iterative clustering
This commit is contained in:
parent
72a187112c
commit
e94e4c0ce1
|
@@ -24,6 +24,7 @@ import org.apache.flink.configuration.Configuration;
|
|||
|
||||
import org.apache.flink.api.java.DataSet;
|
||||
import org.apache.flink.api.java.ExecutionEnvironment;
|
||||
import org.apache.flink.api.java.operators.IterativeDataSet;
|
||||
import org.apache.flink.api.java.utils.ParameterTool;
|
||||
|
||||
|
||||
|
@@ -34,6 +35,7 @@ public class KMeans {
|
|||
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
|
||||
|
||||
final Integer k = params.getInt("k", 3);
|
||||
final Integer maxIterations = params.getInt("maxIterations", 25);
|
||||
|
||||
// Read CSV input
|
||||
DataSet<Tuple1<Double>> csvInput = env.readCsvFile(params.get("input")).types(Double.class);
|
||||
|
@@ -44,9 +46,7 @@ public class KMeans {
|
|||
|
||||
// Generate random centroids
|
||||
final RandomCentroids r = new RandomCentroids(k);
|
||||
DataSet<Double> centroids = env.fromCollection(r, Double.class);
|
||||
|
||||
centroids.print();
|
||||
IterativeDataSet<Double> centroids = env.fromCollection(r, Double.class).iterate(maxIterations);
|
||||
|
||||
// Assign points to centroids
|
||||
DataSet<Tuple2<Double, Integer>> assigned = input
|
||||
|
@@ -59,11 +59,12 @@ public class KMeans {
|
|||
.reduce(new MeanSum())
|
||||
.map(new MeanDivide());
|
||||
|
||||
// Re-assign points to centroids
|
||||
assigned = input
|
||||
.map(new AssignCentroid()).withBroadcastSet(newCentroids, "centroids");
|
||||
DataSet<Double> finalCentroids = centroids.closeWith(newCentroids);
|
||||
|
||||
// Final assignment of points to centroids
|
||||
assigned = input
|
||||
.map(new AssignCentroid()).withBroadcastSet(finalCentroids, "centroids");
|
||||
|
||||
newCentroids.print();
|
||||
assigned.writeAsCsv(params.get("output", "output.csv"));
|
||||
|
||||
env.execute("K-Means clustering");
|
||||
|
|
Reference in a new issue