Xây dựng mạng nơ-ron tích chập đơn giản để nhận diện hình ảnh CIFAR10 bằng PyTorch

Thiết lập:

python 3.11.1

pytorch 2.3.0

Chuẩn bị ban đầu


1. Cấu hình GPU


import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision

thiet_bi = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(thiet_bi)

2. Nhập dữ liệu


Sử dụng dataset tải xuống tập dữ liệu CIFAR10 và phân chia thành tập huấn luyện và kiểm tra Sử dụng dataloader để tải dữ liệu, thiết lập kích thước batch cơ bản

tap_huanluyen_ds = torchvision.datasets.CIFAR10('./du_lieu',
                                                train=True,
                                                transform=torchvision.transforms.ToTensor(), # Chuyển đổi dạng dữ liệu sang Tensor
                                                download=True)

tap_kiemtra_ds  = torchvision.datasets.CIFAR10('./du_lieu',
                                                train=False,
                                                transform=torchvision.transforms.ToTensor(), # Chuyển đổi dạng dữ liệu sang Tensor
                                                download=True)

kich_thuoc_batch = 32
tap_huanluyen_dl = torch.utils.data.DataLoader(tap_huanluyen_ds,
                                               batch_size=kich_thuoc_batch,
                                               shuffle=True)
tap_kiemtra_dl  = torch.utils.data.DataLoader(tap_kiemtra_ds,
                                               batch_size=kich_thuoc_batch)

# Lấy một batch để xem định dạng dữ liệu
# Định dạng dữ liệu: [kich_thuoc_batch, kenh, chieu_cao, rong]
# Trong đó kich_thuoc_batch tự đặt, kenh, chieu_cao và rong là số kênh, chiều cao và chiều rộng của hình ảnh.
anh, nhan = next(iter(tap_huanluyen_dl))
print(anh.shape)

3. Hiển thị dữ liệu

Hàm squeeze() loại bỏ các chiều có kích thước bằng 1 từ ma trận.

plt.figure(figsize=(20, 5))
for i, anh in enumerate(anh[:20]):
    np_anh = anh.numpy().transpose((1, 2, 0))
    plt.subplot(2, 10, i+1)
    plt.imshow(np_anh, cmap=plt.cm.binary)
    plt.axis('off')
plt.show()

Có thể hiển thị nhãn cùng với hình ảnh

fig = plt.figure()
danh_muc=['máy bay','ô tô','chim','mèo','nai','chó','ếch','ngựa','tàu thủy','xe tải']
for i in range(12):
    plt.subplot(3, 4, i+1)
    plt.tight_layout()
    (_, nhan) = tap_huanluyen_ds[i]
    plt.imshow(tap_huanluyen_dl.dataset.data[i],cmap=plt.cm.binary)
    plt.title("Nhãn: {}".format(danh_muc[nhan]))
    plt.xticks([])
    plt.yticks([])
plt.show()

Xây dựng mạng CNN đơn giản

Mạng CNN thường bao gồm mạng trích xuất đặc trưng và mạng phân loại, trong đó mạng trích xuất đặc trưng dùng để trích xuất đặc trưng của hình ảnh, mạng phân loại dùng để phân loại hình ảnh.

1. Chi tiết về torch.nn.Conv2d()

Nguyên mẫu hàm: torch.nn.Conv2d(kenh_vao, ken_ra, kich_thuoc_bo_loc, buoc_nhảy=1, day_viền=0, do_loãng=1, nhom=1, bias=True, chế_độ_viền='zeros', thiet_bi=None, loai_du_lieu=None)

2. Chi tiết về torch.nn.Linear()

Nguyên mẫu hàm: torch.nn.Linear(so_tinh_nhap, so_tinh_xuat, bias=True, thiet_bi=None, loai_du_lieu=None)

3. Chi tiết về torch.nn.MaxPool2d()

Nguyên mẫu hàm: torch.nn.MaxPool2d(kich_thuoc_bo_loc, buoc_nhảy=None, day_viền=0, do_loãng=1, tra_ve_chi_so=False, che_do_tran=False)

4. Tính toán lớp tích chập và lớp pooling

Lớp tích chập:

Kích thước bản đồ đặc trưng đầu vào (H, W) Kích thước bộ lọc tích chập (kH, kW) Bước di chuyển (sH, sW) Đệm (pH, pW) Kích thước bản đồ đặc trưng đầu ra (H’, W’) Công thức tính: H’ = (H + 2pH - kH) / sH + 1 W’ = (W + 2pW - kW) / sW + 1 Lớp pooling:

Kích thước bản đồ đặc trưng đầu vào (H, W) Kích thước cửa sổ pooling (kH, kW) Bước di chuyển (sH, sW) Đệm (pH, pW) Kích thước bản đồ đặc trưng đầu ra (H’, W’) Công thức tính: H’ = (H - kH) / sH + 1 W’ = (W - kW) / sW + 1

so_loai = 10  # Số lượng loại hình ảnh

class MoHinh(nn.Module):
    def __init__(self):
        super().__init__()
        self.bo_loc1 = nn.Conv2d(3, 64, kernel_size=3)   # Lớp tích chập đầu tiên, kích thước bộ lọc 3x3
        self.giam_kich_thuoc1 = nn.MaxPool2d(kernel_size=2)       # Thiết lập lớp giảm kích thước, kích thước bộ lọc 2x2
        self.bo_loc2 = nn.Conv2d(64, 64, kernel_size=3)  # Lớp tích chập thứ hai, kích thước bộ lọc 3x3
        self.giam_kich_thuoc2 = nn.MaxPool2d(kernel_size=2)
        self.bo_loc3 = nn.Conv2d(64, 128, kernel_size=3) # Lớp tích chập thứ ba, kích thước bộ lọc 3x3
        self.giam_kich_thuoc3 = nn.MaxPool2d(kernel_size=2)

        self.phan_loai1 = nn.Linear(512, 256)
        self.phan_loai2 = nn.Linear(256, so_loai)
        
    def tien_truyen(self, x):
        x = self.giam_kich_thuoc1(F.relu(self.bo_loc1(x)))
        x = self.giam_kich_thuoc2(F.relu(self.bo_loc2(x)))
        x = self.giam_kich_thuoc3(F.relu(self.bo_loc3(x)))

        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.phan_loai1(x))
        x = self.phan_loai2(x)
        return x

Tải và hiển thị mô hình

mo_hinh = MoHinh().to(thiet_bi)
summary(mo_hinh)

Hiển thị thông tin mô hình

=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
MoHinh                                    --
├─Conv2d: 1-1                            1,792
├─MaxPool2d: 1-2                         --
├─Conv2d: 1-3                            36,928
├─MaxPool2d: 1-4                         --
├─Conv2d: 1-5                            73,856
├─MaxPool2d: 1-6                         --
├─Linear: 1-7                            131,328
├─Linear: 1-8                            2,570
=================================================================
Total params: 246,474
Trainable params: 246,474
Non-trainable params: 0
=================================================================

Huấn luyện mô hình


1. Thiết lập tham số

ham_mat = nn.CrossEntropyLoss()  # Tạo hàm mất mát
toc_do_hoc = 1e-2  # Tốc độ học
toi_uu = torch.optim.SGD(mo_hinh.parameters(), lr=toc_do_hoc)

2. Viết hàm huấn luyện

def huan_luyen(tap_du_lieu, mo_hinh, ham_mat, toi_uu):
    kich_thuoc = len(tap_du_lieu.dataset)  # Kích thước tập huấn luyện, tổng cộng 60000 hình ảnh
    so_batch = len(tap_du_lieu)   # Số lượng batch, 1875 (60000/32)

    mat, do_chinh_xac = 0, 0  # Khởi tạo tổn thất và độ chính xác

    for anh, nhan in tap_du_lieu:
        anh, nhan = anh.to(thiet_bi), nhan.to(thiet_bi)

        du_bao = mo_hinh(anh)          # Đầu ra của mạng
        loi = ham_mat(du_bao, nhan)  # Tính toán sai lệch giữa đầu ra và giá trị thực tế

        toi_uu.zero_grad()  # Đặt gradient về 0
        loi.backward()      # Lan truyền ngược
        toi_uu.step()       # Cập nhật tham số

        do_chinh_xac += (du_bao.argmax(1) == nhan).type(torch.float).sum().item()
        mat += loi.item()

    do_chinh_xac /= kich_thuoc
    mat /= so_batch

    return do_chinh_xac, mat

3. Viết hàm kiểm thử

def kiem_thu (tap_du_lieu, mo_hinh, ham_mat):
    kich_thuoc        = len(tap_du_lieu.dataset)  # Kích thước tập kiểm thử, tổng cộng 10000 hình ảnh
    so_batch = len(tap_du_lieu)          # Số lượng batch, 313 (10000/32=312.5, làm tròn lên)
    mat, do_chinh_xac = 0, 0

    with torch.no_grad():
        for anh, muc_tieu in tap_du_lieu:
            anh, muc_tieu = anh.to(thiet_bi), muc_tieu.to(thiet_bi)

            du_bao = mo_hinh(anh)
            loi        = ham_mat(du_bao, muc_tieu)

            mat += loi.item()
            do_chinh_xac  += (du_bao.argmax(1) == muc_tieu).type(torch.float).sum().item()

    do_chinh_xac  /= kich_thuoc
    mat /= so_batch

    return do_chinh_xac, mat

4. Tiến hành huấn luyện

if os.path.exists(duong_dan) is not True:
    so_lan_lap     = 10
    tap_luu_mat = []
    tap_luu_do_chinh_xac  = []
    tap_luu_mat_kiem_thu  = []
    tap_luu_do_chinh_xac_kiem_thu   = []
    for lan_lap in range(so_lan_lap):
        mo_hinh.train()
        do_chinh_xac_lan, mat_lan = huan_luyen(tap_huanluyen_dl, mo_hinh, ham_mat, toi_uu)

        mo_hinh.eval()
        do_chinh_xac_kiem_thu_lan, mat_kiem_thu_lan = kiem_thu(tap_kiemtra_dl, mo_hinh, ham_mat)

        tap_luu_do_chinh_xac.append(do_chinh_xac_lan)
        tap_luu_mat.append(mat_lan)
        tap_luu_do_chinh_xac_kiem_thu.append(do_chinh_xac_kiem_thu_lan)
        tap_luu_mat_kiem_thu.append(mat_kiem_thu_lan)

        mau_in = ('Lần lặp:{:2d}, Độ chính xác huấn luyện:{:.1f}%, Tổn thất huấn luyện:{:.3f}, Độ chính xác kiểm thử:{:.1f}%,Tổn thất kiểm thử:{:.3f}')
        print(mau_in.format(lan_lap+1, do_chinh_xac_lan*100, mat_lan, do_chinh_xac_kiem_thu_lan*100, mat_kiem_thu_lan))
    print('Xong')
    torch.save(mo_hinh.state_dict(), duong_dan)

    warnings.filterwarnings("ignore")               
    plt.rcParams['font.sans-serif']    = ['SimHei'] 
    plt.rcParams['axes.unicode_minus'] = False      
    plt.rcParams['figure.dpi']         = 100        

    khoang_lan_lap = range(so_lan_lap)

    plt.figure(figsize=(12, 3))
    plt.subplot(1, 2, 1)

    plt.plot(khoang_lan_lap, tap_luu_do_chinh_xac, label='Độ chính xác huấn luyện')
    plt.plot(khoang_lan_lap, tap_luu_do_chinh_xac_kiem_thu, label='Độ chính xác kiểm thử')
    plt.legend(loc='lower right')
    plt.title('Độ chính xác huấn luyện và kiểm thử')

    plt.subplot(1, 2, 2)
    plt.plot(khoang_lan_lap, tap_luu_mat, label='Tổn thất huấn luyện')
    plt.plot(khoang_lan_lap, tap_luu_mat_kiem_thu, label='Tổn thất kiểm thử')
    plt.legend(loc='upper right')
    plt.title('Tổn thất huấn luyện và kiểm thử')
    plt.show()

Kết quả hiển thị

Epoch: 1, Do chinh xac:14.9%, Toi da mat:2.282, Do chinh xac kiem thu:19.6%,Toi da mat kiem thu:2.163
Epoch: 2, Do chinh xac:25.2%, Toi da mat:2.002, Do chinh xac kiem thu:30.4%,Toi da mat kiem thu:1.877
Epoch: 3, Do chinh xac:35.0%, Toi da mat:1.778, Do chinh xac kiem thu:37.6%,Toi da mat kiem thu:1.730
Epoch: 4, Do chinh xac:40.7%, Toi da mat:1.624, Do chinh xac kiem thu:40.7%,Toi da mat kiem thu:1.635
Epoch: 5, Do chinh xac:44.7%, Toi da mat:1.516, Do chinh xac kiem thu:46.6%,Toi da mat kiem thu:1.457
Epoch: 6, Do chinh xac:48.7%, Toi da mat:1.420, Do chinh xac kiem thu:50.1%,Toi da mat kiem thu:1.371
Epoch: 7, Do chinh xac:52.1%, Toi da mat:1.334, Do chinh xac kiem thu:52.5%,Toi da mat kiem thu:1.325
Epoch: 8, Do chinh xac:55.1%, Toi da mat:1.262, Do chinh xac kiem thu:53.5%,Toi da mat kiem thu:1.314
Epoch: 9, Do chinh xac:57.4%, Toi da mat:1.198, Do chinh xac kiem thu:54.1%,Toi da mat kiem thu:1.276
Epoch:10, Do chinh xac:60.0%, Toi da mat:1.142, Do chinh xac kiem thu:58.9%,Toi da mat kiem thu:1.161
Xong

Kiểm tra kết quả phân loại


# Dự đoán
mo_hinh.eval()

# Định nghĩa hàm tiền xử lý hình ảnh
def xu_ly_truoc_anh(duong_dan_anh, thiet_bi):
    anh = Image.open(duong_dan_anh)
    anh = anh.convert('RGB')
    bien_doi = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                                torchvision.transforms.ToTensor()])
    tensor_anh = bien_doi(anh)
    tensor_anh = torch.reshape(tensor_anh, (1, 3, 32, 32))
    return tensor_anh

duong_dan_anh = ['../kiem_thu/meo.png', '../kiem_thu/cho.png', '../kiem_thu/oto.png', '../kiem_thu/oto1.png', '../kiem_thu/oto2.png']
for i, duong_dan in enumerate(duong_dan_anh):
    anh = xu_ly_truoc_anh(duong_dan, thiet_bi)
    ket_qua = mo_hinh(anh)
    print('ket_qua', "Nhãn: {}".format(danh_muc[ket_qua.argmax(1)]))

    plt.subplot(1, len(duong_dan_anh), i + 1)
    anh = cv2.imread(duong_dan, cv2.IMREAD_COLOR)
    anh = cv2.cvtColor(anh, cv2.COLOR_BGR2RGB)
    plt.imshow(anh)
    plt.axis('off')
    plt.title("Nhãn: {}".format(danh_muc[ket_qua.argmax(1)]))
plt.show()

Thẻ: PyTorch CNN image_classification cifar10

Đăng vào ngày 16 tháng 6 lúc 08:21