source: docs/mtp.md

Multi-Token Prediction — better training signal per step

Standard language-model training predicts ONE token ahead per position. Multi-Token Prediction (MTP) predicts the next H tokens (t+1 through t+H) at every position simultaneously, using H output heads that share the same hidden state. The loss is the mean of the H per-horizon cross-entropies.
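A minimal sketch of that loss in plain Python. It assumes hidden states and head weights are plain nested lists; the names `mtp_loss`, `log_softmax` and `matvec` are illustrative, not TinyGPT's API. Every head reads the same `hidden[t]`, head k (0-based) is scored against `tokens[t + k + 1]`, and positions with no token that far ahead are skipped.

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax over one logit vector.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def matvec(weight, vec):
    # Bias-free projection: weight is vocab x d_model, vec is d_model.
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]

def mtp_loss(hidden, heads, tokens):
    """Mean over horizons of the per-horizon cross-entropy.

    hidden: T hidden-state vectors; hidden[t] has seen tokens[0..t].
    heads:  H weight matrices; heads[k] predicts tokens[t + k + 1].
    tokens: T token ids. Requires T > H so every horizon has a target.
    """
    per_horizon = []
    for k, weight in enumerate(heads):
        offset = k + 1
        # Only positions with a token `offset` steps ahead contribute.
        losses = [
            -log_softmax(matvec(weight, hidden[t]))[tokens[t + offset]]
            for t in range(len(tokens) - offset)
        ]
        per_horizon.append(sum(losses) / len(losses))
    return sum(per_horizon) / len(per_horizon)
```

With a single head this reduces to the ordinary next-token loss, so a dense run and an MTP run differ only in how many heads are passed in. With all-zero weights every head predicts uniformly and the loss is ln(vocab).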

The result: a richer per-step training signal that typically improves final perplexity by 5-15% at the same step count, OR reaches the same target loss in fewer steps. The technique was formalised by Gloeckle et al., 2024 (“Better & Faster Large Language Models via Multi-token Prediction”, arXiv:2404.19737) and later popularised by DeepSeek-V3.


Why it works

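The single-token next-prediction signal is sparse: in a context of length T, each position gets only one token’s worth of supervision. With MTP, each position is scored against up to H future tokens (the last few positions in a window lack targets at the longer horizons), so the signal is roughly H× denser, i.e. 2-4× for H between 2 and 4, without needing more data. The same training data goes further.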
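To make “roughly H× denser” concrete, here is a sketch with illustrative helper names (not TinyGPT code) that lists the (position, target) pairs each horizon scores in one sequence and compares the total with next-token training. Because horizon k has no target for the last k positions, the ratio sits a little below H for short windows and approaches H as the window grows.

```python
def horizon_targets(tokens, horizons):
    """For each horizon k = 1..H, the (position, target) pairs it supervises.

    Position t is scored against tokens[t + k]; the last k positions have
    no target at horizon k, so they drop out of that horizon's loss.
    """
    return [
        [(t, tokens[t + k]) for t in range(len(tokens) - k)]
        for k in range(1, horizons + 1)
    ]

def supervision_ratio(seq_len, horizons):
    # Scored (position, target) pairs with MTP vs. next-token only.
    tokens = list(range(seq_len))
    mtp_pairs = sum(len(pairs) for pairs in horizon_targets(tokens, horizons))
    dense_pairs = len(horizon_targets(tokens, 1)[0])
    return mtp_pairs / dense_pairs
```

For a 10-token window at H = 4 this gives 30 / 9 ≈ 3.3; for a 1024-token window it gives 4086 / 1023 ≈ 3.99.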

The intuition: predicting t+1 is local. Predicting t+5 forces the model to learn longer-range structure (subject-verb agreement across clauses, plot continuity in narrative, completing the rest of an expression in code). Because the hidden state is shared, it must encode information useful for ALL horizons, which pushes representations to carry more than the immediate next token requires.

What’s wired today

tinygpt train --preset tiny --steps 5000 \
    --corpus /tmp/corpus.txt \
    --mtp-horizons 4 \
    --out /tmp/model.tinygpt

Implementation detail: the extra heads are bias-free Linear(d_model, vocab) layers, one for each horizon beyond the first (horizon 1 uses the model’s primary output head). All heads read the model’s final hidden state; only the projection differs. Parameter cost: (H-1) * vocab * d_model. For Huge byte-level (vocab = 256, d_model = 256) at H = 4, that is 3 * 256 * 256 = 196,608, or ~200K extra params (≈2% overhead on a 9.6M base).
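The parameter arithmetic, checked in a few lines. The tiny preset’s `d_model` is not stated in this doc; the value below is only what the smoke-test parameter delta implies, so treat it as an inference rather than a documented preset value.

```python
def mtp_head_params(horizons, vocab, d_model):
    # One bias-free Linear(d_model, vocab) per horizon beyond the first.
    return (horizons - 1) * vocab * d_model

# Huge byte-level figures from the text: vocab=256, d_model=256, H=4.
huge_extra = mtp_head_params(4, 256, 256)
huge_overhead = huge_extra / 9.6e6  # fraction of the 9.6M base

# Tiny preset: the smoke table shows 940K - 842K = 98K extra params.
# Solving (4 - 1) * 256 * d_model = 98K gives d_model close to 128
# (inferred from the table, not a documented preset value).
tiny_d_model_estimate = 98_000 / ((4 - 1) * 256)
tiny_extra_if_128 = mtp_head_params(4, 256, 128)
```

`huge_extra` is 196,608 (about 2.05% of 9.6M), and `tiny_extra_if_128` is 98,304, which matches the ~98K gap between the two smoke-test rows.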

The extra heads are TRAINING-ONLY. They aren’t included in the .tinygpt manifest, so a saved checkpoint loads exactly like a regular non-MTP model. The sample and eval commands consult only the primary head, so downstream tooling doesn’t need to know MTP was used.
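One way to keep checkpoints MTP-agnostic, sketched under two assumptions that are not from this doc: parameters live in a flat name-to-weights mapping, and the extra heads share a hypothetical `mtp_heads.` name prefix. Filtering them out at save time leaves exactly the key set a dense model would have.

```python
def export_params(params, head_prefix="mtp_heads."):
    """Drop training-only MTP heads before writing a checkpoint.

    `params` maps parameter names to weights. `head_prefix` is a
    hypothetical naming convention, not TinyGPT's manifest layout.
    """
    return {
        name: weights
        for name, weights in params.items()
        if not name.startswith(head_prefix)
    }

# Example: two extra heads from an H=3 run are dropped on export.
trained = {
    "tok_emb.weight": [[0.1, 0.2]],
    "lm_head.weight": [[0.3, 0.4]],
    "mtp_heads.0.weight": [[0.5, 0.6]],
    "mtp_heads.1.weight": [[0.7, 0.8]],
}
saved = export_params(trained)
```

The primary head stays in `saved`, which is why sampling and evaluation behave as they would for a dense checkpoint.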

Smoke result

200 KB byte-level corpus, tiny preset, 50 steps:

| Config | Params | Final loss |
|---|---|---|
| Dense, 1 horizon | 842K | 1.76 |
| MTP, 4 horizons | 940K | 2.58 (mean over 4 horizons) |

The MTP loss is the MEAN over horizons, and the longer horizons are intrinsically harder to predict, so 2.58 isn’t directly comparable to the single-horizon 1.76. The primary head’s CE (horizon 1) inside the MTP run is typically lower than the dense baseline’s at matched steps, but it isn’t currently surfaced as a separate stat. (Per-horizon loss reporting is a follow-up.)
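A sketch of what per-horizon reporting could look like, assuming the training loop can hand over one CE value per horizon each step. `HorizonLossTracker` and its field names are hypothetical. The `h1` entry is the figure to set against the dense baseline; `mtp_mean` is the mean-over-horizons quantity shown in the table above.

```python
class HorizonLossTracker:
    """Running mean of each horizon's cross-entropy across training steps."""

    def __init__(self, horizons):
        self.sums = [0.0] * horizons
        self.steps = 0

    def update(self, per_horizon_ce):
        # per_horizon_ce[k] is this step's mean CE for horizon k + 1.
        if len(per_horizon_ce) != len(self.sums):
            raise ValueError("expected one CE value per horizon")
        for k, ce in enumerate(per_horizon_ce):
            self.sums[k] += ce
        self.steps += 1

    def report(self):
        per_horizon = [s / self.steps for s in self.sums]
        return {
            # Mean over horizons: the MTP loss as defined above.
            "mtp_mean": sum(per_horizon) / len(per_horizon),
            # Horizon 1 alone: comparable to a dense run's loss.
            "h1": per_horizon[0],
            "per_horizon": per_horizon,
        }
```

Logging `h1` next to `mtp_mean` would let a dense run and an MTP run be compared on the same footing at matched steps.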

Hyperparameter notes

What’s NOT shipped yet

Where to look