Xây dựng dữ liệu huấn luyện đơn giản cho mô hình BERT

Mô hình BERT gốc được huấn luyện với hai nhiệm vụ chính:

  • Masked Language Model (MLM): Dự đoán các từ bị che (mask) trong câu.
  • Next Sentence Prediction (NSP): Xác định liệu câu thứ hai có phải là câu tiếp theo của câu đầu tiên hay không.

Dữ liệu đầu vào cho BERT bao gồm ba loại embedding:

  1. Token Embeddings: Biểu diễn từng từ/token.
  2. Segment Embeddings: Phân biệt giữa hai câu trong cặp đầu vào.
  3. Position Embeddings: Mã hóa vị trí của từng token trong chuỗi.

Dưới đây là cách xây dựng dữ liệu huấn luyện BERT một cách đơn giản bằng Python.

import re
import random
import numpy as np

# Văn bản mẫu
raw_text = (
    'Sau đó, bài viết đưa ra 5 đề xuất cho quan hệ Trung-Mỹ.\n'
    'Thứ nhất, Mỹ nên khôi phục các chương trình học bổng như "Peace Corps" tại Trung Quốc.\n'
    'Các chương trình này từng giúp Mỹ hiểu rõ hơn về Trung Quốc nhưng đã bị chính quyền Trump hủy bỏ.\n'
    'Thứ hai, Mỹ nên ngừng gán mác tiêu cực cho Viện Khổng Tử – vốn chỉ là trung tâm văn hóa và giáo dục, tương tự Viện Goethe của Đức.\n'
    'Thứ ba, Mỹ nên cho phép các phóng viên Trung Quốc bị trục xuất trước đây trở lại. Đồng thời, Trung Quốc cũng nên mở cửa cho báo chí Mỹ.\n'
    'Thứ tư, Mỹ cần dỡ bỏ lệnh cấm nhập cảnh đối với đảng viên Đảng Cộng sản Trung Quốc.\n'
    'Thứ năm, Mỹ nên mời Trung Quốc mở lại Lãnh sự quán tại Houston.\n'
    'Đổi lại, Trung Quốc sẽ cho phép Mỹ tái mở Lãnh sự quán tại Thành Đô.\n'
    'Bài viết kết luận rằng dù nhỏ, những hành động này rất quan trọng để xây dựng lòng tin và giải quyết các vấn đề phức tạp hơn.'
)

# Tiền xử lý: loại bỏ dấu câu và chuyển về chữ thường
sentences = [re.sub(r'[.,"""!?\\-]', '', s.strip().lower()) for s in raw_text.split('\n') if s.strip()]
all_chars = list(''.join(sentences))

# Xây dựng từ điển
special_tokens = ['[PAD]', '[CLS]', '[SEP]', '[UNK]', '[MASK]']
vocab = special_tokens + sorted(set(all_chars))
word2idx = {token: idx for idx, token in enumerate(vocab)}
idx2word = {idx: token for token, idx in word2idx.items()}
vocab_size = len(vocab)

# Chuyển câu thành dãy ID
tokenized_sentences = [[word2idx[char] for char in sent] for sent in sentences]

Tiếp theo, chuẩn bị dữ liệu theo định dạng yêu cầu của BERT, bao gồm việc tạo cặp câu (cho NSP) và áp dụng chiến lược masking (cho MLM).

max_len = 120      # Độ dài tối đa của chuỗi đầu vào
max_mask = 5       # Số lượng token tối đa được mask trong mỗi mẫu
batch_size = 6

def create_training_instances():
    dataset = []
    pos_count = neg_count = 0
    
    while pos_count < batch_size // 2 or neg_count < batch_size // 2:
        # Chọn ngẫu nhiên hai câu
        i_a = random.randint(0, len(tokenized_sentences) - 1)
        i_b = random.randint(0, len(tokenized_sentences) - 1)
        
        tokens_a = tokenized_sentences[i_a]
        tokens_b = tokenized_sentences[i_b]
        
        # Xác định mối quan hệ giữa hai câu
        is_next = (i_a + 1 == i_b)
        
        # Chỉ chấp nhận mẫu nếu chưa đủ số lượng tương ứng
        if is_next and pos_count >= batch_size // 2:
            continue
        if not is_next and neg_count >= batch_size // 2:
            continue
        
        # Tạo input_ids: [CLS] + A + [SEP] + B + [SEP]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
        
        # Xác định các vị trí có thể bị mask (loại trừ [CLS], [SEP])
        mask_candidates = [
            i for i, tid in enumerate(input_ids)
            if tid not in (word2idx['[CLS]'], word2idx['[SEP]'])
        ]
        random.shuffle(mask_candidates)
        
        num_to_mask = min(max_mask, max(1, int(len(input_ids) * 0.15)))
        masked_pos = []
        masked_labels = []
        
        # Áp dụng chiến lược masking
        for pos in mask_candidates[:num_to_mask]:
            masked_pos.append(pos)
            masked_labels.append(input_ids[pos])
            
            rand = random.random()
            if rand < 0.8:
                input_ids[pos] = word2idx['[MASK]']          # 80% thay bằng [MASK]
            elif rand < 0.9:
                # 10% thay bằng token ngẫu nhiên (không phải token đặc biệt)
                rand_id = random.randint(5, vocab_size - 1)
                input_ids[pos] = rand_id
            # 10% giữ nguyên
        
        # Padding đến max_len
        pad_len = max_len - len(input_ids)
        input_ids += [word2idx['[PAD]']] * pad_len
        segment_ids += [0] * pad_len
        
        # Padding cho masked arrays
        if len(masked_pos) < max_mask:
            pad_needed = max_mask - len(masked_pos)
            masked_pos += [0] * pad_needed
            masked_labels += [0] * pad_needed
        
        # Lưu mẫu
        if is_next:
            dataset.append((input_ids, segment_ids, masked_labels, masked_pos, True))
            pos_count += 1
        else:
            dataset.append((input_ids, segment_ids, masked_labels, masked_pos, False))
            neg_count += 1
            
    return dataset

Sau khi tạo xong dữ liệu, có thể đóng gói thành dataset để sử dụng với DataLoader trong PyTorch:

from torch.utils.data import Dataset, DataLoader

class BertDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self, idx):
        return self.data[idx]

# Tạo dữ liệu và DataLoader
training_data = create_training_instances()
loader = DataLoader(BertDataset(training_data), batch_size=batch_size, shuffle=True)

Lưu ý: Position embeddings thường được tạo tự động trong mô hình BERT dựa trên độ dài chuỗi đầu vào, ví dụ bằng cách sử dụng torch.arange để sinh vị trí cho từng token trong batch.

Thẻ: BERT masked-language-model next-sentence-prediction PyTorch NLP

Đăng vào ngày 22 tháng 5 lúc 20:53