MaxCompute: K-means clustering

Last Updated: Mar 25, 2026

K-means clustering partitions a dataset into k groups by iteratively refining cluster centroids until the centroids stabilize or the iteration limit is reached. This page shows how to implement k-means clustering using the MaxCompute Graph API (Java).

How it works

Each iteration assigns every data point to its nearest centroid, then recomputes the centroid as the mean of all points in that cluster. The algorithm stops when all centroids shift by less than the convergence threshold or the maximum iteration count is reached.

Algorithm steps

  1. Select initial centroids for k clusters.

  2. For each data point, compute its squared Euclidean distance to all k centroids and assign it to the nearest cluster.

  3. Recompute each centroid as the arithmetic mean of all points assigned to that cluster.

  4. If all centroids shift by less than the convergence threshold, stop. Otherwise, repeat from step 2.
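The four steps above can be sketched in plain Java on in-memory arrays. This is a minimal single-process illustration, not the MaxCompute Graph API: the class name KmeansStepsSketch, its method names, and the sample points are hypothetical, and keeping the old centroid for an empty cluster is an assumption made for the sketch.

```java
import java.util.Arrays;

/** Plain-Java sketch of the algorithm steps; names and data are illustrative. */
final class KmeansStepsSketch {

  /** Squared Euclidean distance between two points of equal dimension. */
  static double squaredDistance(double[] a, double[] b) {
    double d = 0.0;
    for (int j = 0; j < a.length; j++) {
      double v = a[j] - b[j];
      d += v * v;
    }
    return d;
  }

  /** Step 2: index of the centroid nearest to the point. */
  static int nearest(double[] point, double[][] centers) {
    int min = 0;
    double minDist = Double.MAX_VALUE;
    for (int i = 0; i < centers.length; i++) {
      double dist = squaredDistance(point, centers[i]);
      if (dist < minDist) {
        minDist = dist;
        min = i;
      }
    }
    return min;
  }

  /** Repeats steps 2 to 4 until every centroid shifts less than threshold or maxIter is hit. */
  static double[][] cluster(double[][] points, double[][] initial,
      double threshold, int maxIter) {
    int k = initial.length;
    int dim = initial[0].length;
    double[][] centers = new double[k][];
    for (int i = 0; i < k; i++) {
      centers[i] = initial[i].clone(); // step 1: caller supplies initial centroids
    }
    for (int iter = 0; iter < maxIter; iter++) {
      double[][] sums = new double[k][dim];
      long[] counts = new long[k];
      for (double[] p : points) {
        int c = nearest(p, centers);
        for (int j = 0; j < dim; j++) {
          sums[c][j] += p[j];
        }
        counts[c]++;
      }
      boolean converged = true;
      for (int i = 0; i < k; i++) {
        if (counts[i] == 0) {
          continue; // assumption: an empty cluster keeps its previous centroid
        }
        double[] next = new double[dim];
        for (int j = 0; j < dim; j++) {
          next[j] = sums[i][j] / counts[i]; // step 3: arithmetic mean
        }
        if (Math.sqrt(squaredDistance(next, centers[i])) >= threshold) {
          converged = false; // step 4: at least one centroid still moving
        }
        centers[i] = next;
      }
      if (converged) {
        break;
      }
    }
    return centers;
  }

  public static void main(String[] args) {
    double[][] points = {{1, 1}, {1, 2}, {9, 9}, {9, 10}};
    double[][] initial = {{0, 0}, {10, 10}};
    // prints [[1.0, 1.5], [9.0, 9.5]]
    System.out.println(Arrays.deepToString(cluster(points, initial, 0.05, 30)));
  }
}
```

With these sample points the assignment is stable after the first pass, so the second pass sees a shift of 0 and the loop exits.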

Key parameters

Parameter | Default | Description
Convergence threshold | 0.05 | Distance below which a centroid's shift between iterations counts as converged. The shift is the Euclidean distance between the old and new centroid (the square root of the summed squared coordinate differences). The value is hard-coded in terminate().
Maximum iterations | 30 | Upper bound on the number of supersteps. The job stops at this limit even if the centroids have not converged. Override it with the optional third command-line argument.
Distance metric | Squared Euclidean | Used to compare distances during cluster assignment. The square root is computed only for the convergence check in terminate().
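Squared distance is enough for cluster assignment because the square root is monotonic: whichever centroid has the smallest squared distance also has the smallest distance. The convergence check takes the square root so that the shift is in the same units as the 0.05 threshold. The following plain-Java sketch illustrates both points; the class DistanceMetricSketch, its methods, and the sample values are hypothetical and not part of the sample job.

```java
/** Illustrates squared distance for assignment and true distance for convergence. */
final class DistanceMetricSketch {

  /** Squared Euclidean distance between two points of equal dimension. */
  static double squaredDistance(double[] a, double[] b) {
    double d = 0.0;
    for (int j = 0; j < a.length; j++) {
      double v = a[j] - b[j];
      d += v * v;
    }
    return d;
  }

  /** True when the centroid moved less than threshold in Euclidean distance. */
  static boolean hasConverged(double[] oldCenter, double[] newCenter, double threshold) {
    return Math.sqrt(squaredDistance(oldCenter, newCenter)) < threshold;
  }

  public static void main(String[] args) {
    double[] point = {3.0, 4.0};
    double[] a = {0.0, 0.0}; // squared distance 25, distance 5
    double[] b = {3.0, 1.0}; // squared distance 9, distance 3

    // Both comparisons pick b as the nearer centroid.
    System.out.println(squaredDistance(point, b) < squaredDistance(point, a)); // true
    System.out.println(
        Math.sqrt(squaredDistance(point, b)) < Math.sqrt(squaredDistance(point, a))); // true

    double[] old = {1.0, 1.0};
    // Shift of about 0.042: below the 0.05 threshold.
    System.out.println(hasConverged(old, new double[] {1.03, 1.03}, 0.05)); // true
    // Shift of about 0.057: above the threshold.
    System.out.println(hasConverged(old, new double[] {1.04, 1.04}, 0.05)); // false
  }
}
```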

Implementation

The following example implements k-means clustering with four classes built on the MaxCompute Graph API.

Component overview

Before reading the code, review each component's responsibility:

KmeansVertex represents a single data point as a Tuple (feature vector). Its compute() method passes the vertex value to the aggregator by calling context.aggregate().

KmeansVertexReader loads the input table and creates one vertex per record. The record number becomes the vertex ID; all columns in the record become the vertex value as a Tuple.

KmeansAggrValue holds the shared aggregation state across workers, with three Tuple fields:

  • centers: current centroid coordinates

  • sums: accumulated coordinate sums per cluster

  • counts: data point counts per cluster
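As a sketch of how these three fields start out, the plain-Java example below builds the initial state from the text of a centers file. It assumes the format the sample code parses: one centroid per line, with comma-separated coordinates. The class AggrStateSketch and the sample file content are hypothetical, and arrays stand in for the Tuple fields.

```java
import java.nio.charset.StandardCharsets;

/** Array-based stand-in for the aggregation state; names are illustrative. */
final class AggrStateSketch {
  double[][] centers; // current centroid coordinates
  double[][] sums;    // accumulated coordinate sums per cluster
  long[] counts;      // data point counts per cluster

  /** Builds the starting state from file bytes: one centroid per line, comma-separated. */
  static AggrStateSketch fromCentersFile(byte[] content) {
    String[] lines = new String(content, StandardCharsets.UTF_8).trim().split("\n");
    AggrStateSketch state = new AggrStateSketch();
    state.centers = new double[lines.length][];
    state.sums = new double[lines.length][];
    state.counts = new long[lines.length]; // all zero until points are aggregated
    for (int i = 0; i < lines.length; i++) {
      String[] fields = lines[i].split(",");
      state.centers[i] = new double[fields.length];
      state.sums[i] = new double[fields.length]; // zero-initialized by Java
      for (int j = 0; j < fields.length; j++) {
        state.centers[i][j] = Double.parseDouble(fields[j].trim());
      }
    }
    return state;
  }

  public static void main(String[] args) {
    // Hypothetical file content: two centroids in two dimensions.
    byte[] content = "1.0,2.0\n5.5, 6.5\n".getBytes(StandardCharsets.UTF_8);
    AggrStateSketch state = fromCentersFile(content);
    System.out.println(state.centers.length);  // 2
    System.out.println(state.centers[1][1]);   // 6.5
    System.out.println(state.counts[0]);       // 0
  }
}
```

The number of lines in the file determines k, and the number of fields per line must match the number of columns in the input table.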

KmeansAggregator encapsulates the main algorithm logic across four methods:

Method | Responsibility
createInitialValue() | On superstep 0, reads initial centroids from the cache file named "centers". On subsequent supersteps, retrieves the centroids computed in the previous iteration via getLastAggregatedValue(0).
aggregate() | For each vertex, finds the nearest centroid using squared Euclidean distance and updates that cluster's sum and count.
merge() | Combines partial sums and counts collected from workers running in parallel.
terminate() | Computes new centroids from sums and counts. If all centroids shift by less than 0.05 or the maximum iteration count is reached, writes the final centroids to the output table and returns true (stop). Otherwise returns false (continue).
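To make the division of labor concrete, the sketch below simulates one superstep with two workers in plain Java: each worker aggregates its own points, the partial results are merged, and a terminate-style step computes the new centroids. It only mirrors the roles of the methods above and does not use the Graph API; the class AggregatorLifecycleSketch, the sample points, and the choice to leave an empty cluster's centroid unchanged are assumptions for illustration.

```java
import java.util.Arrays;

/** Simulates aggregate, merge, and terminate on plain arrays; names are illustrative. */
final class AggregatorLifecycleSketch {
  final double[][] centers;
  final double[][] sums;
  final long[] counts;

  AggregatorLifecycleSketch(double[][] initialCenters) {
    int k = initialCenters.length;
    centers = new double[k][];
    for (int i = 0; i < k; i++) {
      centers[i] = initialCenters[i].clone(); // each worker gets its own copy
    }
    sums = new double[k][initialCenters[0].length];
    counts = new long[k];
  }

  /** Like aggregate(): add one point to its nearest cluster's sum and count. */
  void aggregate(double[] point) {
    int min = 0;
    double minDist = Double.MAX_VALUE;
    for (int i = 0; i < centers.length; i++) {
      double dist = 0.0;
      for (int j = 0; j < point.length; j++) {
        double v = point[j] - centers[i][j];
        dist += v * v;
      }
      if (dist < minDist) {
        minDist = dist;
        min = i;
      }
    }
    for (int j = 0; j < point.length; j++) {
      sums[min][j] += point[j];
    }
    counts[min]++;
  }

  /** Like merge(): fold another worker's partial sums and counts into this one. */
  void merge(AggregatorLifecycleSketch partial) {
    for (int i = 0; i < sums.length; i++) {
      for (int j = 0; j < sums[i].length; j++) {
        sums[i][j] += partial.sums[i][j];
      }
      counts[i] += partial.counts[i];
    }
  }

  /** Like terminate(): set centers to the means, reset sums and counts, report convergence. */
  boolean terminate(double threshold) {
    boolean converged = true;
    for (int i = 0; i < centers.length; i++) {
      if (counts[i] == 0) {
        continue; // assumption: an empty cluster keeps its centroid
      }
      double shift = 0.0;
      double[] next = new double[centers[i].length];
      for (int j = 0; j < next.length; j++) {
        next[j] = sums[i][j] / counts[i];
        double v = next[j] - centers[i][j];
        shift += v * v;
        sums[i][j] = 0.0; // reset for the next iteration
      }
      counts[i] = 0;
      centers[i] = next;
      if (Math.sqrt(shift) >= threshold) {
        converged = false;
      }
    }
    return converged;
  }

  public static void main(String[] args) {
    double[][] start = {{0, 0}, {10, 10}};
    AggregatorLifecycleSketch worker1 = new AggregatorLifecycleSketch(start);
    AggregatorLifecycleSketch worker2 = new AggregatorLifecycleSketch(start);
    worker1.aggregate(new double[] {1, 1});
    worker1.aggregate(new double[] {9, 9});
    worker2.aggregate(new double[] {1, 2});
    worker2.aggregate(new double[] {9, 10});
    worker1.merge(worker2);
    System.out.println(worker1.terminate(0.05));               // false
    System.out.println(Arrays.deepToString(worker1.centers));  // [[1.0, 1.5], [9.0, 9.5]]
  }
}
```

Because sums and counts are additive, the merged result does not depend on how the points are split across workers.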

Sample code

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.log4j.Logger;

import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.io.Text;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;

public class Kmeans {
  private final static Logger LOG = Logger.getLogger(Kmeans.class);

  public static class KmeansVertex extends
      Vertex<Text, Tuple, NullWritable, NullWritable> {

    @Override
    public void compute(
        ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
        Iterable<NullWritable> messages) throws IOException {
      context.aggregate(getValue());
    }

  }

  public static class KmeansVertexReader extends
      GraphLoader<Text, Tuple, NullWritable, NullWritable> {
    @Override
    public void load(LongWritable recordNum, WritableRecord record,
        MutationContext<Text, Tuple, NullWritable, NullWritable> context)
        throws IOException {
      KmeansVertex vertex = new KmeansVertex();
      vertex.setId(new Text(String.valueOf(recordNum.get())));
      vertex.setValue(new Tuple(record.getAll()));
      context.addVertexRequest(vertex);
    }

  }

  public static class KmeansAggrValue implements Writable {

    Tuple centers = new Tuple();
    Tuple sums = new Tuple();
    Tuple counts = new Tuple();

    @Override
    public void write(DataOutput out) throws IOException {
      centers.write(out);
      sums.write(out);
      counts.write(out);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
      centers = new Tuple();
      centers.readFields(in);
      sums = new Tuple();
      sums.readFields(in);
      counts = new Tuple();
      counts.readFields(in);
    }

    @Override
    public String toString() {
      return "centers " + centers.toString() + ", sums " + sums.toString()
          + ", counts " + counts.toString();
    }

  }

  public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {

    @SuppressWarnings("rawtypes")
    @Override
    public KmeansAggrValue createInitialValue(WorkerContext context)
        throws IOException {
      KmeansAggrValue aggrVal = null;
      if (context.getSuperstep() == 0) {
        aggrVal = new KmeansAggrValue();
        aggrVal.centers = new Tuple();
        aggrVal.sums = new Tuple();
        aggrVal.counts = new Tuple();

        byte[] centers = context.readCacheFile("centers");
        String lines[] = new String(centers).split("\n");

        for (int i = 0; i < lines.length; i++) {
          String[] ss = lines[i].split(",");
          Tuple center = new Tuple();
          Tuple sum = new Tuple();
          for (int j = 0; j < ss.length; ++j) {
            center.append(new DoubleWritable(Double.valueOf(ss[j].trim())));
            sum.append(new DoubleWritable(0.0));
          }
          LongWritable count = new LongWritable(0);
          aggrVal.sums.append(sum);
          aggrVal.counts.append(count);
          aggrVal.centers.append(center);
        }
      } else {
        aggrVal = (KmeansAggrValue) context.getLastAggregatedValue(0);
      }

      return aggrVal;
    }

    @Override
    public void aggregate(KmeansAggrValue value, Object item) {
      int min = 0;
      double mindist = Double.MAX_VALUE;
      Tuple point = (Tuple) item;

      for (int i = 0; i < value.centers.size(); i++) {
        Tuple center = (Tuple) value.centers.get(i);
        // compare squared Euclidean distances; sqrt is not needed to find the nearest center
        double dist = 0.0d;
        for (int j = 0; j < center.size(); j++) {
          double v = ((DoubleWritable) point.get(j)).get()
              - ((DoubleWritable) center.get(j)).get();
          dist += v * v;
        }
        if (dist < mindist) {
          mindist = dist;
          min = i;
        }
      }

      // update sum and count
      Tuple sum = (Tuple) value.sums.get(min);
      for (int i = 0; i < point.size(); i++) {
        DoubleWritable s = (DoubleWritable) sum.get(i);
        s.set(s.get() + ((DoubleWritable) point.get(i)).get());
      }
      LongWritable count = (LongWritable) value.counts.get(min);
      count.set(count.get() + 1);
    }

    @Override
    public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
      for (int i = 0; i < value.sums.size(); i++) {
        Tuple sum = (Tuple) value.sums.get(i);
        Tuple that = (Tuple) partial.sums.get(i);
        for (int j = 0; j < sum.size(); j++) {
          DoubleWritable s = (DoubleWritable) sum.get(j);
          s.set(s.get() + ((DoubleWritable) that.get(j)).get());
        }
      }

      for (int i = 0; i < value.counts.size(); i++) {
        LongWritable count = (LongWritable) value.counts.get(i);
        count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
      }
    }

    @SuppressWarnings("rawtypes")
    @Override
    public boolean terminate(WorkerContext context, KmeansAggrValue value)
        throws IOException {

      // compute new centers
      Tuple newCenters = new Tuple(value.sums.size());
      for (int i = 0; i < value.sums.size(); i++) {
        Tuple sum = (Tuple) value.sums.get(i);
        Tuple newCenter = new Tuple(sum.size());
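        LongWritable c = (LongWritable) value.counts.get(i);
        // an empty cluster keeps its previous center; dividing by a zero
        // count would yield NaN and the job could never converge
        Tuple prevCenter = (Tuple) value.centers.get(i);
        for (int j = 0; j < sum.size(); j++) {

          DoubleWritable s = (DoubleWritable) sum.get(j);
          double val = c.get() == 0
              ? ((DoubleWritable) prevCenter.get(j)).get()
              : s.get() / c.get();
          newCenter.set(j, new DoubleWritable(val));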

          // reset sum for next iteration
          s.set(0.0d);
        }
        // reset count for next iteration
        c.set(0);
        newCenters.set(i, newCenter);
      }

      // update centers
      Tuple oldCenters = value.centers;
      value.centers = newCenters;

      LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);

      // compare new/old centers
      boolean converged = true;
      for (int i = 0; i < value.centers.size() && converged; i++) {
        Tuple oldCenter = (Tuple) oldCenters.get(i);
        Tuple newCenter = (Tuple) newCenters.get(i);
        double sum = 0.0d;
        for (int j = 0; j < newCenter.size(); j++) {
          double v = ((DoubleWritable) newCenter.get(j)).get()
              - ((DoubleWritable) oldCenter.get(j)).get();
          sum += v * v;
        }
        double dist = Math.sqrt(sum);
        LOG.info("old center: " + oldCenter + ", new center: " + newCenter
            + ", dist: " + dist);
        // converge threshold for each center: 0.05
        converged = dist < 0.05d;
      }

      if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
        // converged or reach max iteration, output centers
        for (int i = 0; i < value.centers.size(); i++) {
          context.write(((Tuple) value.centers.get(i)).toArray());
        }
        // true means to terminate iteration
        return true;
      }

      // false means to continue iteration
      return false;
    }
  }

  private static void printUsage() {
    System.out.println("Usage: <in> <out> [Max iterations (default 30)]");
    System.exit(-1);
  }

  public static void main(String[] args) throws IOException {
    if (args.length < 2)
      printUsage();

    GraphJob job = new GraphJob();

    job.setGraphLoaderClass(KmeansVertexReader.class);
    job.setRuntimePartitioning(false);
    job.setVertexClass(KmeansVertex.class);
    job.setAggregatorClass(KmeansAggregator.class);
    job.addInput(TableInfo.builder().tableName(args[0]).build());
    job.addOutput(TableInfo.builder().tableName(args[1]).build());

    // default max iteration is 30
    job.setMaxIteration(30);
    if (args.length >= 3)
      job.setMaxIteration(Integer.parseInt(args[2]));

    long start = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
  }
}

Job configuration

The main method configures the GraphJob with the following settings:

Setting | Value | Description
setGraphLoaderClass | KmeansVertexReader.class | Loads input table records as vertices
setVertexClass | KmeansVertex.class | Defines per-vertex compute logic
setAggregatorClass | KmeansAggregator.class | Defines centroid update and convergence logic
setRuntimePartitioning | false | Disables runtime graph partitioning. K-means vertices do not need redistribution during loading, so disabling this improves graph loading performance.
setMaxIteration | 30 (default) | Sets the iteration limit. Pass a third argument to the job to override.
addInput / addOutput | args[0] / args[1] | Input and output table names, passed as command-line arguments