Iterative clustering
This commit is contained in:
parent
72a187112c
commit
e94e4c0ce1
|
@ -24,6 +24,7 @@ import org.apache.flink.configuration.Configuration;
|
||||||
|
|
||||||
import org.apache.flink.api.java.DataSet;
|
import org.apache.flink.api.java.DataSet;
|
||||||
import org.apache.flink.api.java.ExecutionEnvironment;
|
import org.apache.flink.api.java.ExecutionEnvironment;
|
||||||
|
import org.apache.flink.api.java.operators.IterativeDataSet;
|
||||||
import org.apache.flink.api.java.utils.ParameterTool;
|
import org.apache.flink.api.java.utils.ParameterTool;
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,6 +35,7 @@ public class KMeans {
|
||||||
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
|
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
|
||||||
|
|
||||||
final Integer k = params.getInt("k", 3);
|
final Integer k = params.getInt("k", 3);
|
||||||
|
final Integer maxIterations = params.getInt("maxIterations", 25);
|
||||||
|
|
||||||
// Read CSV input
|
// Read CSV input
|
||||||
DataSet<Tuple1<Double>> csvInput = env.readCsvFile(params.get("input")).types(Double.class);
|
DataSet<Tuple1<Double>> csvInput = env.readCsvFile(params.get("input")).types(Double.class);
|
||||||
|
@ -44,9 +46,7 @@ public class KMeans {
|
||||||
|
|
||||||
// Generate random centroids
|
// Generate random centroids
|
||||||
final RandomCentroids r = new RandomCentroids(k);
|
final RandomCentroids r = new RandomCentroids(k);
|
||||||
DataSet<Double> centroids = env.fromCollection(r, Double.class);
|
IterativeDataSet<Double> centroids = env.fromCollection(r, Double.class).iterate(maxIterations);
|
||||||
|
|
||||||
centroids.print();
|
|
||||||
|
|
||||||
// Assign points to centroids
|
// Assign points to centroids
|
||||||
DataSet<Tuple2<Double, Integer>> assigned = input
|
DataSet<Tuple2<Double, Integer>> assigned = input
|
||||||
|
@ -59,11 +59,12 @@ public class KMeans {
|
||||||
.reduce(new MeanSum())
|
.reduce(new MeanSum())
|
||||||
.map(new MeanDivide());
|
.map(new MeanDivide());
|
||||||
|
|
||||||
// Re-assign points to centroids
|
DataSet<Double> finalCentroids = centroids.closeWith(newCentroids);
|
||||||
assigned = input
|
|
||||||
.map(new AssignCentroid()).withBroadcastSet(newCentroids, "centroids");
|
// Final assignment of points to centroids
|
||||||
|
assigned = input
|
||||||
|
.map(new AssignCentroid()).withBroadcastSet(finalCentroids, "centroids");
|
||||||
|
|
||||||
newCentroids.print();
|
|
||||||
assigned.writeAsCsv(params.get("output", "output.csv"));
|
assigned.writeAsCsv(params.get("output", "output.csv"));
|
||||||
|
|
||||||
env.execute("K-Means clustering");
|
env.execute("K-Means clustering");
|
||||||
|
|
Reference in a new issue