Tối ưu hóa Tham Số Sử Dụng K-Fold Cross Validation trong Mô Hình Hóa Máy Học với Matlab

Trong quá trình xây dựng mô hình máy học, việc điều chỉnh tham số là một bước không thể bỏ qua. Bài viết này sẽ hướng dẫn sử dụng công cụ箱 mạng nơ-ron của Matlab để thực hiện tối ưu hóa tham số bằng kỹ thuật K-Fold Cross Validation, đặc biệt là chọn số lượng nút ẩn phù hợp.

Bắt đầu bằng cách chuẩn bị dữ liệu thử nghiệm - sử dụng bộ dữ liệu chuẩn đoán ung thư sẵn có trong Matlab:

load breastcancer
dau_vao = breastInputs; 
dich_phai = ind2vec(breastTargets');  % Chuyển đổi thành dạng one-hot

Lưu ý rằng nhãn được chuyển đổi thành dạng vector để phù hợp với hàm mất mát cross-entropy sau này. Đồng thời, chuẩn hóa dữ liệu:

dau_vao = mapminmax(dau_vao, 0, 1);  % Chuẩn hóa về khoảng [0,1]

Tiếp theo, chia dữ liệu thành 10 phần bằng hàm cvpartition:

k = 10;
chia_du_lieu = cvpartition(size(dau_vao,2), 'KFold',k);

Lưu ý rằng cvpartition chia dựa trên chỉ số mẫu, và mỗi cột của dữ liệu đầu vào đại diện cho một mẫu, do đó sử dụng size(dau_vao,2) để lấy tổng số mẫu.

Bây giờ đến phần chính - tìm kiếm lưới (grid search) để xác định số lượng nút ẩn tốt nhất. Hãy thử từ 8 đến 15 nút:

so_nut_anh = 8:15;
mse_trung_binh = zeros(size(so_nut_anh));

for j = 1:length(so_nut_anh)
    loi_theo_fold = zeros(k,1);
    
    for fold = 1:k
        chi_so_huan_luyen = training(chia_du_lieu,fold);
        chi_so_kiem_thu = test(chia_du_lieu,fold);
        
        mang = patternnet(so_nut_anh(j));
        mang.divideFcn = 'dividetrain';  % Cài đặt quan trọng! Tắt tự động phân chia
        mang = train(mang, dau_vao(:,chi_so_huan_luyen), dich_phai(:,chi_so_huan_luyen));
        
        du_doan = mang(dau_vao(:,chi_so_kiem_thu));
        loi_theo_fold(fold) = crossentropy(du_doan, dich_phai(:,chi_so_kiem_thu));
    end
    
    mse_trung_binh(j) = mean(loi_theo_fold);
end

Một điểm dễ gây lỗi ở đoạn mã trên là patternnet tự động chia tập kiểm định, cần thiết lập divideFcn='dividetrain' để tắt chức năng này, nếu không thì quá trình cross-validation sẽ không chính xác. Quá trình huấn luyện sẽ hiển thị thanh tiến trình epoch trên màn hình, đó là quá trình cập nhật trọng số thông qua backpropagation.

Sau khi hoàn thành vòng lặp, vẽ đồ thị đường cong để đánh giá kết quả:

plot(so_nut_anh, mse_trung_binh, 'ro-')
xlabel('Số lượng nút ẩn')
ylabel('Hàm mất mát cross-entropy')
title('Kết quả K-Fold Cross Validation')
grid on

Nếu thấy hàm mất mát ngừng giảm rõ ràng khi số lượng nút tăng lên 12 và bắt đầu dao động, hãy chọn 12 nút - đây là điểm cân bằng giữa độ phức tạp của mô hình và hiệu suất.

Cuối cùng, huấn luyện lại toàn bộ dữ liệu với tham số tối ưu:

mang_tot_nhat = patternnet(12);
mang_tot_nhat = train(mang_tot_nhat, dau_vao, dich_phai);

Khi đó, đối tượng mang_tot_nhat đã được huấn luyện có thể được lưu lại và triển khai trực tiếp vào môi trường sản xuất để dự đoán. Tuy nhiên, cần lưu ý rằng trong dự án thực tế có thể cần áp dụng chiến lược dừng sớm hoặc thêm regularization L2 để ngăn ngừa hiện tượng overfitting.

Toàn bộ quy trình tiêu tốn thời gian nhất chính là 10 x 8 = 80 lần huấn luyện mạng. Nếu kích thước dữ liệu lớn, có thể sử dụng parfor để song song hóa hoặc chuyển sang GPU để huấn luyện - nhưng đó là câu chuyện khác.

Thẻ: MATLAB neural-network k-fold-cross-validation parameter-tuning

Đăng vào ngày 26 tháng 6 lúc 05:01