package it.polimi.middleware.projects.flink;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.utils.ParameterTool;

/**
 * K-Means clustering on 2D points using Flink's batch (DataSet) API.
 *
 * <p>Pipeline: read (x, y) pairs from a CSV file, compute the bounding box of
 * the data, seed {@code k} random centroids inside that box, then iterate up to
 * {@code maxIterations} rounds of assign-to-nearest-centroid / recompute-means,
 * and finally write each point with its cluster id back out as CSV.
 *
 * <p>Command-line parameters (via {@link ParameterTool}):
 * <ul>
 *   <li>{@code --input}         path of the input CSV (required)</li>
 *   <li>{@code --output}        path of the output CSV (default {@code output.csv})</li>
 *   <li>{@code --k}             number of clusters (default 3)</li>
 *   <li>{@code --maxIterations} iteration cap (default 25)</li>
 * </ul>
 */
public class KMeans {

    public static void main(String[] args) throws Exception {
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        final ParameterTool params = ParameterTool.fromArgs(args);

        final Integer k = params.getInt("k", 3);
        final Integer maxIterations = params.getInt("maxIterations", 25);

        // Read CSV input: each row is an (x, y) coordinate pair.
        DataSet<Tuple2<Double, Double>> inputCsv =
                env.readCsvFile(params.get("input")).types(Double.class, Double.class);

        // Convert to internal format.
        DataSet<Point> input = inputCsv
                .map(tuple -> new Point(tuple.f0, tuple.f1));

        // Find min and max of the coordinates to determine where the initial
        // centroids should be placed.
        DataSet<Tuple4<Double, Double, Double, Double>> area = input
                .map(new MapFunction<Point, Tuple4<Double, Double, Double, Double>>() {
                    // Format points so they can be passed as reduce parameters.
                    @Override
                    public Tuple4<Double, Double, Double, Double> map(Point point) {
                        return new Tuple4<>(point.x, point.y, point.x, point.y);
                    }
                })
                .reduce(new FindArea()); // Gives the minX, minY, maxX, maxY of all the points

        // Generate random centroids inside the bounding box and mark the
        // beginning of the bulk iteration.
        IterativeDataSet<Point> centroids = area
                .flatMap(new RandomCentroids(k))
                .iterate(maxIterations);

        // Assign every point to its closest centroid.
        DataSet<Tuple2<Point, Integer>> assigned = input
                .map(new AssignCentroid()).withBroadcastSet(centroids, "centroids");

        // Recompute each centroid as the mean of its assigned points.
        DataSet<Point> newCentroids = assigned
                .map(new MeanPrepare())   // Add Integer field to the tuple to count the points
                .groupBy(1)               // Group by centroid id
                .reduce(new MeanSum())    // Sum all points per centroid
                .map(new MeanDivide());   // Divide by the number of points to get the average

        // Mark the end of the loop; iteration stops after maxIterations rounds.
        DataSet<Point> finalCentroids = centroids.closeWith(newCentroids);

        // Final assignment of points to the converged centroids (that's the data we want).
        assigned = input
                .map(new AssignCentroid()).withBroadcastSet(finalCentroids, "centroids");

        // Convert to CSV format: x, y, cluster id.
        DataSet<Tuple3<Double, Double, Integer>> output = assigned
                .map(new MapFunction<Tuple2<Point, Integer>, Tuple3<Double, Double, Integer>>() {
                    @Override
                    public Tuple3<Double, Double, Integer> map(Tuple2<Point, Integer> tuple) {
                        return new Tuple3<>(tuple.f0.x, tuple.f0.y, tuple.f1);
                    }
                });

        output.writeAsCsv(params.get("output", "output.csv"));

        env.execute("K-Means clustering");
    }

    /**
     * A 2D point. Mutable on purpose: {@link #addTo} and {@link #divideBy}
     * overwrite the coordinates in place, which is safe here because the input
     * DataSet is re-read from source at each use.
     *
     * <p>Comparable so that broadcast centroids can be sorted identically on
     * every parallel operator instance, giving each centroid a stable index.
     */
    public static class Point implements Comparable<Point>, Serializable {

        public Double x;
        public Double y;

        public Point(Double x, Double y) {
            this.x = x;
            this.y = y;
        }

        /** Lexicographic order: by x, then by y. */
        @Override
        public int compareTo(Point other) {
            int comp = x.compareTo(other.x);
            if (comp == 0) {
                comp = y.compareTo(other.y);
            }
            return comp;
        }

        /** Adds the other point's coordinates to this point (in place) and returns this. */
        public Point addTo(Point other) {
            // Since input is always re-fetched we can overwrite the values.
            x += other.x;
            y += other.y;
            return this;
        }

        /** Divides both coordinates by {@code factor} (in place) and returns this. */
        public Point divideBy(Integer factor) {
            // Since input is always re-fetched we can overwrite the values.
            x /= factor;
            y /= factor;
            return this;
        }

        /** Euclidean distance between this point and {@code other}. */
        public Double distanceTo(Point other) {
            return Math.sqrt(Math.pow(other.x - x, 2) + Math.pow(other.y - y, 2));
        }
    }

    /**
     * Reduces (minX, minY, maxX, maxY) tuples into the bounding box of all points.
     */
    public static class FindArea
            implements ReduceFunction<Tuple4<Double, Double, Double, Double>> {

        @Override
        public Tuple4<Double, Double, Double, Double> reduce(
                Tuple4<Double, Double, Double, Double> a,
                Tuple4<Double, Double, Double, Double> b) {
            return new Tuple4<>(
                    Math.min(a.f0, b.f0),
                    Math.min(a.f1, b.f1),
                    Math.max(a.f2, b.f2),
                    Math.max(a.f3, b.f3));
        }
    }

    /**
     * Emits {@code k} uniformly random points inside the bounding box
     * (minX, minY, maxX, maxY). Note: the unseeded {@link Random} makes the
     * initial centroids — and thus the final clustering — nondeterministic.
     */
    public static class RandomCentroids
            implements FlatMapFunction<Tuple4<Double, Double, Double, Double>, Point> {

        Integer k;
        Random r;

        public RandomCentroids(Integer k) {
            this.k = k;
            this.r = new Random();
        }

        private Double randomRange(Double min, Double max) {
            return min + (r.nextDouble() * (max - min));
        }

        @Override
        public void flatMap(Tuple4<Double, Double, Double, Double> area, Collector<Point> out) {
            for (int i = 0; i < k; i++) {
                out.collect(new Point(randomRange(area.f0, area.f2), randomRange(area.f1, area.f3)));
            }
        }
    }

    /**
     * Maps each point to (point, id of the closest centroid). Centroids arrive
     * via the "centroids" broadcast set and are sorted in {@link #open} so the
     * index-based id is consistent across all parallel instances.
     */
    public static class AssignCentroid extends RichMapFunction<Point, Tuple2<Point, Integer>> {

        private List<Point> centroids;

        @Override
        public void open(Configuration parameters) throws Exception {
            // Centroids are sorted so they have an identifier common to all the operators.
            centroids = new ArrayList<>(
                    getRuntimeContext().<Point>getBroadcastVariable("centroids"));
            Collections.sort(centroids);
        }

        @Override
        public Tuple2<Point, Integer> map(Point point) {
            // Calculate the distance point-centroid for every centroid and
            // keep the identifier of the closest one.
            Integer minCentroid = 0;
            Double minDistance = Double.POSITIVE_INFINITY;
            for (int c = 0; c < centroids.size(); c++) {
                Double distance = point.distanceTo(centroids.get(c));
                if (distance < minDistance) {
                    minCentroid = c;
                    minDistance = distance;
                }
            }
            return new Tuple2<>(point, minCentroid);
        }
    }

    /**
     * (point, centroidId) → (point, centroidId, 1): appends a count field so a
     * subsequent reduce can both sum coordinates and count points per centroid.
     */
    public static class MeanPrepare
            implements MapFunction<Tuple2<Point, Integer>, Tuple3<Point, Integer, Integer>> {

        @Override
        public Tuple3<Point, Integer, Integer> map(Tuple2<Point, Integer> info) {
            return new Tuple3<>(info.f0, info.f1, 1);
        }
    }

    /**
     * Per-centroid reduce: sums the point coordinates and the point counts.
     * The centroid id (f1) is identical within a group, so either side's is kept.
     */
    public static class MeanSum implements ReduceFunction<Tuple3<Point, Integer, Integer>> {

        @Override
        public Tuple3<Point, Integer, Integer> reduce(
                Tuple3<Point, Integer, Integer> a, Tuple3<Point, Integer, Integer> b) {
            return new Tuple3<>(a.f0.addTo(b.f0), a.f1, a.f2 + b.f2);
        }
    }

    /**
     * (summed point, centroidId, count) → mean point: divides the coordinate
     * sums by the number of points in the cluster.
     */
    public static class MeanDivide
            implements MapFunction<Tuple3<Point, Integer, Integer>, Point> {

        @Override
        public Point map(Tuple3<Point, Integer, Integer> info) {
            return info.f0.divideBy(info.f2);
        }
    }
}