Trong web dev, có một cấu trúc gần như mọi framework đều có: vòng lặp request, handler, response. Express, FastAPI, Spring Boot, Rails. Bạn học một cái xong, đọc cái khác chỉ mất 30 phút vì cấu trúc giống nhau.

Trong ML, vòng lặp tương đương là training loop. Mọi neural network từ logistic regression đến Llama-3-400B đều training theo đúng 5 bước: forward, loss, backward, optimizer step, scheduler step. Hiểu một cái là hiểu tất cả.

Vấn đề là phần lớn dev đọc tutorial chỉ thấy trainer.fit() của Lightning hoặc Trainer.train() của HuggingFace. Magic. Khi cần debug “tại sao loss không giảm”, “tại sao gradient explode”, “tại sao learning rate nên warm up”, bạn không biết bắt đầu từ đâu vì cái trainer kia đã ăn hết cả pipeline.

Bài này tháo cái trainer ra. Code training loop từ zero bằng PyTorch thuần, train một MLP nhỏ trên synthetic data, rồi mở rộng sang transformer mini. Sau bài này bạn đọc training code của bất cứ ai cũng không bị lạc.

Mental model: 5 bước, lặp lại

Toàn bộ training loop, dù to hay nhỏ, đều theo cấu trúc này:

for epoch in range(num_epochs):
    for batch in dataloader:
        # 1. Forward: data đi qua model, sinh ra prediction
        logits = model(batch.input)

        # 2. Loss: so prediction với ground truth
        loss = loss_fn(logits, batch.target)

        # 3. Backward: tính gradient của loss theo từng parameter
        loss.backward()

        # 4. Optimizer step: cập nhật parameter theo gradient
        optimizer.step()
        optimizer.zero_grad()

        # 5. Scheduler step: điều chỉnh learning rate
        scheduler.step()

Năm dòng, đủ để train mọi thứ từ MNIST classifier đến GPT-4. Phần còn lại của training infrastructure (logging, checkpointing, multi-GPU, mixed precision) là bao quanh 5 dòng này.

Điều cần ngấm sớm: một step training = một batch đi qua đủ 5 bước, một lần cập nhật weights. Số step để train xong = (số sample / batch size) x số epoch. GPT-3 train 300 tỷ tokens với batch size 3.2 triệu tokens, tổng khoảng 95,000 steps. Mỗi step mất khoảng 5 phút trên 10,000 GPU. Tổng 34 ngày.

Phần 1: Forward pass

Forward là phần dễ nhất. Data đi vào, đi qua từng layer, sinh ra output.

import torch
import torch.nn as nn

class SimpleMLP(nn.Module):
    def __init__(self, in_dim=784, hidden=256, out_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, out_dim)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        return self.fc3(x)

model = SimpleMLP()
batch_input = torch.randn(32, 784)
logits = model(batch_input)

Forward pass đơn giản là gọi model(input). Đằng sau, PyTorch chạy __call__ của Module, cuối cùng gọi forward().

Có một thứ quan trọng PyTorch làm tự động: build computation graph. Mỗi phép toán trong forward (matmul, add, activation) được ghi lại thành một node trong graph. Graph này sẽ được dùng ở bước backward để tính gradient bằng chain rule.

Nếu bạn không muốn build graph (ví dụ lúc inference), wrap forward trong torch.no_grad():

with torch.no_grad():
    logits = model(batch_input)

Trong training, KHÔNG wrap no_grad(). Cần graph để backward.

Phần 2: Loss function

Loss là một con số đo “model dự đoán sai bao nhiêu”. Training là quá trình tối thiểu hoá con số này.

Với classification:

loss_fn = nn.CrossEntropyLoss()
targets = torch.tensor([3, 7, 0, 9])
loss = loss_fn(logits, targets)

CrossEntropyLoss của PyTorch combo LogSoftmax + NLLLoss. Không cần softmax trong model, để raw logits là đủ. Đây là pattern phổ biến nhưng nhiều dev mới hay lặp softmax 2 lần, loss sai mà không biết.

Với LLM, loss vẫn là cross-entropy nhưng tính trên từng token:

loss = nn.functional.cross_entropy(
    logits.view(-1, vocab_size),
    targets.view(-1),
    ignore_index=-100,
)

Bài Probability cho LLM: softmax, cross-entropy, perplexity đã đi sâu vào tại sao cross-entropy là choice mặc định cho LLM, và liên hệ với perplexity. Nếu chưa đọc, đọc lại sau bài này sẽ thấy mượt hơn.

Phần 3: Backward pass

Đây là phần làm nên tên tuổi PyTorch. Một dòng:

loss.backward()

PyTorch sẽ:

  1. Đi ngược computation graph từ loss về từng parameter
  2. Áp chain rule tính d_loss / d_param cho mọi param có requires_grad=True
  3. Lưu gradient vào param.grad

Sau khi gọi loss.backward(), mọi param.grad đều có giá trị. Kiểm tra:

for name, p in model.named_parameters():
    if p.grad is not None:
        print(f"{name}: grad mean={p.grad.mean():.6f}, std={p.grad.std():.6f}")

Hai pitfall hay gặp:

Pitfall 1: quên zero_grad(). Gradient được cộng dồn vào param.grad mỗi lần backward. Nếu không zero, gradient của step trước sẽ tích lũy với step này. Loss diverge ngay. Luôn gọi optimizer.zero_grad() (hoặc model.zero_grad()) trước hoặc sau mỗi step.

Pitfall 2: gradient explosion. Với network sâu (hơn 12 layer) hoặc learning rate quá cao, gradient có thể vọt lên hàng nghìn, sau đó NaN. Cách fix: gradient clipping.

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Cap tổng L2 norm của tất cả gradient ở 1.0. Mọi tutorial LLM đều có dòng này.

Phần 4: Optimizer step

Optimizer quyết định cách dùng gradient để cập nhật parameter. Đơn giản nhất là SGD:

param_new = param_old - learning_rate * gradient

PyTorch:

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer.step()

Nhưng SGD thuần ít khi đủ cho LLM. Optimizer phổ biến nhất hiện nay là AdamW:

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.1,
)

AdamW giữ thêm 2 state buffer cho mỗi parameter: m (first moment, trung bình gradient) và v (second moment, variance gradient). Update rule:

m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * grad^2
param = param - lr * m_hat / (sqrt(v_hat) + eps) - lr * weight_decay * param

Hệ quả: AdamW tốn 3x memory so với SGD vì giữ thêm mv. Với Llama-3-8B FP32, model weights 32GB, optimizer state thêm 64GB, tổng 96GB chỉ cho optimizer + model, chưa kể activation. Đây là lý do tại sao DeepSpeed ZeRO chia optimizer state qua nhiều GPU (bài 17 sẽ chi tiết).

So sánh optimizer:

OptimizerMemory overheadProsCons
SGD1x paramĐơn giản, ít memoryCần tune lr cẩn thận, convergence chậm
SGD + momentum2x paramConvergence nhanh hơnVẫn cần tune
Adam3x paramAdaptive lr, robustMemory cao
AdamW3x paramNhư Adam + decoupled weight decayMemory cao
Lion2x paramMới 2026, ít memory hơn AdamChưa được test rộng

LLM hiện đại gần như 100% dùng AdamW.

Phần 5: Learning rate schedule

Learning rate không nên giữ cố định. Đầu training cần lr lớn để đi nhanh, cuối training cần lr nhỏ để fine-tune. Schedule là cách thay đổi lr theo step.

Schedule phổ biến nhất cho LLM: cosine với warmup.

       lr
        |    warmup        cosine decay
   max  |          ___________
        |        /            \___
        |      /                  \___
        |    /                        \___
   min  |  /                              \___
        |/_______________________________________> step
        0   warmup_steps                      total_steps

Warmup: linear tăng từ 0 đến max trong vài nghìn step đầu. Tránh model rơi vào “bad region” do gradient không ổn định lúc bắt đầu.

Cosine decay: giảm từ max về min theo hàm cos. Smooth, không có cliff.

Implement bằng PyTorch:

from torch.optim.lr_scheduler import LambdaLR
import math

def get_lr_lambda(step, warmup_steps=1000, total_steps=100000, min_ratio=0.1):
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_ratio + (1 - min_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda=get_lr_lambda)

Gọi scheduler.step() sau mỗi optimizer.step().

GPT-3 dùng: max lr 6e-4, warmup 375 triệu tokens, cosine decay xuống 10% của max. Llama-3 dùng: max lr 3e-4, warmup 8000 steps, cosine decay xuống 10%.

Phần 6: Full training loop, runnable

Code này train được trên CPU trong vài phút, dùng synthetic data:

import torch
import torch.nn as nn
import math

torch.manual_seed(42)
N, D = 10000, 64
X = torch.randn(N, D)
true_w = torch.randn(D)
y = X @ true_w + 0.1 * torch.randn(N)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(D, 128), nn.GELU(),
            nn.Linear(128, 128), nn.GELU(),
            nn.Linear(128, 1),
        )
    def forward(self, x):
        return self.net(x).squeeze(-1)

model = MLP()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
total_steps = 2000
warmup = 100

def lr_lambda(step):
    if step < warmup:
        return step / warmup
    progress = (step - warmup) / (total_steps - warmup)
    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
loss_fn = nn.MSELoss()

batch_size = 64
for step in range(total_steps):
    idx = torch.randint(0, N, (batch_size,))
    xb, yb = X[idx], y[idx]

    pred = model(xb)
    loss = loss_fn(pred, yb)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

    if step % 200 == 0:
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"step {step:5d} | loss {loss.item():.4f} | lr {current_lr:.6f}")

Output mẫu:

step     0 | loss 65.3217 | lr 0.000010
step   200 | loss 1.2451 | lr 0.000997
step   400 | loss 0.0823 | lr 0.000936
step   600 | loss 0.0245 | lr 0.000815
step  1800 | loss 0.0098 | lr 0.000125

Loss giảm đều, learning rate warmup rồi decay. Pattern này lặp lại y hệt khi train transformer, chỉ khác model lớn hơn và data nhiều hơn.

Pitfall thực tế: loss không giảm

Có một lần tôi train một MLP nhỏ trên synthetic data, loss đứng yên ở 2.3 suốt 500 step. Đoán đủ thứ: lr quá cao, model bị broken, data sai. Cuối cùng nguyên nhân: quên gọi optimizer.zero_grad().

Gradient cộng dồn qua từng step, đến step 100 thì gradient đã to gấp 100 lần value thật. Optimizer cập nhật param theo gradient bị inflate, param nhảy lung tung không hội tụ.

Fix: thêm 1 dòng optimizer.zero_grad() trước loss.backward(). Loss bắt đầu giảm trong step thứ 2.

Bài học: khi loss không giảm, kiểm tra theo thứ tự:

  1. optimizer.zero_grad() có được gọi không?
  2. loss.backward() có chạy không (in loss để verify nó là tensor có graph)?
  3. Learning rate có quá cao (in optimizer.param_groups[0]["lr"])?
  4. Gradient có NaN không (torch.isnan(p.grad).any() cho mọi param)?
  5. Data có bị normalize lệch không (mean, std)?

90% trường hợp loss không giảm là do 1 trong 5 nguyên nhân trên.

Cheatsheet: PyTorch training API

CodeMục đích
model.train()Chuyển sang training mode (dropout, batch norm hoạt động)
optimizer.zero_grad()Reset gradient buffer về 0
loss.backward()Tính gradient backward qua graph
optimizer.step()Cập nhật param theo gradient
scheduler.step()Cập nhật learning rate
torch.nn.utils.clip_grad_norm_(params, max_norm)Clip gradient L2 norm
torch.no_grad()Context manager tắt graph (inference)
param.gradTensor chứa gradient của param đó
param.requires_gradParam có cần backward không
HyperparameterLLM range phổ biếnLưu ý
Learning rate1e-4 đến 6e-4Pretraining cao hơn fine-tune
Batch size (tokens)0.5M đến 4MGPT-3 dùng 3.2M, Llama-3 dùng 4M
Weight decay0.05 đến 0.1Decoupled trong AdamW
Gradient clip1.0Gần như universal
Warmup steps1-3% total stepsGPT-3: 0.4%
Beta1, beta2 (Adam)(0.9, 0.95)LLM dùng beta2 nhỏ hơn default 0.999

Lời kết

Bạn vừa đi qua nguyên tử của ML: training loop. Mọi paper LLM, mọi codebase research đều build trên 5 dòng này. Khi bạn debug training pipeline lần sau, hãy quay lại 5 bước cơ bản trước khi nghĩ đến những thứ phức tạp hơn.

Hands-on song song:

  1. Copy code trong Phần 6 vào một file train.py, chạy thử với Python local. Không cần GPU. Verify loss giảm.
  2. Modify: thử thay AdamW bằng SGD, xem convergence chậm hơn bao nhiêu. Thử bỏ warmup, xem loss có spike không.
  3. Đọc training loop của nanoGPT (train.py, khoảng 400 dòng). Nhận diện 5 thành phần trên trong đó. Phần còn lại của file là DDP, mixed precision, checkpointing, sẽ học ở bài 16 và 17.
  4. Nếu muốn dataset thật, dùng tinystories từ HuggingFace (datasets.load_dataset("roneneldan/TinyStories")) làm test bed. Tokenize bằng GPT-2 tokenizer rồi train một transformer 6 layer. Chạy trên Colab free tier khoảng 2 tiếng được.

Bài 15 sẽ bàn về Scaling laws Chinchilla: dữ liệu bao nhiêu, parameter bao nhiêu, compute bao nhiêu là tối ưu. Hiểu được scaling laws là biết được “model 7B của Meta train với 15 trillion tokens có phải overkill không”, “training data 100GB của tôi đủ cho model 1B không”. Đây là kiến thức economist của ML engineer.