Tổng quan về Dataset và TensorDataset trong PyTorch

Lớp Dataset

Lớp Dataset trong PyTorch đóng vai trò là giao diện cơ bản cho tập dữ liệu huấn luyện của mô hình. Khi sử dụng lớp Dataset, bạn cần triển khai ba phương thức chính:

  • __init__: Phương thức khởi tạo, được gọi khi tạo đối tượng. Nhận các tham số cần thiết cho lớp.
  • __getitem__: Phương thức truy xuất dữ liệu, được gọi khi sử dụng toán tử [] để lấy một mẫu dữ liệu và nhãn tương ứng.
  • __len__: Phương thức trả về số lượng mẫu trong tập dữ liệu.

Dưới đây là ví dụ triển khai một lớp Dataset tùy chỉnh:

from torch.utils.data import Dataset
import os
from PIL import Image

# Tải tập dữ liệu từ: https://download.pytorch.org/tutorial/hymenoptera_data.zip

class CustomDataset(Dataset):
    def __init__(self, duong_dan_goc, thu_muc_nhan):
        # Biến toàn lớp
        self.duong_dan_goc = duong_dan_goc
        self.thu_muc_nhan = thu_muc_nhan
        self.duong_dan = os.path.join(self.duong_dan_goc, self.thu_muc_nhan)
        self.duong_dan_anh = os.listdir(self.duong_dan)  # Lấy danh sách tất cả tệp trong đường dẫn
    def __getitem__(self, chi_so):
        ten_anh = self.duong_dan_anh[chi_so]  # Lấy ảnh theo chỉ số
        duong_dan_anh = os.path.join(self.duong_dan_goc, self.thu_muc_nhan, ten_anh)
        anh = Image.open(duong_dan_anh)
        nhan = self.thu_muc_nhan
        return anh, nhan
    def __len__(self):
        return len(self.duong_dan_anh)

duong_dan_goc = r'D:\hoc-mai\pytorch\hymenoptera_data\train'  # Thay đổi thành đường dẫn của bạn
thu_muc_kien = "ants"
thu_muc_ong = "bees"
tap_kien = CustomDataset(duong_dan_goc, thu_muc_kien)
tap_ong = CustomDataset(duong_dan_goc, thu_muc_ong)

tap_du_lieu_huan_luyen = tap_kien + tap_ong

anh, nhan = tap_du_lieu_huan_luyen[0]
anh.show()
so_luong = len(tap_du_lieu_huan_luyen)

Lớp TensorDataset

TensorDataset là lớp tiện ích đặc biệt trong PyTorch, được thiết kế để làm việc với dữ liệu dưới dạng tensor. Khác với Dataset tùy chỉnh, TensorDataset đã triển khai sẵn các phương thức __getitem____len__, giúp việc sử dụng trở nên đơn giản hơn.

TensorDataset đóng gói hai tensor: data_tensorlabel_tensor, và cho phép truy cập từng mẫu bằng chỉ số. Lớp này rất hữu ích khi bạn muốn nhanh chóng tạo tập dữ liệu từ các tensor hiện có.

from torch.utils.data import TensorDataset
import torch

# Định nghĩa các tensor dữ liệu
data_tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
label_tensor = torch.tensor([0, 1, 0])

# Tạo đối tượng TensorDataset
tap_du_lieu = TensorDataset(data_tensor, label_tensor)

# Truy cập dữ liệu
print(len(tap_du_lieu))  # Kết quả: 3
print(tap_du_lieu[0])    # Kết quả: (tensor([1, 2]), tensor(0))

Sự khác biệt chính

  • Dataset yêu cầu người dùng tự triển khai các phương thức __getitem____len__. Thường được sử dụng để tạo tập dữ liệu tùy chỉnh, phù hợp với các thao tác tiền xử lý và tải dữ liệu phức tạp.
  • TensorDataset đã triển khai sẵn các phương thức cần thiết, chuyên xử lý dữ liệu dạng tensor. Đóng gói hai tensor data_tensorlabel_tensor, cho phép truy cập mẫu bằng chỉ số. Phù hợp để nhanh chóng tạo tập dữ liệu từ các tensor hiện có.

Thẻ: PyTorch Dataset TensorDataset deep learning

Đăng vào ngày 2 tháng 6 lúc 00:35