torch.nn cung cấp đầy đủ các thành phần thiết yếu — từ lớp tuyến tính đến hàm kích hoạt — giúp người dùng xây dựng kiến trúc tùy chỉnh một cách linh hoạt. Mỗi lớp trong PyTorch đều kế thừa từ nn.Module, và chính mạng nơ-ron cũng là một nn.Module chứa các lớp con khác, tạo nên cấu trúc phân cấp rõ ràng và dễ mở rộng.
Dưới đây là ví dụ minh họa việc xây dựng một mạng nơ-ron phân loại ảnh từ bộ dữ liệu FashionMNIST:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
Chọn thiết bị huấn luyện
Để tăng tốc độ tính toán, ta ưu tiên sử dụng GPU nếu có sẵn; ngược lại sẽ chuyển về CPU:
execution_device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Thiết bị được chọn: {execution_device}")
Định nghĩa kiến trúc mạng
Tạo lớp ImageClassifier kế thừa từ nn.Module. Trong phương thức __init__, ta khai báo các lớp con; trong forward, ta xác định luồng dữ liệu qua các lớp đó:
class ImageClassifier(nn.Module):
def __init__(self, input_dim=784, hidden_dim=512, num_classes=10):
super().__init__()
self.reshape = nn.Flatten()
self.network = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, num_classes)
)
def forward(self, batch_input):
reshaped = self.reshape(batch_input)
return self.network(reshaped)
Sau khi khởi tạo và chuyển mô hình sang thiết bị mục tiêu:
net = ImageClassifier().to(execution_device)
print(net)
Kết quả in ra sẽ hiển thị cấu trúc chi tiết gồm lớp làm phẳng và chuỗi các lớp tuyến tính – phi tuyến xen kẽ.
Truyền dữ liệu qua mô hình
Không gọi trực tiếp forward(); thay vào đó, truyền tensor đầu vào như một đối số thông thường:
sample_batch = torch.randn(1, 28, 28).to(execution_device)
raw_output = net(sample_batch)
probabilities = torch.nn.functional.softmax(raw_output, dim=1)
predicted_label = probabilities.argmax(dim=1)
print(f"Nhãn dự đoán: {predicted_label.item()}")
Phân tích từng lớp
Giả sử ta có một batch gồm 3 ảnh kích thước 28×28:
batch_input = torch.randn(3, 28, 28)
print("Kích thước đầu vào:", batch_input.shape)
nn.Flatten(): Chuyển mỗi ảnh 2D thành vector 1D độ dài 784, giữ nguyên chiều batch:
flattener = nn.Flatten() flattened = flattener(batch_input) print("Sau Flatten:", flattened.shape) # torch.Size([3, 784])nn.Linear(): Ánh xạ tuyến tính từ không gian đầu vào sang không gian ẩn:
projector = nn.Linear(784, 64) encoded = projector(flattened) print("Sau Linear:", encoded.shape) # torch.Size([3, 64])nn.ReLU(): Giới thiệu phi tuyến bằng cách đặt mọi giá trị âm về 0:
activated = torch.relu(encoded) print("Sau ReLU:", activated.shape)nn.Sequential(): Gộp nhiều lớp thành một khối tuần tự:
pipeline = nn.Sequential( nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10) ) result = pipeline(batch_input)nn.Softmax(): Chuẩn hóa đầu ra thành phân phối xác suất:
final_probs = torch.nn.functional.softmax(result, dim=1) print("Tổng xác suất mỗi mẫu:", final_probs.sum(dim=1))
Truy cập tham số mô hình
Mọi trọng số và hệ số chệch đều được tự động đăng ký khi kế thừa nn.Module. Ta có thể liệt kê chúng như sau:
for name, tensor in net.named_parameters():
print(f"{name:25} | shape: {list(tensor.shape)}")
Kết quả sẽ liệt kê tên từng tham số (ví dụ: network.0.weight) cùng kích thước tương ứng — hữu ích cho việc kiểm tra, khởi tạo lại hoặc tối ưu hóa riêng lẻ.