diff --git a/plotClassification.py b/plotClassification.py
new file mode 100755
index 0000000..3417f41
--- /dev/null
+++ b/plotClassification.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+
+import sys
+import numpy as np
+import matplotlib.pyplot as plt
+
+FILENAME = sys.argv[1]  # CSV file: one point per row, class label in the last column
+
+data = np.loadtxt(FILENAME, delimiter=',')
+
+D = data[0].size - 1  # Number of dimensions (the last column is the class)
+
+assert D <= 2
+assert D > 0
+
+X = data[:, 0]
+Y = data[:, 1] if D > 1 else np.zeros(len(data))
+C = data[:, -1]
+
+plt.scatter(X, Y, c=C)
+plt.show()
diff --git a/src/main/java/it/polimi/middleware/projects/flink/KMeans.java b/src/main/java/it/polimi/middleware/projects/flink/KMeans.java
index 3e70451..5e0b010 100644
--- a/src/main/java/it/polimi/middleware/projects/flink/KMeans.java
+++ b/src/main/java/it/polimi/middleware/projects/flink/KMeans.java
@@ -1,11 +1,26 @@
 package it.polimi.middleware.projects.flink;
 
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.TypeHint;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple1;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.configuration.Configuration;
 
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
@@ -18,6 +33,8 @@ public class KMeans {
         final ParameterTool params = ParameterTool.fromArgs(args);
         final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
 
+        final Integer k = params.getInt("k", 3);
+
         // Read CSV input
         DataSet<Tuple1<Double>> csvInput = env.readCsvFile(params.get("input")).types(Double.class);
@@ -25,26 +42,127 @@ public class KMeans {
         DataSet<Double> input = csvInput
                 .map(point -> point.f0);
 
-        // DEBUG Means all the points
-        DataSet<Tuple1<Double>> mean = input
-                .map(new MapFunction<Double, Tuple2<Double, Integer>>() {
-                    public Tuple2<Double, Integer> map(Double value) {
-                        return new Tuple2<Double, Integer>(value, 1);
-                    }
-                })
-                .reduce(new ReduceFunction<Tuple2<Double, Integer>>() {
-                    public Tuple2<Double, Integer> reduce(Tuple2<Double, Integer> a, Tuple2<Double, Integer> b) {
-                        return new Tuple2<Double, Integer>(a.f0 + b.f0, a.f1 + b.f1);
-                    }
-                })
-                .map(new MapFunction<Tuple2<Double, Integer>, Tuple1<Double>>() {
-                    public Tuple1<Double> map(Tuple2<Double, Integer> value) {
-                        return new Tuple1<Double>(value.f0 / value.f1);
-                    }
-                });
+        // Generate k random centroids
+        final RandomCentroids r = new RandomCentroids(k);
+        DataSet<Double> centroids = env.fromCollection(r, Double.class);
 
-        mean.writeAsCsv(params.get("output", "output.csv"));
+        centroids.print();
+
+        // Assign each point to its nearest centroid
+        DataSet<Tuple2<Double, Integer>> assigned = input
+                .map(new AssignCentroid()).withBroadcastSet(centroids, "centroids");
+
+        // Compute the mean of each cluster as its new centroid
+        DataSet<Double> newCentroids = assigned
+                .map(new MeanPrepare())
+                .groupBy(1)  // Group by centroid ID
+                .reduce(new MeanSum())
+                .map(new MeanDivide());
+
+        // Re-assign points to the updated centroids
+        assigned = input
+                .map(new AssignCentroid()).withBroadcastSet(newCentroids, "centroids");
+
+        newCentroids.print();
+        assigned.writeAsCsv(params.get("output", "output.csv"));
 
         env.execute("K-Means clustering");
     }
+
+
+    // Serializable iterator that produces k pseudo-random initial centroids
+    public static class RandomCentroids implements Iterator<Double>, Serializable {
+
+        Integer k;
+        Integer i;
+        Random r;
+
+        public RandomCentroids(Integer k) {
+            this.k = k;
+            this.i = 0;
+            this.r = new Random(0);  // Fixed seed for reproducible initial centroids
+        }
+
+        @Override
+        public boolean hasNext() {
+            return i < k;
+        }
+
+        @Override
+        public Double next() {
+            i += 1;
+            return r.nextDouble();
+        }
+
+        // Default (de)serialization, spelled out so the iterator can be shipped with the job
+        private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException {
+            inputStream.defaultReadObject();
+        }
+
+        private void writeObject(ObjectOutputStream outputStream) throws IOException {
+            outputStream.defaultWriteObject();
+        }
+    }
+
+    public static class AssignCentroid extends RichMapFunction<Double, Tuple2<Double, Integer>> {
+        // Point → (Point, CentroidID)
+        private List<Double> centroids;
+
+        @Override
+        public void open(Configuration parameters) throws Exception {
+            centroids = new ArrayList<>(getRuntimeContext().getBroadcastVariable("centroids"));
+            Collections.sort(centroids);
+        }
+
+        @Override
+        public Tuple2<Double, Integer> map(Double point) {
+            Integer c;
+            Double centroid;
+            Double distance;
+            Integer minCentroid = 0;  // Always overwritten: minDistance starts at +infinity
+            Double minDistance = Double.POSITIVE_INFINITY;
+
+            for (c = 0; c < centroids.size(); c++) {
+                centroid = centroids.get(c);
+                distance = distancePointCentroid(point, centroid);
+                if (distance < minDistance) {
+                    minCentroid = c;
+                    minDistance = distance;
+                }
+            }
+
+            return new Tuple2<Double, Integer>(point, minCentroid);
+        }
+
+        private Double distancePointCentroid(Double point, Double centroid) {
+            // One-dimensional Euclidean distance
+            return Math.abs(point - centroid);
+        }
+    }
+
+    public static class MeanPrepare implements MapFunction<Tuple2<Double, Integer>, Tuple3<Double, Integer, Integer>> {
+        // (Point, CentroidID) → (Point, CentroidID, Number of points)
+        @Override
+        public Tuple3<Double, Integer, Integer> map(Tuple2<Double, Integer> point) {
+            return new Tuple3<Double, Integer, Integer>(point.f0, point.f1, 1);
+        }
+    }
+
+    public static class MeanSum implements ReduceFunction<Tuple3<Double, Integer, Integer>> {
+        // Sums points and counts pairwise; f1 (the CentroidID) is the grouping key,
+        // so it is identical in a and b and carried through unchanged
+        @Override
+        public Tuple3<Double, Integer, Integer> reduce(Tuple3<Double, Integer, Integer> a, Tuple3<Double, Integer, Integer> b) {
+            return new Tuple3<Double, Integer, Integer>(sumPoints(a.f0, b.f0), a.f1, a.f2 + b.f2);
+        }
+
+        private Double sumPoints(Double a, Double b) {
+            return a + b;
+        }
+    }
+
+    public static class MeanDivide implements MapFunction<Tuple3<Double, Integer, Integer>, Double> {
+        // (Sum of points, CentroidID, Number of points) → mean point
+        @Override
+        public Double map(Tuple3<Double, Integer, Integer> point) {
+            return point.f0 / point.f2;
+        }
+    }
 }
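
A quick smoke test for the patch might look like the following (a sketch only: it assumes the job is packaged as a fat JAR, and the JAR and input file names are hypothetical):

  $ flink run target/kmeans.jar --input points.csv --k 3 --output output.csv
  $ ./plotClassification.py output.csv

The job writes one point,centroidID row per input point to output.csv, so plotClassification.py reads it as a one-dimensional dataset (D = 1) and colors each point by the cluster it was assigned to.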