全部產品
Search
文件中心

MaxCompute:單源最短距離

更新時間:Feb 28, 2024

單源最短距離是指給定圖中一個源點,計算源點到其它所有節點的最短距離。Dijkstra演算法是求解有向圖中單源最短距離SSSP(Single Source Shortest Path)的經典演算法。

演算法原理

Dijkstra演算法是通過去更新最短距離值,每個維護到源點的當前最短距離值,當這個值發生變化時,將新值加上權值,發送訊息通知其鄰接點。下一輪迭代時,鄰接點根據收到的訊息,更新其當前最短距離值,當所有的當前最短距離值不再變化時,迭代結束。

  • 初始化:源點s到s自身的距離為0(d[s]=0),其他點u到s的距離為無窮(d[u]=∞)。
  • 迭代:如果存在一條從u到v的邊,則從s到v的最短距離更新為d[v]=min(d[v], d[u]+weight(u, v)),直到所有的點到s的距離不再發生變化時,迭代結束。
說明 對一個有權重的有向圖G=(V,E),從一個源點s到匯點v有很多重路徑,其中邊權和最小的路徑,稱為從s到v的最短距離。

由演算法基本原理可以看出,此演算法非常適合用MaxCompute Graph程式進行求解。

使用情境

圖類型通常分為有向圖和無向圖兩種,MaxCompute均支援。基於來源資料的分布,構造有向圖和無向圖時的路徑通路會存在差異,求解SSSP時會產生不同的結果。MaxCompute Graph以有向圖為基礎資料模型,架構內會以有向圖的模型參與計算。

程式碼範例

以下代碼基於不同的情境,提供不同的程式碼範例。

  • 有向圖
    • 定義類BaseLoadingVertexResolver,此異常類會在SSSP類中被引用。
      import com.aliyun.odps.graph.Edge;
      import com.aliyun.odps.graph.LoadingVertexResolver;
      import com.aliyun.odps.graph.Vertex;
      import com.aliyun.odps.graph.VertexChanges;
      import com.aliyun.odps.io.Writable;
      import com.aliyun.odps.io.WritableComparable;
      
      import java.io.IOException;
      import java.util.HashSet;
      import java.util.Iterator;
      import java.util.List;
      import java.util.Set;
      
      @SuppressWarnings("rawtypes")
      public class BaseLoadingVertexResolver<I extends WritableComparable, V extends Writable, E extends Writable, M extends Writable>
              extends LoadingVertexResolver<I, V, E, M> {
        @Override
        public Vertex<I, V, E, M> resolve(I vertexId, VertexChanges<I, V, E, M> vertexChanges) throws IOException {
      
          Vertex<I, V, E, M> vertex = addVertexIfDesired(vertexId, vertexChanges);
      
          if (vertex != null) {
            addEdges(vertex, vertexChanges);
          } else {
            System.err.println("Ignore all addEdgeRequests for vertex#" + vertexId);
          }
          return vertex;
        }
      
        protected Vertex<I, V, E, M> addVertexIfDesired(
                I vertexId,
                VertexChanges<I, V, E, M> vertexChanges) {
          Vertex<I, V, E, M> vertex = null;
          if (hasVertexAdditions(vertexChanges)) {
            vertex = vertexChanges.getAddedVertexList().get(0);
          }
      
          return vertex;
        }
      
        protected void addEdges(Vertex<I, V, E, M> vertex,
                                VertexChanges<I, V, E, M> vertexChanges) throws IOException {
          Set<I> destVertexId = new HashSet<I>();
          if (vertex.hasEdges()) {
            List<Edge<I, E>> edgeList = vertex.getEdges();
            for (Iterator<Edge<I, E>> edges = edgeList.iterator(); edges.hasNext(); ) {
              Edge<I, E> edge = edges.next();
              if (destVertexId.contains(edge.getDestVertexId())) {
                edges.remove();
              } else {
                destVertexId.add(edge.getDestVertexId());
              }
            }
          }
      
          for (Vertex<I, V, E, M> vertex1 : vertexChanges.getAddedVertexList()) {
            if (vertex1.hasEdges()) {
              List<Edge<I, E>> edgeList = vertex1.getEdges();
              for (Edge<I, E> edge : edgeList) {
                if (destVertexId.contains(edge.getDestVertexId())) continue;
                destVertexId.add(edge.getDestVertexId());
                vertex.addEdge(edge.getDestVertexId(), edge.getValue());
              }
            }
          }
        }
      
        protected boolean hasVertexAdditions(VertexChanges<I, V, E, M> changes) {
          return changes != null && changes.getAddedVertexList() != null
                  && !changes.getAddedVertexList().isEmpty();
        }
      }
      代碼說明:
      • 第15行:定義BaseLoadingVertexResolver。用於處理有向圖資料在載入過程中所遇到的衝突。
      • 第18行:resolve為處理衝突的具體方法。例如當前的某一頂點進行了兩次添加的過程(即進行了兩次addVertexRequest操作),這種行為便稱為衝突載入,需要人為處理衝突之後再參與計算。
    • 定義類SSSP
      import java.io.IOException;
      
      import com.aliyun.odps.graph.Combiner;
      import com.aliyun.odps.graph.ComputeContext;
      import com.aliyun.odps.graph.Edge;
      import com.aliyun.odps.graph.GraphJob;
      import com.aliyun.odps.graph.GraphLoader;
      import com.aliyun.odps.graph.MutationContext;
      import com.aliyun.odps.graph.Vertex;
      import com.aliyun.odps.graph.WorkerContext;
      import com.aliyun.odps.io.WritableRecord;
      import com.aliyun.odps.io.LongWritable;
      import com.aliyun.odps.data.TableInfo;
      
      public class SSSP {
        public static final String START_VERTEX = "sssp.start.vertex.id";
      
        public static class SSSPVertex extends
                Vertex<LongWritable, LongWritable, LongWritable, LongWritable> {
          private static long startVertexId = -1;
      
          public SSSPVertex() {
            this.setValue(new LongWritable(Long.MAX_VALUE));
          }
      
          public boolean isStartVertex(
                  ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context) {
            if (startVertexId == -1) {
              String s = context.getConfiguration().get(START_VERTEX);
              startVertexId = Long.parseLong(s);
            }
            return getId().get() == startVertexId;
          }
      
          @Override
          public void compute(
                  ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context,
                  Iterable<LongWritable> messages) throws IOException {
            long minDist = isStartVertex(context) ? 0 : Long.MAX_VALUE;
            for (LongWritable msg : messages) {
              if (msg.get() < minDist) {
                minDist = msg.get();
              }
            }
            if (minDist < this.getValue().get()) {
              this.setValue(new LongWritable(minDist));
              if (hasEdges()) {
                for (Edge<LongWritable, LongWritable> e : this.getEdges()) {
                  context.sendMessage(e.getDestVertexId(), new LongWritable(minDist + e.getValue().get()));
                }
              }
            } else {
              voteToHalt();
            }
          }
      
          @Override
          public void cleanup(
                  WorkerContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
                  throws IOException {
            context.write(getId(), getValue());
          }
      
          @Override
          public String toString() {
            return "Vertex(id=" + this.getId() + ",value=" + this.getValue() + ",#edges=" + this.getEdges() + ")";
          }
        }
      
        public static class SSSPGraphLoader extends
                GraphLoader<LongWritable, LongWritable, LongWritable, LongWritable> {
          @Override
          public void load(
                  LongWritable recordNum,
                  WritableRecord record,
                  MutationContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
                  throws IOException {
            SSSPVertex vertex = new SSSPVertex();
            vertex.setId((LongWritable) record.get(0));
            String[] edges = record.get(1).toString().split(",");
            for (String edge : edges) {
              String[] ss = edge.split(":");
              vertex.addEdge(new LongWritable(Long.parseLong(ss[0])), new LongWritable(Long.parseLong(ss[1])));
            }
            context.addVertexRequest(vertex);
          }
        }
      
        public static class MinLongCombiner extends
                Combiner<LongWritable, LongWritable> {
          @Override
          public void combine(LongWritable vertexId, LongWritable combinedMessage,
                              LongWritable messageToCombine) throws IOException {
            if (combinedMessage.get() > messageToCombine.get()) {
              combinedMessage.set(messageToCombine.get());
            }
          }
        }
      
        public static void main(String[] args) throws IOException {
          if (args.length < 3) {
            System.out.println("Usage: <startnode> <input> <output>");
            System.exit(-1);
          }
          GraphJob job = new GraphJob();
          job.setGraphLoaderClass(SSSPGraphLoader.class);
          job.setVertexClass(SSSPVertex.class);
          job.setCombinerClass(MinLongCombiner.class);
          job.setLoadingVertexResolver(BaseLoadingVertexResolver.class);
          job.set(START_VERTEX, args[0]);
          job.addInput(TableInfo.builder().tableName(args[1]).build());
          job.addOutput(TableInfo.builder().tableName(args[2]).build());
          long startTime = System.currentTimeMillis();
          job.run();
          System.out.println("Job Finished in "
                  + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
        }
      }
      
                                  
      代碼說明:
      • 第19行:定義SSSPVertex。其中:
        • 點值表示該頂點到源點startVertexId的最短距離。
        • compute()方法使用迭代公式d[v]=min(d[v], d[u]+weight(u, v))計算最短距離值並更新至當前點值。
        • cleanup()方法將當前頂點到源點的最短距離寫入結果表中。
      • 第54行:當前頂點的Value值(即該頂點到源點的最短路徑)未發生變化時,即調用voteToHalt()通過架構使該頂點進入halt狀態。當所有頂點都進入halt狀態時,計算結束。
      • 第71行:定義GraphLoader圖資料以有向圖的方式載入圖資料。通過將表記憶體儲的記錄解析為圖的頂點或邊資訊載入至架構內。如上範例程式碼中,使用者可通過addVertexRequest方式將圖的頂點資訊載入至圖計算的上下文內。
      • 第90行:定義MinLongCombiner。對發送給同一個點的訊息進行合并,最佳化效能,減少記憶體佔用。
      • 第101行:主程式main函數中定義GraphJob。指定VertexGraphLoaderBaseLoadingVertexResolverCombiner等的實現,指定輸入輸出表。
      • 第110行:添加處理衝突的類BaseLoadingVertexResolver
  • 無向圖
    import com.aliyun.odps.data.TableInfo;
    import com.aliyun.odps.graph.*;
    import com.aliyun.odps.io.DoubleWritable;
    import com.aliyun.odps.io.LongWritable;
    import com.aliyun.odps.io.WritableRecord;
    
    import java.io.IOException;
    import java.util.HashSet;
    import java.util.Set;
    
    public class SSSPBenchmark4 {
        public static final String START_VERTEX = "sssp.start.vertex.id";
    
        public static class SSSPVertex extends
                Vertex<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {
            private static long startVertexId = -1;
            public SSSPVertex() {
                this.setValue(new DoubleWritable(Double.MAX_VALUE));
            }
            public boolean isStartVertex(
                    ComputeContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context) {
                if (startVertexId == -1) {
                    String s = context.getConfiguration().get(START_VERTEX);
                    startVertexId = Long.parseLong(s);
                }
                return getId().get() == startVertexId;
            }
    
            @Override
            public void compute(
                    ComputeContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context,
                    Iterable<DoubleWritable> messages) throws IOException {
                double minDist = isStartVertex(context) ? 0 : Double.MAX_VALUE;
                for (DoubleWritable msg : messages) {
                    if (msg.get() < minDist) {
                        minDist = msg.get();
                    }
                }
                if (minDist < this.getValue().get()) {
                    this.setValue(new DoubleWritable(minDist));
                    if (hasEdges()) {
                        for (Edge<LongWritable, DoubleWritable> e : this.getEdges()) {
                            context.sendMessage(e.getDestVertexId(), new DoubleWritable(minDist
                                    + e.getValue().get()));
                        }
                    }
                } else {
                    voteToHalt();
                }
            }
    
            @Override
            public void cleanup(
                    WorkerContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context)
                    throws IOException {
                context.write(getId(), getValue());
            }
        }
    
        public static class MinLongCombiner extends
                Combiner<LongWritable, DoubleWritable> {
            @Override
            public void combine(LongWritable vertexId, DoubleWritable combinedMessage,
                                DoubleWritable messageToCombine) {
                if (combinedMessage.get() > messageToCombine.get()) {
                    combinedMessage.set(messageToCombine.get());
                }
            }
        }
    
        public static class SSSPGraphLoader extends
                GraphLoader<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {
            @Override
            public void load(
                    LongWritable recordNum,
                    WritableRecord record,
                    MutationContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context)
                    throws IOException {
                LongWritable sourceVertexID = (LongWritable) record.get(0);
                LongWritable destinationVertexID = (LongWritable) record.get(1);
                DoubleWritable edgeValue = (DoubleWritable) record.get(2);
                Edge<LongWritable, DoubleWritable> edge = new Edge<LongWritable, DoubleWritable>(destinationVertexID, edgeValue);
                context.addEdgeRequest(sourceVertexID, edge);
                Edge<LongWritable, DoubleWritable> edge2 = new
                        Edge<LongWritable, DoubleWritable>(sourceVertexID, edgeValue);
                context.addEdgeRequest(destinationVertexID, edge2);
            }
        }
    
        public static class SSSPLoadingVertexResolver extends
                LoadingVertexResolver<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {
    
            @Override
            public Vertex<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> resolve(
                    LongWritable vertexId,
                    VertexChanges<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> vertexChanges) throws IOException {
    
                SSSPVertex computeVertex = new SSSPVertex();
                computeVertex.setId(vertexId);
                Set<LongWritable> destinationVertexIDSet = new HashSet<>();
    
                if (hasEdgeAdditions(vertexChanges)) {
                    for (Edge<LongWritable, DoubleWritable> edge : vertexChanges.getAddedEdgeList()) {
                        if (!destinationVertexIDSet.contains(edge.getDestVertexId())) {
                            destinationVertexIDSet.add(edge.getDestVertexId());
                            computeVertex.addEdge(edge.getDestVertexId(), edge.getValue());
                        }
    
                    }
                }
    
                return computeVertex;
            }
    
            protected boolean hasEdgeAdditions(VertexChanges<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> changes) {
                return changes != null && changes.getAddedEdgeList() != null
                        && !changes.getAddedEdgeList().isEmpty();
            }
        }
    
        public static void main(String[] args) throws IOException {
            if (args.length < 2) {
                System.out.println("Usage: <startnode> <input> <output>");
                System.exit(-1);
            }
            GraphJob job = new GraphJob();
            job.setGraphLoaderClass(SSSPGraphLoader.class);
            job.setLoadingVertexResolver(SSSPLoadingVertexResolver.class);
            job.setVertexClass(SSSPVertex.class);
            job.setCombinerClass(MinLongCombiner.class);
            job.set(START_VERTEX, args[0]);
            job.addInput(TableInfo.builder().tableName(args[1]).build());
            job.addOutput(TableInfo.builder().tableName(args[2]).build());
            long startTime = System.currentTimeMillis();
            job.run();
            System.out.println("Job Finished in "
                    + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
        }
    }
                        
    代碼說明:
    • 第15行:定義SSSPVertex。其中:
      • 點值表示該頂點到源點startVertexId的最短距離。
      • compute()方法使用迭代公式d[v]=min(d[v], d[u]+weight(u, v))計算最短距離值並更新至當前點值。
      • cleanup()方法將當前頂點到源點的最短距離寫入結果表中。
    • 第54行:當前頂點的Value值(即該頂點到源點的最短路徑)未發生變化時,即調用voteToHalt()通過架構使該頂點進入halt狀態。當所有頂點都進入halt狀態時,計算結束。
    • 第61行:定義MinLongCombiner。對發送給同一個點的訊息進行合并,最佳化效能,減少記憶體佔用。
    • 第72行:定義GraphLoader圖資料以無向圖的方式載入圖資料。通過addEdgeRequest以兩點之間的邊作為無向邊載入邊資訊,這樣便可保證當前表記憶體儲的圖資訊載入成無向圖。
      • 第80行:第一列表示初始點的ID。
      • 第81行:第二列表示終點的ID。
      • 第82行:第三列表示邊的權重。
      • 第83行:建立邊,由終點ID和邊的權重組成。
      • 第84行:請求給初始點添加邊。
      • 第85行 - 第87行:每條Record表示雙向邊,重複第83行與第84行。
    • 定義SSSPLoadingVertexResolver。用於處理無向圖資料在載入過程中所遇到的衝突。例如當前的某一邊進行了兩次添加的過程(即進行了兩次addEdgeRequest操作),這種行為便稱為衝突載入,需要人為處理重複添加的邊才可保證計算正確性。
    • 第101行:主程式main函數中定義GraphJob。指定VertexGraphLoaderSSSPLoadingVertexResolverCombiner等的實現,指定輸入輸出表。

運行結果

以下是基於有向圖的程式碼範例的運行結果。操作詳情,請參見編寫Graph
vertex    value
1        0
2        2
3        1
4        3
5        2
  • vertex:代表當前頂點。
  • value:代表當前vertex到達源點(1)的最短距離。
說明 無向圖資料,使用者可以參考無向圖程式碼範例中的初始點ID,終點ID,邊的權值自行構造。