package it.polimi.middleware.projects.flink; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Random; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.typeinfo.TypeHint; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.configuration.Configuration; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.utils.ParameterTool; public class KMeans { public static void main(String[] args) throws Exception { final ParameterTool params = ParameterTool.fromArgs(args); final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); final Integer k = params.getInt("k", 3); // Read CSV input DataSet> csvInput = env.readCsvFile(params.get("input")).types(Double.class); // Convert CSV to internal format DataSet input = csvInput .map(point -> point.f0); // Generate random centroids final RandomCentroids r = new RandomCentroids(k); DataSet centroids = env.fromCollection(r, Double.class); centroids.print(); // Assign points to centroids DataSet> assigned = input .map(new AssignCentroid()).withBroadcastSet(centroids, "centroids"); // Calculate means DataSet newCentroids = assigned .map(new MeanPrepare()) .groupBy(1) // GroupBy CentroidID .reduce(new MeanSum()) .map(new MeanDivide()); // Re-assign points to centroids assigned = input .map(new AssignCentroid()).withBroadcastSet(newCentroids, "centroids"); newCentroids.print(); assigned.writeAsCsv(params.get("output", "output.csv")); env.execute("K-Means clustering"); } public static class RandomCentroids implements Iterator, Serializable { Integer k; Integer i; Random r; public RandomCentroids(Integer k) { this.k = k; this.i = 0; this.r = new Random(0); } @Override public boolean hasNext() { return i < k; } @Override public Double next() { i += 1; return r.nextDouble(); } private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException { inputStream.defaultReadObject(); } private void writeObject(ObjectOutputStream outputStream) throws IOException { outputStream.defaultWriteObject(); } } public static class AssignCentroid extends RichMapFunction> { // Point → Point, CentroidID private List centroids; @Override public void open(Configuration parameters) throws Exception { centroids = new ArrayList(getRuntimeContext().getBroadcastVariable("centroids")); Collections.sort(centroids); } @Override public Tuple2 map(Double point) { Integer c; Double centroid; Double distance; Integer minCentroid = 4; Double minDistance = Double.POSITIVE_INFINITY; for (c = 0; c < centroids.size(); c++) { centroid = centroids.get(c); distance = distancePointCentroid(point, centroid); if (distance < minDistance) { minCentroid = c; minDistance = distance; } } return new Tuple2(point, minCentroid); } private Double distancePointCentroid(Double point, Double centroid) { return Math.abs(point - centroid); // return Math.sqrt(Math.pow(point, 2) + Math.pow(centroid, 2)); } } public static class MeanPrepare implements MapFunction, Tuple3> { // Point, CentroidID → Point, CentroidID, Number of points @Override public Tuple3 map(Tuple2 point) { return new Tuple3(point.f0, point.f1, 1); } } public static class MeanSum implements ReduceFunction> { // Point, CentroidID (irrelevant), Number of points @Override public Tuple3 reduce(Tuple3 a, Tuple3 b) { return new Tuple3(sumPoints(a.f0, b.f0), a.f1, a.f2 + b.f2); } private Double sumPoints(Double a, Double b) { return a + b; } } public static class MeanDivide implements MapFunction, Double> { // Point, CentroidID (irrelevant), Number of points → Point @Override public Double map(Tuple3 point) { return point.f0 / point.f2; } } }