Được công bố trên IEEE Transactions on Medical Imaging (tháng 7 năm 2022), bài báo "Mạng Long Short-Term Memory Nhận thức Cấu trúc cho Phát hiện Điểm Giải phạm 3D" của Chen Runnan và các cộng tác viên đã đề xuất một mạng LSTM nhận thức cấu trúc nhằm nâng cao độ chính xác và khả năng chịu lỗi của việc phát hiện điểm giải phẫu trong chụp ảnh sọ 3D.
1. Tổng quan nghiên cứu
Khía cạnhNội dung chínhVấn đề nghiên cứuTrong phân tích chụp ảnh sọ 3D, việc phát hiện chính xác các điểm giải phạm là yếu tố then chốt, nhưng các phương pháp học sâu truyền thống khó có thể mô hình hóa rõ ràng các mối quan hệ phức tạp giữa đặc điểm thị giác và điểm giải phẫu.Đột phá cốt lõiĐề xuất LSTM nhận thức cấu trúc, mã hóa ràng buộc cấu trúc toàn cầu của các điểm giải phẫu vào biểu diễn khái niệm thị giác một cách tự thích nghi.Công nghệ chủ chốtSử dụng chiến lược hai giai đoạn từ thô đến tinh, và tận dụng mạng LSTM để nắm bắt các mối quan hệ phụ thuộc dài hạn, tăng cường hiểu biết về cấu trúc không gian.Đóng góp chínhĐã được xác minh trên các bộ dữ liệu chuẩn về chụp ảnh sọ 2D và 3D cũng như dữ liệu bệnh nhân thực tế, chứng minh tính ưu việt, khả năng tổng quát và ổn định của phương pháp.
2. Phương pháp chi tiết
Bối cảnh và thách thức: Trong chẩn đoán và lập kế hoạch phẫu thuật chỉnh nha và ngoại sọ mặt, việc xác định chính xác các điểm giải phẫu từ hình ảnh 3D như CBCT là bước quan trọng. Các phương pháp truyền thống chủ yếu dựa vào việc gán nhãn thủ công bởi chuyên gia, tồn tại tính chủ quan cao, tốn thời gian và dễ có sự khác biệt giữa các người thực hiện. Mặc dù học sâu đã được áp dụng trong lĩnh vực này, nhiều phương pháp không thể tận dụng hiệu quả cấu trúc không gian toàn cầu và mối quan hệ hình học cố hữu giữa các điểm giải phẫu, dẫn đến sai lệch đáng kể trong các tình huống phức tạp.
Phương pháp cốt lõi: Trọng tâm của bài báo là mạng LSTM Nhận thức Cấu trúc. Điểm mấu chốt của nó là có thể học rõ ràng các mối liên hệ phức tạp giữa đặc điểm thị giác và điểm giải phẫu, từ đó tích hợp thông tin cấu trúc toàn cầu của các điểm giải phẫu vào biểu diễn đặc điểm. Mạng sử dụng thiết kế hai giai đoạn từ thô đến tinh: giai đoạn đầu tiên thực hiện định vị điểm giải phẩm sơ bộ, thô; giai đoạn thứ hai dựa trên thông tin chi tiết hơn (có thể bao gồm đặc điểm hình ảnh cục bộ và kết quả định vị sơ bộ) để tối ưu hóa, cuối cùng đạt được định vị chính xác cấp dưới milimet.
Kết quả thực nghiệm: Nghiên cứu đã được xác minh trên nhiều bộ dữ liệu chuẩn công cộng và một bộ dữ liệu chứa 150 volume CBCT sọ bệnh nhân thực tế. Kết quả thực nghiệm cho thấy phương pháp vượt trội đáng kể so với các phương pháp tiên tiến khác về độ chính xác và khả năng chịu lỗi. Bài báo đặc biệt nhấn mạnh rằng khung framework này có khả năng tổng quát hóa mạnh mẽ, có thể xử lý các nhiệm vụ phát hiện điểm giải phẩm trong cả kịch bản 2D và 3D trong một mô hình thống nhất.
3. Thông tin bổ sung
Nghiên cứu liên quan: Cùng một đội ngũ nghiên cứu cũng đã công bố một bài báo vào năm 2022 với tiêu đề "Phát hiện điểm giải phám bán giám sát thông qua tự huấn luyện có điều tiết hình dạng" trên tạp chí Neurocomputing, khám phá phát hiện điểm giải phám bán giám sát thông qua tự huấn luyện có điều tiết hình dạng, cho thấy sự khám phá sâu sắc và liên tục của đội ngũ này trong lĩnh vực này.
Ý nghĩa thực tế: Nghiên cứu này thúc đẩy sự phát triển của tự động hóa phân tích chụp ảnh sọ 3D. Kết quả của nó giúp nâng cao đáng kể hiệu quả và độ chính xác của chẩn đoán lâm sàng và lập kế hoạch phẫu thuật, đồng thời giảm bớt gánh nặng gán nhãn thủ công cho các bác sĩ, có ý nghĩa tích cực đối với việc phát triển các hệ thống chẩn đoán hỗ trợ y tế thông minh.
4. Đọc mã nguồn
Mã nguồn: runnanchen/SA-LSTM-3D-Landmark-Detection
Mô hình trong MyModel.py
Mạng phát hiện thô
class CoarseDetectionNet(nn.Module):
def __init__(self, configuration):
# landmarkCount, use_gpu, image_scale
super(CoarseDetectionNet, self).__init__()
self.landmarkCount = configuration.landmarkCount
self.gpuEnabled = configuration.use_gpu
self.imageScale = configuration.image_scale
self.unet = MNL.UNet3D(1, 64)
self.conv3d = nn.Sequential(
nn.Conv3d(64, configuration.landmarkCount, 1, 1, 0),
nn.Sigmoid(),
)
def forward(self, input_data):
global_features = self.unet(input_data)
output = self.conv3d(global_features) + 1e-9
heatmap_sum = torch.sum(output.view(self.landmarkCount, -1), dim=1)
global_heatmap = [output[0, i, :, :, :].squeeze() / heatmap_sum[i] for i in range(self.landmarkCount)]
return global_heatmap, global_features
Mạng phát hiện tinh với LSTM
class FineDetectionLSTM(nn.Module):
def __init__(self, configuration):
super(FineDetectionLSTM, self).__init__()
# landmarkCount, use_gpu, iterations, cropSize
self.landmarkCount = configuration.landmarkCount
self.gpuEnabled = configuration.use_gpu
self.encoder = MNL.UNet3DEncoder(1, 64)
self.iterations = configuration.iteration
self.cropSize = configuration.crop_size
self.originalImageSize = configuration.origin_image_size
self.config = configuration
width, height, length = self.originalImageSize
# (576, 768, 768)
self.sizeTensor = torch.tensor([1 / (length - 1), 1 / (height - 1), 1 / (width - 1)]).cuda(self.gpuEnabled)
self.offsetDecodersX = nn.Conv1d(self.landmarkCount, self.landmarkCount, 512 + 64, 1, 0, groups=self.landmarkCount)
self.offsetDecodersY = nn.Conv1d(self.landmarkCount, self.landmarkCount, 512 + 64, 1, 0, groups=self.landmarkCount)
self.offsetDecodersZ = nn.Conv1d(self.landmarkCount, self.landmarkCount, 512 + 64, 1, 0, groups=self.landmarkCount)
self.sharedAttentionGate = nn.Sequential(
nn.Linear(512 + 64, 256),
nn.Tanh(),
)
self.attentionGateHead = nn.Conv1d(self.landmarkCount, self.landmarkCount, 256, 1, 0, groups=self.landmarkCount)
self.graphAttention = MNL.GraphAttention(64, self.gpuEnabled)
def forward(self, initialLandmarks, groundTruth, originalInput, coarseFeature, phase, sizeTensorInv):
hiddenState = 0
predictions = []
cellState = 0
prediction = initialLandmarks.detach()
for i in range(0, self.iterations):
regionsOfInterest = 0
if phase == 'train':
if i == 0:
regionsOfInterest = groundTruth + torch.from_numpy(np.random.normal(loc=0.0, scale=32.0 / self.originalImageSize[2] / 3, size = groundTruth.size())).cuda(self.gpuEnabled).float()
elif i == 1:
regionsOfInterest = groundTruth + torch.from_numpy(np.random.normal(loc=0.0, scale=16.0 / self.originalImageSize[2] / 3, size = groundTruth.size())).cuda(self.gpuEnabled).float()
else:
regionsOfInterest = groundTruth + torch.from_numpy(np.random.normal(loc=0.0, scale=8.0 / self.originalImageSize[2] / 3, size = groundTruth.size())).cuda(self.gpuEnabled).float()
else:
regionsOfInterest = prediction
regionsOfInterest = MyUtils.adjustment(regionsOfInterest, groundTruth)
croppedItems = MyUtils.getCroppedInputsRelated(regionsOfInterest.detach().cpu().numpy(), groundTruth, originalInput, -1, i, self.config)
croppedItems = torch.cat([croppedItems[i].cuda(self.gpuEnabled) for i in range(len(croppedItems))], dim=0)
features = self.encoder(croppedItems).squeeze().unsqueeze(0)
globalFeature = MyUtils.getGlobalFeature(regionsOfInterest.detach().cpu().numpy(), coarseFeature, self.landmarkCount)
globalFeature = self.graphAttention(regionsOfInterest, globalFeature)
features = torch.cat((features, globalFeature), dim=2)
if i == 0:
hiddenState = features
cellState = regionsOfInterest
else:
forgetGate = self.attentionGateHead(self.sharedAttentionGate(hiddenState.squeeze()).unsqueeze(0))
attentionGate = self.attentionGateHead(self.sharedAttentionGate(features.squeeze()).unsqueeze(0))
combinedGate = torch.softmax(torch.cat([forgetGate, attentionGate], dim=2), dim=2)
hiddenState = hiddenState * combinedGate[0, :, 0].view(1, -1, 1) + features * combinedGate[0, :, 1].view(1, -1, 1)
cellState = cellState * combinedGate[0, :, 0].view(1, -1, 1) + regionsOfInterest * combinedGate[0, :, 1].view(1, -1, 1)
x, y, z = self.offsetDecodersX(hiddenState), self.offsetDecodersY(hiddenState), self.offsetDecodersZ(hiddenState)
prediction = torch.cat([x, y, z], dim=2) * sizeTensorInv + cellState
predictions.append(prediction.float())
predictions = torch.cat(predictions, dim=0)
return predictions