K-means clustering partitions a dataset into k groups by iteratively refining cluster centroids until the centroids stabilize or the iteration limit is reached. This page shows how to implement k-means clustering using the MaxCompute Graph API (Java).
How it works
Each iteration assigns every data point to its nearest centroid, then recomputes each centroid as the mean of all points in its cluster. The algorithm stops when every centroid shifts by less than the convergence threshold or the maximum iteration count is reached.
Algorithm steps
1. Select initial centroids for the k clusters.
2. For each data point, compute its squared Euclidean distance to all k centroids and assign it to the nearest cluster.
3. Recompute each centroid as the arithmetic mean of all points assigned to that cluster.
4. If every centroid shifts by less than the convergence threshold, or the maximum iteration count is reached, stop. Otherwise, repeat from step 2.
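The steps above can be sketched on a single machine in plain Java. This is a minimal illustration, not the MaxCompute Graph implementation: the class name KmeansSketch, its method names, and the sample points are hypothetical, and keeping the old centroid when a cluster receives no points is an assumption of this sketch.

```java
import java.util.Arrays;

class KmeansSketch {
    // Squared Euclidean distance between two points of equal dimension.
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int j = 0; j < a.length; j++) {
            double d = a[j] - b[j];
            sum += d * d;
        }
        return sum;
    }

    // Runs k-means in memory and returns the final centroids.
    static double[][] cluster(double[][] points, double[][] initial,
                              double threshold, int maxIterations) {
        int k = initial.length;
        int dim = initial[0].length;
        // Step 1: start from a copy of the supplied initial centroids.
        double[][] centers = new double[k][];
        for (int i = 0; i < k; i++) {
            centers[i] = Arrays.copyOf(initial[i], dim);
        }
        for (int iter = 0; iter < maxIterations; iter++) {
            double[][] sums = new double[k][dim];
            long[] counts = new long[k];
            // Step 2: assign each point to its nearest centroid.
            for (double[] p : points) {
                int nearest = 0;
                double best = Double.MAX_VALUE;
                for (int i = 0; i < k; i++) {
                    double d = squaredDistance(p, centers[i]);
                    if (d < best) {
                        best = d;
                        nearest = i;
                    }
                }
                for (int j = 0; j < dim; j++) {
                    sums[nearest][j] += p[j];
                }
                counts[nearest]++;
            }
            // Step 3: recompute each centroid as the mean of its points.
            boolean converged = true;
            for (int i = 0; i < k; i++) {
                if (counts[i] == 0) {
                    continue; // sketch assumption: an empty cluster keeps its old centroid
                }
                double[] next = new double[dim];
                for (int j = 0; j < dim; j++) {
                    next[j] = sums[i][j] / counts[i];
                }
                // Step 4: the square root is taken only for the convergence check.
                if (Math.sqrt(squaredDistance(next, centers[i])) >= threshold) {
                    converged = false;
                }
                centers[i] = next;
            }
            if (converged) {
                break;
            }
        }
        return centers;
    }

    public static void main(String[] args) {
        double[][] points = {{1, 1}, {1, 2}, {9, 9}, {10, 10}};
        double[][] initial = {{0, 0}, {10, 10}};
        // prints [[1.0, 1.5], [9.5, 9.5]]
        System.out.println(Arrays.deepToString(cluster(points, initial, 0.05, 30)));
    }
}
```

With these sample points the centroids move in the first pass and stay put in the second, so the loop exits after two passes.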
Key parameters
| Parameter | Default | Description |
|---|---|---|
| Convergence threshold | 0.05 | Euclidean distance below which a centroid's shift between iterations counts as converged. The shift is the square root of the summed squared coordinate differences. The value is hard-coded in terminate(). |
| Maximum iterations | 30 | Upper bound on the number of supersteps. The job stops even if centroids have not fully converged. |
| Distance metric | Squared Euclidean | Squared distance is used for cluster assignment comparisons. The square root is computed only when checking convergence in terminate(). |
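The split between squared distance for assignment and true distance for convergence can be illustrated with a short plain-Java sketch. The class DistanceSketch and its method names are hypothetical and do not appear in the sample code; only the 0.05 threshold comes from this page.

```java
class DistanceSketch {
    // Squared Euclidean distance: enough to rank centroids, because
    // taking the square root never changes which value is smallest.
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int j = 0; j < a.length; j++) {
            double d = a[j] - b[j];
            sum += d * d;
        }
        return sum;
    }

    // Index of the nearest centroid, compared on squared distance only.
    static int nearest(double[] point, double[][] centers) {
        int min = 0;
        double minDist = Double.MAX_VALUE;
        for (int i = 0; i < centers.length; i++) {
            double dist = squaredDistance(point, centers[i]);
            if (dist < minDist) {
                minDist = dist;
                min = i;
            }
        }
        return min;
    }

    // The threshold is expressed as a true distance, so the square root
    // is required here before comparing.
    static boolean converged(double[] oldCenter, double[] newCenter, double threshold) {
        return Math.sqrt(squaredDistance(oldCenter, newCenter)) < threshold;
    }

    public static void main(String[] args) {
        double[][] centers = {{0, 0}, {3, 0}};
        // Squared distances from (3, 4) are 25 and 16, so index 1 wins.
        System.out.println(nearest(new double[]{3, 4}, centers));
        // Shift of about 0.042: below the 0.05 threshold.
        System.out.println(converged(new double[]{0, 0}, new double[]{0.03, 0.03}, 0.05));
        // Shift of about 0.057: not converged, although its square (0.0032) is below 0.05.
        System.out.println(converged(new double[]{0, 0}, new double[]{0.04, 0.04}, 0.05));
    }
}
```

The last case shows why the two uses must not be mixed: comparing the squared shift directly against 0.05 would report convergence too early.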
Implementation
The following example implements k-means clustering with four classes built on the MaxCompute Graph API.
Component overview
Before reading the code, review each component's responsibility:
KmeansVertex represents a single data point as a Tuple (feature vector). Its compute() method passes the vertex value to the aggregator by calling context.aggregate().
KmeansVertexReader loads the input table and creates one vertex per record. The record number becomes the vertex ID; all columns in the record become the vertex value as a Tuple.
KmeansAggrValue holds the shared aggregation state across workers, with three Tuple fields:
centers: current centroid coordinates
sums: accumulated coordinate sums per cluster
counts: data point counts per cluster
KmeansAggregator encapsulates the main algorithm logic across four methods:
| Method | Responsibility |
|---|---|
| createInitialValue() | On superstep 0, reads initial centroids from the cache file named "centers". On subsequent supersteps, retrieves the centroids computed in the previous iteration via getLastAggregatedValue(0). |
| aggregate() | For each vertex, finds the nearest centroid using squared Euclidean distance and updates that cluster's sum and count. |
| merge() | Combines partial sums and counts collected from workers running in parallel. |
| terminate() | Computes new centroids from sums and counts. If all centroids shift by less than 0.05 or the maximum iteration count is reached, writes the final centroids to the output table and returns true (stop). Otherwise returns false (continue). |
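How the four methods cooperate can be seen in a plain-Java sketch that swaps the Tuple fields for arrays and simulates two workers by hand. Everything here is illustrative: AggregatorSketch, AggrValue, parseCenters, and newCenters are hypothetical names, no MaxCompute classes are used, and keeping the old centroid for an empty cluster is an assumption of this sketch.

```java
import java.util.Arrays;

class AggregatorSketch {
    // Array stand-in for the centers, sums, and counts fields.
    static class AggrValue {
        double[][] centers;
        double[][] sums;
        long[] counts;

        AggrValue(double[][] centers) {
            this.centers = centers;
            this.sums = new double[centers.length][centers[0].length];
            this.counts = new long[centers.length];
        }
    }

    // One centroid per line, coordinates separated by commas.
    static double[][] parseCenters(String text) {
        String[] lines = text.split("\n");
        double[][] centers = new double[lines.length][];
        for (int i = 0; i < lines.length; i++) {
            String[] parts = lines[i].split(",");
            centers[i] = new double[parts.length];
            for (int j = 0; j < parts.length; j++) {
                centers[i][j] = Double.parseDouble(parts[j].trim());
            }
        }
        return centers;
    }

    // Adds one point to the sum and count of its nearest centroid.
    static void aggregate(AggrValue value, double[] point) {
        int min = 0;
        double minDist = Double.MAX_VALUE;
        for (int i = 0; i < value.centers.length; i++) {
            double dist = 0.0;
            for (int j = 0; j < point.length; j++) {
                double v = point[j] - value.centers[i][j];
                dist += v * v;
            }
            if (dist < minDist) {
                minDist = dist;
                min = i;
            }
        }
        for (int j = 0; j < point.length; j++) {
            value.sums[min][j] += point[j];
        }
        value.counts[min]++;
    }

    // Folds another worker's partial sums and counts into value.
    static void merge(AggrValue value, AggrValue partial) {
        for (int i = 0; i < value.sums.length; i++) {
            for (int j = 0; j < value.sums[i].length; j++) {
                value.sums[i][j] += partial.sums[i][j];
            }
            value.counts[i] += partial.counts[i];
        }
    }

    // Mean of each cluster; an empty cluster keeps its old centroid (sketch assumption).
    static double[][] newCenters(AggrValue value) {
        double[][] result = new double[value.sums.length][];
        for (int i = 0; i < value.sums.length; i++) {
            result[i] = Arrays.copyOf(value.centers[i], value.centers[i].length);
            for (int j = 0; value.counts[i] > 0 && j < result[i].length; j++) {
                result[i][j] = value.sums[i][j] / value.counts[i];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        String file = "0,0\n10,10"; // hypothetical contents of the centers file
        AggrValue workerA = new AggrValue(parseCenters(file));
        AggrValue workerB = new AggrValue(parseCenters(file));
        aggregate(workerA, new double[]{1, 1});
        aggregate(workerA, new double[]{9, 9});
        aggregate(workerB, new double[]{1, 2});
        aggregate(workerB, new double[]{10, 10});
        merge(workerA, workerB);
        // prints [[1.0, 1.5], [9.5, 9.5]]
        System.out.println(Arrays.deepToString(newCenters(workerA)));
    }
}
```

Because sums and counts are both additive, the merged result is the same regardless of how the points are divided among workers.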
Sample code
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.log4j.Logger;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.io.Text;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
public class Kmeans {
private final static Logger LOG = Logger.getLogger(Kmeans.class);
public static class KmeansVertex extends
Vertex<Text, Tuple, NullWritable, NullWritable> {
@Override
public void compute(
ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
Iterable<NullWritable> messages) throws IOException {
context.aggregate(getValue());
}
}
public static class KmeansVertexReader extends
GraphLoader<Text, Tuple, NullWritable, NullWritable> {
@Override
public void load(LongWritable recordNum, WritableRecord record,
MutationContext<Text, Tuple, NullWritable, NullWritable> context)
throws IOException {
KmeansVertex vertex = new KmeansVertex();
vertex.setId(new Text(String.valueOf(recordNum.get())));
vertex.setValue(new Tuple(record.getAll()));
context.addVertexRequest(vertex);
}
}
public static class KmeansAggrValue implements Writable {
Tuple centers = new Tuple();
Tuple sums = new Tuple();
Tuple counts = new Tuple();
@Override
public void write(DataOutput out) throws IOException {
centers.write(out);
sums.write(out);
counts.write(out);
}
@Override
public void readFields(DataInput in) throws IOException {
centers = new Tuple();
centers.readFields(in);
sums = new Tuple();
sums.readFields(in);
counts = new Tuple();
counts.readFields(in);
}
@Override
public String toString() {
return "centers " + centers.toString() + ", sums " + sums.toString()
+ ", counts " + counts.toString();
}
}
public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {
@SuppressWarnings("rawtypes")
@Override
public KmeansAggrValue createInitialValue(WorkerContext context)
throws IOException {
KmeansAggrValue aggrVal = null;
if (context.getSuperstep() == 0) {
aggrVal = new KmeansAggrValue();
aggrVal.centers = new Tuple();
aggrVal.sums = new Tuple();
aggrVal.counts = new Tuple();
byte[] centers = context.readCacheFile("centers");
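// decode with an explicit charset (UTF-8 assumed) instead of the platform default
String[] lines = new String(centers, java.nio.charset.StandardCharsets.UTF_8).split("\n");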
for (int i = 0; i < lines.length; i++) {
String[] ss = lines[i].split(",");
Tuple center = new Tuple();
Tuple sum = new Tuple();
for (int j = 0; j < ss.length; ++j) {
center.append(new DoubleWritable(Double.valueOf(ss[j].trim())));
sum.append(new DoubleWritable(0.0));
}
LongWritable count = new LongWritable(0);
aggrVal.sums.append(sum);
aggrVal.counts.append(count);
aggrVal.centers.append(center);
}
} else {
aggrVal = (KmeansAggrValue) context.getLastAggregatedValue(0);
}
return aggrVal;
}
@Override
public void aggregate(KmeansAggrValue value, Object item) {
int min = 0;
double mindist = Double.MAX_VALUE;
Tuple point = (Tuple) item;
for (int i = 0; i < value.centers.size(); i++) {
Tuple center = (Tuple) value.centers.get(i);
// compare squared Euclidean distances; the sqrt is not needed to find the nearest center
double dist = 0.0d;
for (int j = 0; j < center.size(); j++) {
double v = ((DoubleWritable) point.get(j)).get()
- ((DoubleWritable) center.get(j)).get();
dist += v * v;
}
if (dist < mindist) {
mindist = dist;
min = i;
}
}
// update sum and count
Tuple sum = (Tuple) value.sums.get(min);
for (int i = 0; i < point.size(); i++) {
DoubleWritable s = (DoubleWritable) sum.get(i);
s.set(s.get() + ((DoubleWritable) point.get(i)).get());
}
LongWritable count = (LongWritable) value.counts.get(min);
count.set(count.get() + 1);
}
@Override
public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
for (int i = 0; i < value.sums.size(); i++) {
Tuple sum = (Tuple) value.sums.get(i);
Tuple that = (Tuple) partial.sums.get(i);
for (int j = 0; j < sum.size(); j++) {
DoubleWritable s = (DoubleWritable) sum.get(j);
s.set(s.get() + ((DoubleWritable) that.get(j)).get());
}
}
for (int i = 0; i < value.counts.size(); i++) {
LongWritable count = (LongWritable) value.counts.get(i);
count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
}
}
@SuppressWarnings("rawtypes")
@Override
public boolean terminate(WorkerContext context, KmeansAggrValue value)
throws IOException {
// compute new centers
Tuple newCenters = new Tuple(value.sums.size());
for (int i = 0; i < value.sums.size(); i++) {
Tuple sum = (Tuple) value.sums.get(i);
Tuple newCenter = new Tuple(sum.size());
LongWritable c = (LongWritable) value.counts.get(i);
for (int j = 0; j < sum.size(); j++) {
DoubleWritable s = (DoubleWritable) sum.get(j);
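// an empty cluster has count 0; keep its previous coordinate instead of producing NaN
double val = c.get() == 0
? ((DoubleWritable) ((Tuple) value.centers.get(i)).get(j)).get()
: s.get() / c.get();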
newCenter.set(j, new DoubleWritable(val));
// reset sum for next iteration
s.set(0.0d);
}
// reset count for next iteration
c.set(0);
newCenters.set(i, newCenter);
}
// update centers
Tuple oldCenters = value.centers;
value.centers = newCenters;
LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);
// compare new/old centers
boolean converged = true;
for (int i = 0; i < value.centers.size() && converged; i++) {
Tuple oldCenter = (Tuple) oldCenters.get(i);
Tuple newCenter = (Tuple) newCenters.get(i);
double sum = 0.0d;
for (int j = 0; j < newCenter.size(); j++) {
double v = ((DoubleWritable) newCenter.get(j)).get()
- ((DoubleWritable) oldCenter.get(j)).get();
sum += v * v;
}
double dist = Math.sqrt(sum);
LOG.info("old center: " + oldCenter + ", new center: " + newCenter
+ ", dist: " + dist);
// converge threshold for each center: 0.05
converged = dist < 0.05d;
}
if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
// converged or reach max iteration, output centers
for (int i = 0; i < value.centers.size(); i++) {
context.write(((Tuple) value.centers.get(i)).toArray());
}
// true means to terminate iteration
return true;
}
// false means to continue iteration
return false;
}
}
private static void printUsage() {
System.out.println("Usage: <in> <out> [Max iterations (default 30)]");
System.exit(-1);
}
public static void main(String[] args) throws IOException {
if (args.length < 2)
printUsage();
GraphJob job = new GraphJob();
job.setGraphLoaderClass(KmeansVertexReader.class);
job.setRuntimePartitioning(false);
job.setVertexClass(KmeansVertex.class);
job.setAggregatorClass(KmeansAggregator.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
// default max iteration is 30
job.setMaxIteration(30);
if (args.length >= 3)
job.setMaxIteration(Integer.parseInt(args[2]));
long start = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
Job configuration
The main method configures the GraphJob with the following settings:
| Setting | Value | Description |
|---|---|---|
| setGraphLoaderClass | KmeansVertexReader.class | Loads input table records as vertices |
| setVertexClass | KmeansVertex.class | Defines per-vertex compute logic |
| setAggregatorClass | KmeansAggregator.class | Defines centroid update and convergence logic |
| setRuntimePartitioning | false | Disables runtime graph partitioning. K-means vertices do not need redistribution during loading, so disabling this improves graph loading performance. |
| setMaxIteration | 30 (default) | Sets the iteration limit. Pass a third argument to the job to override. |
| addInput / addOutput | args[0] / args[1] | Input and output table names, passed as command-line arguments |
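The relationship between the iteration limit set in main and the stop condition in terminate() can be isolated in a small plain-Java sketch. IterationLimitSketch and its method names are hypothetical, and the table names kmeans_in and kmeans_out are placeholders, not names from this page.

```java
class IterationLimitSketch {
    static final int DEFAULT_MAX_ITERATIONS = 30;

    // Mirrors main(): an optional third argument overrides the default limit.
    static int resolveMaxIterations(String[] args) {
        return args.length >= 3 ? Integer.parseInt(args[2]) : DEFAULT_MAX_ITERATIONS;
    }

    // Mirrors the stop condition in terminate(). Supersteps are numbered
    // from 0, so the last allowed superstep is maxIterations - 1.
    static boolean shouldStop(boolean converged, long superstep, int maxIterations) {
        return converged || superstep == maxIterations - 1;
    }

    public static void main(String[] args) {
        // Hypothetical argument lists: input table, output table, optional limit.
        int byDefault = resolveMaxIterations(new String[]{"kmeans_in", "kmeans_out"});
        int overridden = resolveMaxIterations(new String[]{"kmeans_in", "kmeans_out", "10"});
        System.out.println(byDefault);                      // prints 30
        System.out.println(overridden);                     // prints 10
        System.out.println(shouldStop(false, 28, byDefault)); // prints false
        System.out.println(shouldStop(false, 29, byDefault)); // prints true
    }
}
```

With the default limit of 30, the job therefore writes its centroids no later than superstep 29, even if some centroid is still moving by 0.05 or more.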