Công cụ trích xuất thực thể dựa trên RNN cho tiếng Trung

Giới thiệu

RNN4IE là một công cụ mã nguồn mở được phát triển nhằm thực hiện trích xuất thông tin từ văn bản tiếng Trung, sử dụng các kiến trúc mạng thần kinh hồi tiếp (RNN) được xây dựng trên PyTorch. Dự án tập trung vào bài toán nhận dạng thực thể có tên (Named Entity Recognition - NER), trong đó dữ liệu được gán nhãn theo định dạng B-I-O (Begin, Inside, Outside).

Kiến trúc mô hình

Dự án cung cấp bốn biến thể mô hình khác nhau, tất cả đều dựa trên GRU làm lớp cơ sở và tích hợp CRF ở tầng ra để tối ưu hóa chuỗi dự đoán:

  • GRU-CRF: Mô hình cơ bản kết hợp GRU với lớp CRF để học các ràng buộc chuỗi.
  • GRU-MHSA-CRF: Tích hợp cơ chế tự chú ý đa đầu (Multi-Head Self-Attention) sau GRU nhằm nắm bắt mối quan hệ dài hạn giữa các từ.
  • GRU-SA-CRF: Sử dụng cơ chế chú ý mềm (Soft Attention) để tăng cường trọng số cho các từ quan trọng trong câu.
  • GRU-XCA-CRF: Áp dụng cơ chế chú ý dựa trên hiệp phương sai chéo (Cross-Covariance Attention) nhằm cải thiện biểu diễn ngữ nghĩa.

Hướng dẫn sử dụng

Cấu hình

Mỗi mô hình yêu cầu một tệp cấu hình riêng (config.cfg) chứa các siêu tham số như kích thước embedding, số lớp GRU, đường dẫn dữ liệu, v.v. Các mẫu cấu hình được cung cấp sẵn tương ứng với từng mô hình: gru_cfg, gru_mhsa_cfg, gru_sa_cfg, gru_xca_cfg.

Đào tạo mô hình

Ví dụ dưới đây minh họa cách khởi động quá trình huấn luyện cho từng mô hình:

from rnn4ie.gru.train import Train
trainer = Train()
trainer.train_model('config.cfg')
from rnn4ie.gru_mhsa.train import Train
trainer = Train()
trainer.train_model('config.cfg')

Dự đoán

Sau khi mô hình được huấn luyện, bạn có thể sử dụng lớp Predict để trích xuất thực thể từ câu mới:

from rnn4ie.gru.predict import Predict
predictor = Predict()
predictor.load_model_vocab('config.cfg')
result = predictor.predict('Theo báo cáo của Tân Hoa Xã, thành phố Lục An, tỉnh An Huy đã được bình chọn là một trong mười thành phố đáng sống nhất!')
print(result)

Tương tự với các mô hình còn lại như gru_mhsa, gru_sa, gru_xca.

Đánh giá hiệu năng

Các chỉ số đánh giá chính bao gồm độ chính xác (Precision), độ phủ (Recall), F1-score và độ phức tạp trung bình (Perplexity - PPL). Có thể sử dụng thư viện scikit-learn để tính toán nhanh chóng:

from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred))

Cài đặt

Dự án hỗ trợ hai cách cài đặt:

  1. Thông qua pip:
    pip install RNN4IE
  2. Từ mã nguồn:
    git clone https://github.com/jiangnanboy/RNN4IE.git
    cd RNN4IE
    python setup.py install

Nếu không muốn cài đặt, người dùng có thể tải trực tiếp mã nguồn và chạy dưới dạng module cục bộ.

Bộ dữ liệu

Dự án sử dụng dữ liệu từ báo Nhân Dân (People's Daily), với các nhãn thực thể: [ORG, PER, LOC, T, O], trong đó T đại diện cho thời gian. Dữ liệu được chia thành tập huấn luyện (train.csv) và tập kiểm thử (dev.csv), định dạng CSV với hai cột: source (văn bản) và target (nhãn B-I-O tương ứng).

Embedding tiền huấn luyện được sử dụng là sgns.sogou.char.bz2 – một ma trận embedding ký tự tiếng Trung phổ biến.

Thẻ: RNN GRU CRF Named Entity Recognition PyTorch

Đăng vào ngày 23 tháng 5 lúc 16:23