From d43f0ebec8a1cce62fad61ea44e447ea9d2da40c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?= Date: Thu, 24 Jan 2019 17:31:52 +0100 Subject: [PATCH] 2 Dimensions --- .../middleware/projects/flink/KMeans.java | 106 ++++++++++++------ 1 file changed, 71 insertions(+), 35 deletions(-) diff --git a/src/main/java/it/polimi/middleware/projects/flink/KMeans.java b/src/main/java/it/polimi/middleware/projects/flink/KMeans.java index 8ec7850..12a96e0 100644 --- a/src/main/java/it/polimi/middleware/projects/flink/KMeans.java +++ b/src/main/java/it/polimi/middleware/projects/flink/KMeans.java @@ -38,40 +38,84 @@ public class KMeans { final Integer maxIterations = params.getInt("maxIterations", 25); // Read CSV input - DataSet> csvInput = env.readCsvFile(params.get("input")).types(Double.class); + DataSet> inputCsv = env.readCsvFile(params.get("input")).types(Double.class, Double.class); - // Convert CSV to internal format - DataSet input = csvInput - .map(point -> point.f0); + // Convert to internal format + DataSet input = inputCsv + .map(tuple -> new Point(tuple.f0, tuple.f1)); // Generate random centroids final RandomCentroids r = new RandomCentroids(k); - IterativeDataSet centroids = env.fromCollection(r, Double.class).iterate(maxIterations); + IterativeDataSet centroids = env.fromCollection(r, Point.class).iterate(maxIterations); // Assign points to centroids - DataSet> assigned = input + DataSet> assigned = input .map(new AssignCentroid()).withBroadcastSet(centroids, "centroids"); // Calculate means - DataSet newCentroids = assigned + DataSet newCentroids = assigned .map(new MeanPrepare()) .groupBy(1) // GroupBy CentroidID .reduce(new MeanSum()) .map(new MeanDivide()); - DataSet finalCentroids = centroids.closeWith(newCentroids); + DataSet finalCentroids = centroids.closeWith(newCentroids); // Final assignment of points to centroids assigned = input .map(new AssignCentroid()).withBroadcastSet(finalCentroids, "centroids"); - assigned.writeAsCsv(params.get("output", "output.csv")); + // Convert to external format + DataSet> output = assigned + .map(new MapFunction, Tuple3>() { + @Override + public Tuple3 map(Tuple2 tuple) { + return new Tuple3(tuple.f0.x, tuple.f0.y, tuple.f1); + } + }); + + output.writeAsCsv(params.get("output", "output.csv")); env.execute("K-Means clustering"); } + public static class Point implements Comparable { + public Double x; + public Double y; - public static class RandomCentroids implements Iterator, Serializable { + public Point(Double x, Double y) { + this.x = x; + this.y = y; + } + + public int compareTo(Point other) { + int comp = x.compareTo(other.x); + if (comp == 0) { + comp = y.compareTo(other.y); + } + return comp; + } + + public Point addTo(Point other) { + // Since input is always re-fetched we can overwrite the values + x += other.x; + y += other.y; + return this; + } + + public Point divideBy(Integer factor) { + x /= factor; + y /= factor; + return this; + } + + public Double distanceTo(Point other) { + return Math.sqrt(Math.pow(other.x - x, 2) + Math.pow(other.y - y, 2)); + } + + } + + public static class RandomCentroids implements Iterator, Serializable { Integer k; Integer i; @@ -89,9 +133,9 @@ public class KMeans { } @Override - public Double next() { + public Point next() { i += 1; - return r.nextDouble(); + return new Point(r.nextDouble(), r.nextDouble()); } private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException { @@ -102,9 +146,9 @@ public class KMeans { } } - public static class AssignCentroid extends RichMapFunction> { + public static class AssignCentroid extends RichMapFunction> { // Point → Point, CentroidID - private List centroids; + private List centroids; @Override public void open(Configuration parameters) throws Exception { @@ -113,56 +157,48 @@ public class KMeans { } @Override - public Tuple2 map(Double point) { + public Tuple2 map(Point point) { Integer c; - Double centroid; + Point centroid; Double distance; Integer minCentroid = 4; Double minDistance = Double.POSITIVE_INFINITY; for (c = 0; c < centroids.size(); c++) { centroid = centroids.get(c); - distance = distancePointCentroid(point, centroid); + distance = point.distanceTo(centroid); if (distance < minDistance) { minCentroid = c; minDistance = distance; } } - return new Tuple2(point, minCentroid); + return new Tuple2(point, minCentroid); } - private Double distancePointCentroid(Double point, Double centroid) { - return Math.abs(point - centroid); - // return Math.sqrt(Math.pow(point, 2) + Math.pow(centroid, 2)); - } } - public static class MeanPrepare implements MapFunction, Tuple3> { + public static class MeanPrepare implements MapFunction, Tuple3> { // Point, CentroidID → Point, CentroidID, Number of points @Override - public Tuple3 map(Tuple2 point) { - return new Tuple3(point.f0, point.f1, 1); + public Tuple3 map(Tuple2 info) { + return new Tuple3(info.f0, info.f1, 1); } } - public static class MeanSum implements ReduceFunction> { + public static class MeanSum implements ReduceFunction> { // Point, CentroidID (irrelevant), Number of points @Override - public Tuple3 reduce(Tuple3 a, Tuple3 b) { - return new Tuple3(sumPoints(a.f0, b.f0), a.f1, a.f2 + b.f2); - } - - private Double sumPoints(Double a, Double b) { - return a + b; + public Tuple3 reduce(Tuple3 a, Tuple3 b) { + return new Tuple3(a.f0.addTo(b.f0), a.f1, a.f2 + b.f2); } } - public static class MeanDivide implements MapFunction, Double> { + public static class MeanDivide implements MapFunction, Point> { // Point, CentroidID (irrelevant), Number of points → Point @Override - public Double map(Tuple3 point) { - return point.f0 / point.f2; + public Point map(Tuple3 info) { + return info.f0.divideBy(info.f2); } }