Mạng Structure-Aware LSTM cho Phát hiện Điểm Giải phẫu 3D

Được công bố trên IEEE Transactions on Medical Imaging (tháng 7 năm 2022), bài báo "Mạng Long Short-Term Memory Nhận thức Cấu trúc cho Phát hiện Điểm Giải phạm 3D" của Chen Runnan và các cộng tác viên đã đề xuất một mạng LSTM nhận thức cấu trúc nhằm nâng cao độ chính xác và khả năng chịu lỗi của việc phát hiện điểm giải phẫu trong chụp ảnh sọ 3D.

1. Tổng quan nghiên cứu

Khía cạnhNội dung chínhVấn đề nghiên cứuTrong phân tích chụp ảnh sọ 3D, việc phát hiện chính xác các điểm giải phạm là yếu tố then chốt, nhưng các phương pháp học sâu truyền thống khó có thể mô hình hóa rõ ràng các mối quan hệ phức tạp giữa đặc điểm thị giác và điểm giải phẫu.Đột phá cốt lõiĐề xuất LSTM nhận thức cấu trúc, mã hóa ràng buộc cấu trúc toàn cầu của các điểm giải phẫu vào biểu diễn khái niệm thị giác một cách tự thích nghi.Công nghệ chủ chốtSử dụng chiến lược hai giai đoạn từ thô đến tinh, và tận dụng mạng LSTM để nắm bắt các mối quan hệ phụ thuộc dài hạn, tăng cường hiểu biết về cấu trúc không gian.Đóng góp chínhĐã được xác minh trên các bộ dữ liệu chuẩn về chụp ảnh sọ 2D và 3D cũng như dữ liệu bệnh nhân thực tế, chứng minh tính ưu việt, khả năng tổng quát và ổn định của phương pháp.

2. Phương pháp chi tiết

Bối cảnh và thách thức: Trong chẩn đoán và lập kế hoạch phẫu thuật chỉnh nha và ngoại sọ mặt, việc xác định chính xác các điểm giải phẫu từ hình ảnh 3D như CBCT là bước quan trọng. Các phương pháp truyền thống chủ yếu dựa vào việc gán nhãn thủ công bởi chuyên gia, tồn tại tính chủ quan cao, tốn thời gian và dễ có sự khác biệt giữa các người thực hiện. Mặc dù học sâu đã được áp dụng trong lĩnh vực này, nhiều phương pháp không thể tận dụng hiệu quả cấu trúc không gian toàn cầu và mối quan hệ hình học cố hữu giữa các điểm giải phẫu, dẫn đến sai lệch đáng kể trong các tình huống phức tạp.

Phương pháp cốt lõi: Trọng tâm của bài báo là mạng LSTM Nhận thức Cấu trúc. Điểm mấu chốt của nó là có thể học rõ ràng các mối liên hệ phức tạp giữa đặc điểm thị giác và điểm giải phẫu, từ đó tích hợp thông tin cấu trúc toàn cầu của các điểm giải phẫu vào biểu diễn đặc điểm. Mạng sử dụng thiết kế hai giai đoạn từ thô đến tinh: giai đoạn đầu tiên thực hiện định vị điểm giải phẩm sơ bộ, thô; giai đoạn thứ hai dựa trên thông tin chi tiết hơn (có thể bao gồm đặc điểm hình ảnh cục bộ và kết quả định vị sơ bộ) để tối ưu hóa, cuối cùng đạt được định vị chính xác cấp dưới milimet.

Kết quả thực nghiệm: Nghiên cứu đã được xác minh trên nhiều bộ dữ liệu chuẩn công cộng và một bộ dữ liệu chứa 150 volume CBCT sọ bệnh nhân thực tế. Kết quả thực nghiệm cho thấy phương pháp vượt trội đáng kể so với các phương pháp tiên tiến khác về độ chính xác và khả năng chịu lỗi. Bài báo đặc biệt nhấn mạnh rằng khung framework này có khả năng tổng quát hóa mạnh mẽ, có thể xử lý các nhiệm vụ phát hiện điểm giải phẩm trong cả kịch bản 2D và 3D trong một mô hình thống nhất.

3. Thông tin bổ sung

Nghiên cứu liên quan: Cùng một đội ngũ nghiên cứu cũng đã công bố một bài báo vào năm 2022 với tiêu đề "Phát hiện điểm giải phám bán giám sát thông qua tự huấn luyện có điều tiết hình dạng" trên tạp chí Neurocomputing, khám phá phát hiện điểm giải phám bán giám sát thông qua tự huấn luyện có điều tiết hình dạng, cho thấy sự khám phá sâu sắc và liên tục của đội ngũ này trong lĩnh vực này.

Ý nghĩa thực tế: Nghiên cứu này thúc đẩy sự phát triển của tự động hóa phân tích chụp ảnh sọ 3D. Kết quả của nó giúp nâng cao đáng kể hiệu quả và độ chính xác của chẩn đoán lâm sàng và lập kế hoạch phẫu thuật, đồng thời giảm bớt gánh nặng gán nhãn thủ công cho các bác sĩ, có ý nghĩa tích cực đối với việc phát triển các hệ thống chẩn đoán hỗ trợ y tế thông minh.

4. Đọc mã nguồn

Mã nguồn: runnanchen/SA-LSTM-3D-Landmark-Detection

Mô hình trong MyModel.py

Mạng phát hiện thô

class CoarseDetectionNet(nn.Module):
    def __init__(self, configuration):
        # landmarkCount, use_gpu, image_scale
        super(CoarseDetectionNet, self).__init__()
        self.landmarkCount = configuration.landmarkCount
        self.gpuEnabled = configuration.use_gpu
        self.imageScale = configuration.image_scale
        self.unet = MNL.UNet3D(1, 64)
        self.conv3d = nn.Sequential(
            nn.Conv3d(64, configuration.landmarkCount, 1, 1, 0),
            nn.Sigmoid(),
        )

    def forward(self, input_data):
        global_features = self.unet(input_data)
        output = self.conv3d(global_features) + 1e-9
        heatmap_sum = torch.sum(output.view(self.landmarkCount, -1), dim=1)
        global_heatmap = [output[0, i, :, :, :].squeeze() / heatmap_sum[i] for i in range(self.landmarkCount)]

        return global_heatmap, global_features

Mạng phát hiện tinh với LSTM

class FineDetectionLSTM(nn.Module):
    def __init__(self, configuration):
        super(FineDetectionLSTM, self).__init__()

        # landmarkCount, use_gpu, iterations, cropSize

        self.landmarkCount = configuration.landmarkCount
        self.gpuEnabled = configuration.use_gpu
        self.encoder = MNL.UNet3DEncoder(1, 64)
        self.iterations = configuration.iteration
        self.cropSize = configuration.crop_size
        self.originalImageSize = configuration.origin_image_size
        self.config = configuration

        width, height, length = self.originalImageSize
        # (576, 768, 768)

        self.sizeTensor = torch.tensor([1 / (length - 1), 1 / (height - 1), 1 / (width - 1)]).cuda(self.gpuEnabled)

        self.offsetDecodersX = nn.Conv1d(self.landmarkCount, self.landmarkCount, 512 + 64, 1, 0, groups=self.landmarkCount)
        self.offsetDecodersY = nn.Conv1d(self.landmarkCount, self.landmarkCount, 512 + 64, 1, 0, groups=self.landmarkCount)
        self.offsetDecodersZ = nn.Conv1d(self.landmarkCount, self.landmarkCount, 512 + 64, 1, 0, groups=self.landmarkCount)

        self.sharedAttentionGate = nn.Sequential(
            nn.Linear(512 + 64, 256),
            nn.Tanh(),
        )
        self.attentionGateHead = nn.Conv1d(self.landmarkCount, self.landmarkCount, 256, 1, 0, groups=self.landmarkCount)
        self.graphAttention = MNL.GraphAttention(64, self.gpuEnabled)

    def forward(self, initialLandmarks, groundTruth, originalInput, coarseFeature, phase, sizeTensorInv):

        hiddenState = 0
        predictions = []
        cellState = 0
        prediction = initialLandmarks.detach()

        for i in range(0, self.iterations):
            regionsOfInterest = 0
            if phase == 'train':
                if i == 0:
                    regionsOfInterest = groundTruth + torch.from_numpy(np.random.normal(loc=0.0, scale=32.0 / self.originalImageSize[2] / 3, size = groundTruth.size())).cuda(self.gpuEnabled).float()
                elif i == 1:
                    regionsOfInterest = groundTruth + torch.from_numpy(np.random.normal(loc=0.0, scale=16.0 / self.originalImageSize[2] / 3, size = groundTruth.size())).cuda(self.gpuEnabled).float()
                else:
                    regionsOfInterest = groundTruth + torch.from_numpy(np.random.normal(loc=0.0, scale=8.0 / self.originalImageSize[2] / 3, size = groundTruth.size())).cuda(self.gpuEnabled).float()
            else:
                regionsOfInterest = prediction

            regionsOfInterest = MyUtils.adjustment(regionsOfInterest, groundTruth)

            croppedItems = MyUtils.getCroppedInputsRelated(regionsOfInterest.detach().cpu().numpy(), groundTruth, originalInput, -1, i, self.config)
            croppedItems = torch.cat([croppedItems[i].cuda(self.gpuEnabled) for i in range(len(croppedItems))], dim=0)
            features = self.encoder(croppedItems).squeeze().unsqueeze(0)

            globalFeature = MyUtils.getGlobalFeature(regionsOfInterest.detach().cpu().numpy(), coarseFeature, self.landmarkCount)
            globalFeature = self.graphAttention(regionsOfInterest, globalFeature)
            features = torch.cat((features, globalFeature), dim=2)
            
            if i == 0:
                hiddenState = features
                cellState = regionsOfInterest
            else:
                forgetGate = self.attentionGateHead(self.sharedAttentionGate(hiddenState.squeeze()).unsqueeze(0))
                attentionGate = self.attentionGateHead(self.sharedAttentionGate(features.squeeze()).unsqueeze(0))
                combinedGate = torch.softmax(torch.cat([forgetGate, attentionGate], dim=2), dim=2)

                hiddenState = hiddenState * combinedGate[0, :, 0].view(1, -1, 1) + features * combinedGate[0, :, 1].view(1, -1, 1)
                cellState = cellState * combinedGate[0, :, 0].view(1, -1, 1) + regionsOfInterest * combinedGate[0, :, 1].view(1, -1, 1)

            x, y, z = self.offsetDecodersX(hiddenState), self.offsetDecodersY(hiddenState), self.offsetDecodersZ(hiddenState)
            prediction = torch.cat([x, y, z], dim=2) * sizeTensorInv + cellState
            predictions.append(prediction.float())

        predictions = torch.cat(predictions, dim=0)

        return predictions

Thẻ: deep-learning medical-imaging lstm 3d-reconstruction anatomical-landmarks

Đăng vào ngày 17 tháng 05 lúc 01:19