Bài 9 đã xây xong intuition: Query là câu hỏi, Key là nhãn, Value là nội dung thực. Dot product Q·K cho điểm liên quan, softmax chuẩn hóa thành attention weights, rồi nhân với V để trộn thông tin. Đẹp về ý nghĩa — nhưng vẫn còn là lý thuyết.

Bài này code thật. Mục tiêu cụ thể: viết một SelfAttention class bằng NumPy thuần, xử lý batch input, có causal mask, có scaling — rồi verify rằng output của nó khớp với torch.nn.functional.scaled_dot_product_attention trong PyTorch. Nếu hai cái đó match, bạn biết mình không làm sai.

80 dòng. Không PyTorch cho core logic. Sau bài này attention không còn là hộp đen.

Mental model trước khi code

Self-attention layer là 5 bước tuần tự. Hiểu pipeline trước, code sẽ tự nhiên hơn:

input X [B, N, D]
    -> [Q, K, V] projection:  X @ W_q,  X @ W_k,  X @ W_v   -> [B, N, D]
    -> scores:                 Q @ K.transpose(-2, -1)         -> [B, N, N]
    -> scale:                  / sqrt(d_k)
    -> mask (causal):          cộng -inf vào upper triangle
    -> softmax                                                  -> [B, N, N]
    -> output:                 weights @ V                      -> [B, N, D]

Chú ý: shape đầu vào [B, N, D] và shape đầu ra [B, N, D] giống nhau. Self-attention không thay đổi kích thước — nó chỉ làm phong phú thêm mỗi vector bằng context từ các token khác trong cùng sequence.

Ba chữ cái đáng ghi nhớ: B là batch size, N là sequence length (số token), D là d_model (chiều embedding).

Phần 1: Setup — dimensions và test input

import numpy as np
np.random.seed(42)

# Config
batch_size = 2
seq_len    = 5
d_model    = 8   # embedding dimension

# Fake input: 2 sequences, mỗi cái 5 tokens, mỗi token là vector 8 chiều
X = np.random.randn(batch_size, seq_len, d_model)
print(f"Input shape: {X.shape}")  # (2, 5, 8)

np.random.seed(42) để kết quả reproducible — quan trọng khi so sánh với PyTorch sau này. randn cho ra normal distribution với mean=0, std=1, phù hợp mô phỏng embedding đã được normalize.

Trong thực tế, X là output của embedding layer cộng positional encoding. Ở đây dùng random để tập trung vào logic attention.

Phần 2: Q, K, V projections

# Weight matrices (thực tế được học; ở đây random init nhỏ)
W_q = np.random.randn(d_model, d_model) * 0.1
W_k = np.random.randn(d_model, d_model) * 0.1
W_v = np.random.randn(d_model, d_model) * 0.1

Q = X @ W_q  # (2, 5, 8)
K = X @ W_k  # (2, 5, 8)
V = X @ W_v  # (2, 5, 8)

Tại sao nhân * 0.1 khi khởi tạo weights? Nếu weights quá lớn, dot product Q·K sẽ có magnitude lớn, softmax sẽ bão hòa ngay, gradient nhỏ — network khó train. Đây là Xavier-style initialization đơn giản.

Cơ chế broadcasting trong NumPy: X có shape (2, 5, 8), W_q có shape (8, 8). Khi viết X @ W_q, NumPy xử lý @ theo batched matmul: hai chiều cuối làm matmul (5,8) @ (8,8) = (5,8), chiều batch 2 được giữ nguyên. Kết quả (2, 5, 8). Không cần loop qua từng sample trong batch.

Ba phép nhân này là ba “góc nhìn” khác nhau của cùng một input. W_q học cách trích xuất “cái gì đang tìm kiếm”, W_k học “cái gì đang được offer”, W_v học “nội dung thực sự cần lấy về”. Ba projection riêng biệt cho phép attention linh hoạt hơn nhiều so với dùng X trực tiếp.

Phần 3: Attention scores

d_k = d_model
scores = Q @ K.transpose(0, 2, 1)  # (2, 5, 5)
scores = scores / np.sqrt(d_k)
print(f"Scores shape: {scores.shape}")  # (2, 5, 5)

K.transpose(0, 2, 1) hoán đổi hai chiều cuối: (2, 5, 8) thành (2, 8, 5). Batched matmul (2,5,8) @ (2,8,5) cho ra (2, 5, 5) — ma trận vuông N×N chứa điểm liên quan giữa từng cặp token.

scores[b, i, j] là độ liên quan của token i đến token j trong sequence b.

Tại sao chia cho sqrt(d_k)? Xét trường hợp Q và K là các vector random với mean=0, variance=1. Khi d_k lớn, dot product Q·K là tổng của d_k phần tử ngẫu nhiên — variance của tổng này bằng d_k, std bằng sqrt(d_k). Nếu không scale, với d_k=512, dot product có thể đạt magnitude ~22. Đưa vào softmax, các giá trị lớn như vậy sẽ cho ra phân phối gần như one-hot: một token chiếm 99.9% attention, còn lại gần 0. Gradient của softmax gần như biến mất — network không học được. Chia cho sqrt(d_k) đưa variance về ~1, softmax hoạt động ở regime tốt hơn.

Phần 4: Causal mask

Causal mask là thứ biến self-attention từ “nhìn thấy toàn bộ sequence” thành “chỉ nhìn thấy quá khứ và hiện tại”. Bắt buộc khi train language model vì không thể cho phép token ở vị trí i nhìn vào token j > i — token chưa được sinh ra không thể ảnh hưởng đến dự đoán hiện tại.

mask = np.triu(np.ones((seq_len, seq_len)), k=1) * -1e9
scores = scores + mask

np.triu(..., k=1) tạo upper triangle không bao gồm đường chéo chính:

token\token  0    1    2    3    4
0          [ 0   -inf -inf -inf -inf ]
1          [ 0    0   -inf -inf -inf ]
2          [ 0    0    0   -inf -inf ]
3          [ 0    0    0    0   -inf ]
4          [ 0    0    0    0    0   ]

Token 0 chỉ attend đến chính nó. Token 2 attend đến token 0, 1, 2. Token 4 attend đến tất cả. Đây chính xác là cách GPT xử lý: token hiện tại nhìn về quá khứ, không nhìn về tương lai.

Tại sao -1e9 thay vì -inf thực sự? Khi cộng -inf vào rồi qua softmax, exp(-inf) = 0 về mặt toán học. Nhưng trong floating point, np.exp(-np.inf) cho ra 0.0 — ổn. Tuy nhiên, -inf + giá trị hữu hạn = -inf trong float64, nhưng có thể gây NaN trong một số tính toán mixed-precision (float16). -1e9 an toàn hơn: exp(-1e9) ≈ 0 nhưng tránh được các edge case. Thực tế PyTorch dùng -inf với masked_fill, nhưng cũng xử lý riêng cho float16. Với NumPy demo này, -1e9 là lựa chọn pragmatic.

Broadcasting tự động: mask shape (5, 5) được broadcast sang scores shape (2, 5, 5) — cùng mask áp dụng cho cả hai samples trong batch.

Phần 5: Softmax

def softmax(x, axis=-1):
    x_max = np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(x - x_max)
    return e_x / e_x.sum(axis=axis, keepdims=True)

weights = softmax(scores, axis=-1)
print(f"Weights shape: {weights.shape}")      # (2, 5, 5)
print(f"Row sums (should be 1.0):")
print(weights[0].sum(axis=-1))               # [1.0, 1.0, 1.0, 1.0, 1.0]

axis=-1 có nghĩa softmax theo chiều cuối cùng — tức là theo chiều “key token”. Mỗi hàng i trong ma trận (N, N) là phân phối xác suất cho biết token i phân bổ attention bao nhiêu cho từng token j.

Numerical stability trick: Không viết trực tiếp exp(x) / sum(exp(x)). Nếu x có giá trị lớn (ví dụ 100), exp(100) overflow thành inf trong float32. Trick là trừ max trước: softmax(x) = softmax(x - max(x)) về mặt toán học là tương đương (tử và mẫu cùng nhân exp(-max(x))), nhưng số lớn nhất trong exp(...) bây giờ là exp(0) = 1 — không overflow.

Kiểm tra weights.sum(axis=-1) phải ra tất cả 1.0. Nếu không, implementation có bug.

Phần 6: Output = weights @ V

output = weights @ V  # (2, 5, 8)
print(f"Output shape: {output.shape}")  # (2, 5, 8)

Đây là bước tổng hợp. weights[b, i, :] là phân phối attention của token i trong batch b. Nhân với V[b, :, :] cho ra output[b, i, :] — vector của token i sau khi đã được làm phong phú bởi thông tin từ các token khác, mỗi token đóng góp theo đúng trọng số attention.

Shape cuối (2, 5, 8) khớp với shape đầu vào (2, 5, 8). Self-attention là một phép biến đổi “shape-preserving” — đưa vào ma trận N×D, nhận lại ma trận N×D nhưng giàu context hơn.

Đây là điều khiến Transformer có thể stack nhiều layer: output của layer này trở thành input của layer tiếp theo mà không cần reshape gì cả.

Phần 7: Wrap thành class reusable

class SelfAttention:
    def __init__(self, d_model):
        self.d_model = d_model
        self.W_q = np.random.randn(d_model, d_model) * 0.1
        self.W_k = np.random.randn(d_model, d_model) * 0.1
        self.W_v = np.random.randn(d_model, d_model) * 0.1

    def __call__(self, X, causal=True):
        B, N, D = X.shape

        Q = X @ self.W_q  # (B, N, D)
        K = X @ self.W_k  # (B, N, D)
        V = X @ self.W_v  # (B, N, D)

        # Scores + scaling
        scores = Q @ K.transpose(0, 2, 1) / np.sqrt(self.d_model)  # (B, N, N)

        # Causal mask
        if causal:
            mask = np.triu(np.ones((N, N)), k=1) * -1e9
            scores = scores + mask

        # Attention weights
        weights = softmax(scores, axis=-1)  # (B, N, N)

        # Output
        return weights @ V  # (B, N, D)


# Test
np.random.seed(42)
attn = SelfAttention(d_model=8)

X = np.random.randn(2, 5, 8)
out = attn(X)
print(f"Output shape: {out.shape}")   # (2, 5, 8)
print(f"Sample output[0, 0, :4]: {out[0, 0, :4].round(4)}")

Class này chứa toàn bộ logic trong 20 dòng. __call__ cho phép gọi như function: attn(X). Flag causal=True là default vì đây là setup cần thiết cho language model; có thể tắt đi với causal=False cho encoder-only model như BERT.

Lưu ý B, N, D = X.shape — unpack shape để code rõ ràng hơn, và N được dùng khi tạo mask để mask tự động adapt với sequence length khác nhau.

Phần 8: Verify với PyTorch

Đây là bước quan trọng nhất. Tự implement một thuật toán mà không verify là không đủ — cần bằng chứng rằng output đúng.

import torch
import numpy as np

# Reuse X, Q, K, V từ phần trước (trước khi wrap vào class)
np.random.seed(42)
batch_size, seq_len, d_model = 2, 5, 8

X   = np.random.randn(batch_size, seq_len, d_model)
W_q = np.random.randn(d_model, d_model) * 0.1
W_k = np.random.randn(d_model, d_model) * 0.1
W_v = np.random.randn(d_model, d_model) * 0.1

Q = X @ W_q
K = X @ W_k
V = X @ W_v

# NumPy output (recompute sạch)
d_k    = d_model
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
mask   = np.triu(np.ones((seq_len, seq_len)), k=1) * -1e9
scores = scores + mask
weights = softmax(scores, axis=-1)
output_np = weights @ V

# PyTorch reference
Q_t = torch.tensor(Q, dtype=torch.float64)
K_t = torch.tensor(K, dtype=torch.float64)
V_t = torch.tensor(V, dtype=torch.float64)

output_torch = torch.nn.functional.scaled_dot_product_attention(
    Q_t, K_t, V_t,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,
)

# So sánh
match = np.allclose(output_np, output_torch.numpy(), atol=1e-6)
print(f"Match: {match}")   # Match: True

# Nếu muốn xem độ sai khác
diff = np.abs(output_np - output_torch.numpy())
print(f"Max absolute diff: {diff.max():.2e}")   # < 1e-10

Một lưu ý quan trọng về dtype: NumPy mặc định float64, PyTorch mặc định float32. Khi so sánh hai bên, dùng torch.float64 để tránh sai khác do precision. Nếu dùng float32, diff vẫn nhỏ (< 1e-5) nhưng không hoàn toàn khớp ở bit level.

torch.nn.functional.scaled_dot_product_attention là implementation chính thức của PyTorch từ 2.0, hỗ trợ Flash Attention nếu hardware cho phép. Khi is_causal=True, nó tự động apply causal mask — đây là thứ chúng ta tự viết bằng np.triu. Hai bên khớp nhau xác nhận logic của np.triu(..., k=1) * -1e9 là đúng.

Cheatsheet

Shape biến đổi qua từng bước:

BướcOperationInput shapeOutput shape
ProjectionX @ W_q/k/v(B, N, D)(B, N, D)
ScoresQ @ K.T(B, N, D), (B, D, N)(B, N, N)
Scale/ sqrt(d_k)(B, N, N)(B, N, N)
Mask+ upper_triangle(-inf)(B, N, N)(B, N, N)
Softmaxexp / sum(B, N, N)(B, N, N) (mỗi hàng sum=1)
Outputweights @ V(B, N, N), (B, N, D)(B, N, D)

Năm điểm phải nhớ:

  1. QKV là ba projection riêng biệt từ cùng một input — ba “góc nhìn” khác nhau của cùng dữ liệu.
  2. Scores là ma trận N×N — tốn bộ nhớ O(N²). Đây là lý do sequence length lớn là bottleneck của attention cổ điển.
  3. Chia sqrt(d_k) để giữ variance — không phải ma thuật, là toán học cơ bản về variance của tổng.
  4. Causal mask bắt buộc với language model — không có nó, model thấy tương lai trong quá trình training, leaks answer.
  5. Output shape = input shape — cho phép stack nhiều layer không cần reshape giữa các layer.

Lời kết

80 dòng NumPy, không có framework. Bạn vừa implement thứ nằm trong tim của mọi LLM hiện đại — từ GPT-2 năm 2019 đến Llama-3 năm 2024, core attention logic không thay đổi nhiều. Những gì thay đổi là cách tối ưu hóa nó: Flash Attention (tính attention mà không materialize ma trận N×N đầy đủ), GQA (grouped-query attention để giảm memory KV cache), Sliding Window Attention (giới hạn context window). Nhưng tất cả đều là biến thể của pipeline 5 bước bạn vừa code.

Trước khi sang bài 11, thử mấy thứ này để hiểu sâu hơn:

  • Thay đổi seq_len thành 10, 20, 50 và xem attention weights thay đổi thế nào
  • Tắt causal mask (causal=False) và so sánh weights pattern với có mask
  • print(weights[0]) để visualize ma trận 5×5 — hàng nào attend nhiều nhất vào đâu?
  • Thử d_model=1 — điều gì xảy ra khi chỉ có 1 chiều?

Bài 11 sẽ mở rộng thành multi-head attention: thay vì một bộ QKV duy nhất, dùng nhiều bộ song song rồi concat lại. Về mặt code, đó là reshape và thêm vài dòng. Về mặt lý do, mỗi head học một khía cạnh khác nhau của quan hệ giữa các token — head này chuyên về syntax, head kia về semantic, head khác về long-range dependency. Cùng một cơ chế, sức mạnh nhân lên h lần.