Bài 10 đã code được single-head self-attention từ zero bằng NumPy — Q, K, V, scaled dot-product, causal mask, softmax. Nó chạy được. Input tensor vào, output tensor ra, shape đúng.
Nhưng GPT-3 không có 1 head. Nó có 96 layers, mỗi layer 96 heads — xấp xỉ 9,000 attention heads đang làm việc song song mỗi khi bạn hỏi nó một câu. Llama-3-70B có 64 heads. Llama-3-8B có 32 heads. Thậm chí model nhỏ nhất trong dòng Llama cũng không bao giờ chạy single-head.
Câu hỏi hợp lý: nếu single-head đã tính được attention, tại sao lại cần thêm 31 cái nữa? Compute có giới hạn, memory có giới hạn, tại sao tốn thêm?
Bài này trả lời câu hỏi đó — intuition, cơ chế, công thức, code NumPy đầy đủ, và một cái nhìn vào điều gì đó thú vị: các head thực sự học được những thứ khác nhau, không phải ngẫu nhiên.
Mental model tổng quát
single-head attention:
1 ánh mắt nhìn toàn bộ câu, học 1 kiểu quan hệ tại một thời điểm
multi-head attention:
nhiều ánh mắt song song, mỗi cái chuyên về một kiểu:
head 1 ── syntax: "subject" nhìn về "verb"
head 2 ── coreference: "nó" nhìn về danh từ đứng trước
head 3 ── long-range: token đầu nhìn về token cuối câu
head 4 ── entity: "Paris" nhìn về "France"
...
output = concat tất cả góc nhìn → project về dimension gốc
Ý tưởng cốt lõi: một softmax-weighted attention matrix chỉ có thể học được một pattern duy nhất trong một lần chạy. Ngôn ngữ có nhiều kiểu quan hệ đồng thời — syntactic, semantic, positional, coreference — và tất cả chúng đều cần được nắm bắt trong cùng một forward pass. Multi-head là cách để làm điều đó mà không tăng depth (không thêm layer), chỉ tăng width.
Phần 1: Vấn đề với single-head
Trong single-head attention, mỗi token tính một vector attention weight phân phối trên tất cả token còn lại:
"The cat sat on the mat because it was soft"
token "it" nhìn về:
"The" 0.03
"cat" 0.25
"sat" 0.05
"on" 0.02
"the" 0.02
"mat" 0.58 <-- thảm, vì thảm mới "soft"
...
Attention weight là một phân phối xác suất — tổng bằng 1. Điều đó có nghĩa: một head chỉ có thể phân bổ attention một lần. Nó phải chọn giữa “học coreference (it → mat)” hoặc “học syntactic (it → is)”. Không thể làm cả hai cùng lúc trong một phân phối duy nhất.
Nhưng câu “it was soft” cần cả hai: phải biết “it” refer đến “mat” (coreference), đồng thời phải biết “it” là subject của “was” (syntax). Và trong một đoạn văn dài hơn, còn cần quan hệ long-range, entity linking, temporal ordering…
Đây không phải vấn đề của model quá nhỏ. Đây là giới hạn kiến trúc: một softmax không thể biểu diễn nhiều distribution cùng lúc.
Solution đơn giản và hiệu quả: chạy nhiều attention head song song, mỗi head có Q/K/V projection riêng, cho phép mỗi head học một loại pattern khác nhau.
Phần 2: Kiến trúc multi-head attention
Chia dimension, không nhân compute
Cách naive nhất để có h heads: tạo h bộ Q/K/V ma trận riêng biệt, mỗi bộ kích thước (d_model, d_model). Nhưng như vậy compute tăng h lần — không hiệu quả.
Cách thực tế: chia d_model thành h phần. Mỗi head nhận d_k = d_model / h dimension, không phải d_model đầy đủ.
Ví dụ cụ thể:
d_model = 512 (dimension của mỗi token embedding)
h = 8 (số heads)
d_k = 64 (dimension mỗi head = 512 / 8)
8 heads, mỗi head xử lý 64 dimension
→ tổng computation ~ bằng 1 head xử lý 512 dimension
Về lý thuyết, tổng compute không tăng. Nhưng thay vì 1 head nhìn toàn bộ 512 chiều cùng một phân phối attention, ta có 8 heads, mỗi head nhìn 64 chiều riêng với phân phối attention riêng.
Cấu trúc projection
Mỗi head i có ba ma trận projection riêng:
W_q_i : (d_model, d_k) — project Q cho head i
W_k_i : (d_model, d_k) — project K cho head i
W_v_i : (d_model, d_k) — project V cho head i
Và sau khi tất cả heads chạy xong, có một output projection chung:
W_O : (d_model, d_model) — trộn output từ tất cả heads
Flow dữ liệu qua multi-head
Input X: (B, N, d_model) -- batch, sequence_len, dim
|
v
8 heads chạy song song:
head_1: X → (Q1, K1, V1) → Attention → out_1 (B, N, d_k)
head_2: X → (Q2, K2, V2) → Attention → out_2 (B, N, d_k)
...
head_8: X → (Q8, K8, V8) → Attention → out_8 (B, N, d_k)
|
v
Concat: (B, N, d_k * 8) = (B, N, d_model)
|
v
Output projection W_O: (B, N, d_model)
Mỗi head tạo ra output kích thước (B, N, d_k). Concat 8 heads lại thành (B, N, d_k × 8) = (B, N, d_model). Nhân với W_O để “trộn” thông tin từ tất cả heads lại thành vector cuối kích thước (B, N, d_model) — giống y hệt single-head về shape, chỉ khác về nội dung.
Phần 3: Công thức và shapes
Công thức chính thức từ paper “Attention is All You Need”:
head_i = Attention(X @ W_q_i, X @ W_k_i, X @ W_v_i)
MultiHead(X) = Concat(head_1, ..., head_h) @ W_O
Trong đó hàm Attention là scaled dot-product từ bài 10:
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
Shapes cụ thể, ví dụ với GPT-small (d_model=512, h=8, d_k=64):
| Tensor | Shape | Ghi chú |
|---|---|---|
| X (input) | (B, N, 512) | sequence với d_model=512 |
| W_q_i, W_k_i, W_v_i | (512, 64) | projection per head |
| Q_i, K_i, V_i | (B, N, 64) | projected queries/keys/values |
| Attention weights | (B, N, N) | per head |
| head_i output | (B, N, 64) | per head |
| Concat output | (B, N, 512) | sau khi concat 8 heads |
| W_O | (512, 512) | output projection |
| Final output | (B, N, 512) | về dimension gốc |
Điểm quan trọng: input và output cùng shape (B, N, d_model). Multi-head attention là một “black box” nhìn từ ngoài giống hệt single-head — chỉ khác là nó tính toán phức tạp hơn bên trong.
Phần 4: Implementation trick — reshape thay vì loop
Cách naive: viết for-loop qua từng head, tạo h bộ ma trận riêng. Đúng về mặt toán nhưng chậm vì không tận dụng batch matmul.
Cách hiệu quả: dùng 1 ma trận lớn rồi reshape. Thay vì h ma trận W_q kích thước (d_model, d_k), dùng một ma trận W_q kích thước (d_model, d_model) rồi reshape kết quả.
# Thay vì:
[W_q_1, W_q_2, ..., W_q_h] -- h ma trận (d_model, d_k) riêng
# Dùng:
W_q shape (d_model, d_model) -- 1 ma trận lớn
# Sau khi project:
Q = X @ W_q -- (B, N, d_model)
Q = Q.reshape(B, N, h, d_k) -- (B, N, h, d_k)
Q = Q.transpose(0, 2, 1, 3) -- (B, h, N, d_k)
# Bây giờ tất cả heads nằm trong batch dimension
# Batched matmul tự động chạy song song tất cả heads
scores = Q @ K.transpose(...) -- (B, h, N, N)
Reshape này không thay đổi kết quả về mặt toán học — chỉ sắp xếp lại data trong memory để tận dụng vectorization. Đây là lý do implementation thực tế trong PyTorch/JAX nhanh hơn nhiều so với naive loop.
Phần 5: Code multi-head attention bằng NumPy
Code đầy đủ, chạy được, không dependencies ngoài NumPy:
import numpy as np
np.random.seed(42)
class MultiHeadAttention:
def __init__(self, d_model, num_heads):
assert d_model % num_heads == 0, "d_model phải chia hết cho num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 1 big projection matrix per Q/K/V, split sau khi nhân
scale = 0.1
self.W_q = np.random.randn(d_model, d_model) * scale
self.W_k = np.random.randn(d_model, d_model) * scale
self.W_v = np.random.randn(d_model, d_model) * scale
self.W_o = np.random.randn(d_model, d_model) * scale
def __call__(self, X, causal=True):
B, N, D = X.shape
h, d_k = self.num_heads, self.d_k
# Project toàn bộ sequence với 1 matmul
Q = X @ self.W_q # (B, N, D)
K = X @ self.W_k # (B, N, D)
V = X @ self.W_v # (B, N, D)
# Split heads: (B, N, D) -> (B, h, N, d_k)
Q = Q.reshape(B, N, h, d_k).transpose(0, 2, 1, 3)
K = K.reshape(B, N, h, d_k).transpose(0, 2, 1, 3)
V = V.reshape(B, N, h, d_k).transpose(0, 2, 1, 3)
# Attention scores: (B, h, N, N)
scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)
# Causal mask: token i không nhìn được token j > i
if causal:
mask = np.triu(np.ones((N, N)), k=1) * -1e9
scores = scores + mask
# Softmax trên axis cuối (N keys)
weights = self._softmax(scores, axis=-1)
# Weighted sum of values: (B, h, N, d_k)
output = weights @ V
# Concat heads: (B, h, N, d_k) -> (B, N, D)
output = output.transpose(0, 2, 1, 3).reshape(B, N, D)
# Output projection
return output @ self.W_o, weights
@staticmethod
def _softmax(x, axis):
# Numerically stable softmax
x_shifted = x - np.max(x, axis=axis, keepdims=True)
e_x = np.exp(x_shifted)
return e_x / e_x.sum(axis=axis, keepdims=True)
# Test basic shapes
X = np.random.randn(2, 5, 8) # batch=2, seq_len=5, d_model=8
mha = MultiHeadAttention(d_model=8, num_heads=2)
out, weights = mha(X)
print(f"Input shape: {X.shape}") # (2, 5, 8)
print(f"Output shape: {out.shape}") # (2, 5, 8)
print(f"Weights shape: {weights.shape}") # (2, 2, 5, 5) -- (batch, heads, N, N)
Output khi chạy:
Input shape: (2, 5, 8)
Output shape: (2, 5, 8)
Weights shape: (2, 2, 5, 5)
weights có shape (B, h, N, N) — mỗi head có một attention matrix riêng. Đây là điểm để visualize head specialization.
So sánh output từng head
Thêm đoạn này để thấy rõ các heads cho kết quả khác nhau:
# Tách weights từng head để so sánh
head_0_weights = weights[0, 0] # batch 0, head 0: (5, 5)
head_1_weights = weights[0, 1] # batch 0, head 1: (5, 5)
print("Head 0 attention (first row, first token attending to all):")
print(np.round(head_0_weights[0], 3))
print("Head 1 attention (first row):")
print(np.round(head_1_weights[0], 3))
# Với random weights, hai head sẽ cho pattern khác nhau
# Sau training, sự khác biệt này có ý nghĩa ngữ nghĩa
Với random weights, sự khác biệt chỉ là nhiễu. Sau training, head 0 có thể học pattern ngắn (attend mostly local tokens), head 1 học pattern dài. Đó là head specialization.
Phần 6: Heads học gì trong thực tế
Một số nghiên cứu quan trọng đã visualize attention patterns trong BERT và GPT:
Clark et al. 2019 phân tích BERT, phát hiện:
- Một số heads chuyên “attend to next token” — phản ánh local syntactic structure
- Một số heads chuyên “attend to previous token”
- Head đặc biệt chuyên về dấu câu
[SEP]và[CLS] - Heads học được coreference resolution mà không được train explicitly
Voita et al. 2019 thử prune attention heads trong Transformer dịch máy:
- Có thể remove 48/64 heads (75%) mà BLEU score gần như không đổi
- 16 heads “quan trọng” đảm nhận các vai trò rõ ràng: positional heads, syntactic heads, rare-word heads
Điều này dẫn đến một quan sát thực tế quan trọng: không phải tất cả heads đều quan trọng như nhau. Nhiều heads có thể là redundant — model học nhiều “cách backup” cho cùng một kiểu pattern. Đây liên quan đến lottery ticket hypothesis: trong một mạng lớn, luôn tồn tại một subnetwork nhỏ hơn đạt performance tương đương.
Nhưng từ góc độ training, để biết head nào quan trọng cần phải train trước — nên về mặt thực tế, đơn giản hơn là giữ nhiều heads và để model tự tìm ra cái nào cần dùng.
Ví dụ visualize pattern đơn giản
# Simulate câu có 6 tokens: ["The", "cat", "sat", "it", "was", "soft"]
# Giả sử head 0 attend strongly to local context
# Giả sử head 1 attend strongly to distant tokens
# Đây là pattern bạn sẽ thấy sau training, không phải với random weights
local_pattern = np.array([
[0.70, 0.20, 0.05, 0.02, 0.02, 0.01], # "The" mostly self-attend
[0.40, 0.45, 0.10, 0.03, 0.01, 0.01], # "cat" attend "The" + self
[0.10, 0.30, 0.50, 0.07, 0.02, 0.01], # "sat" attend "cat" + self
[0.05, 0.10, 0.15, 0.60, 0.07, 0.03], # "it" attend "cat" (coreference)
[0.02, 0.05, 0.10, 0.30, 0.48, 0.05], # "was" attend "it" (verb-subject)
[0.01, 0.02, 0.03, 0.04, 0.05, 0.85], # "soft" attend "mat" (predicate)
])
# Head 1: long-range dependency pattern
longrange_pattern = np.array([
[0.30, 0.05, 0.05, 0.20, 0.20, 0.20], # attend cả đầu lẫn cuối
[0.10, 0.30, 0.05, 0.25, 0.20, 0.10],
[0.08, 0.10, 0.25, 0.20, 0.25, 0.12],
[0.05, 0.08, 0.10, 0.25, 0.30, 0.22], # "it" attend "soft" (long-range)
[0.05, 0.05, 0.10, 0.15, 0.35, 0.30],
[0.05, 0.05, 0.10, 0.15, 0.30, 0.35],
])
print("Local context head — each token mostly attends to neighbors:")
for i, row in enumerate(local_pattern):
dominant = row.argmax()
print(f" token {i}: dominant attention → token {dominant} ({row[dominant]:.2f})")
print("\nLong-range head — attend across distances:")
for i, row in enumerate(longrange_pattern):
dominant = row.argmax()
print(f" token {i}: dominant attention → token {dominant} ({row[dominant]:.2f})")
Phần 7: Variants — MHA, MQA, GQA, MLA
Kể từ 2017, đã có nhiều biến thể của multi-head attention, chủ yếu tập trung vào giảm memory footprint trong inference (KV cache là phần tốn memory nhất).
MHA — Multi-Head Attention (original, 2017)
Mỗi head có bộ Q, K, V projection riêng hoàn toàn. Đây là kiến trúc trong bài viết này.
h heads × (W_q, W_k, W_v) mỗi head
→ memory KV cache: 2 × h × d_k × N × L (N = seq_len, L = layers)
MQA — Multi-Query Attention (2019, Shazeer)
Tất cả heads chia sẻ chung một W_k và một W_v, chỉ Q là khác nhau.
h bộ W_q riêng
1 bộ W_k shared
1 bộ W_v shared
→ memory KV cache giảm h lần
→ quality giảm nhẹ, nhưng inference nhanh hơn đáng kể
Dùng trong PaLM và một số phiên bản GPT-J.
GQA — Grouped-Query Attention (2023, Google)
Compromise giữa MHA và MQA: chia heads thành G nhóm, các heads trong cùng nhóm share K và V.
h heads, G groups (G < h)
→ mỗi group có 1 W_k, 1 W_v shared
→ mỗi head vẫn có W_q riêng
ví dụ: 32 heads, 8 groups → 4 heads / group
→ memory KV cache giảm 4 lần so với MHA
Đây là kiến trúc được dùng trong Llama 2/3, Mistral, Gemma. Đây là sự lựa chọn mặc định của phần lớn open-source LLM hiện tại.
MLA — Multi-head Latent Attention (2024, DeepSeek)
DeepSeek V2/V3 dùng kiến trúc này. Thay vì lưu K và V đầy đủ trong KV cache, model nén chúng xuống một “latent vector” nhỏ hơn, rồi expand lại khi cần.
KV cache: lưu latent vector c kích thước nhỏ
→ khi cần: expand c → (K, V) qua learned projection
→ memory giảm đáng kể so với MHA
→ DeepSeek V2 claim giảm 93.3% KV cache memory
Đây là một trong những lý do DeepSeek có inference cost thấp hơn các model tương đương.
| Variant | Q | K | V | KV Cache | Dùng trong |
|---|---|---|---|---|---|
| MHA | h riêng | h riêng | h riêng | lớn nhất | GPT-2, BERT, original Transformer |
| MQA | h riêng | 1 shared | 1 shared | giảm h lần | PaLM, GPT-J |
| GQA | h riêng | G shared | G shared | giảm h/G lần | Llama 2/3, Mistral, Gemma |
| MLA | h riêng | latent | latent | giảm ~93% | DeepSeek V2/V3 |
Cheatsheet
5 điểm cần nhớ về multi-head attention:
-
Mỗi head học một pattern khác nhau. Không phải do design cứng — do training tự nhiên phân vai. Head nào “thấy được” pattern nào thì phụ trách pattern đó.
-
d_k = d_model / h, không phải d_model. Mỗi head xử lý dimension nhỏ hơn, nên tổng compute xấp xỉ single-head với cùng d_model.
-
Reshape, không loop. Efficient implementation dùng reshape + transpose để tất cả heads chạy song song trong một batched matmul.
-
W_O là quan trọng. Output projection không phải formality — nó “trộn” information từ tất cả heads, cho phép mỗi head influence output cuối theo trọng số được học.
-
GQA là default 2024. Nếu bạn đọc một model card mới và thấy “GQA” hay “grouped-query”, đây là multi-head với K/V được share trong group để tiết kiệm KV cache.
Sizes trong các model thực tế:
GPT-2 small: d_model=768, h=12, d_k=64
GPT-2 large: d_model=1280, h=20, d_k=64
GPT-3 175B: d_model=12288, h=96, d_k=128
Llama-3-8B: d_model=4096, h=32, d_k=128, GQA groups=8
Llama-3-70B: d_model=8192, h=64, d_k=128, GQA groups=8
DeepSeek V3: d_model=7168, h=128, d_k=? MLA
Lời kết
Bài 10 code single-head. Bài này thêm “nhiều” vào. Nhưng “nhiều” không chỉ là chạy nhiều lần — nó là architectural decision cho phép model học nhiều loại quan hệ song song trong cùng một layer, rồi trộn chúng lại qua output projection.
Trước khi sang bài 12, thử một experiment nhỏ: lấy code MultiHeadAttention trong bài này, chạy với num_heads=1, num_heads=2, num_heads=4, num_heads=8 với cùng d_model=64. Quan sát shape của weights mỗi lần — đặc biệt là (B, h, N, N). Khi h=1, bạn đang nhìn vào attention pattern của single-head. Khi h=8, bạn có 8 patterns song song. Print ra và so sánh — với random weights chúng sẽ khác nhau, nhưng chưa có ý nghĩa. Sau khi có training (bài 14), sự khác nhau đó mới có cấu trúc.
Bài 12: Transformer block đầy đủ — ghép multi-head attention với Feed-Forward Network, Layer Normalization, và residual connection. Đây là “module” lặp lại 32 đến 96 lần trong mỗi LLM. Hiểu block này là hiểu kiến trúc cốt lõi của mọi thứ từ BERT đến GPT-4 đến Llama.