Thành phần và Khối trong PyTorch
Môi trường: PyCharm + Python 3.8
1. Khối và Lớp
Khối (block) có thể biểu diễn:
- Một lớp đơn lẻ
- Nhóm nhiều lớp kết hợp
- Toàn bộ mô hình
1.1. Tạo khối tùy chỉnh
Ví dụ về mạng nơ-ron với 2 lớp tuyến tính:
import torch
from torch import nn
model = nn.Sequential(
nn.Linear(20, 256),
nn.ReLU(),
nn.Linear(256, 10))
X = torch.rand(2, 20)
print(f"Đầu vào ngẫu nhiên:\n{X}")
print(f"Kết quả đầu ra:\n{model(X)}")
1.2. Triển khai khối tuần tự
Viết lại Sequential với cấu trúc khác:
class MySequential(nn.Module):
def __init__(self, *args):
super().__init__()
for idx, module in enumerate(args):
self._modules[str(idx)] = module
def forward(self, X):
for block in self._modules.values():
X = block(X)
return X
2. Quản lý Tham số
Ví dụ về mạng nơ-ron ẩn:
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
print(f"Đầu vào:\n{X}")
print(f"Kết quả:\n{net(X)}")
2.1. Truy cập Tham số
print(f"Tham số lớp thứ 2:\n{net[2].state_dict()}")
2.2. Khởi tạo Tham số
Khởi tạo trọng số với phân phối chuẩn:
def init_normal(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight, mean=0, std=0.01)
nn.init.zeros_(m.bias)
net.apply(init_normal)
3. Tích Chập và CNN
3.1. Nguyên lý Tích Chập
Triển khai hàm tích chập 2D:
def cross_correlation(X, K):
h, w = K.shape
Y = torch.zeros(X.shape[0]-h+1, X.shape[1]-w+1)
for i in range(Y.shape[0]):
for j in range(Y.shape[1]):
Y[i,j] = (X[i:i+h, j:j+w] * K).sum()
return Y
3.2. Lớp Tích Chập
class Conv2D(nn.Module):
def __init__(self, kernel_size):
super().__init__()
self.weight = nn.Parameter(torch.rand(kernel_size))
self.bias = nn.Parameter(torch.zeros(1))
def forward(self, x):
return cross_correlation(x, self.weight) + self.bias
4. LeNet
Cài đặt kiến trúc LeNet:
net = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(400, 120), nn.Sigmoid(),
nn.Linear(120, 84), nn.Sigmoid(),
nn.Linear(84, 10))
5. Sử dụng GPU
def try_gpu(i=0):
if torch.cuda.device_count() >= i + 1:
return torch.device(f'cuda:{i}')
return torch.device('cpu')
device = try_gpu()
net.to(device)