Giới thiệu
Mô hình BERT tiêu chuẩn thường giới hạn độ dài chuỗi đầu vào ở mức 512 token. Tuy nhiên, trong nhiều bài toán thực tế, dữ liệu văn bản có thể dài hơn đáng kể. Để giải quyết vấn đề này, chúng ta cần điều chỉnh tham số max_position_embeddings trong cấu hình mô hình và tùy chỉnh quá trình huấn luyện để phù hợp với độ dài mới.
1. Xây dựng Lớp Dữ Liệu Tùy Chỉnh
Đầu tiên, cần tạo một class kế thừa từ Dataset của PyTorch để quản lý việc nạp dữ liệu từ file CSV. Class này sẽ trả về cặp văn bản và nhãn tương ứng cho mỗi mẫu dữ liệu.
from torch.utils.data import Dataset
from datasets import load_dataset
class LongTextDataset(Dataset):
def __init__(self, data_split):
# Nạp dữ liệu từ file CSV cục bộ
self.data_store = load_dataset(
path="csv",
data_files=f"dataset/{data_split}.csv",
split="train"
)
def __len__(self):
return len(self.data_store)
def __getitem__(self, index):
content = self.data_store[index]["text"]
target = self.data_store[index]["label"]
return content, target
if __name__ == "__main__":
sample_data = LongTextDataset("validation")
for item in sample_data:
print(item)
2. Điều Cấu Hình Và Khởi Tạo Mô Hình
Để hỗ trợ chuỗi dài hơn, chúng ta cần tải cấu hình của BERT, sửa đổi giá trị vị trí nhúng tối đa, sau đó khởi tạo lại mô hình. Trong ví dụ này, độ dài được tăng lên 1500 token. Lưu ý rằng lớp encoder sẽ được đóng băng (freeze) để chỉ huấn luyện lớp embeddings và lớp phân loại cuối cùng, giúp tiết kiệm bộ nhớ và tập trung thích nghi vị trí.
from transformers import BertModel, BertConfig
import torch
# Xác định thiết bị tính toán
hardware_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Đường dẫn tới thư mục mô hình预训练
BASE_MODEL_DIR = "path/to/bert-base-chinese"
# Tải cấu hình và điều chỉnh độ dài tối đa
config = BertConfig.from_pretrained(BASE_MODEL_DIR)
MAX_SEQ_LEN = 1500
config.max_position_embeddings = MAX_SEQ_LEN
# Khởi tạo mô hình BERT với cấu hình mới
bert_backbone = BertModel(config).to(hardware_device)
# Định nghĩa mô hình phân loại downstream
class BERTLongTextNet(torch.nn.Module):
def __init__(self):
super().__init__()
# Lớp fully connected để phân loại
self.classifier = torch.nn.Linear(768, 10)
def forward(self, input_ids, attention_mask, token_type_ids):
# Lấy output từ lớp embeddings
embed_out = bert_backbone.embeddings(input_ids=input_ids)
# Chuẩn hóa attention_mask
attn_mask = attention_mask.to(torch.float)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
attn_mask = attn_mask.to(embed_out.dtype)
# Đóng băng phần encoder, không tính gradient
with torch.no_grad():
encoder_out = bert_backbone.encoder(embed_out, attention_mask=attn_mask)
# Lấy trạng thái ẩn đầu tiên và đưa qua lớp phân loại
logits = self.classifier(encoder_out.last_hidden_state[:, 0])
return logits
print(bert_backbone.embeddings.position_embeddings)
3. Thiết Lập Quy Trình Huấn Luyện
Quá trình huấn luyện bao gồm việc tokenize dữ liệu với độ dài mới, tạo DataLoader và chạy vòng lặp tối ưu hóa. Hàm collate_fn sẽ đảm bảo mỗi batch được padding đúng kích thước MAX_SEQ_LEN.
import torch
from torch.utils.data import DataLoader
from transformers import AdamW, BertTokenizer
from data_module import LongTextDataset # Import lớp dataset đã viết
from model_module import BERTLongTextNet # Import lớp model đã viết
hardware_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TRAIN_ROUNDS = 30000
# Khởi tạo tokenizer
tokenizer = BertTokenizer.from_pretrained(BASE_MODEL_DIR)
def batch_prepare(batch_data):
texts = [item[0] for item in batch_data]
targets = [item[1] for item in batch_data]
# Tokenize với độ dài tối đa đã điều chỉnh
encoded = tokenizer.batch_encode_plus(
batch_text_or_text_pairs=texts,
truncation=True,
padding="max_length",
max_length=MAX_SEQ_LEN,
return_tensors="pt",
return_length=True
)
token_ids = encoded["input_ids"]
attn_mask = encoded["attention_mask"]
seg_ids = encoded["token_type_ids"]
label_tensor = torch.LongTensor(targets)
return token_ids, attn_mask, seg_ids, label_tensor
# Khởi tạo dataset và loader
train_set = LongTextDataset("train")
val_set = LongTextDataset("validation")
train_loader = DataLoader(
dataset=train_set,
batch_size=20,
shuffle=True,
drop_last=True,
collate_fn=batch_prepare
)
if __name__ == "__main__":
print(f"Running on: {hardware_device}")
model = BERTLongTextNet().to(hardware_device)
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()
model.train()
for epoch in range(TRAIN_ROUNDS):
# Vòng lặp huấn luyện
for idx, (token_ids, attn_mask, seg_ids, labels) in enumerate(train_loader):
token_ids = token_ids.to(hardware_device)
attn_mask = attn_mask.to(hardware_device)
seg_ids = seg_ids.to(hardware_device)
labels = labels.to(hardware_device)
outputs = model(token_ids, attn_mask, seg_ids)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Ghi log mỗi 5 batch
if idx % 5 == 0:
predictions = outputs.argmax(dim=1)
accuracy = (predictions == labels).sum().item() / len(labels)
print(f"Epoch: {epoch}, Batch: {idx}, Loss: {loss.item():.4f}, Acc: {accuracy:.4f}")
# Lưu tham số mô hình định kỳ
torch.save(model.state_dict(), f"checkpoints/epoch_{epoch}_longbert.pth")
print(f"Epoch {epoch} saved successfully.")