This repository has been archived on 2019-08-08. You can view files and clone it, but cannot push or open issues or pull requests.
s9-mtds-prj-flink/src/main/java/it/polimi/middleware/projects/flink/KMeans.java

169 lines
5.6 KiB
Java

package it.polimi.middleware.projects.flink;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
public class KMeans {
/**
 * Runs a single K-Means refinement step over one-dimensional points.
 *
 * <p>Reads a one-column CSV of doubles, creates {@code k} initial centroids with
 * {@link RandomCentroids}, assigns each point to its nearest centroid, recomputes each
 * centroid as the mean of its assigned points, re-assigns the points to the recomputed
 * centroids and writes the resulting {@code (point, centroidId)} pairs as CSV.
 *
 * <p>Recognised arguments:
 * <ul>
 *   <li>{@code --input}: path of the CSV file to read (required)</li>
 *   <li>{@code --output}: path of the CSV file to write (default {@code output.csv})</li>
 *   <li>{@code --k}: number of centroids, must be positive (default 3)</li>
 * </ul>
 *
 * @param args command-line arguments, parsed with {@link ParameterTool#fromArgs(String[])}
 * @throws IllegalArgumentException if {@code k} is not positive
 * @throws Exception if the required {@code input} argument is missing or the job fails
 */
public static void main(String[] args) throws Exception {
    final ParameterTool params = ParameterTool.fromArgs(args);
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    // Fail fast with a descriptive error instead of handing a null path to the CSV reader.
    final String inputPath = params.getRequired("input");
    final String outputPath = params.get("output", "output.csv");

    // Without at least one centroid every point would end up unassigned.
    final int k = params.getInt("k", 3);
    if (k <= 0) {
        throw new IllegalArgumentException("k must be positive, got " + k);
    }

    // Read CSV input
    DataSet<Tuple1<Double>> csvInput = env.readCsvFile(inputPath).types(Double.class);

    // Convert CSV to internal format
    DataSet<Double> input = csvInput
        .map(point -> point.f0);

    // Generate random centroids
    final RandomCentroids r = new RandomCentroids(k);
    DataSet<Double> centroids = env.fromCollection(r, Double.class);
    centroids.print();

    // Assign points to centroids
    DataSet<Tuple2<Double, Integer>> assigned = input
        .map(new AssignCentroid()).withBroadcastSet(centroids, "centroids");

    // Calculate means
    DataSet<Double> newCentroids = assigned
        .map(new MeanPrepare())
        .groupBy(1) // GroupBy CentroidID
        .reduce(new MeanSum())
        .map(new MeanDivide());

    // Re-assign points to centroids
    assigned = input
        .map(new AssignCentroid()).withBroadcastSet(newCentroids, "centroids");
    newCentroids.print();

    assigned.writeAsCsv(outputPath);
    env.execute("K-Means clustering");
}
/**
 * Serializable iterator that yields exactly {@code k} pseudo-random doubles, used as the
 * initial centroids.
 *
 * <p>Values come from {@link Random#nextDouble()} on a generator seeded with 0, so every
 * new instance produces the same sequence.
 */
public static class RandomCentroids implements Iterator<Double>, Serializable {
    /** Total number of values this iterator yields. */
    Integer k;
    /** Number of values yielded so far. */
    Integer i;
    /** Value source; the fixed seed keeps the sequence reproducible. */
    Random r;

    /**
     * @param k number of values to yield; a value of zero or less yields nothing
     */
    public RandomCentroids(Integer k) {
        this.k = k;
        this.i = 0;
        this.r = new Random(0);
    }

    /** @return {@code true} while fewer than {@code k} values have been yielded */
    @Override
    public boolean hasNext() {
        return i < k;
    }

    /**
     * Returns the next pseudo-random value.
     *
     * @return the next value of the sequence
     * @throws NoSuchElementException if all {@code k} values have already been yielded
     */
    @Override
    public Double next() {
        // Honour the Iterator contract instead of producing values past the announced end.
        if (!hasNext()) {
            throw new NoSuchElementException("all " + k + " centroids have already been generated");
        }
        i += 1;
        return r.nextDouble();
    }

    /** Serialization hook; delegates to the default field-by-field deserialization. */
    private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException {
        inputStream.defaultReadObject();
    }

    /** Serialization hook; delegates to the default field-by-field serialization. */
    private void writeObject(ObjectOutputStream outputStream) throws IOException {
        outputStream.defaultWriteObject();
    }
}
/**
 * Maps a point to a {@code (point, centroidId)} pair, where the ID is the index of the
 * nearest centroid in the sorted list of centroids received through the
 * {@code "centroids"} broadcast variable.
 */
public static class AssignCentroid extends RichMapFunction<Double, Tuple2<Double, Integer>> {
    // Point → Point, CentroidID

    /**
     * ID emitted when no centroid is at a distance strictly below infinity (for example
     * when the broadcast set is empty or the distance is NaN). It is negative so it can
     * never collide with a real list index.
     */
    private static final int NO_CENTROID = -1;

    /** Centroids in ascending order; a centroid's ID is its index in this list. */
    private List<Double> centroids;

    /**
     * Copies the broadcast centroids into a private list and sorts it.
     *
     * @param parameters configuration passed in by the runtime; not read here
     */
    @Override
    public void open(Configuration parameters) throws Exception {
        centroids = new ArrayList<>(getRuntimeContext().<Double>getBroadcastVariable("centroids"));
        // Sorting makes an ID depend only on the centroid values, not on their arrival order.
        Collections.sort(centroids);
    }

    /**
     * Finds the centroid closest to {@code point}; on ties the lowest ID wins.
     *
     * @param point the point to classify
     * @return the point paired with the ID of its nearest centroid, or with
     *         {@code -1} when no centroid qualifies
     */
    @Override
    public Tuple2<Double, Integer> map(Double point) {
        int minCentroid = NO_CENTROID;
        double minDistance = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.size(); c++) {
            final double distance = distancePointCentroid(point, centroids.get(c));
            // Strict comparison keeps the first (lowest-ID) centroid on equal distances.
            if (distance < minDistance) {
                minCentroid = c;
                minDistance = distance;
            }
        }
        return new Tuple2<>(point, minCentroid);
    }

    /** Distance between two one-dimensional values: the absolute difference. */
    private Double distancePointCentroid(Double point, Double centroid) {
        return Math.abs(point - centroid);
    }
}
/**
 * Extends each {@code (point, centroidId)} pair with a count of 1 so that sums and
 * counts can later be accumulated together.
 */
public static class MeanPrepare implements MapFunction<Tuple2<Double, Integer>, Tuple3<Double, Integer, Integer>> {
    // Point, CentroidID → Point, CentroidID, Number of points

    /**
     * @param point a {@code (point, centroidId)} pair
     * @return {@code (point, centroidId, 1)}
     */
    @Override
    public Tuple3<Double, Integer, Integer> map(Tuple2<Double, Integer> point) {
        final Double value = point.f0;
        final Integer centroidId = point.f1;
        return new Tuple3<>(value, centroidId, 1);
    }
}
/**
 * Combines two {@code (sum, centroidId, count)} records into one by adding the sums and
 * the counts; the centroid ID is taken from the first record.
 */
public static class MeanSum implements ReduceFunction<Tuple3<Double, Integer, Integer>> {
    // Point, CentroidID (irrelevant), Number of points

    /**
     * @param a first partial result
     * @param b second partial result
     * @return {@code (a.f0 + b.f0, a.f1, a.f2 + b.f2)}
     */
    @Override
    public Tuple3<Double, Integer, Integer> reduce(Tuple3<Double, Integer, Integer> a, Tuple3<Double, Integer, Integer> b) {
        final Double total = a.f0 + b.f0;
        final Integer count = a.f2 + b.f2;
        return new Tuple3<>(total, a.f1, count);
    }
}
/**
 * Turns an accumulated {@code (sum, centroidId, count)} record into the mean
 * {@code sum / count}.
 */
public static class MeanDivide implements MapFunction<Tuple3<Double, Integer, Integer>, Double> {
    // Point, CentroidID (irrelevant), Number of points → Point

    /**
     * @param point an accumulated {@code (sum, centroidId, count)} record
     * @return the sum divided by the count, computed in floating point
     */
    @Override
    public Double map(Tuple3<Double, Integer, Integer> point) {
        final double sum = point.f0;
        final int count = point.f2;
        return sum / count;
    }
}
}