Documentation
This commit is contained in:
parent
a9a0e518e8
commit
24e85deb17
3 changed files with 54 additions and 30 deletions
|
|
@ -34,9 +34,9 @@ import org.apache.flink.api.java.utils.ParameterTool;
|
|||
public class KMeans {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
final ParameterTool params = ParameterTool.fromArgs(args);
|
||||
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
|
||||
|
||||
final ParameterTool params = ParameterTool.fromArgs(args);
|
||||
final Integer k = params.getInt("k", 3);
|
||||
final Integer maxIterations = params.getInt("maxIterations", 25);
|
||||
|
||||
|
|
@ -49,28 +49,18 @@ public class KMeans {
|
|||
|
||||
// Find min and max of the coordinates to determine where the initial centroids should be
|
||||
DataSet<Tuple4<Double, Double, Double, Double>> area = input
|
||||
.map(new MapFunction<Point, Tuple4<Double, Double, Double, Double>>() {
|
||||
.map(new MapFunction<Point, Tuple4<Double, Double, Double, Double>>() { // Format points so
|
||||
// they can be passed as reduce parameters
|
||||
@Override
|
||||
public Tuple4<Double, Double, Double, Double> map(Point point) {
|
||||
return new Tuple4<Double, Double, Double, Double>(point.x, point.y, point.x, point.y);
|
||||
}
|
||||
}).reduce(new FindArea());
|
||||
|
||||
area.print();
|
||||
|
||||
DataSet<Tuple2<Double, Double>> testCentroids = area
|
||||
.flatMap(new RandomCentroids(k))
|
||||
.map(new MapFunction<Point, Tuple2<Double, Double>>() {
|
||||
@Override
|
||||
public Tuple2<Double, Double> map(Point point) {
|
||||
return new Tuple2<Double, Double>(point.x, point.y);
|
||||
}});
|
||||
testCentroids.print();
|
||||
}).reduce(new FindArea()); // Gives the minX, minY, maxX, maxY of all the point
|
||||
|
||||
// Generate random centroids
|
||||
IterativeDataSet<Point> centroids = area
|
||||
.flatMap(new RandomCentroids(k))
|
||||
.iterate(maxIterations);
|
||||
.flatMap(new RandomCentroids(k)) // Create centroids randomly in the area of the points
|
||||
.iterate(maxIterations); // Mark beginning of the loop
|
||||
|
||||
// Assign points to centroids
|
||||
DataSet<Tuple2<Point, Integer>> assigned = input
|
||||
|
|
@ -78,18 +68,18 @@ public class KMeans {
|
|||
|
||||
// Calculate means
|
||||
DataSet<Point> newCentroids = assigned
|
||||
.map(new MeanPrepare())
|
||||
.map(new MeanPrepare()) // Add Integer field to tuple to count the points
|
||||
.groupBy(1) // GroupBy CentroidID
|
||||
.reduce(new MeanSum())
|
||||
.map(new MeanDivide());
|
||||
.reduce(new MeanSum()) // Sum every points by centroid
|
||||
.map(new MeanDivide()); // Divide by the number of points to get the average
|
||||
|
||||
DataSet<Point> finalCentroids = centroids.closeWith(newCentroids);
|
||||
DataSet<Point> finalCentroids = centroids.closeWith(newCentroids); // Mark end of the loop
|
||||
|
||||
// Final assignment of points to centroids
|
||||
// Final assignment of points to centroids (that's the data we want)
|
||||
assigned = input
|
||||
.map(new AssignCentroid()).withBroadcastSet(finalCentroids, "centroids");
|
||||
|
||||
// Convert to external format
|
||||
// Convert to CSV format
|
||||
DataSet<Tuple3<Double, Double, Integer>> output = assigned
|
||||
.map(new MapFunction<Tuple2<Point, Integer>, Tuple3<Double, Double, Integer>>() {
|
||||
@Override
|
||||
|
|
@ -128,6 +118,7 @@ public class KMeans {
|
|||
}
|
||||
|
||||
public Point divideBy(Integer factor) {
|
||||
// Since input is always re-fetched we can overwrite the values
|
||||
x /= factor;
|
||||
y /= factor;
|
||||
return this;
|
||||
|
|
@ -148,12 +139,13 @@ public class KMeans {
|
|||
}
|
||||
|
||||
public static class RandomCentroids implements FlatMapFunction<Tuple4<Double, Double, Double, Double>, Point> {
|
||||
// minX, minY, maxX, maxY → Point × k
|
||||
Integer k;
|
||||
Random r;
|
||||
|
||||
public RandomCentroids(Integer k) {
|
||||
this.k = k;
|
||||
this.r = new Random(0);
|
||||
this.r = new Random();
|
||||
}
|
||||
|
||||
private Double randomRange(Double min, Double max) {
|
||||
|
|
@ -174,16 +166,19 @@ public class KMeans {
|
|||
|
||||
@Override
|
||||
public void open(Configuration parameters) throws Exception {
|
||||
// Centroids are sorted so they have an identifier common to all the operators
|
||||
centroids = new ArrayList(getRuntimeContext().getBroadcastVariable("centroids"));
|
||||
Collections.sort(centroids);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tuple2<Point, Integer> map(Point point) {
|
||||
// Calculate the distance Point-Centroid for all centroids,
|
||||
// keep the identifier of the closest centroid
|
||||
Integer c;
|
||||
Point centroid;
|
||||
Double distance;
|
||||
Integer minCentroid = 4;
|
||||
Integer minCentroid = 0;
|
||||
Double minDistance = Double.POSITIVE_INFINITY;
|
||||
|
||||
for (c = 0; c < centroids.size(); c++) {
|
||||
|
|
|
|||
Reference in a new issue