Support other ranges than [0, 1]

This commit is contained in:
Geoffrey Frogeye 2019-01-24 20:12:14 +01:00
parent d43f0ebec8
commit a9a0e518e8
2 changed files with 43 additions and 19 deletions

View file

@ -5,8 +5,10 @@ import sys
random.seed(0) random.seed(0)
MAX = 100
D = int(sys.argv[1]) # Number of dimensions D = int(sys.argv[1]) # Number of dimensions
N = int(sys.argv[2]) # Number of vectors N = int(sys.argv[2]) # Number of vectors
for _ in range(N): for _ in range(N):
print(','.join([str(random.random()) for _ in range(D)])) print(','.join([str(random.random() * MAX) for _ in range(D)]))

View file

@ -12,6 +12,7 @@ import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RichMapFunction;
@ -20,7 +21,9 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.ExecutionEnvironment;
@ -44,9 +47,30 @@ public class KMeans {
DataSet<Point> input = inputCsv DataSet<Point> input = inputCsv
.map(tuple -> new Point(tuple.f0, tuple.f1)); .map(tuple -> new Point(tuple.f0, tuple.f1));
// Find min and max of the coordinates to determine where the initial centroids should be
DataSet<Tuple4<Double, Double, Double, Double>> area = input
.map(new MapFunction<Point, Tuple4<Double, Double, Double, Double>>() {
@Override
public Tuple4<Double, Double, Double, Double> map(Point point) {
return new Tuple4<Double, Double, Double, Double>(point.x, point.y, point.x, point.y);
}
}).reduce(new FindArea());
area.print();
DataSet<Tuple2<Double, Double>> testCentroids = area
.flatMap(new RandomCentroids(k))
.map(new MapFunction<Point, Tuple2<Double, Double>>() {
@Override
public Tuple2<Double, Double> map(Point point) {
return new Tuple2<Double, Double>(point.x, point.y);
}});
testCentroids.print();
// Generate random centroids // Generate random centroids
final RandomCentroids r = new RandomCentroids(k); IterativeDataSet<Point> centroids = area
IterativeDataSet<Point> centroids = env.fromCollection(r, Point.class).iterate(maxIterations); .flatMap(new RandomCentroids(k))
.iterate(maxIterations);
// Assign points to centroids // Assign points to centroids
DataSet<Tuple2<Point, Integer>> assigned = input DataSet<Tuple2<Point, Integer>> assigned = input
@ -115,34 +139,32 @@ public class KMeans {
} }
public static class RandomCentroids implements Iterator<Point>, Serializable { public static class FindArea implements ReduceFunction<Tuple4<Double, Double, Double, Double>> {
// minX, minY, maxX, maxY
@Override
public Tuple4<Double, Double, Double, Double> reduce(Tuple4<Double, Double, Double, Double> a, Tuple4<Double, Double, Double, Double> b) {
return new Tuple4<Double, Double, Double, Double>(Math.min(a.f0, b.f0), Math.min(a.f1, b.f1), Math.max(a.f2, b.f2), Math.max(a.f3, b.f3));
}
}
public static class RandomCentroids implements FlatMapFunction<Tuple4<Double, Double, Double, Double>, Point> {
Integer k; Integer k;
Integer i;
Random r; Random r;
public RandomCentroids(Integer k) { public RandomCentroids(Integer k) {
this.k = k; this.k = k;
this.i = 0;
this.r = new Random(0); this.r = new Random(0);
} }
@Override private Double randomRange(Double min, Double max) {
public boolean hasNext() { return min + (r.nextDouble() * (max - min));
return i < k;
} }
@Override @Override
public Point next() { public void flatMap(Tuple4<Double, Double, Double, Double> area, Collector<Point> out) {
i += 1; for (int i = 0; i < k; i++) {
return new Point(r.nextDouble(), r.nextDouble()); out.collect(new Point(randomRange(area.f0, area.f2), randomRange(area.f1, area.f3)));
} }
private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException {
inputStream.defaultReadObject();
}
private void writeObject(ObjectOutputStream outputStream) throws IOException {
outputStream.defaultWriteObject();
} }
} }