From e94e4c0ce1f296154833b2e68c41d6d5b3b2bec9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?= Date: Thu, 24 Jan 2019 15:30:36 +0100 Subject: [PATCH] Iterative clustering --- .../polimi/middleware/projects/flink/KMeans.java | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/main/java/it/polimi/middleware/projects/flink/KMeans.java b/src/main/java/it/polimi/middleware/projects/flink/KMeans.java index 5e0b010..8ec7850 100644 --- a/src/main/java/it/polimi/middleware/projects/flink/KMeans.java +++ b/src/main/java/it/polimi/middleware/projects/flink/KMeans.java @@ -24,6 +24,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.operators.IterativeDataSet; import org.apache.flink.api.java.utils.ParameterTool; @@ -34,6 +35,7 @@ public class KMeans { final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); final Integer k = params.getInt("k", 3); + final Integer maxIterations = params.getInt("maxIterations", 25); // Read CSV input DataSet> csvInput = env.readCsvFile(params.get("input")).types(Double.class); @@ -44,9 +46,7 @@ public class KMeans { // Generate random centroids final RandomCentroids r = new RandomCentroids(k); - DataSet centroids = env.fromCollection(r, Double.class); - - centroids.print(); + IterativeDataSet centroids = env.fromCollection(r, Double.class).iterate(maxIterations); // Assign points to centroids DataSet> assigned = input @@ -59,11 +59,12 @@ public class KMeans { .reduce(new MeanSum()) .map(new MeanDivide()); - // Re-assign points to centroids - assigned = input - .map(new AssignCentroid()).withBroadcastSet(newCentroids, "centroids"); + DataSet finalCentroids = centroids.closeWith(newCentroids); + + // Final assignment of points to centroids + assigned = input + .map(new AssignCentroid()).withBroadcastSet(finalCentroids, "centroids"); - newCentroids.print(); assigned.writeAsCsv(params.get("output", "output.csv")); env.execute("K-Means clustering");