Tối ưu hóa mô hình tiền huấn luyện bằng Transformers và PyTorch

Tối ưu hóa mô hình tiền huấn luyện

Để sử dụng thư viện Transformers, cần nắm vững kiến thức cơ bản về PyTorch. Hướng dẫn chi tiết: Chương 6 của khóa học Transformers - Kiến thức PyTorch cần thiết

  1. Chuẩn bị dữ liệu

Ta sử dụng nhiệm vụ xác định tính đồng nghĩa (mỗi lần đầu vào hai câu, xác định xem chúng có phải là đồng nghĩa không). Dữ liệu sẽ lấy từ tập AFQMC do Ant Financial cung cấp, bao gồm 34334/4316/3861 cặp câu trong tập huấn luyện, kiểm tra và test:

{"sentence1": "Trả hết nợ rồi sao vẫn hiển thị phải thanh toán trên tài khoản Huabei", "sentence2": "Sao trả toàn bộ Huabei mà vẫn hiển thị chưa thanh toán", "label": "1"}

Dataset

PyTorch sử dụng lớp DatasetDataLoader để xử lý dữ liệu. Ta sẽ xây dựng lớp dữ liệu tùy chỉnh kế thừa từ Dataset. Tập AFQMC được lưu dưới dạng JSON, do đó ta sử dụng thư viện json để đọc từng dòng:

from torch.utils.data import Dataset
import json

class DatasetCustom(Dataset):
    def __init__(self, path_file):
        super().__init__()
        self.data = self.load_data(path_file)

    def load_data(self, path_file):
        data_dict = {}
        with open(path_file, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                sample = json.loads(line.strip())
                data_dict[idx] = sample
        return data_dict

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

train_dataset = DatasetCustom("./data/afqmc_public/train.json")
valid_dataset = DatasetCustom("./data/afqmc_public/dev.json")

print(train_dataset[0])

Kết quả in ra:

{'sentence1': 'Trả nợ vay mượn bằng phương thức đều đặn có thể thay đổi thành trả lãi trước không', 'sentence2': 'Vay mượn có thể trả lãi trước không', 'label': '0'}

Nếu dữ liệu quá lớn, ta có thể sử dụng lớp IterableDataset để xử lý theo từng phần:

from torch.utils.data import IterableDataset
import json

class IterableDatasetCustom(IterableDataset):
    def __init__(self, path_file):
        self.path_file = path_file

    def __iter__(self):
        with open(self.path_file, 'r', encoding='utf-8') as f:
            for line in f:
                sample = json.loads(line.strip())
                yield sample

train_dataset = IterableDatasetCustom("./data/afqmc_public/train.json")
print(next(iter(train_dataset)))

Kết quả in ra:

{'sentence1': 'Trả nợ vay mượn bằng phương thức đều đặn có thể thay đổi thành trả lãi trước không', 'sentence2': 'Vay mượn có thể trả lãi trước không', 'label': '0'}

DataLoader

Sau đó sử dụng DataLoader để tải dữ liệu theo batch và chuyển đổi thành định dạng đầu vào phù hợp với mô hình. Đối với các nhiệm vụ NLP, bước này bao gồm mã hóa văn bản (padding, truncation...):

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

model_name = "bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def batch_processing(samples):
    sentences1, sentences2 = [], []
    labels = []
    for sample in samples:
        sentences1.append(sample['sentence1'])
        sentences2.append(sample['sentence2'])
        labels.append(int(sample['label']))
    
    encoded = tokenizer(
        sentences1,
        sentences2,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    labels_tensor = torch.tensor(labels)
    return encoded, labels_tensor

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=batch_processing)

batch_input, batch_label = next(iter(train_loader))
print("Kích thước batch đầu vào:", {k: v.shape for k, v in batch_input.items()})
print("Kích thước nhãn:", batch_label.shape)

Kết quả in ra:

Kích thước batch đầu vào: {
'input_ids': torch.Size([4, 26]), 
'token_type_ids': torch.Size([4, 26]), 
'attention_mask': torch.Size([4, 26])
}
Kích thước nhãn: torch.Size([4])
  1. Xây dựng và huấn luyện mô hình

Thiết kế mô hình

Đối với bài toán phân loại, ta có thể sử dụng lớp AutoModelForSequenceClassification. Tuy nhiên, trong thực tế thường cần tùy chỉnh thêm:

import torch.nn as nn
from transformers import AutoModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Sử dụng thiết bị: {device}")

class ModelCustom(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, 2)

    def forward(self, inputs):
        outputs = self.bert(**inputs)
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        return logits

model = ModelCustom().to(device)
print(model)

Kết quả in ra:

Sử dụng thiết bị: cuda
ModelCustom(
  (bert): BertModel(...)
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
)

Cách tiếp cận phổ biến hơn là kế thừa từ lớp BertPreTrainedModel:

from transformers import AutoConfig, BertPreTrainedModel, BertModel

class ModelCustom(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 2)
        self.post_init()

    def forward(self, inputs):
        outputs = self.bert(**inputs)
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        return logits

config = AutoConfig.from_pretrained(model_name)
model = ModelCustom.from_pretrained(model_name, config=config).to(device)
print(model)

Kết quả in ra:

ModelCustom(
  (bert): BertModel(...)
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
)

Tối ưu hóa tham số

Chia quy trình huấn luyện thành vòng lặp huấn luyện và kiểm tra:

from tqdm import tqdm

def train_loop(dataloader, model, loss_func, optimizer, scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    finish_steps = (epoch - 1) * len(dataloader)
    
    model.train()
    for step, (inputs, labels) in enumerate(dataloader, 1):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        progress_bar.set_description(f"loss: {total_loss/(finish_steps + step):>7f}")
        progress_bar.update(1)
    return total_loss

def evaluate_loop(dataloader, model, mode="Test"):
    size = len(dataloader.dataset)
    correct = 0
    
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            correct += (outputs.argmax(1) == labels).sum().item()
    
    accuracy = correct / size
    print(f"{mode} Độ chính xác: {(100*accuracy):>0.1f}%\n")
    return accuracy

Khởi tạo optimizer và scheduler:

from transformers import AdamW, get_scheduler

learning_rate = 1e-5
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_steps = 3 * len(train_loader)
scheduler = get_scheduler("linear", optimizer=optimizer, num_training_steps=num_steps)

Quy trình huấn luyện hoàn chỉnh:

loss_func = nn.CrossEntropyLoss()
best_accuracy = 0

for epoch in range(3):
    print(f"Epoch {epoch+1}/3\n--------------------------------------")
    total_loss = train_loop(train_loader, model, loss_func, optimizer, scheduler, epoch+1, 0)
    valid_acc = evaluate_loop(valid_loader, model, "Valid")
    
    if valid_acc > best_accuracy:
        best_accuracy = valid_acc
        torch.save(model.state_dict(), f"model_epoch_{epoch+1}_acc_{(100*best_accuracy):0.1f}.pth")
print("Hoàn tất!")
  1. Tài liệu tham khảo

Hướng dẫn Transformers - Chương 7: Tối ưu hóa mô hình tiền huấn luyện

Thẻ: Transformers PyTorch BERT fine-tuning

Đăng vào ngày 22 tháng 6 lúc 09:41