From a9a0e518e88d37bd05cee5aa2485b17f69ff76af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?= Date: Thu, 24 Jan 2019 20:12:14 +0100 Subject: [PATCH] Support other ranges than [0, 1] --- genVectors.py | 4 +- .../middleware/projects/flink/KMeans.java | 58 +++++++++++++------ 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/genVectors.py b/genVectors.py index b31a919..9318841 100755 --- a/genVectors.py +++ b/genVectors.py @@ -5,8 +5,10 @@ import sys random.seed(0) +MAX = 100 + D = int(sys.argv[1]) # Number of dimensions N = int(sys.argv[2]) # Number of vectors for _ in range(N): - print(','.join([str(random.random()) for _ in range(D)])) + print(','.join([str(random.random() * MAX) for _ in range(D)])) diff --git a/src/main/java/it/polimi/middleware/projects/flink/KMeans.java b/src/main/java/it/polimi/middleware/projects/flink/KMeans.java index 12a96e0..b8d5121 100644 --- a/src/main/java/it/polimi/middleware/projects/flink/KMeans.java +++ b/src/main/java/it/polimi/middleware/projects/flink/KMeans.java @@ -12,6 +12,7 @@ import java.util.Iterator; import java.util.List; import java.util.Random; +import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.functions.RichMapFunction; @@ -20,7 +21,9 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.configuration.Configuration; +import org.apache.flink.util.Collector; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; @@ -44,9 +47,30 @@ public class KMeans { DataSet input = inputCsv .map(tuple -> new Point(tuple.f0, tuple.f1)); + // Find min and max of the coordinates to determine where the initial centroids should be + DataSet> area = input + .map(new MapFunction>() { + @Override + public Tuple4 map(Point point) { + return new Tuple4(point.x, point.y, point.x, point.y); + } + }).reduce(new FindArea()); + + area.print(); + + DataSet> testCentroids = area + .flatMap(new RandomCentroids(k)) + .map(new MapFunction>() { + @Override + public Tuple2 map(Point point) { + return new Tuple2(point.x, point.y); + }}); + testCentroids.print(); + // Generate random centroids - final RandomCentroids r = new RandomCentroids(k); - IterativeDataSet centroids = env.fromCollection(r, Point.class).iterate(maxIterations); + IterativeDataSet centroids = area + .flatMap(new RandomCentroids(k)) + .iterate(maxIterations); // Assign points to centroids DataSet> assigned = input @@ -115,34 +139,32 @@ public class KMeans { } - public static class RandomCentroids implements Iterator, Serializable { + public static class FindArea implements ReduceFunction> { + // minX, minY, maxX, maxY + @Override + public Tuple4 reduce(Tuple4 a, Tuple4 b) { + return new Tuple4(Math.min(a.f0, b.f0), Math.min(a.f1, b.f1), Math.max(a.f2, b.f2), Math.max(a.f3, b.f3)); + } + } + public static class RandomCentroids implements FlatMapFunction, Point> { Integer k; - Integer i; Random r; public RandomCentroids(Integer k) { this.k = k; - this.i = 0; this.r = new Random(0); } - @Override - public boolean hasNext() { - return i < k; + private Double randomRange(Double min, Double max) { + return min + (r.nextDouble() * (max - min)); } @Override - public Point next() { - i += 1; - return new Point(r.nextDouble(), r.nextDouble()); - } - - private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException { - inputStream.defaultReadObject(); - } - private void writeObject(ObjectOutputStream outputStream) throws IOException { - outputStream.defaultWriteObject(); + public void flatMap(Tuple4 area, Collector out) { + for (int i = 0; i < k; i++) { + out.collect(new Point(randomRange(area.f0, area.f2), randomRange(area.f1, area.f3))); + } } }