This topic describes how to use OssCheckpoint to read and write checkpoints (snapshots of a model's state saved at specific points during training) directly from OSS.
Prerequisites
OSS Connector for AI/ML is installed and configured. For more information, see Install OSS Connector for AI/ML and Configure OSS Connector for AI/ML.
OssCheckpoint
OssCheckpoint is suitable for scenarios in which training results (checkpoints) need to be read and written during model training.
The following example shows how to use OssCheckpoint to read and write a checkpoint.
import torch
from osstorchconnector import OssCheckpoint

ENDPOINT = "http://oss-cn-beijing-internal.aliyuncs.com"
CRED_PATH = "/root/.alibabacloud/credentials"
CONFIG_PATH = "/etc/oss-connector/config.json"

# Create a checkpoint object with OssCheckpoint
checkpoint = OssCheckpoint(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)

# Read a checkpoint from OSS
CHECKPOINT_READ_URI = "oss://checkpoint/epoch.0"
with checkpoint.reader(CHECKPOINT_READ_URI) as reader:
    state_dict = torch.load(reader)

# Write a checkpoint to OSS
CHECKPOINT_WRITE_URI = "oss://checkpoint/epoch.1"
with checkpoint.writer(CHECKPOINT_WRITE_URI) as writer:
    torch.save(state_dict, writer)
Data types
Checkpoint objects created by OssCheckpoint implement common IO interfaces. For more information, see Data types in OSS Connector for AI/ML.
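The reader and writer returned by a checkpoint object behave as file-like streams, which is why torch.load and torch.save accept them directly in the example above. As a rough illustration of that contract only (not the connector's actual implementation), the sketch below uses the standard library's io.BytesIO and pickle in place of the OSS-backed streams and torch:

```python
import io
import pickle

# Any object exposing write() can stand in for checkpoint.writer(...);
# any object exposing read()/seek() can stand in for checkpoint.reader(...).
buffer = io.BytesIO()

state_dict = {"layer.weight": [0.1, 0.2], "layer.bias": [0.0]}
pickle.dump(state_dict, buffer)   # analogous to torch.save(state_dict, writer)

buffer.seek(0)
restored = pickle.load(buffer)    # analogous to torch.load(reader)
print(restored == state_dict)     # the round trip preserves the state dict
```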
Parameter settings
The following table describes the parameters that you must configure when you use OssCheckpoint.
Parameter | Type | Required | Description |
endpoint | string | Yes | The endpoint (access domain name) of OSS. For more information, see Regions and endpoints. |
cred_path | string | Yes | The path of the credential file. The default path is /root/.alibabacloud/credentials. |
config_path | string | Yes | The path of the OSS Connector configuration file. The default path is /etc/oss-connector/config.json. |
Distributed checkpoint (DCP)
OSS Connector for AI/ML supports the PyTorch distributed checkpoint (DCP) feature starting from V1.2.3. You can use OssDCPFileSystem to store and load distributed checkpoints directly on OSS.
The following example shows how to use OssDCPFileSystem to store and load a distributed checkpoint. Note that DCP.async_save returns a handle whose result() method blocks until the save has completed.
import torchvision
import torch.distributed.checkpoint as DCP
from osstorchconnector import OssDCPFileSystem
import torch
ENDPOINT = "http://oss-cn-beijing-internal.aliyuncs.com"
CONFIG_PATH = "/etc/oss-connector/config.json"
CRED_PATH = "/root/.alibabacloud/credentials"
OSS_URI = "oss://ossconnectorbucket/dcp-checkpoint-resnet18"
model = torchvision.models.resnet18()
# write to OSS
fs = OssDCPFileSystem(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
oss_storage_writer = fs.writer(OSS_URI)
# DCP.save or DCP.async_save
checkpoint_future = DCP.async_save(
    state_dict=model.state_dict(),
    storage_writer=oss_storage_writer,
)
checkpoint_future.result()
# load from OSS
loaded_state_dict = {
    key: torch.zeros_like(value) for key, value in model.state_dict().items()
}
oss_storage_reader = fs.reader(OSS_URI)
DCP.load(
    loaded_state_dict,
    storage_reader=oss_storage_reader,
)

Safetensors
OSS Connector for AI/ML supports the safetensors format starting from V1.2.0rc6. You can use OssSafetensor to store and load safetensors files directly on OSS.
The following example shows how to use OssSafetensor to store and load safetensors files.
import torch
from osstorchconnector import OssSafetensor
ENDPOINT = "http://oss-cn-beijing-internal.aliyuncs.com"
CONFIG_PATH = "/etc/oss-connector/config.json"
CRED_PATH = "/root/.alibabacloud/credentials"
OSS_URI = "oss://ossconnectorbucket/safetensors/model.safetensors"
sfts = OssSafetensor(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
# save tensors to safetensor file on OSS
tensors = {"embedding": torch.rand((512, 1024)), "attention": torch.rand((256, 256))}
metadata = {"a": "a", "b": "b"}
sfts.save_file(tensors, OSS_URI, metadata)
# load safetensor file from OSS
loaded_tensors = sfts.load_file(OSS_URI, device="cpu")
# or load tensors by safe_open
with sfts.safe_open(OSS_URI, device="cpu") as f:
    metadata = f.metadata()  # get metadata
    for key in f.keys():  # read tensors by key
        tensor = f.get_tensor(key)
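The safe_open pattern above relies on the handle exposing metadata(), keys(), and get_tensor(key), so each tensor can be fetched by key instead of loading the whole file at once. To make that access pattern concrete, here is a small self-contained mock of such a handle; it is an illustration only, not the OssSafetensor implementation, and plain lists stand in for tensors:

```python
from contextlib import contextmanager

class FakeSafetensorHandle:
    """Illustrative stand-in for the handle yielded by safe_open()."""
    def __init__(self, tensors, metadata):
        self._tensors = tensors
        self._metadata = metadata

    def metadata(self):
        # String-to-string metadata stored alongside the tensors
        return self._metadata

    def keys(self):
        return list(self._tensors)

    def get_tensor(self, key):
        # A real handle would read only this tensor's bytes on demand
        return self._tensors[key]

@contextmanager
def fake_safe_open(tensors, metadata):
    yield FakeSafetensorHandle(tensors, metadata)

# Mirrors the loop in the example above:
tensors = {"embedding": [0.1, 0.2], "attention": [0.3]}
with fake_safe_open(tensors, {"a": "a"}) as f:
    meta = f.metadata()
    loaded = {key: f.get_tensor(key) for key in f.keys()}

print(meta, loaded)
```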