Tinh chỉnh Llama 3 – Phiên bản GPT đánh giá bài báo lần 5: Tinh chỉnh Llama 3 với tập dữ liệu đánh giá bài báo sớm 7 khía cạnh

Giới thiệu

Sau khi ra mắt Llama 3, để tinh chỉnh mô hình này bằng tập dữ liệu đánh giá bài báo, có nhiều cách tiếp cận khác nhau — đây cũng chính là cấu trúc của bài viết này:

  1. Phần 1: Sử dụng tập dữ liệu đánh giá bài báo nội bộ qua PI để tinh chỉnh Llama 3
  2. Phần 2: Tinh chỉnh Llama 3 bằng LLaMA Factory và tập dữ liệu đánh giá bài báo
  3. Phần 3: Không sử dụng PI và S2-attn, chạy thành công Llama-3-8B-Instruct-262k
    • Tinh chỉnh với dữ liệu muộn (4 phần)
    • Sau đó tinh chỉnh với dữ liệu sớm (4 phần)
  4. Phần 4: Sử dụng PI và flash attention v2 để tinh chỉnh Llama-3-8B-Instruct-8k
    • Tinh chỉnh với dữ liệu sớm (4 phần)
  5. Phần 5: Phiên bản GPT đánh giá bài báo lần 5: Tinh chỉnh Llama 3 với tập dữ liệu 15K đánh giá bài báo sớm 7 khía cạnh (tình huống 4)
    • Tinh chỉnh với dữ liệu sớm (7 phần)

Phần 1: Tinh chỉnh Llama 3 bằng PI với tập dữ liệu đánh giá bài báo nội bộ

1.1: Tinh chỉnh Llama 3-8b bằng PI

  • Sử dụng mã nguồn từ LongLoRA (Xem thêm tại: Từ LongLoRA đến LongQLoRA (bao gồm phân tích mã nguồn): Mở rộng chiều dài ngữ cảnh cho mô hình lớn)
  • Vì mã nguồn LongLoRA đã hỗ trợ PI + S2-attn, nên chỉ cần tắt phần liên quan đến S2-attn sẽ tương đương với việc tinh chỉnh Llama 3 bằng PI

1.2: Tinh chỉnh Llama 3 trên nền tảng Thousand-Army Model Platform của Baidu

Phần 2: Tinh chỉnh Llama 3 bằng LLaMA Factory và tập dữ liệu đánh giá bài báo

LLaMA Factory hiện tại đã hỗ trợ mô hình Llama 3, cung cấp hướng dẫn thực hành chi tiết về việc tinh chỉnh mô hình Llama 3 trên tài nguyên T4 miễn phí tại Colab: Liên kết Colab

Cộng đồng cũng đã công bố hai mô hình tiếng Trung được tinh chỉnh từ framework này:

  1. Llama3-8B-Chinese-Chat: Mô hình đầu tiên sử dụng thuật toán ORPO để tinh chỉnh Llama3 tiếng Trung, bài viết giới thiệu: Link
  2. Llama3-Chinese: Mô hình đầu tiên sử dụng thuật toán DoRA và LoRA+ để tinh chỉnh Llama3 tiếng Trung, kho mã nguồn: GitHub

Phần 3: Tinh chỉnh Llama-3-8B-Instruct-262k không dùng PI và S2-attn

3.1: Tinh chỉnh Llama 3 8B Instruct 262k với 15K dữ liệu "tình huống 1: dữ liệu muộn 4 phần"

3.1.1: Tinh chỉnh Llama 3 8B Instruct 262k với 1.5K dữ liệu "tình huống 1: dữ liệu muộn 4 phần"

Ngày 25/5/2024, bạn Lài trong nhóm dự án đánh giá bài báo của công ty đã sử dụng tập dữ liệu đánh giá bài báo nội bộ (chỉ lấy 1.5K mẫu từ dữ liệu 4 phần của bài báo muộn trong tình huống 1, và cả phần 3.1.1 và 3.1.2 đều sử dụng dữ liệu này) để chạy thử Llama 3.

Mô hình Llama 3 được sử dụng là Llama-3-8B-Instruct-262k, phiên bản không lượng tử hóa, so với các phiên bản khác thường là đã lượng tử hóa. Mô hình này có độ chính xác nửa định dạng (half precision), và số lượt tải xuống cao hơn.

Dưới đây là một số chi tiết kỹ thuật:

  1. Khi sử dụng A40 + 1.5K dữ liệu, sử dụng s2atten (*S2-attention + flash attention*) để tiết kiệm bộ nhớ GPU, vì mô hình đã mở rộng chiều dài lên 26k nên không cần PI. Tuy nhiên, khi lưu mô hình, A40 48GB gặp lỗi quá tải (OOM), nguyên nhân là do thiết lập per_eval_device_batch_size quá lớn. Dù vậy, A40 vẫn đủ khả năng xử lý dữ liệu với chiều dài 12k trở lên. Việc OOM xảy ra không phải do quá trình huấn luyện mà do kích thước bảng từ vựng lớn (128K) khiến mô hình dễ bị tràn bộ nhớ.
  2. Sau đó chuyển sang A100 để huấn luyện (vẫn 1.5K dữ liệu), tắt s2atten, sử dụng flash attention v2 với chiều dài 12K, thu được kết quả như hình dưới.

3.1.2: Tinh chỉnh Llama-3 với 5K–15K dữ liệu "tình huống 1: dữ liệu muộn 4 phần"

Sau đó, khi huấn luyện với 8 card A40 với 5K hoặc 15K dữ liệu, không sử dụng S2-attention, sử dụng flash attention v2 với chiều dài 12K.

Mã nguồn giống như trước, vẫn dùng mã LongQLoRA từ khóa học tháng 7, nhưng cấu hình đa card.

Đặt hai máy 8 card A40, mỗi máy xử lý 5K và 15K dữ liệu riêng biệt.

Dưới đây là kết quả suy luận sau khi tinh chỉnh với 15K dữ liệu (đánh giá bài báo muộn 4 phần) cho bài báo YaRN:

Sau đó, Lài tiến hành suy luận trên tập kiểm tra với bài báo muộn, xuất ra 4 phần đánh giá.

Cuối cùng, Văn Nhược đánh giá hiệu suất bằng cách so sánh với GPT4-1106 và Llama2 (cũng là dữ liệu muộn 4 phần).

3.2: Tinh chỉnh Llama 3 8B Instruct 262k với 15K dữ liệu "tình huống 3: dữ liệu sớm 4 phần"

3.2.1: So sánh Llama3 phiên bản tình huống 3 với các phiên bản trước

Lưu ý sự khác biệt giữa phần này và hai phần trước:

  • Phần 3.1: Sử dụng dữ liệu bài báo muộn 4 phần để tinh chỉnh Llama3-262k, tương tự như tình huống 1 trong bài viết đầu.
  • Phần 3.2: Dùng dữ liệu bài báo sớm 4 phần để tinh chỉnh Llama3, tương tự như tình huống 3 trong bài viết đầu.

Sau khi tinh chỉnh, có thể so sánh hiệu suất với các mô hình sau (dựa vào tình huống dữ liệu):

  • So sánh giữa Llama3 (tình huống 3) và Llama2 (tình huống 3): Theo lý thuyết, Llama3 nên vượt trội hơn.
  • So sánh giữa Llama3 (tình huống 3) và Llama2 (tình huống 1): Dữ liệu tình huống 3 tốt hơn, nên Llama3 nên vượt trội hơn, nhưng kết quả lại không như mong đợi. Nguyên nhân xem thêm tại khóa học "Thực chiến tinh chỉnh mô hình lớn cho dự án thương mại".

3.2.2: So sánh Llama3 tình huống 1 và Llama2 tình huống 1

Kết quả cho thấy, việc tinh chỉnh Llama3-8b-instruct-262k với dữ liệu tình huống 1 bằng flash attention v2 không đạt hiệu suất cao. Cụ thể:

  • Trái: Tinh chỉnh Llama3-8b-instruct-262k bằng flash attention v2
  • Phải: Tinh chỉnh Llama2-7b-chat + PI

Hai mô hình có hiệu suất gần ngang nhau, điều này cho thấy hiệu suất Llama3 chưa vượt trội hơn Llama2. Có thể do dữ liệu tinh chỉnh không phù hợp với mô hình, nên sẽ thử lại ở phần tiếp theo với Llama-3-8B-Instruct-8k + PI.

Phần 4: Tinh chỉnh Llama-3-8B-Instruct-8k bằng PI và flash attention v2

Tập dữ liệu huấn luyện có 15K mẫu, độ dài trung bình khoảng 9k, tối đa 12k. Phương pháp đánh giá: so sánh số điểm đúng với ground truth, chọn mô hình có loss thấp nhất trên tập validation.

So sánh hiệu suất giữa Llama3-8b-8k và Llama3-8b-262k & Llama2.

4.1: So sánh hiệu suất Llama3-8b-instruct-8k + PI và Llama3-8b-instruct-262k với dữ liệu tình huống 3

Kết quả cho thấy, Llama3-8b-8k + PI có hiệu suất tốt hơn rõ rệt:

  • Trái: Tinh chỉnh Llama3-8b-8k + PI bằng flash attention v2
  • Phải: Tinh chỉnh Llama3-8b-instruct-262k bằng flash attention v2

4.2: So sánh Llama3-8b-instruct-8k + PI và Llama2-7b-chat

4.2.1: Llama3 tình huống 3 vượt trội hơn Llama2 tình huống 3

Kết quả thực nghiệm cho thấy Llama3 vượt trội hơn Llama2 trong ngữ cảnh đánh giá bài báo:

  • Trái: Tinh chỉnh Llama3-8b-8k + PI bằng flash attention v2
  • Phải: Tinh chỉnh Llama2-7b-chat + PI

4.2.2: Llama3 tình huống 3 vượt trội hơn Llama2 tình huống 1

Thử nghiệm này lại chứng minh thêm rằng Llama3 tốt hơn Llama2:

  • Trái: Tinh chỉnh Llama3-8b-8k + PI bằng flash attention v2
  • Phải: Tinh chỉnh Llama2-7b-chat + PI

Phần 5: Phiên bản GPT đánh giá bài báo lần 5 – Tinh chỉnh Llama 3 với tập dữ liệu 15K đánh giá bài báo sớm 7 khía cạnh (tình huống 4)

5.1: Tinh chỉnh Llama3-8b-8k với dữ liệu tình huống 4 (sớm 7 phần)

5.1.1: So sánh thay đổi giữa tình huống 3 và 4

1. Các tham số tinh chỉnh:

  • Đảm bảo so sánh công bằng với tình huống 1 và 3, số vòng lặp huấn luyện giữ nguyên. Với tình huống 4, checkpoint được chọn là 1800, tương ứng khoảng 1.95 epoch.
  • Không cải thiện hiệu suất đáng kể từ các thí nghiệm thay đổi tham số ở tình huống 3, nên giữ nguyên các tham số mặc định.
Tham số Mô tả
batch size=16 Tổng batch size tích lũy gradient
lr=1e-4 Learning rate
max_prompt_length=11138 Chiều dài tối đa của bài báo, cắt ngắn nếu vượt
max_response_length=1150 Chiều dài tối đa của đánh giá, cắt ngắn nếu vượt
save_steps=100 Lưu mô hình sau mỗi 100 bước
num_train_epoch=3 Số epoch huấn luyện

2. Prompt hệ thống:

Prompt hệ thống của Lài được thiết kế để phù hợp với 7 khía cạnh đánh giá từ phiên bản v4 của Á Huy (Xem thêm tại Julyedu):

SYSTEM_PROMPT = """Dưới đây là một "Hướng dẫn" mô tả một nhiệm vụ, đi kèm với một đầu vào cung cấp bối cảnh. Viết một phản hồi hoàn chỉnh cho yêu cầu này.
Hướng dẫn:
Bạn là một chuyên gia đánh giá bài báo hội nghị học máy chuyên nghiệp, đánh giá một bài báo nhất định theo 7 tiêu chí:
** Đánh giá ý tưởng trong bài báo **  
** So sánh với các công trình tương tự trước, những khác biệt cơ bản, cải tiến, đổi mới **  
** Đánh giá kết quả thực nghiệm trong bài báo **  
** Lý do tiềm năng để chấp nhận **  
** Lý do tiềm năng để từ chối **  
** Đề xuất khác để cải thiện chất lượng bài báo **  
** Bình luận đánh giá quan trọng khác **  
Bài báo được đưa ra như sau."""

5.1.2: Phân tích kết quả suy luận

  • a) Tổng số mẫu: 285
  • b) Số lượng mục con trung bình: 10.3894
  • c) Phân bố tổng số mục con (bên trái): phần lớn mẫu có 12 mục con
  • d) Phân bố số mục trống (bên phải): khoảng 50% mẫu có một mục trống (đại diện cho "không trả lời")
  • e) Phân bố số mục con theo từng phần lớn:
    • Phần lớn mẫu trống tập trung ở phần cuối (hình thứ hai hàng cuối)
    • Ngoài phần cuối, phần "lý do từ chối" (hàng thứ hai, cột đầu tiên) có một số ít mục trống

5.1.3: Đánh giá hiệu suất

5.1.3.1: So sánh Llama3-8b-8k và Llama2-7b-chat với dữ liệu tình huống 4

  • Trái: Tinh chỉnh Llama3-8b-instruct-8k với flash attention v2
  • Phải: Tinh chỉnh Llama2-7b-chat

Kết luận: Llama3 vượt trội hơn Llama2 với dữ liệu tình huống 4.

5.1.3.2: So sánh tinh chỉnh 7 phần vs 4 phần

  • Trái: Tinh chỉnh Llama3-8b-instruct-8k với flash attention v2
  • Phải: Tinh chỉnh Llama3-8b-instruct-8k với 4 phần

Kết luận: Với cùng chiến lược, việc tinh chỉnh với 7 phần mang lại hiệu suất tốt hơn rõ rệt.

5.1.3.3: So sánh Llama3-8b-8k và GPT4-1106

  • Trái: Tinh chỉnh Llama3-8b-instruct-8k với flash attention v2
  • Phải: GPT4-1106 với prompt 7 phần

Kết luận: GPT4-1106 tạo ra nhiều quan điểm hơn do prompt chi tiết hơn, tuy nhiên hiệu quả chính xác không cao. Llama3 có nhiều trường hợp "không trả lời", khiến nó bị thiệt thòi trong đánh giá.

5.2: Tối ưu chiến lược suy luận cho mô hình sau khi tinh chỉnh với dữ liệu tình huống 4 – Dưới đây là phương pháp giảm tỷ lệ xuất hiện chuỗi "không liên quan"

Trong quá trình suy luận, mô hình có xu hướng "không trả lời" nhiều, giả định rằng điều này do chuỗi "<No related terms>" xuất hiện thường xuyên trong tập huấn luyện, khiến mô hình dễ chọn chuỗi này.

5.2.1: Triển khai bằng Hugging Face

Mã nguồn giảm tỷ lệ xuất hiện chuỗi cụ thể:

class HuggingFacePenaltySequenceLogitsProcessor():
    def __init__(self, 
                 tokenizer, 
                 target_sequences = [], 
                 penalty_factor=0.5,
                 use_multiplicative=True
                ):
        """
        Khởi tạo bộ xử lý giảm tỷ lệ xuất hiện chuỗi trong Hugging Face.
        Tham số:
        - tokenizer: Bộ phân tích từ
        - target_sequences: Danh sách chuỗi cần giảm tỷ lệ xuất hiện
        - penalty_factor: Hệ số giảm tỷ lệ (0.0 - 1.0)
        - use_multiplicative: Giảm tỷ lệ bằng nhân hay trừ
        """

5.2.2: Đánh giá hiệu suất với chuỗi bị hạn chế

Đánh giá với hệ số phạt 0.7, chỉ áp dụng trong quá trình suy luận:

  • Trái: Llama3-8b-8k + chuỗi bị hạn chế
  • Phải: Llama3-8b-8k thông thường

Kết quả: Tăng 14 điểm hiệu suất. So với GPT4-1106 giảm 5.5 điểm khoảng cách.

Xem thêm chi tiết tại khóa học "Thực hành phát triển mô hình lớn – khóa 2" trên julyedu.com

Dù hiện tại Llama3 chưa vượt trội hơn GPT4, nhưng các phiên bản sau như Gemma2 và Llama3.1 sẽ vượt qua GPT4. Xem thêm trong các bài viết tiếp theo.

Thẻ: llama3 tinh chỉnh mô hình lớn đánh giá bài báo gpt4

Đăng vào ngày 17 tháng 05 lúc 13:17