Whale provides GraphKeys to distinguish collections that use different aggregation methods. This topic describes the syntax and features of the GraphKey-related operations and provides sample code that shows how to call them.

Operation | Description
add_to_collection | Adds a tensor to a collection that uses a specific aggregation method.
get_all_collections | Returns all tensors in all non-empty collections.
get_collection | Returns all tensors in a collection that uses a specific aggregation method.

Background information

To increase sample throughput during model training, you can combine distributed parallelism strategies such as data parallelism, model parallelism, and pipeline parallelism. By default, however, Whale does not aggregate the output tensor values of operators across replicas. When you debug an algorithm or tune it toward convergence, you must regularly check how metrics such as local or global loss and accuracy change. For this purpose, Whale provides several GraphKeys that distinguish collections by aggregation method. You can call the whale.add_to_collection, whale.get_all_collections, or whale.get_collection operation to view or modify the tensors in a specific collection.

add_to_collection

  • Syntax
    add_to_collection(tensor, graph_key)
  • Description

    Adds a tensor to a collection with a specific GraphKey so that the Whale framework can automatically aggregate the tensor values as needed. If this operation is not called, the sess.run method returns the tensor values for a single replica.

  • Parameters
    • tensor: the tensor to be aggregated. The value is the output of an operator and is of the TensorFlow Tensor type.
    • graph_key: the aggregation method of the tensor. The value is of the STRING type. For more information about valid values, see GraphKeys and scenarios.
  • Return value

    None.

  • Examples
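    import tensorflow as tf
    import whale as wh
    # Scenario: Check the global mean of loss for all model replicas.
    # labels and logits are assumed to be tensors that your model already defines.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=tf.cast(logits, tf.float32))
    # Register the loss so that Whale aggregates it as a mean across all replicas.
    wh.add_to_collection(loss, wh.GraphKeys.GLOBAL_MEAN_OBJECTS)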

get_all_collections

  • Syntax
    get_all_collections()
  • Description

    Returns all tensors in all non-empty collections.

  • Parameters

    None.

  • Return value

    The value is of the LIST type. If all collections are empty, an empty list is returned.

  • Examples
    import whale as wh
    # Scenario: Print all tensors in all non-empty collections.
    print(wh.get_all_collections())

get_collection

  • Syntax
    get_collection(graph_key)
  • Description

    Returns all tensors in a collection with a specific GraphKey.

  • Parameters

    graph_key: the aggregation method of the tensors in the specific collection. The value is of the STRING type. For more information about valid values, see GraphKeys and scenarios.

  • Return value

    The value is of the LIST type. If the specific collection is empty, an empty list is returned.

  • Examples
    import whale as wh
    # Scenario: Check all tensors in the collection whose GraphKey is wh.GraphKeys.GLOBAL_MEAN_OBJECTS.
    print(wh.get_collection(wh.GraphKeys.GLOBAL_MEAN_OBJECTS))
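
The documented behavior of the three operations can be mimicked with a small in-memory registry. The sketch below is a hypothetical stand-in, not Whale's implementation. The key strings, the _collections dictionary, and the choice to return one flat list from get_all_collections are assumptions made for illustration, and plain strings stand in for tensors.

```python
# Illustrative stand-in only: this is NOT Whale's implementation.
# It mimics the documented contract of the three operations with a plain dict.
from collections import defaultdict

# Hypothetical keys that mirror wh.GraphKeys; the string values are assumptions.
GLOBAL_MEAN_OBJECTS = "global_mean_objects"
LOCAL_SUM_OBJECTS = "local_sum_objects"

_collections = defaultdict(list)  # graph_key -> list of registered tensors


def add_to_collection(tensor, graph_key):
    """Register a tensor under graph_key. Returns None, as documented."""
    _collections[graph_key].append(tensor)


def get_collection(graph_key):
    """Return the tensors under graph_key, or an empty list if there are none."""
    return list(_collections.get(graph_key, []))


def get_all_collections():
    """Return the tensors of every non-empty collection (assumed: one flat list)."""
    return [t for tensors in _collections.values() for t in tensors]


print(get_all_collections())                # nothing registered yet: []
add_to_collection("loss", GLOBAL_MEAN_OBJECTS)
print(get_collection(GLOBAL_MEAN_OBJECTS))  # ['loss']
print(get_collection(LOCAL_SUM_OBJECTS))    # empty collection: []
```

In real code the registered object would be a TensorFlow tensor and the key would be one of the wh.GraphKeys constants listed in GraphKeys and scenarios.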

GraphKeys and scenarios

GraphKey | Description | Scenario
GLOBAL_CONCAT_OBJECTS | Concatenates the tensors in the collection for all model replicas along axis 0. | Check the data consumed by all model replicas in an iteration.
LOCAL_CONCAT_OBJECTS | Concatenates the tensors in the collection for the current model replica along axis 0. | Check the data consumed by the current model replica in an iteration. For example, check the aggregation of microbatches in pipeline parallelism mode.
GLOBAL_MEAN_OBJECTS | Calculates the mean of all tensors in the collection for all model replicas. | Check the mean of loss for all model replicas in an iteration.
LOCAL_MEAN_OBJECTS | Calculates the mean of all tensors in the collection for the current model replica. | Check the mean of loss for the current model replica in an iteration. For example, check the mean of loss for multiple microbatches in pipeline parallelism mode.
GLOBAL_SUM_OBJECTS | Calculates the sum of all tensors in the collection for all model replicas. | Check the sum of loss for all model replicas in an iteration.
LOCAL_SUM_OBJECTS | Calculates the sum of all tensors in the collection for the current model replica. | Check the sum of loss for the current model replica in an iteration. For example, check the sum of loss for multiple microbatches in pipeline parallelism mode.
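
The difference between the local and global variants can be illustrated with plain Python lists. This is a minimal sketch of one reading of the descriptions above, not Whale code. The names batches and losses, the two-replica and two-microbatch layout, and all numbers are hypothetical, and the global mean is assumed to weight every tensor from every replica equally.

```python
# Pure-Python illustration of the six aggregation methods; not Whale code.
# batches[r][m]: hypothetical sample IDs consumed by replica r in microbatch m.
# losses[r][m]:  hypothetical scalar loss of replica r for microbatch m.
batches = [
    [[0, 1], [2, 3]],  # replica 0
    [[4, 5], [6, 7]],  # replica 1
]
losses = [
    [0.5, 1.5],  # replica 0
    [2.0, 4.0],  # replica 1
]


def flatten(per_replica):
    """Collect the tensors of all replicas into one list."""
    return [t for replica in per_replica for t in replica]


def concat_axis0(tensors):
    """Concatenate 1-D tensors along axis 0."""
    return [x for t in tensors for x in t]


current = 0  # index of the current model replica

local_concat = concat_axis0(batches[current])     # LOCAL_CONCAT_OBJECTS
global_concat = concat_axis0(flatten(batches))    # GLOBAL_CONCAT_OBJECTS
local_sum = sum(losses[current])                  # LOCAL_SUM_OBJECTS
global_sum = sum(flatten(losses))                 # GLOBAL_SUM_OBJECTS
local_mean = local_sum / len(losses[current])     # LOCAL_MEAN_OBJECTS
global_mean = global_sum / len(flatten(losses))   # GLOBAL_MEAN_OBJECTS

print(local_concat, global_concat)
print(local_sum, global_sum)
print(local_mean, global_mean)
```

With these numbers, the local values describe only replica 0 (its two microbatches), whereas the global values also include replica 1. That gap is why a loss that is not added to a collection can look different from the loss of the whole job.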