2 Dimensions
This commit is contained in:
parent
e94e4c0ce1
commit
d43f0ebec8
|
@ -38,40 +38,84 @@ public class KMeans {
|
|||
final Integer maxIterations = params.getInt("maxIterations", 25);
|
||||
|
||||
// Read CSV input
|
||||
DataSet<Tuple1<Double>> csvInput = env.readCsvFile(params.get("input")).types(Double.class);
|
||||
DataSet<Tuple2<Double, Double>> inputCsv = env.readCsvFile(params.get("input")).types(Double.class, Double.class);
|
||||
|
||||
// Convert CSV to internal format
|
||||
DataSet<Double> input = csvInput
|
||||
.map(point -> point.f0);
|
||||
// Convert to internal format
|
||||
DataSet<Point> input = inputCsv
|
||||
.map(tuple -> new Point(tuple.f0, tuple.f1));
|
||||
|
||||
// Generate random centroids
|
||||
final RandomCentroids r = new RandomCentroids(k);
|
||||
IterativeDataSet<Double> centroids = env.fromCollection(r, Double.class).iterate(maxIterations);
|
||||
IterativeDataSet<Point> centroids = env.fromCollection(r, Point.class).iterate(maxIterations);
|
||||
|
||||
// Assign points to centroids
|
||||
DataSet<Tuple2<Double, Integer>> assigned = input
|
||||
DataSet<Tuple2<Point, Integer>> assigned = input
|
||||
.map(new AssignCentroid()).withBroadcastSet(centroids, "centroids");
|
||||
|
||||
// Calculate means
|
||||
DataSet<Double> newCentroids = assigned
|
||||
DataSet<Point> newCentroids = assigned
|
||||
.map(new MeanPrepare())
|
||||
.groupBy(1) // GroupBy CentroidID
|
||||
.reduce(new MeanSum())
|
||||
.map(new MeanDivide());
|
||||
|
||||
DataSet<Double> finalCentroids = centroids.closeWith(newCentroids);
|
||||
DataSet<Point> finalCentroids = centroids.closeWith(newCentroids);
|
||||
|
||||
// Final assignment of points to centroids
|
||||
assigned = input
|
||||
.map(new AssignCentroid()).withBroadcastSet(finalCentroids, "centroids");
|
||||
|
||||
assigned.writeAsCsv(params.get("output", "output.csv"));
|
||||
// Convert to external format
|
||||
DataSet<Tuple3<Double, Double, Integer>> output = assigned
|
||||
.map(new MapFunction<Tuple2<Point, Integer>, Tuple3<Double, Double, Integer>>() {
|
||||
@Override
|
||||
public Tuple3<Double, Double, Integer> map(Tuple2<Point, Integer> tuple) {
|
||||
return new Tuple3<Double, Double, Integer>(tuple.f0.x, tuple.f0.y, tuple.f1);
|
||||
}
|
||||
});
|
||||
|
||||
output.writeAsCsv(params.get("output", "output.csv"));
|
||||
|
||||
env.execute("K-Means clustering");
|
||||
}
|
||||
|
||||
public static class Point implements Comparable<Point> {
|
||||
public Double x;
|
||||
public Double y;
|
||||
|
||||
public static class RandomCentroids implements Iterator<Double>, Serializable {
|
||||
/**
 * Creates a point from its two coordinates.
 *
 * @param x the x coordinate
 * @param y the y coordinate
 */
public Point(Double x, Double y) {
    // The two assignments are independent of each other.
    this.y = y;
    this.x = x;
}
|
||||
|
||||
public int compareTo(Point other) {
|
||||
int comp = x.compareTo(other.x);
|
||||
if (comp == 0) {
|
||||
comp = y.compareTo(other.y);
|
||||
}
|
||||
return comp;
|
||||
}
|
||||
|
||||
public Point addTo(Point other) {
|
||||
// Since input is always re-fetched we can overwrite the values
|
||||
x += other.x;
|
||||
y += other.y;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Point divideBy(Integer factor) {
|
||||
x /= factor;
|
||||
y /= factor;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Double distanceTo(Point other) {
|
||||
return Math.sqrt(Math.pow(other.x - x, 2) + Math.pow(other.y - y, 2));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class RandomCentroids implements Iterator<Point>, Serializable {
|
||||
|
||||
Integer k;
|
||||
Integer i;
|
||||
|
@ -89,9 +133,9 @@ public class KMeans {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Double next() {
|
||||
public Point next() {
|
||||
i += 1;
|
||||
return r.nextDouble();
|
||||
return new Point(r.nextDouble(), r.nextDouble());
|
||||
}
|
||||
|
||||
private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException {
|
||||
|
@ -102,9 +146,9 @@ public class KMeans {
|
|||
}
|
||||
}
|
||||
|
||||
public static class AssignCentroid extends RichMapFunction<Double, Tuple2<Double, Integer>> {
|
||||
public static class AssignCentroid extends RichMapFunction<Point, Tuple2<Point, Integer>> {
|
||||
// Point → Point, CentroidID
|
||||
private List<Double> centroids;
|
||||
private List<Point> centroids;
|
||||
|
||||
@Override
|
||||
public void open(Configuration parameters) throws Exception {
|
||||
|
@ -113,56 +157,48 @@ public class KMeans {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Tuple2<Double, Integer> map(Double point) {
|
||||
public Tuple2<Point, Integer> map(Point point) {
|
||||
Integer c;
|
||||
Double centroid;
|
||||
Point centroid;
|
||||
Double distance;
|
||||
Integer minCentroid = 4;
|
||||
Double minDistance = Double.POSITIVE_INFINITY;
|
||||
|
||||
for (c = 0; c < centroids.size(); c++) {
|
||||
centroid = centroids.get(c);
|
||||
distance = distancePointCentroid(point, centroid);
|
||||
distance = point.distanceTo(centroid);
|
||||
if (distance < minDistance) {
|
||||
minCentroid = c;
|
||||
minDistance = distance;
|
||||
}
|
||||
}
|
||||
|
||||
return new Tuple2<Double, Integer>(point, minCentroid);
|
||||
return new Tuple2<Point, Integer>(point, minCentroid);
|
||||
}
|
||||
|
||||
private Double distancePointCentroid(Double point, Double centroid) {
|
||||
return Math.abs(point - centroid);
|
||||
// return Math.sqrt(Math.pow(point, 2) + Math.pow(centroid, 2));
|
||||
}
|
||||
}
|
||||
|
||||
public static class MeanPrepare implements MapFunction<Tuple2<Double, Integer>, Tuple3<Double, Integer, Integer>> {
|
||||
public static class MeanPrepare implements MapFunction<Tuple2<Point, Integer>, Tuple3<Point, Integer, Integer>> {
|
||||
// Point, CentroidID → Point, CentroidID, Number of points
|
||||
@Override
|
||||
public Tuple3<Double, Integer, Integer> map(Tuple2<Double, Integer> point) {
|
||||
return new Tuple3<Double, Integer, Integer>(point.f0, point.f1, 1);
|
||||
public Tuple3<Point, Integer, Integer> map(Tuple2<Point, Integer> info) {
|
||||
return new Tuple3<Point, Integer, Integer>(info.f0, info.f1, 1);
|
||||
}
|
||||
}
|
||||
|
||||
public static class MeanSum implements ReduceFunction<Tuple3<Double, Integer, Integer>> {
|
||||
public static class MeanSum implements ReduceFunction<Tuple3<Point, Integer, Integer>> {
|
||||
// Point, CentroidID (irrelevant), Number of points
|
||||
@Override
|
||||
public Tuple3<Double, Integer, Integer> reduce(Tuple3<Double, Integer, Integer> a, Tuple3<Double, Integer, Integer> b) {
|
||||
return new Tuple3<Double, Integer, Integer>(sumPoints(a.f0, b.f0), a.f1, a.f2 + b.f2);
|
||||
}
|
||||
|
||||
private Double sumPoints(Double a, Double b) {
|
||||
return a + b;
|
||||
public Tuple3<Point, Integer, Integer> reduce(Tuple3<Point, Integer, Integer> a, Tuple3<Point, Integer, Integer> b) {
|
||||
return new Tuple3<Point, Integer, Integer>(a.f0.addTo(b.f0), a.f1, a.f2 + b.f2);
|
||||
}
|
||||
}
|
||||
|
||||
public static class MeanDivide implements MapFunction<Tuple3<Double, Integer, Integer>, Double> {
|
||||
public static class MeanDivide implements MapFunction<Tuple3<Point, Integer, Integer>, Point> {
|
||||
// Point, CentroidID (irrelevant), Number of points → Point
|
||||
@Override
|
||||
public Double map(Tuple3<Double, Integer, Integer> point) {
|
||||
return point.f0 / point.f2;
|
||||
public Point map(Tuple3<Point, Integer, Integer> info) {
|
||||
return info.f0.divideBy(info.f2);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Reference in a new issue