Để xây dựng một mô hình học sâu cho bài toán phân loại ảnh, bạn cần thực hiện các bước cơ bản: tải dữ liệu, định nghĩa kiến trúc mạng, chọn hàm mất mát và thuật toán tối ưu, sau đó huấn luyện và đánh giá mô hình. Trong hướng dẫn này, chúng ta sẽ sử dụng bộ dữ liệu CIFAR-10 để minh họa quy trình hoàn chỉnh.
1. Tải và chuẩn hóa dữ liệu CIFAR-10
Gói torchvision cung cấp tiện ích để tải các bộ dữ liệu thị giác phổ biến như CIFAR-10, bao gồm 10 lớp: máy bay, ô tô, chim, mèo, hươu, chó, ếch, ngựa, tàu thủy và xe tải. Mỗi ảnh có kích thước 32×32 pixel với 3 kênh màu.
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
batch_size = 4
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=batch_size, shuffle=True, num_workers=2
)
testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=batch_size, shuffle=False, num_workers=2
)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Để trực quan hóa dữ liệu, ta có thể hiển thị một batch ảnh:
import matplotlib.pyplot as plt
import numpy as np
def show_image(img):
img = img / 2 + 0.5 # chuyển từ [-1,1] về [0,1]
np_img = img.numpy()
plt.imshow(np.transpose(np_img, (1, 2, 0)))
plt.show()
data_iter = iter(trainloader)
images, labels = next(data_iter)
show_image(torchvision.utils.make_grid(images))
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
2. Xây dựng mạng nơ-ron tích chập
Mô hình được thiết kế để xử lý ảnh đầu vào 3 kênh:
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
model = SimpleCNN()
3. Thiết lập hàm mất mát và bộ tối ưu
import torch.optim as optim
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
4. Huấn luyện mô hình
for epoch in range(2):
total_loss = 0.0
for i, (inputs, targets) in enumerate(trainloader):
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, targets)
loss.backward()
optimizer.step()
total_loss += loss.item()
if (i + 1) % 2000 == 0:
print(f'[{epoch + 1}, {i + 1}] loss: {total_loss / 2000:.3f}')
total_loss = 0.0
torch.save(model.state_dict(), './cifar_model.pth')
5. Đánh giá trên tập kiểm thử
Sau khi huấn luyện, ta kiểm tra độ chính xác tổng thể và theo từng lớp:
# Tải lại mô hình (minh họa cách khôi phục)
model = SimpleCNN()
model.load_state_dict(torch.load('./cifar_model.pth'))
# Đánh giá toàn bộ tập test
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in testloader:
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy on test set: {100 * correct / total:.0f} %')
# Đánh giá theo từng lớp
class_correct = {cls: 0 for cls in classes}
class_total = {cls: 0 for cls in classes}
with torch.no_grad():
for inputs, labels in testloader:
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
for label, pred in zip(labels, preds):
cls_name = classes[label]
class_total[cls_name] += 1
if label == pred:
class_correct[cls_name] += 1
for cls in classes:
acc = 100 * class_correct[cls] / class_total[cls]
print(f'Accuracy for {cls:5s}: {acc:.1f} %')
Chạy trên GPU
Để tăng tốc độ huấn luyện, ta có thể di chuyển mô hình và dữ liệu lên GPU:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Trong vòng lặp huấn luyện:
for inputs, labels in trainloader:
inputs, labels = inputs.to(device), labels.to(device)
# ... phần còn lại giữ nguyên
Lưu ý: Với mô hình nhỏ như trên, lợi ích từ GPU có thể không rõ rệt. Để thấy hiệu quả rõ hơn, hãy mở rộng số lượng kênh trong các lớp tích chập.