ここでは Aggregator の実装と関連 API について説明し、KmeansClustering を例に Aggregator の使用方法を説明します。
MaxCompute Graph では、Aggregator はグローバル情報の収集と処理を手助けします。 MaxCompute Graph では、Aggregator はグローバル情報の要約と処理に使用されます。
Aggregator の実装
- 1 つは、分散モードのすべての Worker 上で実行されます。
- もう 1 つは、単一頂点モードで AggregatorOwner が配置された Worker 上で実行されます。
Aggregator の API
- createStartupValue(context)
この API はすべての Worker で一度だけ実行されます。 すべてのスーパーステップが始まる前に呼び出され、通常 AggregatorValue の初期化に使用されます。 最初のスーパーステップの反復 (スーパーステップは 0) では、API で初期化された AggregatorValue オブジェクトは WorkerContext.getLastAggregatedValue() または ComputeContext.getLastAggregatedValue() の呼び出しにより取得できます。
- createInitialValue(context)
この API は各スーパーステップが開始されるとすべての Worker で呼び出されます。 現在の反復の AggregatorValue を初期化するために使用されます。通常、前回の反復の結果は WorkerContext.getLastAggregatedValue() によって取得され、部分的な初期化が実行されます。
- aggregate(value, item)
この API はすべての Worker で実行されます。 前述の 2 つの API はフレームワークによって自動的に呼び出されるのに対し、この API はComputeContext#aggregate(item) の明示的な呼び出しによってトリガーされます。 この API は部分集計を実行するために使用されます。 最初のパラメーター値は、Worker が現在のスーパーステップで集計した結果を示します。 初期値は createInitialValue によって返されるオブジェクトです。 2 番目のパラメーターは、ユーザーコードが ComputeContext#aggregate(item) を呼び出すときに送信されます。 この API では、item は通常、集計値を更新するために使用されます。 すべての集計演算が実行された後、取得された値は、Worker の部分集計結果になります。 その後、結果はフレームワークによって AggregatorOwner が配置されている Worker に送信されます。
- merge(value, partial)
この API は、AggregatorOwner が配置された Worker によって実行されます。 Worker の部分集計結果をマージしてグローバル集計オブジェクトを取得するために使用されます。集計と同様に、value は集計結果を示し、partial は集計対象のオブジェクトを示します。partial は値を更新するために使用されます。
たとえば 3 つの Worker w0、w1、w2 が存在し、集計結果を p0、p1、p2 とします。p1、p0、p2 の順番で AggregatorOwner が配置された Worker へ送信された場合、マージ順序は次のようになります。
- merge(p1, p0) が最初に実行され、p1 と p0 が p1’ として集計されます。
- merge(p1’, p2) が実行され、p1’ と p2 は集計されて、このスーパーステップでのグローバル集計結果である p1’’ となります。
前の例では、1 つの Worker しか存在しない場合、merge() 操作の実行は不要であることを示しています。 つまり、merge() は呼び出されません。
- terminate(context, value)
AggregatorOwner が配置された Worker が merge() を実行した後、フレームワークは terminate(context, value) を呼び出し、最後の処理を実行します。 2 番目のパラメーターの値は merge() により取得されたグローバル集計結果を示します。 グローバル集計結果はこのメソッド内でさらに変更可能です。 terminate() が実行されると、フレームワークは次のスーパーステップのためにグローバル集計オブジェクトをすべての Worker に送信します。 terminate() の特別な機能は、true が返されると、ジョブ全体の反復が終了するというものです。 それ以外の場合、反復は継続されます。 機械学習シナリオでは、通常、収束後に true が返されてジョブの終了が決定されます。
KmeansClustering 例
- GraphLoader セクション
GraphLoader: GraphLoader の部分は、入力テーブルをロードし、グラフの頂点または辺に変換するために使用されます。 入力テーブルのデータの各行はサンプルであり、サンプルは頂点を構成し、VertexValue はサンプルを格納するために使用されます。
最初に、書き込み可能なクラス KmeansValue が VertexValue タイプとして定義されます。public static class KmeansValue implements Writable { DenseVector sample; public KmeansValue() { public KmeansValue(DenseVector v) { this.sample = v; @Override public void write(DataOutput out) throws IOException { wirteForDenseVector(out, sample); @Override public void readFields(DataInput in) throws IOException { sample = readFieldsForDenseVector(in);
KmeansValue: DenseVector オブジェクトは、サンプルを格納するために KmeansValue にカプセル化されています。 DenseVector 型は matrix-toolkits-java. wirteForDenseVector() を継承し、 readFieldsForDenseVector() はシリアル化と逆シリアル化に使用されます。 詳細については、付属ファイル Kmeans 内の完全なコードをご参照ください。
カスタム KmeansReader コードは次のとおりです。public static class KmeansReader extends GraphLoader<LongWritable, KmeansValue, NullWritable, NullWritable> { @Override public void load( LongWritable recordNum, WritableRecord record, MutationContext<LongWritable, KmeansValue, NullWritable, NullWritable> context) throws IOException { KmeansVertex v = new KmeansVertex(); v.setId(recordNum); int n = record.size(); DenseVector dv = new DenseVector(n); for (int i = 0; i < n; i++) { dv.set(i, ((DoubleWritable)record.get(i)).get()); v.setValue(new KmeansValue(dv)); context.addVertexRequest(v);
KmeansReader では、データの各行 (レコード) が読み込まれるときに頂点が作成されます。recordNum は頂点 ID として使用され、レコードの内容は DenseVector オブジェクトに変換され、VertexValue にカプセル化されます。
- Vertexカスタム KmeansVertex コード: ロジックについては、各反復で管理されるサンプルに対して部分集計が実行されます。 ロジックについての詳細は、次のセクションの「Aggregator」の実装をご参照ください。
public static class KmeansVertex extends Vertex<LongWritable, KmeansValue, NullWritable, NullWritable> { @Override public void compute( ComputeContext<LongWritable, KmeansValue, NullWritable, NullWritable> context, Iterable<NullWritable> messages) throws IOException { context.aggregate(getValue());
- AggregatorKmeans 全体の主なロジックは Aggregator に集計されます。カスタム KmeansAggrValue は、集計と送信の対象となるコンテンツの管理に使用されます。
public static class KmeansAggrValue implements Writable { DenseMatrix centroids; DenseMatrix sums; // used to recalculate new centroids DenseVector counts; // used to recalculate new centroids @Override public void write(DataOutput out) throws IOException { wirteForDenseDenseMatrix(out, centroids); wirteForDenseDenseMatrix(out, sums); wirteForDenseVector(out, counts); @Override public void readFields(DataInput in) throws IOException { centroids = readFieldsForDenseMatrix(in); sums = readFieldsForDenseMatrix(in); counts = readFieldsForDenseVector(in);
KmeansAggrValue では 3 つのオブジェクトが管理されています。centroids は既存の K 中心点を示します。 サンプルが m ディメンションの場合、centroids は K x m のマトリクスです。sums は centroids と同じサイズのマトリクスで、各要素は特定の中心点に最も近いサンプルの特定のディメンションの合計を記録します。 たとえば、sums(i, j) は、中心点 i に最も近いサンプルのディメンション j の合計を示します。
counts は各中心点に最も近いサンプル数を記録する K ディメンションベクトルです。sums と counts は主な集計内容である新しい中心点を計算するために一緒に使用されます。
次に KmeansAggregator がカスタム Aggregator 実装に使用されます。 前述の API の順に実装について説明します。- createStartupValue() を実行します。
public static class KmeansAggregator extends Aggregator<KmeansAggrValue> { public KmeansAggrValue createStartupValue(WorkerContext context) throws IOException { KmeansAggrValue av = new KmeansAggrValue(); byte[] centers = context.readCacheFile("centers"); String lines[] = new String(centers).split("\n"); int rows = lines.length; int cols = lines[0].split(",").length; // assumption rows >= 1 av.centroids = new DenseMatrix(rows, cols); av.sums = new DenseMatrix(rows, cols); av.sums.zero(); av.counts = new DenseVector(rows); av.counts.zero(); for (int i = 0; i < lines.length; i++) { String[] ss = lines[i].split(","); for (int j = 0; j < ss.length; j++) { av.centroids.set(i, j, Double.valueOf(ss[j])); return av;
このメソッドでは、KmeansAggrValue オブジェクトが初期化され、初期中心点がリソースファイルセンターから読み込まれ、値が重心に与えられます。sums と counts の初期値は 0 です。
- createInitialValue() の実行します。
@Override public void aggregate(KmeansAggrValue value, Object item) throws IOException { DenseVector sample = ((KmeansValue)item).sample; // find the nearest centroid int min = findNearestCentroid(value.centroids, sample); // update sum and count for (int i = 0; i < sample.size(); i ++) { value.sums.add(min, i, sample.get(i)); value.counts.add(min, 1.0d);
この createInitialValue() メソッドでは、サンプルアイテムと最短のユークリッド距離にある中心点のインデックスを見つけるために findNearestCentroid() が呼び出されます。 次に、各ディメンションが合計に加算され、カウントの値がプラス 1 になります。 findNearestCentroid() の実装方法の詳細は、付属ファイル Kmeans をご参照ください。
- createStartupValue() を実行します。
- merge() を実行します。
@Override public void merge(KmeansAggrValue value, KmeansAggrValue partial) throws IOException { value.sums.add(partial.sums); value.counts.add(partial.counts);
マージ処理の実装ロジックは各 Workder で集計された合計とカウントの値を追加します。
- terminate() を実行します。
@Override public boolean terminate(WorkerContext context, KmeansAggrValue value) throws IOException { // Calculate the new means to be the centroids (original sums) DenseMatrix newCentriods = calculateNewCentroids(value.sums, value.counts, value.centroids); // print old centroids and new centroids for debugging System.out.println("\nsuperstep: " + context.getSuperstep() + "\nold centriod:\n" + value.centroids + " new centriod:\n" + newCentriods); boolean converged = isConverged(newCentriods, value.centroids, 0.05d); System.out.println("superstep: " + context.getSuperstep() + "/" + (context.getMaxIteration() - 1) + " converged: " + converged); if (converged || context.getSuperstep() == context.getMaxIteration() - 1) { // converged or reach max iteration, output centriods for (int i = 0; i < newCentriods.numRows(); i++) { Writable[] centriod = new Writable[newCentriods.numColumns()]; for (int j = 0; j < newCentriods.numColumns(); j++) { centriod[j] = new DoubleWritable(newCentriods.get(i, j)); context.write(centriod); // true means to terminate iteration return true; // update centriods value.centroids.set(newCentriods); // false means to continue iteration return false;
terminate() では、calculateNewCentroids() が合計とカウントに基づいて呼び出され、平均値を計算して新しい中心点を取得します。 次に、新しい中心点と古い中心点との間のユークリッド距離に基づいて isConverged() が呼び出され、中心点が収束しているかどうかが判定されます。 収束または反復の回数が上限のしきい値に達すると、新しい中心点が出力され、true が返されて反復が終了します。 それ以外の場合は、中心点が更新され、反復を続けるために false が返されます calculateNewCentroids() および isConverged() の実装方法の詳細については、付属ファイルをご参照ください。
- main() メソッドmain() メソッドは、Graph ジョブの作成、関連する設定の実行、およびジョブの送信に使用されます。 コードは以下のとおりです。
public static void main(final String [] args)throws IOException{ if (args.length < 2) printUsage(); GraphJob job = new GraphJob(); job.setGraphLoaderClass(KmeansReader.class); job.setRuntimePartitioning(false); job.setVertexClass(KmeansVertex.class); job.setAggregatorClass(KmeansAggregator.class); job.addInput(TableInfo.builder().tableName(args[0]).build()); job.addOutput(TableInfo.builder().tableName(args[1]).build()); // default max iteration is 30 job.setMaxIteration(30); if (args.length >= 3) job.setMaxIteration(Integer.parseInt(args[2])); long start = System.currentTimeMillis(); job.run(); System.out.println("Job Finished in " + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
注 job.setRuntimePartitioning(false) が false に設定されている場合、各 Worker によってロードされたデータは Partitioner に基づいてパーティション化されません。 つまり、データをロードする人がそれを維持します。
まとめ
ここでは、MaxCompute Graph の Aggregator の機能、API の意味、および KmeansClustering の例を紹介しています。 まとめると、Aggregator は以下のように実装できます。
- 各 Worker は AggregatorValue 生成のための起動中に createStartupValue を実行します。
- 各 Worker は、各反復が現在のラウンドで AggregatorValue を初期化する前に createInitialValue を実行します。
- 反復では、各頂点は context.aggregate() を使用してaggregate() を実行し、Worker で部分的な反復を実装します。
- 各 Worker は、AggregatorOwner が配置されている Worker に部分的な反復結果を送信します。
- AggregatorOwner が配置されている Worker は、グローバル集計を実装するために複数回マージを実行します。
- AggregatorOwner が配置されている Worker は、グローバル集計結果処理を終了し、反復を終了するかどうかを決定します。