Artikel ini menjelaskan cara menggunakan OssCheckpoint untuk membaca dan menulis langsung ke checkpoint di Object Storage Service (OSS). Checkpoint menyimpan status model pada titik tertentu selama pelatihan.
Prasyarat
OSS Connector for AI/ML telah diinstal dan dikonfigurasi. Untuk informasi lebih lanjut, lihat Install OSS Connector for AI/ML dan Configure OSS Connector for AI/ML.
OssCheckpoint
Gunakan OssCheckpoint untuk membaca dan menulis hasil pelatihan selama pelatihan model.
Contoh berikut menunjukkan cara menggunakan OssCheckpoint untuk membaca dari dan menulis ke checkpoint.
import torch
from osstorchconnector import OssCheckpoint
ENDPOINT = "http://oss-cn-beijing-internal.aliyuncs.com"
CRED_PATH = "/root/.alibabacloud/credentials"
CONFIG_PATH = "/etc/oss-connector/config.json"
# Buat objek checkpoint.
checkpoint = OssCheckpoint(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
# Baca dari checkpoint.
CHECKPOINT_READ_URI = "oss://checkpoint/epoch.0"
with checkpoint.reader(CHECKPOINT_READ_URI) as reader:
state_dict = torch.load(reader)
# Tulis ke checkpoint.
CHECKPOINT_WRITE_URI = "oss://checkpoint/epoch.1"
with checkpoint.writer(CHECKPOINT_WRITE_URI) as writer:
torch.save(state_dict, writer)
Tipe Data
Objek checkpoint yang dibuat oleh OssCheckpoint mengimplementasikan antarmuka I/O umum. Untuk informasi lebih lanjut, lihat Data types in OSS Connector for AI/ML.
Parameter
OssCheckpoint memerlukan parameter berikut.
Parameter | Type | Required | Description |
endpoint | string | Yes | Nama domain akses untuk OSS. Untuk informasi lebih lanjut, lihat Regions and endpoints. |
cred_path | string | Yes | Jalur default file kredensial adalah |
config_path | string | Yes | Jalur default file konfigurasi OSS Connector adalah |
Distributed checkpoint (DCP)
OSS Connector for AI/ML mendukung fitur PyTorch Distributed Checkpoint (DCP) mulai dari versi V1.2.3. Anda dapat menggunakan OssDCPFileSystem untuk menyimpan dan membaca langsung checkpoint terdistribusi di OSS.
Contoh berikut menunjukkan cara menggunakan OssDCPFileSystem untuk menyimpan dan memuat checkpoint terdistribusi.
import torchvision
import torch.distributed.checkpoint as DCP
from osstorchconnector import OssDCPFileSystem
import torch
ENDPOINT = "http://oss-cn-beijing-internal.aliyuncs.com"
CONFIG_PATH = "/etc/oss-connector/config.json"
CRED_PATH = "/root/.alibabacloud/credentials"
OSS_URI = "oss://ossconnectorbucket/dcp-checkpoint-resnet18"
model = torchvision.models.resnet18()
# Tulis ke OSS.
fs = OssDCPFileSystem(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
oss_storage_writer = fs.writer(OSS_URI)
# Gunakan DCP.save atau DCP.async_save.
checkpoint_future = DCP.async_save(
state_dict=model.state_dict(),
storage_writer=oss_storage_writer,
)
checkpoint_future.result()
# Muat dari OSS.
loaded_state_dict = {
key: torch.zeros_like(value) for key, value in model.state_dict().items()
}
oss_storage_reader = fs.reader(OSS_URI)
DCP.load(
loaded_state_dict,
storage_reader=oss_storage_reader,
)Safetensors
OSS Connector for AI/ML mendukung format safetensors mulai dari versi V1.2.0rc6. Anda dapat menggunakan OssSafetensor untuk menyimpan dan membaca langsung file safetensors di OSS.
Contoh berikut menunjukkan cara menggunakan OssSafetensor untuk menyimpan dan memuat file safetensors.
import torch
from osstorchconnector import OssSafetensor
ENDPOINT = "http://oss-cn-beijing-internal.aliyuncs.com"
CONFIG_PATH = "/etc/oss-connector/config.json"
CRED_PATH = "/root/.alibabacloud/credentials"
OSS_URI = "oss://ossconnectorbucket/safetensors/model.safetensors"
sfts = OssSafetensor(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
# Simpan tensor sebagai file safetensors ke OSS.
tensors = {"embedding": torch.rand((512, 1024)), "attention": torch.rand((256, 256))}
metadata = {"a": "a", "b": "b"}
sfts.save_file(tensors, OSS_URI, metadata)
# Muat file safetensor dari OSS.
loaded_tensors = sfts.load_file(OSS_URI, device="cpu")
# Atau muat tensor menggunakan safe_open.
with sfts.safe_open(OSS_URI, device ="cpu") as f:
metadata = f.metadata() # Dapatkan metadata.
for key in f.keys(): # Baca tensor berdasarkan kunci.
tensor = f.get_tensor(key)