This repository has been archived on 2019-08-08. You can view files and clone it, but cannot push or open issues or pull requests.
s9-mtds-prj-flink/src/main/java/it/polimi/middleware/projects/flink/KMeans.java
2019-01-24 20:44:44 +01:00

223 lines
8.5 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package it.polimi.middleware.projects.flink;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.utils.ParameterTool;
public class KMeans {
public static void main(String[] args) throws Exception {
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    final ParameterTool params = ParameterTool.fromArgs(args);
    final Integer numClusters = params.getInt("k", 3);
    final Integer iterationLimit = params.getInt("maxIterations", 25);

    // Load the raw (x, y) pairs from the input CSV file.
    final DataSet<Tuple2<Double, Double>> rawPoints =
            env.readCsvFile(params.get("input")).types(Double.class, Double.class);

    // Wrap each pair into the internal Point representation.
    final DataSet<Point> points = rawPoints.map(tuple -> new Point(tuple.f0, tuple.f1));

    // Compute the bounding box (minX, minY, maxX, maxY) of all points so the
    // initial centroids can be drawn inside it.
    final DataSet<Tuple4<Double, Double, Double, Double>> boundingBox = points
            .map(new MapFunction<Point, Tuple4<Double, Double, Double, Double>>() {
                // Expand each point into a degenerate box so FindArea can reduce them.
                @Override
                public Tuple4<Double, Double, Double, Double> map(Point p) {
                    return new Tuple4<Double, Double, Double, Double>(p.x, p.y, p.x, p.y);
                }
            })
            .reduce(new FindArea());

    // Draw k random centroids inside the bounding box and open the bulk iteration.
    final IterativeDataSet<Point> loopStart = boundingBox
            .flatMap(new RandomCentroids(numClusters))
            .iterate(iterationLimit);

    // One iteration step: label every point with its nearest centroid ...
    DataSet<Tuple2<Point, Integer>> labeled = points
            .map(new AssignCentroid()).withBroadcastSet(loopStart, "centroids");

    // ... then move each centroid to the mean of the points assigned to it.
    final DataSet<Point> recomputed = labeled
            .map(new MeanPrepare())  // attach a count of 1 to every point
            .groupBy(1)              // group by centroid id
            .reduce(new MeanSum())   // component-wise sum + point count per centroid
            .map(new MeanDivide());  // divide the sum by the count

    // Close the loop: feed the recomputed centroids back into the iteration.
    final DataSet<Point> converged = loopStart.closeWith(recomputed);

    // Produce the final labeling against the converged centroids.
    labeled = points.map(new AssignCentroid()).withBroadcastSet(converged, "centroids");

    // Flatten to (x, y, centroidId) rows for CSV output.
    final DataSet<Tuple3<Double, Double, Integer>> rows = labeled
            .map(new MapFunction<Tuple2<Point, Integer>, Tuple3<Double, Double, Integer>>() {
                @Override
                public Tuple3<Double, Double, Integer> map(Tuple2<Point, Integer> t) {
                    return new Tuple3<Double, Double, Integer>(t.f0.x, t.f0.y, t.f1);
                }
            });

    rows.writeAsCsv(params.get("output", "output.csv"));
    env.execute("K-Means clustering");
}
public static class Point implements Comparable<Point> {
    // Simple mutable 2D point. Public fields plus a public no-arg constructor
    // make this a valid Flink POJO (avoids falling back to generic serialization).
    public Double x;
    public Double y;

    /** No-arg constructor required for POJO-style serialization; starts at the origin. */
    public Point() {
        this(0.0, 0.0);
    }

    public Point(Double x, Double y) {
        this.x = x;
        this.y = y;
    }

    /** Orders points by x, breaking ties by y; consistent with {@link #equals(Object)}. */
    @Override
    public int compareTo(Point other) {
        int comp = x.compareTo(other.x);
        if (comp == 0) {
            comp = y.compareTo(other.y);
        }
        return comp;
    }

    // equals/hashCode are required so that compareTo is consistent with equals
    // (Comparable contract) and points behave correctly in hashed collections.
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Point)) {
            return false;
        }
        final Point other = (Point) obj;
        return x.equals(other.x) && y.equals(other.y);
    }

    @Override
    public int hashCode() {
        return 31 * x.hashCode() + y.hashCode();
    }

    /** Adds the other point's coordinates into this point in place and returns this. */
    public Point addTo(Point other) {
        // In-place on purpose: the original author notes the input is always re-fetched.
        x += other.x;
        y += other.y;
        return this;
    }

    /** Divides both coordinates by factor in place and returns this. */
    public Point divideBy(Integer factor) {
        x /= factor;
        y /= factor;
        return this;
    }

    /** Euclidean distance between this point and other. */
    public Double distanceTo(Point other) {
        return Math.sqrt(Math.pow(other.x - x, 2) + Math.pow(other.y - y, 2));
    }
}
public static class FindArea implements ReduceFunction<Tuple4<Double, Double, Double, Double>> {
    // Merges two partial bounding boxes (minX, minY, maxX, maxY) into a single
    // box covering both.
    @Override
    public Tuple4<Double, Double, Double, Double> reduce(
            Tuple4<Double, Double, Double, Double> left,
            Tuple4<Double, Double, Double, Double> right) {
        final Double minX = Math.min(left.f0, right.f0);
        final Double minY = Math.min(left.f1, right.f1);
        final Double maxX = Math.max(left.f2, right.f2);
        final Double maxY = Math.max(left.f3, right.f3);
        return new Tuple4<Double, Double, Double, Double>(minX, minY, maxX, maxY);
    }
}
public static class RandomCentroids implements FlatMapFunction<Tuple4<Double, Double, Double, Double>, Point> {
    // Emits k centroids placed uniformly at random inside the given bounding
    // box (minX, minY, maxX, maxY).
    Integer k;
    Random r;

    public RandomCentroids(Integer k) {
        this.k = k;
        this.r = new Random();
    }

    /** Uniform draw in [min, max). */
    private Double randomRange(Double min, Double max) {
        return min + r.nextDouble() * (max - min);
    }

    @Override
    public void flatMap(Tuple4<Double, Double, Double, Double> area, Collector<Point> out) {
        int remaining = k;
        while (remaining-- > 0) {
            final Double cx = randomRange(area.f0, area.f2);
            final Double cy = randomRange(area.f1, area.f3);
            out.collect(new Point(cx, cy));
        }
    }
}
public static class AssignCentroid extends RichMapFunction<Point, Tuple2<Point, Integer>> {
    // Maps each Point to (Point, index of nearest centroid). The centroid list
    // arrives through the "centroids" broadcast set.
    private List<Point> centroids;

    @Override
    public void open(Configuration parameters) throws Exception {
        // Copy and sort the broadcast centroids so every parallel instance
        // assigns the same index to the same centroid. Typed access replaces
        // the original raw-type ArrayList construction (unchecked warning).
        centroids = new ArrayList<>(getRuntimeContext().<Point>getBroadcastVariable("centroids"));
        Collections.sort(centroids);
    }

    @Override
    public Tuple2<Point, Integer> map(Point point) {
        // Linear scan over the centroids, keeping the index of the closest one.
        // Primitive locals avoid per-comparison boxing.
        int closest = 0;
        double best = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.size(); c++) {
            final double distance = point.distanceTo(centroids.get(c));
            if (distance < best) {
                closest = c;
                best = distance;
            }
        }
        return new Tuple2<Point, Integer>(point, closest);
    }
}
public static class MeanPrepare implements MapFunction<Tuple2<Point, Integer>, Tuple3<Point, Integer, Integer>> {
    // (point, centroidId) -> (point, centroidId, 1): seeds the per-centroid
    // point counter consumed by the averaging reduce.
    @Override
    public Tuple3<Point, Integer, Integer> map(Tuple2<Point, Integer> assignment) {
        final Integer initialCount = 1;
        return new Tuple3<Point, Integer, Integer>(assignment.f0, assignment.f1, initialCount);
    }
}
public static class MeanSum implements ReduceFunction<Tuple3<Point, Integer, Integer>> {
    // Component-wise sum of the points of one centroid group, carrying the
    // centroid id (identical within a group) and the running point count.
    @Override
    public Tuple3<Point, Integer, Integer> reduce(Tuple3<Point, Integer, Integer> a, Tuple3<Point, Integer, Integer> b) {
        // Build a fresh Point instead of mutating an input (the original used
        // a.f0.addTo(b.f0)): Flink may reuse input objects in reduce functions,
        // so in-place mutation of inputs is unsafe.
        final Point sum = new Point(a.f0.x + b.f0.x, a.f0.y + b.f0.y);
        return new Tuple3<Point, Integer, Integer>(sum, a.f1, a.f2 + b.f2);
    }
}
public static class MeanDivide implements MapFunction<Tuple3<Point, Integer, Integer>, Point> {
    // (sumPoint, centroidId, count) -> mean point: divides the component-wise
    // sum by the number of points in the group.
    @Override
    public Point map(Tuple3<Point, Integer, Integer> summed) {
        final Point total = summed.f0;
        final Integer count = summed.f2;
        return total.divideBy(count);
    }
}
}