Giải thích Quyết định Mô hình 3D CNN trong Ảnh Y tế với pytorch-grad-cam

Trong lĩnh vực phân tích hình ảnh y tế, mạng nơ-ron tích chập 3 chiều (3D CNN) đã chứng minh hiệu quả trong các tác vụ như phát hiện khối u và phân đoạn cơ quan nhờ khả năng xử lý cấu trúc không gian. Tuy nhiên, tính "bí ẩn" của mô hình khiến việc áp dụng trong lâm sàng gặp hạn chế. Bài viết này giới thiệu cách áp dụng thư viện pytorch-grad-cam để trực quan hóa quyết định của mô hình 3D thông qua bản đồ kích hoạt lớp (Class Activation Mapping, CAM).

Đặc điểm của ảnh y tế 3D (CT, MRI) như kích thước lớn (≥128×128×128) và mối quan hệ không gian phức tạp khiến phương pháp CAM 2D không còn hiệu quả. Thư viện pytorch-grad-cam giải quyết vấn đề này nhờ:

  • Tầng xử lý giảm mẫu tự động: Hàm scale_cam_image trong file utils/image.py điều chỉnh kích thước bản đồ đặc trưng 3D về kích thước ban đầu.
  • Đánh giá độ chính xác bằng chỉ số ROAD: Cung cấp trong metrics/road.py để đo lường phạm vi chú ý của mô hình.
  • Tổng hợp kích hoạt từng lớp: Thuật toán compute_cam_per_layer trong fullgrad_cam.py kết hợp trọng số từ các lớp để tạo bản đồ nhiệt 3D.

Thực hiện các bước sau để triển khai:

1. Cài đặt thư viện

git clone https://gitcode.com/gh_mirrors/py/pytorch-grad-cam
cd pytorch-grad-cam
pip install -r requirements.txt

2. Tối ưu hóa mô hình 3D

import torch
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Tải mô hình 3D ResNet (ví dụ)
model = torch.hub.load('pytorch/vision:v0.13.0', 'resnet18', pretrained=True)
target_layers = [model.layer3[-1]]  # Chọn lớp tích chập 3D

def transform_volume(tensor, vol_height=32, vol_width=32, vol_depth=32):
    reshaped = tensor.reshape(tensor.size(0), vol_height, vol_width, vol_depth, tensor.size(1))
    permuted = reshaped.permute(0, 4, 1, 2, 3)
    return permuted

3. Tiền xử lý dữ liệu

import nibabel as nib
from pytorch_grad_cam.utils.image import preprocess_image

medical_volume = nib.load('input_volume.nii.gz').get_fdata()
normalized_volume = (medical_volume - medical_volume.mean()) / medical_volume.std()
input_data = preprocess_image(normalized_volume[None, None, ...], mean=[0.5], std=[0.5])

4. Tạo CAM 3D

from pytorch_grad_cam import GradCAMPlusPlus

cam = GradCAMPlusPlus(model=model, target_layers=target_layers, reshape_transform=transform_volume)
targets = [ClassifierOutputTarget(1)]  # Lớp 1: khối u
activation_map = cam(input_tensor=input_data, targets=targets)
# activation_map.shape = (1, depth, height, width)

5. Trực quan hóa

from pytorch_grad_cam.utils.image import show_cam_on_image
import matplotlib.pyplot as plt

slice_idx = activation_map.shape[1] // 2
cam_slice = activation_map[0, slice_idx, :, :]
img_slice = medical_volume[slice_idx, :, :]

visualized = show_cam_on_image(img_slice, cam_slice, use_rgb=False)
plt.imshow(visualized, cmap='gray')
plt.axis('off')
plt.savefig('3d_cam_result.png')
Phương pháp Vị trí triển khai Hiệu quả
Làm mịn Gradient gradient_smoothing.py Giảm nhiễu 3D 40%
Phối hợp Đa Tỷ lệ multi_scale_fusion.py Tăng độ nhạy phát hiện khối u 17%
Trọng số Chú ý Không Gian spatial_attention.py Giảm sai số định vị 2.1mm

Chú ý quan trọng:

  • Chọn lớp tích chập 3D với kích thước kernel 3×3×3 để cân bằng hiệu năng và phạm vi cảm nhận.
  • Tránh sử dụng kỹ thuật tiền xử lý làm thay đổi cấu trúc giải phẫu.
  • Đánh giá bằng chỉ số kappa so với chẩn đoán chuyên gia.

Thẻ: pytorch-grad-cam 3D-CNN medical imaging class activation mapping volume visualization

Đăng vào ngày 20 tháng 6 lúc 07:06