Mở Rộng Độ Dài Chuỗi Cho Mô Hình BERT Trong Huấn Luyện

Giới thiệu

Mô hình BERT tiêu chuẩn thường giới hạn độ dài chuỗi đầu vào ở mức 512 token. Tuy nhiên, trong nhiều bài toán thực tế, dữ liệu văn bản có thể dài hơn đáng kể. Để giải quyết vấn đề này, chúng ta cần điều chỉnh tham số max_position_embeddings trong cấu hình mô hình và tùy chỉnh quá trình huấn luyện để phù hợp với độ dài mới.

1. Xây dựng Lớp Dữ Liệu Tùy Chỉnh

Đầu tiên, cần tạo một class kế thừa từ Dataset của PyTorch để quản lý việc nạp dữ liệu từ file CSV. Class này sẽ trả về cặp văn bản và nhãn tương ứng cho mỗi mẫu dữ liệu.

from torch.utils.data import Dataset
from datasets import load_dataset

class LongTextDataset(Dataset):
    def __init__(self, data_split):
        # Nạp dữ liệu từ file CSV cục bộ
        self.data_store = load_dataset(
            path="csv", 
            data_files=f"dataset/{data_split}.csv", 
            split="train"
        )
    
    def __len__(self):
        return len(self.data_store)

    def __getitem__(self, index):
        content = self.data_store[index]["text"]
        target = self.data_store[index]["label"]
        return content, target

if __name__ == "__main__":
    sample_data = LongTextDataset("validation")
    for item in sample_data:
        print(item)

2. Điều Cấu Hình Và Khởi Tạo Mô Hình

Để hỗ trợ chuỗi dài hơn, chúng ta cần tải cấu hình của BERT, sửa đổi giá trị vị trí nhúng tối đa, sau đó khởi tạo lại mô hình. Trong ví dụ này, độ dài được tăng lên 1500 token. Lưu ý rằng lớp encoder sẽ được đóng băng (freeze) để chỉ huấn luyện lớp embeddings và lớp phân loại cuối cùng, giúp tiết kiệm bộ nhớ và tập trung thích nghi vị trí.

from transformers import BertModel, BertConfig
import torch

# Xác định thiết bị tính toán
hardware_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Đường dẫn tới thư mục mô hình预训练
BASE_MODEL_DIR = "path/to/bert-base-chinese"

# Tải cấu hình và điều chỉnh độ dài tối đa
config = BertConfig.from_pretrained(BASE_MODEL_DIR)
MAX_SEQ_LEN = 1500
config.max_position_embeddings = MAX_SEQ_LEN

# Khởi tạo mô hình BERT với cấu hình mới
bert_backbone = BertModel(config).to(hardware_device)

# Định nghĩa mô hình phân loại downstream
class BERTLongTextNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Lớp fully connected để phân loại
        self.classifier = torch.nn.Linear(768, 10)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Lấy output từ lớp embeddings
        embed_out = bert_backbone.embeddings(input_ids=input_ids)
        
        # Chuẩn hóa attention_mask
        attn_mask = attention_mask.to(torch.float)
        attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
        attn_mask = attn_mask.to(embed_out.dtype)
        
        # Đóng băng phần encoder, không tính gradient
        with torch.no_grad():
            encoder_out = bert_backbone.encoder(embed_out, attention_mask=attn_mask)
        
        # Lấy trạng thái ẩn đầu tiên và đưa qua lớp phân loại
        logits = self.classifier(encoder_out.last_hidden_state[:, 0])
        return logits

print(bert_backbone.embeddings.position_embeddings)

3. Thiết Lập Quy Trình Huấn Luyện

Quá trình huấn luyện bao gồm việc tokenize dữ liệu với độ dài mới, tạo DataLoader và chạy vòng lặp tối ưu hóa. Hàm collate_fn sẽ đảm bảo mỗi batch được padding đúng kích thước MAX_SEQ_LEN.

import torch
from torch.utils.data import DataLoader
from transformers import AdamW, BertTokenizer
from data_module import LongTextDataset  # Import lớp dataset đã viết
from model_module import BERTLongTextNet # Import lớp model đã viết

hardware_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TRAIN_ROUNDS = 30000

# Khởi tạo tokenizer
tokenizer = BertTokenizer.from_pretrained(BASE_MODEL_DIR)

def batch_prepare(batch_data):
    texts = [item[0] for item in batch_data]
    targets = [item[1] for item in batch_data]
    
    # Tokenize với độ dài tối đa đã điều chỉnh
    encoded = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=texts,
        truncation=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
        return_tensors="pt",
        return_length=True
    )
    
    token_ids = encoded["input_ids"]
    attn_mask = encoded["attention_mask"]
    seg_ids = encoded["token_type_ids"]
    label_tensor = torch.LongTensor(targets)
    
    return token_ids, attn_mask, seg_ids, label_tensor

# Khởi tạo dataset và loader
train_set = LongTextDataset("train")
val_set = LongTextDataset("validation")

train_loader = DataLoader(
    dataset=train_set,
    batch_size=20,
    shuffle=True,
    drop_last=True,
    collate_fn=batch_prepare
)

if __name__ == "__main__":
    print(f"Running on: {hardware_device}")
    model = BERTLongTextNet().to(hardware_device)
    optimizer = AdamW(model.parameters(), lr=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    model.train()
    for epoch in range(TRAIN_ROUNDS):
        # Vòng lặp huấn luyện
        for idx, (token_ids, attn_mask, seg_ids, labels) in enumerate(train_loader):
            token_ids = token_ids.to(hardware_device)
            attn_mask = attn_mask.to(hardware_device)
            seg_ids = seg_ids.to(hardware_device)
            labels = labels.to(hardware_device)
            
            outputs = model(token_ids, attn_mask, seg_ids)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Ghi log mỗi 5 batch
            if idx % 5 == 0:
                predictions = outputs.argmax(dim=1)
                accuracy = (predictions == labels).sum().item() / len(labels)
                print(f"Epoch: {epoch}, Batch: {idx}, Loss: {loss.item():.4f}, Acc: {accuracy:.4f}")
        
        # Lưu tham số mô hình định kỳ
        torch.save(model.state_dict(), f"checkpoints/epoch_{epoch}_longbert.pth")
        print(f"Epoch {epoch} saved successfully.")

Thẻ: BERT PyTorch NLP Transformers DeepLearning

Đăng vào ngày 26 tháng 5 lúc 02:37