Support coordinate ranges other than [0, 1]
This commit is contained in:
parent
d43f0ebec8
commit
a9a0e518e8
|
@ -5,8 +5,10 @@ import sys
|
|||
|
||||
random.seed(0)  # Fixed seed so repeated runs generate the same data set

MAX = 100  # Scale factor: coordinates are drawn from [0, MAX) instead of [0, 1)

D = int(sys.argv[1])  # Number of dimensions
N = int(sys.argv[2])  # Number of vectors

# Emit N lines, each holding D comma-separated coordinates scaled by MAX.
for _ in range(N):
    print(','.join(str(random.random() * MAX) for _ in range(D)))
|
||||
|
|
|
@ -12,6 +12,7 @@ import java.util.Iterator;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import org.apache.flink.api.common.functions.FlatMapFunction;
|
||||
import org.apache.flink.api.common.functions.MapFunction;
|
||||
import org.apache.flink.api.common.functions.ReduceFunction;
|
||||
import org.apache.flink.api.common.functions.RichMapFunction;
|
||||
|
@ -20,7 +21,9 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
|
|||
import org.apache.flink.api.java.tuple.Tuple1;
|
||||
import org.apache.flink.api.java.tuple.Tuple2;
|
||||
import org.apache.flink.api.java.tuple.Tuple3;
|
||||
import org.apache.flink.api.java.tuple.Tuple4;
|
||||
import org.apache.flink.configuration.Configuration;
|
||||
import org.apache.flink.util.Collector;
|
||||
|
||||
import org.apache.flink.api.java.DataSet;
|
||||
import org.apache.flink.api.java.ExecutionEnvironment;
|
||||
|
@ -44,9 +47,30 @@ public class KMeans {
|
|||
DataSet<Point> input = inputCsv
|
||||
.map(tuple -> new Point(tuple.f0, tuple.f1));
|
||||
|
||||
// Find min and max of the coordinates to determine where the initial centroids should be
|
||||
DataSet<Tuple4<Double, Double, Double, Double>> area = input
|
||||
.map(new MapFunction<Point, Tuple4<Double, Double, Double, Double>>() {
|
||||
@Override
|
||||
public Tuple4<Double, Double, Double, Double> map(Point point) {
|
||||
return new Tuple4<Double, Double, Double, Double>(point.x, point.y, point.x, point.y);
|
||||
}
|
||||
}).reduce(new FindArea());
|
||||
|
||||
area.print();
|
||||
|
||||
DataSet<Tuple2<Double, Double>> testCentroids = area
|
||||
.flatMap(new RandomCentroids(k))
|
||||
.map(new MapFunction<Point, Tuple2<Double, Double>>() {
|
||||
@Override
|
||||
public Tuple2<Double, Double> map(Point point) {
|
||||
return new Tuple2<Double, Double>(point.x, point.y);
|
||||
}});
|
||||
testCentroids.print();
|
||||
|
||||
// Generate random centroids
|
||||
final RandomCentroids r = new RandomCentroids(k);
|
||||
IterativeDataSet<Point> centroids = env.fromCollection(r, Point.class).iterate(maxIterations);
|
||||
IterativeDataSet<Point> centroids = area
|
||||
.flatMap(new RandomCentroids(k))
|
||||
.iterate(maxIterations);
|
||||
|
||||
// Assign points to centroids
|
||||
DataSet<Tuple2<Point, Integer>> assigned = input
|
||||
|
@ -115,34 +139,32 @@ public class KMeans {
|
|||
|
||||
}
|
||||
|
||||
public static class RandomCentroids implements Iterator<Point>, Serializable {
|
||||
public static class FindArea implements ReduceFunction<Tuple4<Double, Double, Double, Double>> {
|
||||
// minX, minY, maxX, maxY
|
||||
@Override
|
||||
public Tuple4<Double, Double, Double, Double> reduce(Tuple4<Double, Double, Double, Double> a, Tuple4<Double, Double, Double, Double> b) {
|
||||
return new Tuple4<Double, Double, Double, Double>(Math.min(a.f0, b.f0), Math.min(a.f1, b.f1), Math.max(a.f2, b.f2), Math.max(a.f3, b.f3));
|
||||
}
|
||||
}
|
||||
|
||||
public static class RandomCentroids implements FlatMapFunction<Tuple4<Double, Double, Double, Double>, Point> {
|
||||
Integer k;
|
||||
Integer i;
|
||||
Random r;
|
||||
|
||||
public RandomCentroids(Integer k) {
|
||||
this.k = k;
|
||||
this.i = 0;
|
||||
this.r = new Random(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return i < k;
|
||||
private Double randomRange(Double min, Double max) {
|
||||
return min + (r.nextDouble() * (max - min));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Point next() {
|
||||
i += 1;
|
||||
return new Point(r.nextDouble(), r.nextDouble());
|
||||
public void flatMap(Tuple4<Double, Double, Double, Double> area, Collector<Point> out) {
|
||||
for (int i = 0; i < k; i++) {
|
||||
out.collect(new Point(randomRange(area.f0, area.f2), randomRange(area.f1, area.f3)));
|
||||
}
|
||||
|
||||
private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException {
|
||||
inputStream.defaultReadObject();
|
||||
}
|
||||
private void writeObject(ObjectOutputStream outputStream) throws IOException {
|
||||
outputStream.defaultWriteObject();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Reference in a new issue