Mô hình học sâu kết hợp nhận diện ý định và điền khung dựa trên cơ chế chú ý

Một trong những bài toán nền tảng trong hệ thống hội thoại thông minh là đồng thời xác định ý định tổng thể của câu (intent detection) và trích xuất các thực thể có vai trò ngữ nghĩa cụ thể (slot filling), ví dụ: "Đặt vé máy bay từ Hà Nội đến Đà Nẵng ngày 15/04" → intent = book_flight, slots = {from_city: "Hà Nội", to_city: "Đà Nẵng", date: "15/04"}. Thay vì huấn luyện hai mô hình riêng biệt, cách tiếp cận hiệu quả hơn là xây dựng một kiến trúc chung để tận dụng sự tương quan giữa hai tác vụ — tương tự như mô hình kết hợp NER và relation classification trong trích xuất tri thức.

Mô hình được trình bày dưới đây mở rộng kiến trúc RNN hai chiều với cơ chế chú ý (BiGRU + Attention), được thiết kế đặc biệt cho việc học đồng thời hai đầu ra: một nhãn phân loại toàn câu (intent), và một chuỗi nhãn tuần tự (slots). Kiến trúc không sử dụng attention theo kiểu encoder-decoder cổ điển mà tích hợp trọng số chú ý vào cả hai nhánh xử lý — vừa để tổng hợp biểu diễn toàn cục cho ý định, vừa để điều chỉnh trạng thái ẩn từng bước khi sinh nhãn slot.

Thiết kế mô hình

  • Nhánh ý định: Lấy đầu ra ẩn cuối cùng của BiGRU (ghép trạng thái forward và backward), sau đó áp dụng phép tính trọng số chú ý lên toàn bộ chuỗi ẩn mã hóa — kết quả là một vector ngữ cảnh được tính trung bình có trọng số. Vector này được đưa vào lớp fully-connected để phân loại ý định.
  • Nhánh khung: Sử dụng một GRUCell đơn lẻ hoạt động từng bước theo thời gian. Tại mỗi bước t, đầu vào gồm: (i) biểu diễn ẩn tại vị trí t từ encoder, (ii) embedding của nhãn slot ở bước trước, và (iii) vector ngữ cảnh chú ý được tái tính từ trạng thái ẩn hiện tại và toàn bộ chuỗi mã hóa. Đầu ra của GRUCell được chiếu qua lớp phân lớp để dự đoán nhãn slot tại vị trí đó.
  • Hàm mất mát tổng hợp: Tổng tuyến tính của hai thành phần: cross_entropy(intent_pred, intent_true)cross_entropy(slot_preds.view(-1, num_slots), slot_labels.view(-1)).

Triển khai dữ liệu và mô hình

Sử dụng thư viện torchtext để tiền xử lý dữ liệu ATIS:

from torchtext import data

def tokenize(text):
    return text.split()

SRC = data.Field(
    tokenize=tokenize,
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    pad_token='<pad>',
    unk_token='<unk>',
    batch_first=True,
    fix_length=64,
    include_lengths=True
)

TGT = data.Field(
    tokenize=tokenize,
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    pad_token='<pad>',
    unk_token='<unk>',
    batch_first=True,
    fix_length=64,
    include_lengths=True
)

INTENT = data.LabelField(dtype=torch.long, sequential=False)

train_data, val_data = data.TabularDataset.splits(
    path='./atis/',
    train='train.csv',
    validation='test.csv',
    format='csv',
    fields=[('idx', None), ('intent', INTENT), ('utterance', SRC), ('slots', TGT)]
)

SRC.build_vocab(train_data, min_freq=2)
TGT.build_vocab(train_data, min_freq=2)
INTENT.build_vocab(train_data)

Kiến trúc mô hình được viết lại gọn gàng, loại bỏ logic thừa và làm rõ luồng dữ liệu:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContextualAttention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.att_proj = nn.Linear(hidden_size * 2, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, query, memory):
        # query: [B, H*2], memory: [B, T, H*2]
        B, T, H2 = memory.shape
        query_expanded = query.unsqueeze(1)  # [B, 1, H*2]
        combined = torch.tanh(self.att_proj(torch.cat([query_expanded, memory], dim=2)))
        scores = self.score(combined).squeeze(-1)  # [B, T]
        weights = F.softmax(scores, dim=-1)  # [B, T]
        context = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)  # [B, H*2]
        return context, weights

class JointIntentSlotModel(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        hidden_size,
        num_intent,
        num_slot,
        slot_embed_dim,
        dropout=0.3,
        pad_idx=1
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.encoder = nn.GRU(
            embed_dim, hidden_size // 2,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if 1 > 1 else 0
        )
        self.attention = ContextualAttention(hidden_size)
        self.slot_embedding = nn.Embedding(num_slot, slot_embed_dim)
        self.decoder_cell = nn.GRUCell(slot_embed_dim + hidden_size * 2, hidden_size)
        self.intent_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_intent)
        )
        self.slot_head = nn.Linear(hidden_size, num_slot)

    def forward(self, x, lengths):
        # x: [B, T], lengths: [B]
        embedded = self.embedding(x)  # [B, T, E]
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=True
        )
        enc_out, h_n = self.encoder(packed)  # enc_out: [B, T, H], h_n: [2, B, H//2]
        enc_out, _ = nn.utils.rnn.pad_packed_sequence(
            enc_out, batch_first=True, padding_value=self.pad_idx
        )

        # Lấy hidden state cuối cùng (forward + backward)
        final_hidden = torch.cat([h_n[0], h_n[1]], dim=1)  # [B, H]

        # Tính ngữ cảnh chú ý cho intent
        intent_context, _ = self.attention(final_hidden, enc_out)  # [B, H]
        intent_logits = self.intent_head(intent_context)  # [B, C_intent]

        # Dự đoán slot từng bước
        B, T = x.shape
        slot_logits = torch.zeros(B, T, self.slot_head.out_features, device=x.device)
        
        # Khởi tạo input token đầu tiên là <sos> (giả sử index = 2)
        prev_slot = torch.full((B,), 2, dtype=torch.long, device=x.device)
        
        # Duyệt từng vị trí
        for t in range(T):
            # Lấy encoder output tại vị trí t
            enc_t = enc_out[:, t, :]  # [B, H]
            # Embedding nhãn trước đó
            slot_emb = self.slot_embedding(prev_slot)  # [B, S]
            # Kết hợp: encoder_t + slot_emb + intent_context
            combined_input = torch.cat([enc_t, slot_emb, intent_context], dim=1)  # [B, H + S + H]
            # Cập nhật trạng thái decoder
            dec_hidden = self.decoder_cell(combined_input, final_hidden)  # [B, H]
            # Dự đoán slot tại t
            slot_logits[:, t, :] = self.slot_head(dec_hidden)  # [B, C_slot]
            # Cập nhật prev_slot cho bước sau
            prev_slot = slot_logits[:, t, :].argmax(dim=1)

        return slot_logits, intent_logits

Mô hình được huấn luyện trong 10 epoch với hàm mất mát tổng hợp. Giai đoạn suy luận (inference) được thực hiện bằng cách chạy vòng lặp giải mã từng bước, bắt đầu từ token khởi tạo và cập nhật liên tục dựa trên dự đoán trước đó.

Thẻ: PyTorch NLP attention-mechanism sequence-labeling intent-detection

Đăng vào ngày 24 tháng 6 lúc 10:18