This topic describes how to use Object Storage Service (OSS) SDK for Python, OSS IO for TensorFlow, and OSS API for PyTorch to read data from and write data to OSS.
Background information
OSS is a secure, cost-effective, and highly reliable cloud storage service provided by Alibaba Cloud. It enables you to store large amounts of data in the cloud. By default, Data Science Workshop (DSW) instances have Network Attached Storage (NAS) file systems attached. You can also use DSW instances with OSS if you require larger storage.
OSS SDK for Python
In most cases, you can use OSS SDK for Python to read data from and write data to OSS. For more information, see the OSS SDK for Python documentation. The oss2 package for Python is preinstalled in DSW.
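A minimal sketch of reading and writing OSS objects with oss2 follows. The `write_object` and `read_object` helpers are illustrative, not part of oss2, and the placeholder credentials, endpoint, and bucket name in the usage comment are assumptions to replace with your own values:

```python
def write_object(bucket, key, data):
    """Upload bytes or str to OSS. `bucket` is an oss2.Bucket instance."""
    bucket.put_object(key, data)


def read_object(bucket, key):
    """Download an object from OSS and return its bytes."""
    return bucket.get_object(key).read()


# Usage on a DSW instance (placeholders are assumptions; fill in your values):
#   import oss2
#   auth = oss2.Auth("<your_access_key_id>", "<your_access_key_secret>")
#   bucket = oss2.Bucket(auth, "<your_endpoint>", "<your_bucket_name>")
#   write_object(bucket, "example/hello.txt", "Hello, OSS!")
#   print(read_object(bucket, "example/hello.txt"))
```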
OSS IO for TensorFlow
DSW provides the tensorflow_io module. TensorFlow users can use this module to read data from and write data to OSS directly, so they do not need to frequently copy data or model files when training TensorFlow models.
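A hedged sketch of the pattern: once tensorflow_io is imported, TensorFlow's file APIs can address `oss://` paths. The `read_oss_file` helper and the placeholder path are illustrative assumptions; the exact path and credential scheme follow your DSW and tensorflow_io configuration:

```python
def read_oss_file(path: str) -> bytes:
    """Read an object through TensorFlow's file API.

    Requires `import tensorflow_io` somewhere in the process so that the
    oss:// filesystem scheme is registered with TensorFlow.
    """
    import tensorflow as tf
    with tf.io.gfile.GFile(path, "rb") as f:
        return f.read()


# Usage on a DSW instance (the path below is a placeholder assumption):
#   import tensorflow_io  # registers the oss:// filesystem scheme
#   data = read_oss_file("oss://<your_bucket>/<your_object>")
```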
OSS API for PyTorch
For PyTorch users, DSW provides OSS API for Python to read data from and write data to OSS.
- Load training data
You can store training data in an OSS bucket. The paths and labels of the samples must be stored in an index file in the same OSS bucket. You can implement a custom Dataset and use the DataLoader API in PyTorch to read data in parallel through multiple worker processes. The following code block shows an example:

```python
import io

import oss2
import torch
from PIL import Image


class OSSDataset(torch.utils.data.dataset.Dataset):
    def __init__(self, endpoint, bucket, auth, index_file):
        self._bucket = oss2.Bucket(auth, endpoint, bucket)
        # The index file stores comma-separated "path:label" samples.
        self._indices = (
            self._bucket.get_object(index_file).read().decode("utf-8").split(',')
        )

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, index):
        img_path, label = self._indices[index].strip().split(':')
        img_str = self._bucket.get_object(img_path)
        img_buf = io.BytesIO(img_str.read())
        img = Image.open(img_buf).convert('RGB')
        img_buf.close()
        return img, label


dataset = OSSDataset(endpoint, bucket, auth, index_file)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_loaders,
    pin_memory=True)
```
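The index-file convention used in this topic (samples separated by commas, path and label separated by colons) can be illustrated on its own. The `parse_index` helper below is hypothetical, not part of oss2:

```python
def parse_index(raw: bytes):
    """Split a raw index file into (path, label) pairs.

    Samples are separated by commas; within each sample, the object path
    and the label are separated by a colon.
    """
    entries = raw.decode("utf-8").split(",")
    return [tuple(e.strip().split(":", 1)) for e in entries if e.strip()]


# Example: two samples in one index file.
pairs = parse_index(b"images/cat1.jpg:0,images/dog1.jpg:1")
# pairs == [("images/cat1.jpg", "0"), ("images/dog1.jpg", "1")]
```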
In the preceding code:
- endpoint: the endpoint of OSS.
- bucket: the name of the OSS bucket.
- auth: the authentication object (for example, an oss2.Auth instance).
- index_file: the path of the index file.
Set the parameters based on your requirements.
Note: In this topic, samples in the index file are separated with commas (,). Within each sample, the path and label are separated with colons (:).
- Write logs to OSS
You can implement a custom StreamHandler to write log data to OSS.
Note: You cannot write logs through multiple threads in parallel.
```python
import logging

import oss2


class OSSLoggingHandler(logging.StreamHandler):
    def __init__(self, endpoint, bucket, auth, log_file):
        super().__init__()
        self._bucket = oss2.Bucket(auth, endpoint, bucket)
        self._log_file = log_file
        # Create the log object; subsequent records are appended after this position.
        self._pos = self._bucket.append_object(self._log_file, 0, '')

    def emit(self, record):
        msg = self.format(record)
        self._pos = self._bucket.append_object(
            self._log_file, self._pos.next_position, msg)


oss_handler = OSSLoggingHandler(endpoint, bucket, auth, log_file)
logging.basicConfig(
    stream=oss_handler,
    format='[%(asctime)s] [%(levelname)s] [%(process)d#%(threadName)s] '
           '[%(filename)s:%(lineno)d] %(message)s',
    level=logging.INFO)
```
In the preceding code:
- endpoint: the endpoint of OSS.
- bucket: the name of the OSS bucket.
- auth: the authentication object.
- log_file: the path where you want to store the log file.
Set the parameters based on your requirements.
- Save or load models
You can use the oss2 API for Python to save and load PyTorch models. For more information about how to save and load models in PyTorch, see the PyTorch documentation.
- Save a model
```python
from io import BytesIO

import oss2
import torch

bucket_name = "<your_bucket_name>"
bucket = oss2.Bucket(auth, endpoint, bucket_name)

# Serialize the model state to an in-memory buffer, then upload it.
buffer = BytesIO()
torch.save(model.state_dict(), buffer)
bucket.put_object("<your_model_path>", buffer.getvalue())
```
In the preceding code:
- endpoint: the endpoint of OSS.
- bucket_name: the name of the OSS bucket. The name cannot start with oss://.
- auth: the authentication object.
- <your_model_path>: the path where you want to store the model.
Set the parameters based on your requirements.
- Load a model
```python
from io import BytesIO

import oss2
import torch

bucket_name = "<your_bucket_name>"
bucket = oss2.Bucket(auth, endpoint, bucket_name)

# Download the serialized state dict into memory and load it.
buffer = BytesIO(bucket.get_object("<your_model_path>").read())
model.load_state_dict(torch.load(buffer))
```
In the preceding code:
- endpoint: the endpoint of OSS.
- bucket_name: the name of the OSS bucket. The name cannot start with oss://.
- auth: the authentication object.
- <your_model_path>: the path where the model is stored.
Set the parameters based on your requirements.