Lớp Dataset
Lớp Dataset trong PyTorch đóng vai trò là giao diện cơ bản cho tập dữ liệu huấn luyện của mô hình. Khi sử dụng lớp Dataset, bạn cần triển khai ba phương thức chính:
__init__: Phương thức khởi tạo, được gọi khi tạo đối tượng. Nhận các tham số cần thiết cho lớp.__getitem__: Phương thức truy xuất dữ liệu, được gọi khi sử dụng toán tử[]để lấy một mẫu dữ liệu và nhãn tương ứng.__len__: Phương thức trả về số lượng mẫu trong tập dữ liệu.
Dưới đây là ví dụ triển khai một lớp Dataset tùy chỉnh:
from torch.utils.data import Dataset
import os
from PIL import Image
# Tải tập dữ liệu từ: https://download.pytorch.org/tutorial/hymenoptera_data.zip
class CustomDataset(Dataset):
def __init__(self, duong_dan_goc, thu_muc_nhan):
# Biến toàn lớp
self.duong_dan_goc = duong_dan_goc
self.thu_muc_nhan = thu_muc_nhan
self.duong_dan = os.path.join(self.duong_dan_goc, self.thu_muc_nhan)
self.duong_dan_anh = os.listdir(self.duong_dan) # Lấy danh sách tất cả tệp trong đường dẫn
def __getitem__(self, chi_so):
ten_anh = self.duong_dan_anh[chi_so] # Lấy ảnh theo chỉ số
duong_dan_anh = os.path.join(self.duong_dan_goc, self.thu_muc_nhan, ten_anh)
anh = Image.open(duong_dan_anh)
nhan = self.thu_muc_nhan
return anh, nhan
def __len__(self):
return len(self.duong_dan_anh)
duong_dan_goc = r'D:\hoc-mai\pytorch\hymenoptera_data\train' # Thay đổi thành đường dẫn của bạn
thu_muc_kien = "ants"
thu_muc_ong = "bees"
tap_kien = CustomDataset(duong_dan_goc, thu_muc_kien)
tap_ong = CustomDataset(duong_dan_goc, thu_muc_ong)
tap_du_lieu_huan_luyen = tap_kien + tap_ong
anh, nhan = tap_du_lieu_huan_luyen[0]
anh.show()
so_luong = len(tap_du_lieu_huan_luyen)
Lớp TensorDataset
TensorDataset là lớp tiện ích đặc biệt trong PyTorch, được thiết kế để làm việc với dữ liệu dưới dạng tensor. Khác với Dataset tùy chỉnh, TensorDataset đã triển khai sẵn các phương thức __getitem__ và __len__, giúp việc sử dụng trở nên đơn giản hơn.
TensorDataset đóng gói hai tensor: data_tensor và label_tensor, và cho phép truy cập từng mẫu bằng chỉ số. Lớp này rất hữu ích khi bạn muốn nhanh chóng tạo tập dữ liệu từ các tensor hiện có.
from torch.utils.data import TensorDataset
import torch
# Định nghĩa các tensor dữ liệu
data_tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
label_tensor = torch.tensor([0, 1, 0])
# Tạo đối tượng TensorDataset
tap_du_lieu = TensorDataset(data_tensor, label_tensor)
# Truy cập dữ liệu
print(len(tap_du_lieu)) # Kết quả: 3
print(tap_du_lieu[0]) # Kết quả: (tensor([1, 2]), tensor(0))
Sự khác biệt chính
- Dataset yêu cầu người dùng tự triển khai các phương thức
__getitem__và__len__. Thường được sử dụng để tạo tập dữ liệu tùy chỉnh, phù hợp với các thao tác tiền xử lý và tải dữ liệu phức tạp. - TensorDataset đã triển khai sẵn các phương thức cần thiết, chuyên xử lý dữ liệu dạng tensor. Đóng gói hai tensor
data_tensorvàlabel_tensor, cho phép truy cập mẫu bằng chỉ số. Phù hợp để nhanh chóng tạo tập dữ liệu từ các tensor hiện có.