本記事では、OssCheckpoint を使用して Object Storage Service (OSS) 上のチェックポイントを直接読み書きする方法について説明します。チェックポイントは、モデルトレーニング中の特定時点におけるモデルの状態を保存するものです。
前提条件
OSS Connector for AI/ML がインストールおよび構成済みである必要があります。詳細については、「OSS Connector for AI/ML のインストール」および「OSS Connector for AI/ML の構成」をご参照ください。
OssCheckpoint
OssCheckpoint を使用して、モデルトレーニング中にトレーニング結果を読み書きします。
この例では、OssCheckpoint を使用したチェックポイントの読み書き方法を示します。
import torch
from osstorchconnector import OssCheckpoint
ENDPOINT = "http://oss-cn-beijing-internal.aliyuncs.com"
CRED_PATH = "/root/.alibabacloud/credentials"
CONFIG_PATH = "/etc/oss-connector/config.json"
# チェックポイントオブジェクトを作成します。
checkpoint = OssCheckpoint(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
# チェックポイントから読み取ります。
CHECKPOINT_READ_URI = "oss://checkpoint/epoch.0"
with checkpoint.reader(CHECKPOINT_READ_URI) as reader:
state_dict = torch.load(reader)
# チェックポイントに書き込みます。
CHECKPOINT_WRITE_URI = "oss://checkpoint/epoch.1"
with checkpoint.writer(CHECKPOINT_WRITE_URI) as writer:
torch.save(state_dict, writer)
データ型
OssCheckpoint で作成されたチェックポイントオブジェクトは、一般的な I/O インターフェイスを実装しています。詳細については、「OSS Connector for AI/ML のデータ型」をご参照ください。
パラメーター
OssCheckpoint には、以下のパラメーターが必要です。
パラメーター | 型 | 必須 | 説明 |
endpoint | string | はい | OSS へのアクセスに使用するドメイン名です。詳細については、「リージョンとエンドポイント」をご参照ください。 |
cred_path | string | はい | 認証情報ファイルのデフォルトパスは |
config_path | string | はい | OSS Connector 構成ファイルのデフォルトパスは |
分散チェックポイント(DCP)
OSS Connector for AI/ML は、バージョン 1.2.3 以降で PyTorch Distributed Checkpoint(DCP)機能をサポートしています。OssDCPFileSystem を使用して、OSS 上に分散チェックポイントを直接保存・読み取りできます。
この例では、OssDCPFileSystem を使用した分散チェックポイントの保存および読み込み方法を示します。
import torchvision
import torch.distributed.checkpoint as DCP
from osstorchconnector import OssDCPFileSystem
import torch
ENDPOINT = "http://oss-cn-beijing-internal.aliyuncs.com"
CONFIG_PATH = "/etc/oss-connector/config.json"
CRED_PATH = "/root/.alibabacloud/credentials"
OSS_URI = "oss://ossconnectorbucket/dcp-checkpoint-resnet18"
model = torchvision.models.resnet18()
# OSS へ書き込みます。
fs = OssDCPFileSystem(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
oss_storage_writer = fs.writer(OSS_URI)
# DCP.save または DCP.async_save を使用します。
checkpoint_future = DCP.async_save(
state_dict=model.state_dict(),
storage_writer=oss_storage_writer,
)
checkpoint_future.result()
# OSS から読み込みます。
loaded_state_dict = {
key: torch.zeros_like(value) for key, value in model.state_dict().items()
}
oss_storage_reader = fs.reader(OSS_URI)
DCP.load(
loaded_state_dict,
storage_reader=oss_storage_reader,
)Safetensors
OSS Connector for AI/ML は、バージョン 1.2.0rc6 以降で safetensors フォーマットをサポートしています。OssSafetensor を使用して、OSS 上の safetensors ファイルを直接保存・読み取りできます。
この例では、OssSafetensor を使用した safetensors ファイルの保存および読み込み方法を示します。
import torch
from osstorchconnector import OssSafetensor
ENDPOINT = "http://oss-cn-beijing-internal.aliyuncs.com"
CONFIG_PATH = "/etc/oss-connector/config.json"
CRED_PATH = "/root/.alibabacloud/credentials"
OSS_URI = "oss://ossconnectorbucket/safetensors/model.safetensors"
sfts = OssSafetensor(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
# テンソルを safetensors ファイルとして OSS に保存します。
tensors = {"embedding": torch.rand((512, 1024)), "attention": torch.rand((256, 256))}
metadata = {"a": "a", "b": "b"}
sfts.save_file(tensors, OSS_URI, metadata)
# OSS から safetensors ファイルを読み込みます。
loaded_tensors = sfts.load_file(OSS_URI, device="cpu")
# または safe_open を使用してテンソルを読み込みます。
with sfts.safe_open(OSS_URI, device ="cpu") as f:
metadata = f.metadata() # メタデータを取得します。
for key in f.keys(): # キーごとにテンソルを読み込みます。
tensor = f.get_tensor(key)