One time calculations
This commit is contained in:
parent
de6a881428
commit
72a187112c
21
plotClassification.py
Executable file
21
plotClassification.py
Executable file
|
@ -0,0 +1,21 @@
|
||||||
|
#!/usr/bin/env python3
"""Scatter-plot a classification result stored in a CSV file.

Usage: plotClassification.py FILE.csv

Every row holds one or two coordinates followed by a class value in the last
column. Points are coloured by that last column; one-dimensional data is drawn
along y = 0.
"""

import sys
import numpy as np
import matplotlib.pyplot as plt

if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} FILE.csv")

FILENAME = sys.argv[1]  # CSV file

# ndmin=2 keeps the array two-dimensional even when the file has a single row;
# without it loadtxt returns a 1-D array and the column indexing below breaks.
data = np.loadtxt(FILENAME, delimiter=',', ndmin=2)

D = data.shape[1] - 1  # Number of dimensions (the last column is the class)

assert D <= 2, "at most two coordinate columns are supported"
assert D > 0, "expected at least one coordinate column plus the class column"

X = data[:, 0]
# One-dimensional input has no second coordinate: place every point on y = 0.
Y = data[:, 1] if D > 1 else np.zeros(len(data))
C = data[:, -1]

plt.scatter(X, Y, c=C)
plt.show()
|
|
@ -1,11 +1,26 @@
|
||||||
package it.polimi.middleware.projects.flink;
|
package it.polimi.middleware.projects.flink;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.ObjectInputStream;
|
||||||
|
import java.io.ObjectOutputStream;
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
import org.apache.flink.api.common.functions.MapFunction;
|
import org.apache.flink.api.common.functions.MapFunction;
|
||||||
import org.apache.flink.api.common.functions.ReduceFunction;
|
import org.apache.flink.api.common.functions.ReduceFunction;
|
||||||
|
import org.apache.flink.api.common.functions.RichMapFunction;
|
||||||
import org.apache.flink.api.common.typeinfo.TypeHint;
|
import org.apache.flink.api.common.typeinfo.TypeHint;
|
||||||
import org.apache.flink.api.common.typeinfo.TypeInformation;
|
import org.apache.flink.api.common.typeinfo.TypeInformation;
|
||||||
import org.apache.flink.api.java.tuple.Tuple1;
|
import org.apache.flink.api.java.tuple.Tuple1;
|
||||||
import org.apache.flink.api.java.tuple.Tuple2;
|
import org.apache.flink.api.java.tuple.Tuple2;
|
||||||
|
import org.apache.flink.api.java.tuple.Tuple3;
|
||||||
|
import org.apache.flink.configuration.Configuration;
|
||||||
|
|
||||||
import org.apache.flink.api.java.DataSet;
|
import org.apache.flink.api.java.DataSet;
|
||||||
import org.apache.flink.api.java.ExecutionEnvironment;
|
import org.apache.flink.api.java.ExecutionEnvironment;
|
||||||
|
@ -18,6 +33,8 @@ public class KMeans {
|
||||||
final ParameterTool params = ParameterTool.fromArgs(args);
|
final ParameterTool params = ParameterTool.fromArgs(args);
|
||||||
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
|
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
|
||||||
|
|
||||||
|
final Integer k = params.getInt("k", 3);
|
||||||
|
|
||||||
// Read CSV input
|
// Read CSV input
|
||||||
DataSet<Tuple1<Double>> csvInput = env.readCsvFile(params.get("input")).types(Double.class);
|
DataSet<Tuple1<Double>> csvInput = env.readCsvFile(params.get("input")).types(Double.class);
|
||||||
|
|
||||||
|
@ -25,26 +42,127 @@ public class KMeans {
|
||||||
DataSet<Double> input = csvInput
|
DataSet<Double> input = csvInput
|
||||||
.map(point -> point.f0);
|
.map(point -> point.f0);
|
||||||
|
|
||||||
// DEBUG Means all the points
|
// Generate random centroids
|
||||||
DataSet<Tuple1<Double>> mean = input
|
final RandomCentroids r = new RandomCentroids(k);
|
||||||
.map(new MapFunction<Double, Tuple2<Double, Integer>>() {
|
DataSet<Double> centroids = env.fromCollection(r, Double.class);
|
||||||
public Tuple2<Double, Integer> map(Double value) {
|
|
||||||
return new Tuple2<Double, Integer>(value, 1);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.reduce(new ReduceFunction<Tuple2<Double, Integer>>() {
|
|
||||||
public Tuple2<Double, Integer> reduce(Tuple2<Double, Integer> a, Tuple2<Double, Integer> b) {
|
|
||||||
return new Tuple2<Double, Integer>(a.f0 + b.f0, a.f1 + b.f1);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.map(new MapFunction<Tuple2<Double, Integer>, Tuple1<Double>>() {
|
|
||||||
public Tuple1<Double> map(Tuple2<Double, Integer> value) {
|
|
||||||
return new Tuple1<Double>(value.f0 / value.f1);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
mean.writeAsCsv(params.get("output", "output.csv"));
|
centroids.print();
|
||||||
|
|
||||||
|
// Assign points to centroids
|
||||||
|
DataSet<Tuple2<Double, Integer>> assigned = input
|
||||||
|
.map(new AssignCentroid()).withBroadcastSet(centroids, "centroids");
|
||||||
|
|
||||||
|
// Calculate means
|
||||||
|
DataSet<Double> newCentroids = assigned
|
||||||
|
.map(new MeanPrepare())
|
||||||
|
.groupBy(1) // GroupBy CentroidID
|
||||||
|
.reduce(new MeanSum())
|
||||||
|
.map(new MeanDivide());
|
||||||
|
|
||||||
|
// Re-assign points to centroids
|
||||||
|
assigned = input
|
||||||
|
.map(new AssignCentroid()).withBroadcastSet(newCentroids, "centroids");
|
||||||
|
|
||||||
|
newCentroids.print();
|
||||||
|
assigned.writeAsCsv(params.get("output", "output.csv"));
|
||||||
|
|
||||||
env.execute("K-Means clustering");
|
env.execute("K-Means clustering");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static class RandomCentroids implements Iterator<Double>, Serializable {
|
||||||
|
|
||||||
|
Integer k;
|
||||||
|
Integer i;
|
||||||
|
Random r;
|
||||||
|
|
||||||
|
public RandomCentroids(Integer k) {
|
||||||
|
this.k = k;
|
||||||
|
this.i = 0;
|
||||||
|
this.r = new Random(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return i < k;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double next() {
|
||||||
|
i += 1;
|
||||||
|
return r.nextDouble();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException {
|
||||||
|
inputStream.defaultReadObject();
|
||||||
|
}
|
||||||
|
private void writeObject(ObjectOutputStream outputStream) throws IOException {
|
||||||
|
outputStream.defaultWriteObject();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class AssignCentroid extends RichMapFunction<Double, Tuple2<Double, Integer>> {
|
||||||
|
// Point → Point, CentroidID
|
||||||
|
private List<Double> centroids;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void open(Configuration parameters) throws Exception {
|
||||||
|
centroids = new ArrayList(getRuntimeContext().getBroadcastVariable("centroids"));
|
||||||
|
Collections.sort(centroids);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Tuple2<Double, Integer> map(Double point) {
|
||||||
|
Integer c;
|
||||||
|
Double centroid;
|
||||||
|
Double distance;
|
||||||
|
Integer minCentroid = 4;
|
||||||
|
Double minDistance = Double.POSITIVE_INFINITY;
|
||||||
|
|
||||||
|
for (c = 0; c < centroids.size(); c++) {
|
||||||
|
centroid = centroids.get(c);
|
||||||
|
distance = distancePointCentroid(point, centroid);
|
||||||
|
if (distance < minDistance) {
|
||||||
|
minCentroid = c;
|
||||||
|
minDistance = distance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Tuple2<Double, Integer>(point, minCentroid);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Double distancePointCentroid(Double point, Double centroid) {
|
||||||
|
return Math.abs(point - centroid);
|
||||||
|
// return Math.sqrt(Math.pow(point, 2) + Math.pow(centroid, 2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MeanPrepare implements MapFunction<Tuple2<Double, Integer>, Tuple3<Double, Integer, Integer>> {
|
||||||
|
// Point, CentroidID → Point, CentroidID, Number of points
|
||||||
|
@Override
|
||||||
|
public Tuple3<Double, Integer, Integer> map(Tuple2<Double, Integer> point) {
|
||||||
|
return new Tuple3<Double, Integer, Integer>(point.f0, point.f1, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MeanSum implements ReduceFunction<Tuple3<Double, Integer, Integer>> {
|
||||||
|
// Point, CentroidID (irrelevant), Number of points
|
||||||
|
@Override
|
||||||
|
public Tuple3<Double, Integer, Integer> reduce(Tuple3<Double, Integer, Integer> a, Tuple3<Double, Integer, Integer> b) {
|
||||||
|
return new Tuple3<Double, Integer, Integer>(sumPoints(a.f0, b.f0), a.f1, a.f2 + b.f2);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Double sumPoints(Double a, Double b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MeanDivide implements MapFunction<Tuple3<Double, Integer, Integer>, Double> {
|
||||||
|
// Point, CentroidID (irrelevant), Number of points → Point
|
||||||
|
@Override
|
||||||
|
public Double map(Tuple3<Double, Integer, Integer> point) {
|
||||||
|
return point.f0 / point.f2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue