Support other ranges than [0, 1]
This commit is contained in:
parent
d43f0ebec8
commit
a9a0e518e8
|
@ -5,8 +5,10 @@ import sys
|
||||||
|
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
|
|
||||||
|
MAX = 100
|
||||||
|
|
||||||
D = int(sys.argv[1]) # Number of dimensions
|
D = int(sys.argv[1]) # Number of dimensions
|
||||||
N = int(sys.argv[2]) # Number of vectors
|
N = int(sys.argv[2]) # Number of vectors
|
||||||
|
|
||||||
for _ in range(N):
|
for _ in range(N):
|
||||||
print(','.join([str(random.random()) for _ in range(D)]))
|
print(','.join([str(random.random() * MAX) for _ in range(D)]))
|
||||||
|
|
|
@ -12,6 +12,7 @@ import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
import org.apache.flink.api.common.functions.FlatMapFunction;
|
||||||
import org.apache.flink.api.common.functions.MapFunction;
|
import org.apache.flink.api.common.functions.MapFunction;
|
||||||
import org.apache.flink.api.common.functions.ReduceFunction;
|
import org.apache.flink.api.common.functions.ReduceFunction;
|
||||||
import org.apache.flink.api.common.functions.RichMapFunction;
|
import org.apache.flink.api.common.functions.RichMapFunction;
|
||||||
|
@ -20,7 +21,9 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
|
||||||
import org.apache.flink.api.java.tuple.Tuple1;
|
import org.apache.flink.api.java.tuple.Tuple1;
|
||||||
import org.apache.flink.api.java.tuple.Tuple2;
|
import org.apache.flink.api.java.tuple.Tuple2;
|
||||||
import org.apache.flink.api.java.tuple.Tuple3;
|
import org.apache.flink.api.java.tuple.Tuple3;
|
||||||
|
import org.apache.flink.api.java.tuple.Tuple4;
|
||||||
import org.apache.flink.configuration.Configuration;
|
import org.apache.flink.configuration.Configuration;
|
||||||
|
import org.apache.flink.util.Collector;
|
||||||
|
|
||||||
import org.apache.flink.api.java.DataSet;
|
import org.apache.flink.api.java.DataSet;
|
||||||
import org.apache.flink.api.java.ExecutionEnvironment;
|
import org.apache.flink.api.java.ExecutionEnvironment;
|
||||||
|
@ -44,9 +47,30 @@ public class KMeans {
|
||||||
DataSet<Point> input = inputCsv
|
DataSet<Point> input = inputCsv
|
||||||
.map(tuple -> new Point(tuple.f0, tuple.f1));
|
.map(tuple -> new Point(tuple.f0, tuple.f1));
|
||||||
|
|
||||||
|
// Find min and max of the coordinates to determine where the initial centroids should be
|
||||||
|
DataSet<Tuple4<Double, Double, Double, Double>> area = input
|
||||||
|
.map(new MapFunction<Point, Tuple4<Double, Double, Double, Double>>() {
|
||||||
|
@Override
|
||||||
|
public Tuple4<Double, Double, Double, Double> map(Point point) {
|
||||||
|
return new Tuple4<Double, Double, Double, Double>(point.x, point.y, point.x, point.y);
|
||||||
|
}
|
||||||
|
}).reduce(new FindArea());
|
||||||
|
|
||||||
|
area.print();
|
||||||
|
|
||||||
|
DataSet<Tuple2<Double, Double>> testCentroids = area
|
||||||
|
.flatMap(new RandomCentroids(k))
|
||||||
|
.map(new MapFunction<Point, Tuple2<Double, Double>>() {
|
||||||
|
@Override
|
||||||
|
public Tuple2<Double, Double> map(Point point) {
|
||||||
|
return new Tuple2<Double, Double>(point.x, point.y);
|
||||||
|
}});
|
||||||
|
testCentroids.print();
|
||||||
|
|
||||||
// Generate random centroids
|
// Generate random centroids
|
||||||
final RandomCentroids r = new RandomCentroids(k);
|
IterativeDataSet<Point> centroids = area
|
||||||
IterativeDataSet<Point> centroids = env.fromCollection(r, Point.class).iterate(maxIterations);
|
.flatMap(new RandomCentroids(k))
|
||||||
|
.iterate(maxIterations);
|
||||||
|
|
||||||
// Assign points to centroids
|
// Assign points to centroids
|
||||||
DataSet<Tuple2<Point, Integer>> assigned = input
|
DataSet<Tuple2<Point, Integer>> assigned = input
|
||||||
|
@ -115,34 +139,32 @@ public class KMeans {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class RandomCentroids implements Iterator<Point>, Serializable {
|
public static class FindArea implements ReduceFunction<Tuple4<Double, Double, Double, Double>> {
|
||||||
|
// minX, minY, maxX, maxY
|
||||||
|
@Override
|
||||||
|
public Tuple4<Double, Double, Double, Double> reduce(Tuple4<Double, Double, Double, Double> a, Tuple4<Double, Double, Double, Double> b) {
|
||||||
|
return new Tuple4<Double, Double, Double, Double>(Math.min(a.f0, b.f0), Math.min(a.f1, b.f1), Math.max(a.f2, b.f2), Math.max(a.f3, b.f3));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class RandomCentroids implements FlatMapFunction<Tuple4<Double, Double, Double, Double>, Point> {
|
||||||
Integer k;
|
Integer k;
|
||||||
Integer i;
|
|
||||||
Random r;
|
Random r;
|
||||||
|
|
||||||
public RandomCentroids(Integer k) {
|
public RandomCentroids(Integer k) {
|
||||||
this.k = k;
|
this.k = k;
|
||||||
this.i = 0;
|
|
||||||
this.r = new Random(0);
|
this.r = new Random(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private Double randomRange(Double min, Double max) {
|
||||||
public boolean hasNext() {
|
return min + (r.nextDouble() * (max - min));
|
||||||
return i < k;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Point next() {
|
public void flatMap(Tuple4<Double, Double, Double, Double> area, Collector<Point> out) {
|
||||||
i += 1;
|
for (int i = 0; i < k; i++) {
|
||||||
return new Point(r.nextDouble(), r.nextDouble());
|
out.collect(new Point(randomRange(area.f0, area.f2), randomRange(area.f1, area.f3)));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException {
|
|
||||||
inputStream.defaultReadObject();
|
|
||||||
}
|
|
||||||
private void writeObject(ObjectOutputStream outputStream) throws IOException {
|
|
||||||
outputStream.defaultWriteObject();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Reference in a new issue