本文為您介紹如何使用OSS Python SDK及OSS Python API讀寫OSS資料。
使用建議
如果需要頻繁訪問和處理大規模資料,推薦您將OSS註冊為資料集並掛載;如果只需要臨時訪問OSS資料,或者根據商務邏輯來決定是否訪問OSS,可採用本文的SDK和API方式靈活讀取。
OSS Python SDK
DSW已預裝OSS2 Python包,您可以參見如下方法讀寫OSS資料。
鑒權及初始化。
import oss2 auth = oss2.Auth('<your_AccessKey_ID>', '<your_AccessKey_Secret>') bucket = oss2.Bucket(auth, '<your_oss_endpoint>', '<your_bucket_name>')需要根據實際需要修改以下參數。
參數
描述
<your_AccessKey_ID>、<your_AccessKey_Secret>
阿里雲的AccessKey ID、AccessKey Secret,擷取方式請參見建立AccessKey。
<your_oss_endpoint>
OSS網域名稱。需要根據執行個體的地區選擇對應的OSS網域名稱:
華北2(北京)後付費執行個體:
oss-cn-beijing.aliyuncs.com華北2(北京)預付費執行個體:
oss-cn-beijing-internal.aliyuncs.com華東2(上海)GPU P100執行個體或CPU執行個體:
oss-cn-shanghai.aliyuncs.com華東2(上海)GPU M40執行個體:
oss-cn-shanghai-internal.aliyuncs.com
其他請參見OSS地區和訪問網域名稱。
<your_bucket_name>
Bucket名稱,且開頭不帶oss://。
讀寫OSS資料。
#讀取一個完整檔案。 result = bucket.get_object('<your_file_path/your_file>') print(result.read()) #按Range讀取資料。 result = bucket.get_object('<your_file_path/your_file>', byte_range=(0, 99)) #寫資料至OSS。 bucket.put_object('<your_file_path/your_file>', '<your_object_content>') #對檔案進行Append。 result = bucket.append_object('<your_file_path/your_file>', 0, '<your_object_content>') result = bucket.append_object('<your_file_path/your_file>', result.next_position, '<your_object_content>')其中
<your_file_path/your_file>表示待讀寫的檔案路徑,<your_object_content>表示待Append的內容,需要根據實際情況修改。
OSS Python API
對於PyTorch使用者,DSW提供OSS Python API,用於直接讀寫OSS資料。
您可以在OSS儲存訓練資料或模型:
載入訓練資料
您可以將資料存放在一個OSS Bucket中,並將資料路徑和對應的Label儲存在同一個OSS Bucket的索引檔案中。通過自訂DataSet,在PyTorch中使用
DataLoaderAPI多進程並行讀取資料,樣本如下。import io import oss2 import PIL import torch class OSSDataset(torch.utils.data.dataset.Dataset): def __init__(self, endpoint, bucket, auth, index_file): self._bucket = oss2.Bucket(auth, endpoint, bucket) self._indices = self._bucket.get_object(index_file).read().split(',') def __len__(self): return len(self._indices) def __getitem__(self, index): img_path, label = self._indices(index).strip().split(':') img_str = self._bucket.get_object(img_path) img_buf = io.BytesIO() img_buf.write(img_str.read()) img_buf.seek(0) img = Image.open(img_buf).convert('RGB') img_buf.close() return img, label dataset = OSSDataset(endpoint, bucket, auth, index_file) data_loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, num_workers=num_loaders, pin_memory=True)其中
endpoint為OSS網域名稱,bucket為Bucket名稱,auth為鑒權對象,index_file為索引檔案的路徑,都需要根據實際情況修改。說明樣本中,索引檔案格式為每條樣本使用英文逗號(,)分隔,樣本路徑與Label之間使用英文冒號(:)分隔。
Save或Load模型
您可以使用OSS2 Python API Save或Load PyTorch模型(關於PyTorch如何Save或Load模型,詳情請參見PyTorch),樣本如下:
Save模型
from io import BytesIO import torch import oss2 # bucket_name bucket_name = "<your_bucket_name>" bucket = oss2.Bucket(auth, endpoint, bucket_name) buffer = BytesIO() torch.save(model.state_dict(), buffer) bucket.put_object("<your_model_path>", buffer.getvalue())其中
endpoint為OSS網域名稱,<your_bucket_name>為OSS Bucket名稱,且開頭不帶oss://,auth為鑒權對象,<your_model_path>為模型路徑,都需要根據實際情況修改。Load模型
from io import BytesIO import torch import oss2 bucket_name = "<your_bucket_name>" bucket = oss2.Bucket(auth, endpoint, bucket_name) buffer = BytesIO(bucket.get_object("<your_model_path>").read()) model.load_state_dict(torch.load(buffer))其中
endpoint為OSS網域名稱,<your_bucket_name>為OSS Bucket名稱,且開頭不帶oss://,auth為鑒權對象,<your_model_path>為模型路徑,都需要根據實際情況修改。