Vai trò của Bộ nhớ trong Mô hình
Khi giải quyết các bài toán điều khiển tiêu chuẩn như cân bằng thanh trên xe đẩy (CartPole), thông tin trạng thái tại thời điểm hiện tại thường đủ để dự đoán hành động tiếp theo, tuân theo giả định Markov bậc nhất. Tuy nhiên, đối với nhiều môi trường phức tạp hơn hoặc trong điều kiện quan sát không đầy đủ (Partially Observable Markov Decision Process - POMDP), chỉ dựa vào trạng thái đơn lẻ là chưa đủ.
Một số chiến lược trước đây đã đề xuất việc đưa chuỗi nhiều khung hình liên tiếp làm đầu vào cho mô hình Deep Q-Network (DQN). Một phương án thay thế hiệu quả hơn là tích hợp mạng nơ-ron có khả năng ghi nhớ ngữ cảnh, cụ thể là sử dụng kiến trúc RNN (Recurrent Neural Network) làm hàm chính sách (Policy Network). Điều này cho phép tác nhân phân biệt được lịch sử các trạng thái đã trải qua, cải thiện hiệu suất ra quyết định trong các tình huống phụ thuộc vào trình tự.
Cấu hình Kiến trúc Mạng
Cơ sở triển khai dựa trên nguyên lý Actor-Critic. Để thích ứng với yêu cầu bộ nhớ, phần Actor sẽ được tinh chỉnh thành một mạng RNN, trong khi phần Critic giữ nguyên dạng mạng nơ-ron truyền thẳng thông thường. Dưới đây là cách tổ chức mã nguồn:
Hệ thống thư viện
Mô hình sử dụng các công cụ tiêu chuẩn cho học sâu và môi trường thử nghiệm:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
import numpy as np
import matplotlib.pyplot as plt
# Thiết lập thiết bị tính toán (GPU nếu có)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Mạng Giá trị (Value Network)
Đây là mạng con dùng để ước lượng giá trị kỳ vọng của trạng thái (State-Value Function).
class ValueFunction(nn.Module):
def __init__(self, input_dim, hidden_units):
super(ValueFunction, self).__init__()
self.layers = nn.Sequential(
nn.Linear(input_dim, hidden_units),
nn.ReLU(),
nn.Linear(hidden_units, 1)
)
def forward(self, x):
return self.layers(x)
Mạng Chính sách (RNN Policy Network)
Sử dụng cell GRU để xử lý dữ liệu chuỗi. Đầu ra bao gồm xác suất chọn hành động và trạng thái ẩn mới.
class RecurrentPolicy(nn.Module):
def __init__(self, input_dim, hidden_units, action_count):
super(RecurrentPolicy, self).__init__()
# Sử dụng GRUCell để lưu trữ thông tin ẩn tại mỗi bước thời gian
self.memory_cell = nn.GRUCell(input_dim, hidden_units)
self.output_layer = nn.Linear(hidden_units, action_count)
def forward(self, current_state, hidden_memory=None):
# Cập nhật trạng thái ẩn
if hidden_memory is None:
hidden_memory = torch.zeros(current_state.size(0),
self.memory_cell.hidden_size,
device=device)
updated_memory = self.memory_cell(current_state, hidden_memory)
logits = self.output_layer(F.leaky_relu(updated_memory))
probabilities = F.softmax(logits, dim=-1)
return probabilities, updated_memory
Quy Trình Tối Ưu Hóa Tác Nhân
Lớp điều khiển chính (Agent) kết hợp hai mạng trên và thực hiện cập nhật trọng số dựa trên độ lợi khoảng cách (TD error).
Khởi tạo Agent
class RNN_Agent_Critic:
def __init__(self, input_dim, hidden_units, action_count, learning_rate_actor,
learning_rate_critic, discount_factor):
self.actor_net = RecurrentPolicy(input_dim, hidden_units, action_count)
self.critic_net = ValueFunction(input_dim, hidden_units)
self.optimizer_actor = torch.optim.Adam(self.actor_net.parameters(), lr=learning_rate_actor)
self.optimizer_critic = torch.optim.Adam(self.critic_net.parameters(), lr=learning_rate_critic)
self.discount = discount_factor
self.device = device
Hành động và Ghi nhớ Xác suất
Thay vì chờ đợi đến cuối episode mới tính toán gradient, ta thu thập log-probability ngay tại mỗi bước hành động để hỗ trợ backpropagation đúng cách.
def execute_step(self, obs_vector, hidden_state=None):
# Chuyển đổi quan sát thành Tensor
state_tensor = torch.FloatTensor(obs_vector).unsqueeze(0).to(self.device)
prob_dist, next_hidden = self.actor_net(state_tensor, hidden_state)
dist_object = torch.distributions.Categorical(prob_dist)
action = dist_object.sample()
log_likelihood = dist_object.log_prob(action)
return action.item(), next_hidden, log_likelihood
Giai đoạn Cập nhật Tham số
Tính toán mục tiêu TD (Temporal Difference) và áp dụng các hàm mất mát riêng cho Actor và Critic.
def train_model(self, trajectory_buffer):
states = torch.FloatTensor(trajectory_buffer['states']).to(self.device)
rewards = torch.FloatTensor(trajectory_buffer['rewards']).view(-1, 1).to(self.device)
next_states = torch.FloatTensor(trajectory_buffer['next_states']).to(self.device)
done_flags = torch.FloatTensor(trajectory_buffer['dones']).view(-1, 1).to(self.device)
# Kết hợp log_prob thành Tensor liên tục
log_probs = torch.cat([torch.tensor(p) for p in trajectory_buffer['log_probs']]).to(self.device)
# Tính toán Giá trị Tiếp Theo và Mục Tiêu TD
with torch.no_grad():
td_target = rewards + self.discount * self.critic_net(next_states) * (1.0 - done_flags)
td_error = td_target - self.critic_net(states)
# Mất mát cho Actor (Policy Gradient)
actor_loss = torch.mean(-log_probs * td_error.detach())
# Mất mát cho Critic (Regressor)
critic_loss = F.mse_loss(self.critic_net(states), td_target.detach())
# Backpropagation
self.optimizer_actor.zero_grad()
self.optimizer_critic.zero_grad()
actor_loss.backward()
critic_loss.backward()
self.optimizer_actor.step()
self.optimizer_critic.step()
Tiến Trình Huấn Luyện
Quá trình huấn luyện lặp lại các episode, khởi tạo lại trạng thái ẩn của RNN mỗi khi bắt đầu episode mới để đảm bảo tính độc lập giữa các kịch bản.
def run_training_process(environment_name, total_episodes):
env = gym.make(environment_name)
hyperparams = {
'input_dim': env.observation_space.shape[0],
'hidden_units': 64,
'action_count': env.action_space.n,
'lr_actor': 1e-3,
'lr_critic': 1e-2,
'gamma': 0.99
}
agent = RNN_Agent_Critic(**hyperparams)
scores_list = []
print(f"Bắt đầu huấn luyện {total_episodes} vòng.")
for i_episode in range(total_episodes):
obs, info = env.reset()
hidden_state = None
episode_return = 0
steps_record = {'states': [], 'next_states': [], 'rewards': [], 'dones': [], 'log_probs': []}
done_condition = False
while not done_condition:
action_id, hidden_state, prob_log = agent.execute_step(obs, hidden_state)
next_obs, reward, termination, truncated, _ = env.step(action_id)
done_condition = termination or truncated
steps_record['states'].append(obs)
steps_record['next_states'].append(next_obs)
steps_record['rewards'].append(reward)
steps_record['dones'].append(done_condition)
steps_record['log_probs'].append(prob_log)
obs = next_obs
episode_return += reward
scores_list.append(episode_return)
agent.train_model(steps_record)
if (i_episode + 1) % 100 == 0:
avg_score = np.mean(scores_list[-100:])
print(f"Episode {i_episode+1}: Điểm trung bình gần đây = {avg_score:.2f}")
return scores_list
Xác Nhận Hiệu Suất
Khi so sánh kết quả giữa mạng lưới MLP thông thường và mạng lưới tích hợp RNN trên môi trường CartPole-v1, người ta quan sát thấy:
- Mô hình sử dụng RNN hội tụ chậm hơn do độ phức tạp tính toán tăng thêm của các tham số recurrent.
- Trên môi trường đơn giản như CartPole, giả định Markov thường thỏa mãn tốt, do đó lợi ích từ việc duy trì trạng thái ẩn qua RNN không quá nổi bật so với việc tăng chi phí tính toán.
- Tuy nhiên, với các bài toán đòi hỏi ghi nhớ dài hạn, cấu trúc này sẽ phát huy sức mạnh vượt trội.