Flax trong hệ sinh thái JAX: Khung học sâu linh hoạt và hiệu năng cao

Flax là một khung học sâu được xây dựng trên nền tảng JAX, tập trung vào tính linh hoạt, khả năng mở rộng và kiểm soát chi tiết đối với luồng tính toán. Khác với các khung truyền thống, Flax không che giấu cơ chế thực thi — thay vào đó, nó cung cấp giao diện hướng đối tượng để định nghĩa mô hình trong khi vẫn giữ nguyên đặc tính hàm thuần (pure functional) và khả năng biên dịch của JAX.

Tại sao Flax phù hợp với nghiên cứu hiện đại?

  • Thiết kế dựa trên module trạng thái rõ ràng: Mỗi thành phần như lớp tích chập, lớp chuẩn hóa hay dropout đều là một thể hiện của lớp nn.Module, cho phép tách biệt logic định nghĩa cấu trúc khỏi logic cập nhật trạng thái.
  • Hợp nhất liền mạch với JAX: Các hàm như jax.jit, jax.vmap, jax.pmapjax.grad hoạt động trực tiếp với mô hình Flax mà không cần wrapper phức tạp.
  • Hỗ trợ trạng thái có thể huấn luyện và trạng thái cố định riêng biệt: Ví dụ: tham số mô hình (params) và thống kê của lớp BatchNorm (batch_stats) được lưu trữ và cập nhật độc lập.
  • Giao diện nhất quán cho huấn luyện, suy luận và triển khai: Cùng một mô hình có thể được sử dụng ở cả ba chế độ chỉ bằng cách điều chỉnh cách truyền dữ liệu trạng thái.

Xây dựng mạng nơ-ron đơn giản với Flax

Dưới đây là một ví dụ minh họa cách định nghĩa và huấn luyện một mạng perceptron đa tầng (MLP) sử dụng Flax và Optax:

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

class SimplePerceptron(nn.Module):
  hidden_size: int
  num_classes: int

  @nn.compact
  def __call__(self, x, training: bool = True):
    x = nn.Dense(self.hidden_size)(x)
    x = nn.Dropout(0.2)(x, deterministic=not training)
    x = nn.relu(x)
    x = nn.Dense(self.num_classes)(x)
    return x

# Khởi tạo trạng thái huấn luyện
rng = jax.random.key(42)
model = SimplePerceptron(hidden_size=64, num_classes=10)
init_x = jnp.ones((1, 28 * 28))
variables = model.init(rng, init_x, training=True)
params = variables['params']
batch_stats = variables.get('batch_stats', {})

tx = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx
)

# Hàm mất mát với xử lý trạng thái BatchNorm
def compute_loss(params, batch_stats, model, batch, rng, training=True):
  variables = {'params': params, 'batch_stats': batch_stats}
  logits, new_model_state = model.apply(
      variables, 
      batch['image'], 
      training=training, 
      mutable=['batch_stats'],
      rngs={'dropout': rng}
  )
  one_hot_labels = jax.nn.one_hot(batch['label'], 10)
  loss = jnp.mean(optax.softmax_cross_entropy(logits, one_hot_labels))
  if training:
    return loss, (new_model_state['batch_stats'], logits)
  else:
    return loss, logits

@jax.jit
def train_step(state, batch, rng):
  grad_fn = jax.value_and_grad(
      lambda p: compute_loss(p, state.batch_stats, model, batch, rng, training=True)[0], 
      has_aux=False
  )
  loss, grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state, loss

Mở rộng sang kiến trúc CNN và quản lý trạng thái

Flax hỗ trợ dễ dàng xây dựng các mô hình có trạng thái nội bộ như lớp chuẩn hóa theo lô hoặc lớp dropout. Dưới đây là một mô hình CNN tối giản nhưng đầy đủ chức năng:

class LightweightCNN(nn.Module):
  num_classes: int

  @nn.compact
  def __call__(self, x, training: bool = True):
    # Lớp tích chập đầu tiên + ReLU + Pooling
    x = nn.Conv(features=16, kernel_size=(3, 3))(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    # Lớp tích chập thứ hai
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    # Làm phẳng và phân lớp
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(self.num_classes)(x)
    return x

Quản lý tham số nâng cao

Flax cung cấp cơ chế để kiểm soát chi tiết việc cập nhật từng nhóm tham số — ví dụ, đóng băng các lớp đầu trong vi điều chỉnh:

from flax.core import freeze, unfreeze

# Đóng băng toàn bộ tham số
frozen_params = freeze(state.params)

# Chỉ giải phóng lớp cuối cùng để cập nhật
thawed = unfreeze(frozen_params)
thawed['Dense_1'] = unfreeze(thawed['Dense_1'])  # Giả sử lớp cuối là Dense_1

# Tạo trạng thái mới với tham số đã chọn
state_finetune = state.replace(params=freeze(thawed))

Lưu/truy xuất trạng thái huấn luyện

Flax tích hợp sẵn công cụ quản lý checkpoint mạnh mẽ, hỗ trợ lưu trạng thái huấn luyện dưới dạng thư mục có cấu trúc rõ ràng:

from flax.training import checkpoints

# Lưu checkpoint tại epoch hiện tại
checkpoints.save_checkpoint(
    ckpt_dir='/tmp/flax_ckpt',
    target={'state': state, 'epoch': epoch},
    step=epoch,
    overwrite=True,
    keep=3
)

# Khôi phục từ checkpoint gần nhất
restored = checkpoints.restore_checkpoint(
    ckpt_dir='/tmp/flax_ckpt',
    target={'state': None, 'epoch': 0}
)

Tối ưu hóa triển khai đa thiết bị

Với JAX làm nền tảng, Flax tự nhiên hỗ trợ huấn luyện phân tán qua nhiều GPU/TPU bằng cách kết hợp pmapshard:

from jax.sharding import PositionalSharding
import jax.numpy as jnp

# Thiết lập phân vùng cho 4 thiết bị
devices = jax.local_devices()[:4]
sharding = PositionalSharding(devices)

# Chia nhỏ dữ liệu đầu vào theo chiều batch
def shard_batch(batch):
  return jax.tree_map(
      lambda x: jax.device_put(x, sharding.reshape(-1, len(devices))),
      batch
  )

# Hàm huấn luyện song song
@jax.pmap
def p_train_step(state, batch):
  loss, grads = jax.value_and_grad(
      lambda p: compute_loss(p, state.batch_stats, model, batch, jax.random.key(0), training=True)[0]
  )(state.params)
  state = state.apply_gradients(grads=grads)
  return state, loss

Thẻ: flax jax optax linen deep-learning

Đăng vào ngày 30 tháng 5 lúc 22:55