Chi tiết về presence_penalty, frequency_penalty và repetition_penalty trong mô hình ngôn ngữ

Khi làm việc với các mô hình ngôn ngữ lớn, đặc biệt là qua API, bạn sẽ gặp ba tham số quen thuộc: presence_penalty, frequency_penaltyrepetition_penalty. Nhiều tài liệu chỉ giải thích sơ lược rằng cả ba đều dùng để giảm lặp từ, nhưng sự khác biệt cụ thể thường không được đề cập. Bài viết này sẽ phân tích chi tiết từ mã nguồn thực tế, giúp bạn hiểu rõ cách chúng hoạt động.

Bắt đầu từ mã nguồn

Trong transformers của Hugging Face, tham số presence_penalty chỉ xuất hiện trong danh sách UNUSED_CHAT_COMPLETION_FIELDS, nghĩa là chưa được hỗ trợ chính thức. Xem tại: source.

Ngược lại, vllm triển khai đầy đủ cả ba tham số. Bạn có thể tham khảo các hàm test:

  • test_sampler_presence_penalty: source
  • test_sampler_frequency_penaltytest_sampler_repetition_penalty nằm cùng file.

Tất cả các hàm này cuối cùng gọi đến hàm apply_penalties tại: source.

Phân tích chi tiết

Trước hết, cần phân biệt prompt tokens (token từ đầu vào) và output tokens (token do mô hình sinh ra). Mỗi tham số có phạm vi thống kê khác nhau.

Giả sử logits đầu ra có dạng (num_seqs, seq_length, vocab_size). Khi sinh token mới, ta chỉ xét logits của token cuối cùng: logits[:, -1, :] có kích thước (vocab_size).

1. Repetition Penalty

Cách hoạt động: Duyệt toàn bộ sequence (cả prompt và output), tìm các token đã xuất hiện. Với mỗi token này, điều chỉnh logits tương ứng:

if logit >= 0:
    logit = logit / p
else:
    logit = logit * p

Trong đó p là giá trị repetition_penalty. Khi p > 1, logits của token đã xuất hiện bị giảm, giúp tránh lặp. p = 1 nghĩa là không áp dụng penalty.

2. Presence Penalty

Mã nguồn trong vllm:

logits -= presence_penalties.unsqueeze(dim=1) * output_mask
  • output_mask là ma trận boolean cùng kích thước với logits, đánh dấu token nào đã xuất hiện trong output tokens.
  • Penalty là hằng số, trừ trực tiếp vào logits của token đã xuất hiện. presence_penalty = 0 nghĩa là tắt.

3. Frequency Penalty

Mã nguồn:

logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
  • output_bin_counts đếm số lần xuất hiện của mỗi token trong output.
  • Penalty tỉ lệ thuận với tần suất: token xuất hiện càng nhiều thì bị giảm càng mạnh.
  • Chỉ áp dụng cho output tokens, không tính prompt tokens.
  • frequency_penalty = 0 nghĩa là tắt.

Tổng kết

Tham sốPhạm viCách tínhGiá trị tắt
repetition_penaltyprompt + outputChia/nhân với hằng số1
presence_penaltyoutputTrừ hằng số nếu có xuất hiện0
frequency_penaltyoutputTrừ theo tần suất xuất hiện0

Giá trị càng cao thì khả năng lặp càng thấp, và ngược lại.

Thẻ: logits Transformers vLLM penalty OpenAI API

Đăng vào ngày 25 tháng 6 lúc 09:29