You can call the whale.auto_parallel operation to implement distributed model training with ease. This topic describes the syntax and parameters of the operation and provides sample code that shows how to call it.

Background information

Whale implements distributed model training by grouping resources and dividing models. To achieve efficient distributed training this way, you must have a good understanding of how to balance the cluster and scope settings. To simplify this, Whale provides the whale.auto_parallel operation, which lets you implement parallel training with a single line of code.
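For context, the following is a minimal sketch of the explicit, scope-based setup that whale.auto_parallel replaces. The wh.cluster and wh.replica usage shown here is an assumption based on the scope-based workflow this paragraph alludes to; refer to the cluster and scope topics for the authoritative syntax.
    import whale as wh
    
    # Assumed scope-based setup: group resources into a cluster, then place
    # the model definition inside a data-parallelism (replica) scope.
    cluster = wh.cluster()
    with cluster:
        with wh.replica():
            # Construct your model here.
            model_definition()
With whale.auto_parallel, the cluster grouping and scope placement above are inferred automatically from a single call.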

Operation description

  • Syntax
    auto_parallel(modes)
  • Description
    Implements parallelism with a single line of code. Examples:
    • whale.auto_parallel(whale.replica): automatically implements data parallelism for the model.
    • whale.auto_parallel(whale.split): automatically implements operator splitting for the model.
    • whale.auto_parallel(whale.pipeline): automatically implements operator splitting and pipeline parallelism for the model.
    • whale.auto_parallel([whale.pipeline, whale.replica]): automatically implements hybrid parallelism by combining model parallelism with data parallelism. The modes parameter accepts multiple scope-related primitives in combination.
    • whale.auto_parallel(): automatically infers a strategy for dividing models to implement parallelism.
    Note: Only whale.auto_parallel(whale.replica) is currently supported.
  • Parameters

    modes: the parallelism mode, such as data parallelism, model parallelism, pipeline parallelism, or a combination of these. Only automatic data parallelism is currently supported. To enable it, set modes=whale.replica.

  • Return value

    None.

  • Examples
    Add wh.auto_parallel(wh.replica) after import whale as wh to automatically group resources and implement data parallelism. Only the core code is shown here. For the complete code, see auto_data_parallel.py. A fuller, end-to-end sketch follows the snippet below.
    import whale as wh
    
    wh.auto_parallel(wh.replica)
    
    # Construct your model here.
    model_definition()
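    The following end-to-end sketch shows how the snippet above might be embedded in a TensorFlow 1.x training script. Only the wh.auto_parallel(wh.replica) call comes from this topic; the placeholder model, optimizer, random input data, and training loop are illustrative assumptions, not part of the Whale API.
     import numpy as np
     import tensorflow as tf
     import whale as wh
     
     # Enable automatic data parallelism with a single call.
     wh.auto_parallel(wh.replica)
     
     # Illustrative placeholder model: a single dense layer over flattened
     # 28 x 28 inputs with 10 output classes.
     features = tf.placeholder(tf.float32, shape=[None, 784])
     labels = tf.placeholder(tf.int64, shape=[None])
     logits = tf.layers.dense(features, 10)
     loss = tf.reduce_mean(
         tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
     train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
     
     with tf.Session() as sess:
         sess.run(tf.global_variables_initializer())
         for step in range(100):
             # Replace the random batch with your real training data.
             batch_x = np.random.rand(32, 784).astype("float32")
             batch_y = np.random.randint(0, 10, size=32)
             sess.run(train_op, feed_dict={features: batch_x, labels: batch_y})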