Một trong những bài toán nền tảng trong hệ thống hội thoại thông minh là đồng thời xác định ý định tổng thể của câu (intent detection) và trích xuất các thực thể có vai trò ngữ nghĩa cụ thể (slot filling), ví dụ: "Đặt vé máy bay từ Hà Nội đến Đà Nẵng ngày 15/04" → intent = book_flight, slots = {from_city: "Hà Nội", to_city: "Đà Nẵng", date: "15/04"}. Thay vì huấn luyện hai mô hình riêng biệt, cách tiếp cận hiệu quả hơn là xây dựng một kiến trúc chung để tận dụng sự tương quan giữa hai tác vụ — tương tự như mô hình kết hợp NER và relation classification trong trích xuất tri thức.
Mô hình được trình bày dưới đây mở rộng kiến trúc RNN hai chiều với cơ chế chú ý (BiGRU + Attention), được thiết kế đặc biệt cho việc học đồng thời hai đầu ra: một nhãn phân loại toàn câu (intent), và một chuỗi nhãn tuần tự (slots). Kiến trúc không sử dụng attention theo kiểu encoder-decoder cổ điển mà tích hợp trọng số chú ý vào cả hai nhánh xử lý — vừa để tổng hợp biểu diễn toàn cục cho ý định, vừa để điều chỉnh trạng thái ẩn từng bước khi sinh nhãn slot.
Thiết kế mô hình
- Nhánh ý định: Lấy đầu ra ẩn cuối cùng của BiGRU (ghép trạng thái forward và backward), sau đó áp dụng phép tính trọng số chú ý lên toàn bộ chuỗi ẩn mã hóa — kết quả là một vector ngữ cảnh được tính trung bình có trọng số. Vector này được đưa vào lớp fully-connected để phân loại ý định.
- Nhánh khung: Sử dụng một GRUCell đơn lẻ hoạt động từng bước theo thời gian. Tại mỗi bước
t, đầu vào gồm: (i) biểu diễn ẩn tại vị tríttừ encoder, (ii) embedding của nhãn slot ở bước trước, và (iii) vector ngữ cảnh chú ý được tái tính từ trạng thái ẩn hiện tại và toàn bộ chuỗi mã hóa. Đầu ra của GRUCell được chiếu qua lớp phân lớp để dự đoán nhãn slot tại vị trí đó. - Hàm mất mát tổng hợp: Tổng tuyến tính của hai thành phần:
cross_entropy(intent_pred, intent_true)vàcross_entropy(slot_preds.view(-1, num_slots), slot_labels.view(-1)).
Triển khai dữ liệu và mô hình
Sử dụng thư viện torchtext để tiền xử lý dữ liệu ATIS:
from torchtext import data
def tokenize(text):
return text.split()
SRC = data.Field(
tokenize=tokenize,
lower=True,
init_token='<sos>',
eos_token='<eos>',
pad_token='<pad>',
unk_token='<unk>',
batch_first=True,
fix_length=64,
include_lengths=True
)
TGT = data.Field(
tokenize=tokenize,
lower=True,
init_token='<sos>',
eos_token='<eos>',
pad_token='<pad>',
unk_token='<unk>',
batch_first=True,
fix_length=64,
include_lengths=True
)
INTENT = data.LabelField(dtype=torch.long, sequential=False)
train_data, val_data = data.TabularDataset.splits(
path='./atis/',
train='train.csv',
validation='test.csv',
format='csv',
fields=[('idx', None), ('intent', INTENT), ('utterance', SRC), ('slots', TGT)]
)
SRC.build_vocab(train_data, min_freq=2)
TGT.build_vocab(train_data, min_freq=2)
INTENT.build_vocab(train_data)
Kiến trúc mô hình được viết lại gọn gàng, loại bỏ logic thừa và làm rõ luồng dữ liệu:
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContextualAttention(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.att_proj = nn.Linear(hidden_size * 2, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias=False)
def forward(self, query, memory):
# query: [B, H*2], memory: [B, T, H*2]
B, T, H2 = memory.shape
query_expanded = query.unsqueeze(1) # [B, 1, H*2]
combined = torch.tanh(self.att_proj(torch.cat([query_expanded, memory], dim=2)))
scores = self.score(combined).squeeze(-1) # [B, T]
weights = F.softmax(scores, dim=-1) # [B, T]
context = torch.bmm(weights.unsqueeze(1), memory).squeeze(1) # [B, H*2]
return context, weights
class JointIntentSlotModel(nn.Module):
def __init__(
self,
vocab_size,
embed_dim,
hidden_size,
num_intent,
num_slot,
slot_embed_dim,
dropout=0.3,
pad_idx=1
):
super().__init__()
self.pad_idx = pad_idx
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
self.encoder = nn.GRU(
embed_dim, hidden_size // 2,
num_layers=1,
bidirectional=True,
batch_first=True,
dropout=dropout if 1 > 1 else 0
)
self.attention = ContextualAttention(hidden_size)
self.slot_embedding = nn.Embedding(num_slot, slot_embed_dim)
self.decoder_cell = nn.GRUCell(slot_embed_dim + hidden_size * 2, hidden_size)
self.intent_head = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(hidden_size * 2, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, num_intent)
)
self.slot_head = nn.Linear(hidden_size, num_slot)
def forward(self, x, lengths):
# x: [B, T], lengths: [B]
embedded = self.embedding(x) # [B, T, E]
packed = nn.utils.rnn.pack_padded_sequence(
embedded, lengths, batch_first=True, enforce_sorted=True
)
enc_out, h_n = self.encoder(packed) # enc_out: [B, T, H], h_n: [2, B, H//2]
enc_out, _ = nn.utils.rnn.pad_packed_sequence(
enc_out, batch_first=True, padding_value=self.pad_idx
)
# Lấy hidden state cuối cùng (forward + backward)
final_hidden = torch.cat([h_n[0], h_n[1]], dim=1) # [B, H]
# Tính ngữ cảnh chú ý cho intent
intent_context, _ = self.attention(final_hidden, enc_out) # [B, H]
intent_logits = self.intent_head(intent_context) # [B, C_intent]
# Dự đoán slot từng bước
B, T = x.shape
slot_logits = torch.zeros(B, T, self.slot_head.out_features, device=x.device)
# Khởi tạo input token đầu tiên là <sos> (giả sử index = 2)
prev_slot = torch.full((B,), 2, dtype=torch.long, device=x.device)
# Duyệt từng vị trí
for t in range(T):
# Lấy encoder output tại vị trí t
enc_t = enc_out[:, t, :] # [B, H]
# Embedding nhãn trước đó
slot_emb = self.slot_embedding(prev_slot) # [B, S]
# Kết hợp: encoder_t + slot_emb + intent_context
combined_input = torch.cat([enc_t, slot_emb, intent_context], dim=1) # [B, H + S + H]
# Cập nhật trạng thái decoder
dec_hidden = self.decoder_cell(combined_input, final_hidden) # [B, H]
# Dự đoán slot tại t
slot_logits[:, t, :] = self.slot_head(dec_hidden) # [B, C_slot]
# Cập nhật prev_slot cho bước sau
prev_slot = slot_logits[:, t, :].argmax(dim=1)
return slot_logits, intent_logits
Mô hình được huấn luyện trong 10 epoch với hàm mất mát tổng hợp. Giai đoạn suy luận (inference) được thực hiện bằng cách chạy vòng lặp giải mã từng bước, bắt đầu từ token khởi tạo và cập nhật liên tục dựa trên dự đoán trước đó.