Kỹ Thuật Chưng Cất Tri Thức Từ DeepSeek-R1 Sang Mô Hình Nhỏ Hơn

Các mô hình học sâu hiện đại đã tạo ra bước ngoặt lớn trong trí tuệ nhân tạo, tuy nhiên kích thước khổng lồ và yêu cầu tính toán cao thường là rào cản khi triển khai thực tế. Kỹ thuật chưng cất mô hình (Model Distillation) giải quyết vấn đề này bằng cách chuyển giao tri thức từ một mô hình lớn phức tạp (giáo viên) sang một mô hình nhỏ gọn hơn (học sinh).

Bài viết này sẽ hướng dẫn quy trình sử dụng các kỹ thuật chuyên sâu như LoRA (Low-Rank Adaptation) để chưng cất khả năng suy luận của DeepSeek-R1 vào một mô hình nhẹ hơn như Microsoft Phi-3-Mini.

Khái niệm về Chưng cất Mô hình

Chưng cất là phương pháp machine learning mà trong đó một mô hình nhỏ được huấn luyện để bắt chước hành vi của một mô hình lớn đã được tiền huấn luyện. Mục tiêu cuối cùng là giữ lại phần lớn hiệu suất của mô hình gốc trong khi giảm đáng kể chi phí tính toán và bộ nhớ cần thiết.

Ý tưởng này bắt nguồn từ nghiên cứu tiên phong của Geoffrey Hinton về知识蒸馏 (Knowledge Distillation). Thay vì học trực tiếp từ dữ liệu thô, mô hình học sinh học từ đầu ra hoặc các biểu diễn trung gian của mô hình giáo viên, tương tự như quá trình học tập của con người.

Lợi ích cốt lõi:

  • Tối ưu chi phí: Mô hình nhỏ tiêu thụ ít tài nguyên phần cứng hơn.
  • Độ trễ thấp: Phù hợp cho các ứng dụng thời gian thực hoặc thiết bị biên.
  • Tính chuyên biệt: Dễ dàng điều chỉnh cho các lĩnh vực cụ thể mà không cần huấn luyện lại mô hình khổng lồ.

Các phương pháp Chưng cất phổ biến

Có nhiều cách tiếp cận để thực hiện quá trình này, mỗi cách đều có ưu điểm riêng:

  1. Chưng cất dữ liệu (Data Distillation): Mô hình giáo viên tạo ra dữ liệu tổng hợp hoặc nhãn giả để huấn luyện mô hình học sinh. Phương pháp này linh hoạt cho nhiều tác vụ khác nhau.
  2. Chưng cất Logits: Học sinh được huấn luyện để khớp với điểm số đầu ra thô (logits) của giáo viên trước khi qua hàm softmax. Cách này giữ lại thông tin về độ tin cậy của quyết định.
  3. Chưng cất đặc trưng (Feature Distillation): Truyền tải kiến thức từ các lớp trung gian. Việc căn chỉnh các biểu diễn tiềm ẩn giúp học sinh nắm bắt các đặc trưng trừu tượng tốt hơn.

Bối cảnh mô hình蒸馏 của DeepSeek

DeepSeek AI đã phát hành nhiều mô hình đã được chưng cất dựa trên các kiến trúc phổ biến như Qwen và Llama. Họ sử dụng 800.000 mẫu dữ liệu thu thập từ DeepSeek-R1 để tinh chỉnh các mô hình nguồn mở. Dù nhỏ hơn nhiều, các mô hình này vẫn đạt hiệu suất ấn tượng trên nhiều benchmark.

Tại sao nên tự chưng cất mô hình?

Mặc dù đã có các mô hình chưng cất sẵn, việc tự thực hiện quy trình này mang lại nhiều lợi thế:

  • Tối ưu hóa tác vụ cụ thể: Các mô hình chung thường không đủ chuyên sâu. Ví dụ, một chatbot tài chính cần dữ liệu suy luận về rủi ro và dự báo giá mà mô hình chung có thể thiếu.
  • Kiểm soát tài nguyên: Bạn có thể điều chỉnh kích thước mô hình học sinh phù hợp chính xác với hạ tầng hiện có.
  • Hiệu suất thực tế: Benchmark không phải lúc nào cũng phản ánh đúng hiệu năng trong môi trường sản xuất thực tế.
  • Cải thiện liên tục: Mô hình tự chưng cất có thể được cập nhật thường xuyên khi có dữ liệu mới, khác với các mô hình tĩnh.

Hướng dẫn kỹ thuật: Chưng cất DeepSeek-R1 vào Phi-3

Bước 1: Cài đặt môi trường

Đầu tiên, cần cài đặt các thư viện cần thiết cho PyTorch và Hugging Face.

pip install torch transformers peft trl datasets accelerate bitsandbytes

Bước 2: Chuẩn bị và Định dạng Dữ liệu

Bạn có thể tự tạo dữ liệu bằng cách triển khai DeepSeek-R1 qua Ollama, nhưng trong hướng dẫn này, chúng ta sẽ sử dụng bộ dữ liệu Magpie-Reasoning-V2. Bộ dữ liệu này chứa 250.000 mẫu suy luận Chain-of-Thought (CoT) do DeepSeek-R1 tạo ra, bao phủ các tác vụ toán học, lập trình và giải quyết vấn đề.

Mỗi mẫu dữ liệu cần được định dạng đúng theo template chat của mô hình đích (Phi-3).

import os
from datasets import load_dataset

# Tải bộ dữ liệu
raw_data = load_dataset(
    "Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B", 
    token=os.getenv("HF_TOKEN")
)
train_split = raw_data["train"]

# Hàm định dạng mẫu dữ liệu
def prepare_sample(example):
    return {
        "text": (
            "<|user|>\n"
            f"{example['instruction']}\n"
            "<|end|>\n"
            "<|assistant|>\n"
            f"{example['response']}\n"
            "<|end|>"
        )
    }

# Áp dụng định dạng và chia tập dữ liệu
processed_data = train_split.map(prepare_sample, batched=False, remove_columns=train_split.column_names)
final_dataset = processed_data.train_test_split(test_size=0.1)

Lưu ý rằng mỗi mô hình ngôn ngữ lớn đều có quy chuẩn riêng về định dạng hội thoại. Việc căn chỉnh dữ liệu đúng cấu trúc <|user|>, <|assistant|> là bắt buộc để mô hình học được cách phản hồi phù hợp.

Bước 3: Khởi tạo Model và Tokenizer

Để mô hình nhỏ học được khả năng suy luận từng bước, chúng ta cần thêm các token đặc biệt vào bộ từ vựng.

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

base_model_name = "microsoft/phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

# Thêm token đặc biệt cho quá trình suy nghĩ
REASONING_TOKENS = ["<think>", "</think>"]
tokenizer.add_special_tokens({"additional_special_tokens": REASONING_TOKENS})
tokenizer.pad_token = tokenizer.eos_token

# Tải mô hình với hỗ trợ flash attention
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2"
)
model.resize_token_embeddings(len(tokenizer))

Bước 4: Cấu hình LoRA để tối ưu bộ nhớ

LoRA giúp giảm thiểu tài nguyên bằng cách đóng băng trọng số gốc và chỉ huấn luyện các lớp adapter nhỏ.

from peft import LoraConfig

lora_settings = LoraConfig(
    r=16,  # Hạng của ma trận thấp
    lora_alpha=32,  # Hệ số tỷ lệ
    lora_dropout=0.1,  # Tỷ lệ dropout
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Các lớp attention mục tiêu
    bias="none",
    task_type="CAUSAL_LM"
)

Bước 5: Thiết lập Tham số Huấn luyện

from transformers import TrainingArguments

training_config = TrainingArguments(
    output_dir="./phi-3-distilled-output",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=20,
    learning_rate=3e-5,
    fp16=True,
    optim="paged_adamw_32bit",
    max_grad_norm=0.5,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine"
)

Bước 6: Tiến hành Huấn luyện

Sử dụng SFTTrainer từ thư viện TRL để đơn giản hóa quá trình Supervised Fine-Tuning.

from trl import SFTTrainer
from transformers import DataCollatorForLanguageModeling

# Cấu hình collator cho dữ liệu
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Khởi tạo trainer
supervisor = SFTTrainer(
    model=model,
    args=training_config,
    train_dataset=final_dataset["train"],
    eval_dataset=final_dataset["test"],
    data_collator=data_collator,
    peft_config=lora_settings
)

# Bắt đầu quá trình học
supervisor.train()
supervisor.save_model("./phi-3-distilled-output")
tokenizer.save_pretrained("./phi-3-distilled-output")

Bước 7: Hợp nhất và Lưu mô hình

Sau khi huấn luyện, cần hợp nhất các trọng số LoRA vào mô hình gốc để sử dụng độc lập mà không cần phụ thuộc vào thư viện PEFT.

merged_model = supervisor.model.merge_and_unload()
merged_model.save_pretrained("./phi-3-distilled-final")
tokenizer.save_pretrained("./phi-3-distilled-final")

Bước 8: Kiểm tra Suy luận

Cuối cùng, tải mô hình đã hợp nhất để kiểm tra khả năng sinh văn bản.

from transformers import pipeline

# Tải mô hình đã tinh chỉnh
inference_model = AutoModelForCausalLM.from_pretrained(
    "./phi-3-distilled-final",
    device_map="auto",
    torch_dtype=torch.float16
)

inference_tokenizer = AutoTokenizer.from_pretrained("./phi-3-distilled-final")
inference_model.resize_token_embeddings(len(inference_tokenizer))

# Tạo pipeline chat
chat_bot = pipeline(
    "text-generation",
    model=inference_model,
    tokenizer=inference_tokenizer,
    device_map="auto"
)

# Câu hỏi kiểm tra
input_prompt = """<|user|>
Xác suất để gieo được tổng 7 chấm với hai xúc xắc là bao nhiêu?
<|end|>
<|assistant|>
"""

result = chat_bot(
    input_prompt,
    max_new_tokens=5000,
    temperature=0.7,
    do_sample=True,
    eos_token_id=inference_tokenizer.eos_token_id
)

print(result[0]['generated_text'])

Khi so sánh kết quả, mô hình chưa qua chưng cất thường đưa ra đáp án ngắn gọn và trực tiếp. Ngược lại, mô hình sau khi chưng cất từ DeepSeek-R1 sẽ hiển thị rõ ràng phần <think>, phân tích từng bước logic trước khi đưa ra kết luận cuối cùng. Cấu trúc này giúp tăng độ chính xác đáng kể khi xử lý các vấn đề phức tạp đòi hỏi suy luận nhiều tầng.

Thẻ: deepseek-r1 knowledge-distillation lora phi-3 llm-fine-tuning

Đăng vào ngày 26 tháng 5 lúc 13:12