223 lines
8.5 KiB
Java
223 lines
8.5 KiB
Java
package it.polimi.middleware.projects.flink;
|
||
|
||
import java.io.IOException;
|
||
import java.io.ObjectInputStream;
|
||
import java.io.ObjectOutputStream;
|
||
import java.io.Serializable;
|
||
import java.util.ArrayList;
|
||
import java.util.Arrays;
|
||
import java.util.Collection;
|
||
import java.util.Collections;
|
||
import java.util.Iterator;
|
||
import java.util.List;
|
||
import java.util.Random;
|
||
|
||
import org.apache.flink.api.common.functions.FlatMapFunction;
|
||
import org.apache.flink.api.common.functions.MapFunction;
|
||
import org.apache.flink.api.common.functions.ReduceFunction;
|
||
import org.apache.flink.api.common.functions.RichMapFunction;
|
||
import org.apache.flink.api.common.typeinfo.TypeHint;
|
||
import org.apache.flink.api.common.typeinfo.TypeInformation;
|
||
import org.apache.flink.api.java.tuple.Tuple1;
|
||
import org.apache.flink.api.java.tuple.Tuple2;
|
||
import org.apache.flink.api.java.tuple.Tuple3;
|
||
import org.apache.flink.api.java.tuple.Tuple4;
|
||
import org.apache.flink.configuration.Configuration;
|
||
import org.apache.flink.util.Collector;
|
||
|
||
import org.apache.flink.api.java.DataSet;
|
||
import org.apache.flink.api.java.ExecutionEnvironment;
|
||
import org.apache.flink.api.java.operators.IterativeDataSet;
|
||
import org.apache.flink.api.java.utils.ParameterTool;
|
||
|
||
|
||
public class KMeans {
|
||
|
||
public static void main(String[] args) throws Exception {
|
||
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
|
||
|
||
final ParameterTool params = ParameterTool.fromArgs(args);
|
||
final Integer k = params.getInt("k", 3);
|
||
final Integer maxIterations = params.getInt("maxIterations", 25);
|
||
|
||
// Read CSV input
|
||
DataSet<Tuple2<Double, Double>> inputCsv = env.readCsvFile(params.get("input")).types(Double.class, Double.class);
|
||
|
||
// Convert to internal format
|
||
DataSet<Point> input = inputCsv
|
||
.map(tuple -> new Point(tuple.f0, tuple.f1));
|
||
|
||
// Find min and max of the coordinates to determine where the initial centroids should be
|
||
DataSet<Tuple4<Double, Double, Double, Double>> area = input
|
||
.map(new MapFunction<Point, Tuple4<Double, Double, Double, Double>>() { // Format points so
|
||
// they can be passed as reduce parameters
|
||
@Override
|
||
public Tuple4<Double, Double, Double, Double> map(Point point) {
|
||
return new Tuple4<Double, Double, Double, Double>(point.x, point.y, point.x, point.y);
|
||
}
|
||
}).reduce(new FindArea()); // Gives the minX, minY, maxX, maxY of all the point
|
||
|
||
// Generate random centroids
|
||
IterativeDataSet<Point> centroids = area
|
||
.flatMap(new RandomCentroids(k)) // Create centroids randomly in the area of the points
|
||
.iterate(maxIterations); // Mark beginning of the loop
|
||
|
||
// Assign points to centroids
|
||
DataSet<Tuple2<Point, Integer>> assigned = input
|
||
.map(new AssignCentroid()).withBroadcastSet(centroids, "centroids");
|
||
|
||
// Calculate means
|
||
DataSet<Point> newCentroids = assigned
|
||
.map(new MeanPrepare()) // Add Integer field to tuple to count the points
|
||
.groupBy(1) // GroupBy CentroidID
|
||
.reduce(new MeanSum()) // Sum every points by centroid
|
||
.map(new MeanDivide()); // Divide by the number of points to get the average
|
||
|
||
DataSet<Point> finalCentroids = centroids.closeWith(newCentroids); // Mark end of the loop
|
||
|
||
// Final assignment of points to centroids (that's the data we want)
|
||
assigned = input
|
||
.map(new AssignCentroid()).withBroadcastSet(finalCentroids, "centroids");
|
||
|
||
// Convert to CSV format
|
||
DataSet<Tuple3<Double, Double, Integer>> output = assigned
|
||
.map(new MapFunction<Tuple2<Point, Integer>, Tuple3<Double, Double, Integer>>() {
|
||
@Override
|
||
public Tuple3<Double, Double, Integer> map(Tuple2<Point, Integer> tuple) {
|
||
return new Tuple3<Double, Double, Integer>(tuple.f0.x, tuple.f0.y, tuple.f1);
|
||
}
|
||
});
|
||
|
||
output.writeAsCsv(params.get("output", "output.csv"));
|
||
|
||
env.execute("K-Means clustering");
|
||
}
|
||
|
||
public static class Point implements Comparable<Point> {
|
||
public Double x;
|
||
public Double y;
|
||
|
||
public Point(Double x, Double y) {
|
||
this.x = x;
|
||
this.y = y;
|
||
}
|
||
|
||
public int compareTo(Point other) {
|
||
int comp = x.compareTo(other.x);
|
||
if (comp == 0) {
|
||
comp = y.compareTo(other.y);
|
||
}
|
||
return comp;
|
||
}
|
||
|
||
public Point addTo(Point other) {
|
||
// Since input is always re-fetched we can overwrite the values
|
||
x += other.x;
|
||
y += other.y;
|
||
return this;
|
||
}
|
||
|
||
public Point divideBy(Integer factor) {
|
||
// Since input is always re-fetched we can overwrite the values
|
||
x /= factor;
|
||
y /= factor;
|
||
return this;
|
||
}
|
||
|
||
public Double distanceTo(Point other) {
|
||
return Math.sqrt(Math.pow(other.x - x, 2) + Math.pow(other.y - y, 2));
|
||
}
|
||
|
||
}
|
||
|
||
public static class FindArea implements ReduceFunction<Tuple4<Double, Double, Double, Double>> {
|
||
// minX, minY, maxX, maxY
|
||
@Override
|
||
public Tuple4<Double, Double, Double, Double> reduce(Tuple4<Double, Double, Double, Double> a, Tuple4<Double, Double, Double, Double> b) {
|
||
return new Tuple4<Double, Double, Double, Double>(Math.min(a.f0, b.f0), Math.min(a.f1, b.f1), Math.max(a.f2, b.f2), Math.max(a.f3, b.f3));
|
||
}
|
||
}
|
||
|
||
public static class RandomCentroids implements FlatMapFunction<Tuple4<Double, Double, Double, Double>, Point> {
|
||
// minX, minY, maxX, maxY → Point × k
|
||
Integer k;
|
||
Random r;
|
||
|
||
public RandomCentroids(Integer k) {
|
||
this.k = k;
|
||
this.r = new Random();
|
||
}
|
||
|
||
private Double randomRange(Double min, Double max) {
|
||
return min + (r.nextDouble() * (max - min));
|
||
}
|
||
|
||
@Override
|
||
public void flatMap(Tuple4<Double, Double, Double, Double> area, Collector<Point> out) {
|
||
for (int i = 0; i < k; i++) {
|
||
out.collect(new Point(randomRange(area.f0, area.f2), randomRange(area.f1, area.f3)));
|
||
}
|
||
}
|
||
}
|
||
|
||
public static class AssignCentroid extends RichMapFunction<Point, Tuple2<Point, Integer>> {
|
||
// Point → Point, CentroidID
|
||
private List<Point> centroids;
|
||
|
||
@Override
|
||
public void open(Configuration parameters) throws Exception {
|
||
// Centroids are sorted so they have an identifier common to all the operators
|
||
centroids = new ArrayList(getRuntimeContext().getBroadcastVariable("centroids"));
|
||
Collections.sort(centroids);
|
||
}
|
||
|
||
@Override
|
||
public Tuple2<Point, Integer> map(Point point) {
|
||
// Calculate the distance Point-Centroid for all centroids,
|
||
// keep the identifier of the closest centroid
|
||
Integer c;
|
||
Point centroid;
|
||
Double distance;
|
||
Integer minCentroid = 0;
|
||
Double minDistance = Double.POSITIVE_INFINITY;
|
||
|
||
for (c = 0; c < centroids.size(); c++) {
|
||
centroid = centroids.get(c);
|
||
distance = point.distanceTo(centroid);
|
||
if (distance < minDistance) {
|
||
minCentroid = c;
|
||
minDistance = distance;
|
||
}
|
||
}
|
||
|
||
return new Tuple2<Point, Integer>(point, minCentroid);
|
||
}
|
||
|
||
}
|
||
|
||
public static class MeanPrepare implements MapFunction<Tuple2<Point, Integer>, Tuple3<Point, Integer, Integer>> {
|
||
// Point, CentroidID → Point, CentroidID, Number of points
|
||
@Override
|
||
public Tuple3<Point, Integer, Integer> map(Tuple2<Point, Integer> info) {
|
||
return new Tuple3<Point, Integer, Integer>(info.f0, info.f1, 1);
|
||
}
|
||
}
|
||
|
||
public static class MeanSum implements ReduceFunction<Tuple3<Point, Integer, Integer>> {
|
||
// Point, CentroidID (irrelevant), Number of points
|
||
@Override
|
||
public Tuple3<Point, Integer, Integer> reduce(Tuple3<Point, Integer, Integer> a, Tuple3<Point, Integer, Integer> b) {
|
||
return new Tuple3<Point, Integer, Integer>(a.f0.addTo(b.f0), a.f1, a.f2 + b.f2);
|
||
}
|
||
}
|
||
|
||
public static class MeanDivide implements MapFunction<Tuple3<Point, Integer, Integer>, Point> {
|
||
// Point, CentroidID (irrelevant), Number of points → Point
|
||
@Override
|
||
public Point map(Tuple3<Point, Integer, Integer> info) {
|
||
return info.f0.divideBy(info.f2);
|
||
}
|
||
}
|
||
|
||
}
|