Dijkstra は有向グラフの SSSP (Single Source Shortest Path) を計算する典型的なアルゴリズムです。

重み付き有向グラフ G=(V,E) では、多くの経路がソース頂点 s からシンク頂点 v に向かっています。これらの経路では、辺の重みの合計が最小となる頂点が s から v の最短距離と呼ばれます。

アルゴリズムの基本概念は以下のとおりです。
  • 初期化: ソース頂点 s から s 自身への距離はゼロ (d[s] = 0) で、他の頂点 u から s までの距離は無限大 (d[u]=∞) です。
  • 反復: u から v への辺が存在する場合、s から v への最短距離は d[v] = min(d[v], d[u] + weight(u, v)) として更新されます。 すべての頂点から s までの距離が 変化しなくなると、反復を終了します。

このアルゴリズムの基本概念は、MaxCompute Graph プログラムを使用するソリューションに適用可能なことを示しています。 各頂点はソース頂点までの現在の最短距離を保持します。 値が変わると、新たな値と辺の重みを含むメッセージが、隣接する頂点に対して送信されます。 次の反復では、受け取ったメッセージに基づいて隣接する頂点が現在の最短距離を更新します。 現在のすべての頂点の最短距離が変化しなくなると、反復を終了します。

コード例

SSSP のコードは以下のとおりです。
import java.io.IOException;

import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.graph.Combiner;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.data.TableInfo;

public class SSSP {

  public static final String START_VERTEX = "sssp.start.vertex.id";

  public static class SSSPVertex extends
      Vertex<LongWritable, LongWritable, LongWritable, LongWritable> {

    private static long startVertexId = -1;

    public SSSPVertex() {
      this.setValue(new LongWritable(Long.MAX_VALUE));
    

    public boolean isStartVertex(
        ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context) {
      if (startVertexId == -1) {
        String s = context.getConfiguration().get(START_VERTEX);
        startVertexId = Long.parseLong(s);
      
      return getId().get() == startVertexId;
    

    @Override
    public void compute(
        ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context,
        Iterable<LongWritable> messages) throws IOException {
      long minDist = isStartVertex(context) ? 0 : Integer.MAX_VALUE;
      for (LongWritable msg : messages) {
        if (msg.get() < minDist) {
          minDist = msg.get();
        
      

      if (minDist < this.getValue().get()) {
        this.setValue(new LongWritable(minDist));
        if (hasEdges()) {
          for (Edge<LongWritable, LongWritable> e : this.getEdges()) {
            context.sendMessage(e.getDestVertexId(), new LongWritable(minDist
                + e.getValue().get()));
          
        
      } else {
        voteToHalt();
      
    

    @Override
    public void cleanup(
        WorkerContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
        throws IOException {
      context.write(getId(), getValue());
    
  

  public static class MinLongCombiner extends
      Combiner<LongWritable, LongWritable> {

    @Override
    public void combine(LongWritable vertexId, LongWritable combinedMessage,
        LongWritable messageToCombine) throws IOException {
      if (combinedMessage.get() > messageToCombine.get()) {
        combinedMessage.set(messageToCombine.get());
      
    

  

  public static class SSSPVertexReader extends
      GraphLoader<LongWritable, LongWritable, LongWritable, LongWritable> {

    @Override
    public void load(
        LongWritable recordNum,
        WritableRecord record,
        MutationContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
        throws IOException {
      SSSPVertex vertex = new SSSPVertex();
      vertex.setId((LongWritable) record.get(0));
      String[] edges = record.get(1).toString().split(",");
      for (int i = 0; i < edges.length; i++) {
        String[] ss = edges[i].split(":");
        vertex.addEdge(new LongWritable(Long.parseLong(ss[0])),
            new LongWritable(Long.parseLong(ss[1])));
      

      context.addVertexRequest(vertex);
    

  

  public static void main(String[] args) throws IOException {
    if (args.length < 2) {
      System.out.println("Usage: <startnode> <input> <output>");
      System.exit(-1);
    

    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(SSSPVertexReader.class);
    job.setVertexClass(SSSPVertex.class);
    job.setCombinerClass(MinLongCombiner.class);

    job.set(START_VERTEX, args[0]);
    job.addInput(TableInfo.builder().tableName(args[1]).build());
    job.addOutput(TableInfo.builder().tableName(args[2]).build());

    long startTime = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
  

SSSP のソースコードの詳細を以下に示します。
  • 19 行目: SSSPVertex を定義します。
    • 頂点の値は、この頂点からソース頂点 startVertexId までの現在の最短距離を示します。
    • compute() メソッドは、反復計算式 d[v] = min(d[v], d[u] + weight(u, v)) を使用して頂点の値を更新します。
    • cleanup() メソッドは、頂点とソース頂点までの最短距離を結果テーブルに書き込みます。
  • 58 行目: 頂点の値が変化しない場合、voteToHalt() が呼び出され、頂点が停止状態に変わったことをフレームワークに通知します。 すべての頂点が 停止状態に変わると、この計算は終了します。
  • 70 行目: MinLongCombiner を定義し、同一頂点に送信されるメッセージを組み合わせてパフォーマンスを最適化し、メモリの使用を減らします。
  • 83 行目: SSSPVertexReader クラスを定義し、グラフをロードし、 テーブルの各レコードを頂点に分解します。 レコードの最初の列は頂点 ID で、2 列目は 2:2、3:1、4:4 のように、頂点を始点とするすべての辺のセットを保存します。
  • 106 行目: メインプログラム (main 関数) を実行し、GraphJob を定義し、 Vertex/GraphLoader/Combiner および入出力テーブルの実装を指定します。