zlacker

janals (OP) 2026-02-05 00:26:56
Well, for example, the last step is a softmax over all the output logits, and there is one logit per vocab entry. To compute the denominator you need the sum of the exponentials of every logit, which is O(N) in the vocab size N.
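A minimal pure-Python sketch of that softmax step, where `logits` holds one value per vocab entry. Subtracting the max logit first is a standard numerical-stability trick I've added (it doesn't change the result); the point is that every pass, including the denominator sum, walks the whole vocab.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits (one per vocab entry)."""
    m = max(logits)                           # O(N) pass: find the max logit
    exps = [math.exp(x - m) for x in logits]  # O(N) pass: shift so exp() can't overflow
    denom = sum(exps)                         # O(N) pass: the softmax denominator
    return [e / denom for e in exps]          # O(N) pass: normalize

probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # → [0.659, 0.242, 0.099]
```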

The bigger impact comes before that: you need to project the hidden states onto the vocabulary with a weight matrix of something like 4096x250000. A bigger vocab means more FLOPs.
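A back-of-the-envelope sketch of that projection's cost, assuming the example numbers above (hidden size 4096, 250,000-token vocab) and the usual convention of counting each multiply-add as 2 FLOPs. It ignores bias terms and the softmax itself; the takeaway is that both weights and FLOPs scale linearly with vocab size.

```python
def output_projection_cost(d_model, vocab_size, tokens=1):
    """Rough cost of projecting hidden states onto the vocab.

    Each token's hidden vector (length d_model) is multiplied by a
    d_model x vocab_size matrix: d_model * vocab_size multiply-adds,
    counted as 2 FLOPs each.
    """
    params = d_model * vocab_size
    flops = 2 * d_model * vocab_size * tokens
    return params, flops

params, flops = output_projection_cost(4096, 250_000)
print(f"{params:,} weights, {flops:,} FLOPs per token")
# → 1,024,000,000 weights, 2,048,000,000 FLOPs per token
```

Doubling the vocab to 500,000 doubles both numbers.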

If you’re on a GPU the work is parallelized, so in practice it may not be quite linear if everything fits nicely in memory. On a CPU, though, you’re going to struggle more.

This is why the juiciest target when shrinking models is the token embedding table. For example, ALBERT factorized the whole embedding table into two low-rank matrices.
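A rough sketch of why that factorization saves so much: instead of one V x H table, you store a V x E table plus an E x H projection, with E much smaller than H. The sizes reuse the 4096/250,000 example above; E = 128 is an assumption (I believe it's the value ALBERT used, but treat it as illustrative). The helper names are hypothetical, and the tiny random matrices just show that the lookup is a row of the small table times the projection.

```python
import random

def embedding_params(vocab_size, hidden, factor_dim=None):
    """Params of a full V x H table, or of a V x E table plus an E x H projection."""
    if factor_dim is None:
        return vocab_size * hidden
    return vocab_size * factor_dim + factor_dim * hidden

def factorized_lookup(token_id, table, proj):
    """Embed a token as (row of the V x E table) @ (E x H projection)."""
    row = table[token_id]                                   # length E
    return [sum(row[e] * proj[e][h] for e in range(len(row)))
            for h in range(len(proj[0]))]                   # length H

V, H, E = 250_000, 4096, 128  # E = 128 is an assumed factor dimension
print(f"{embedding_params(V, H):,} vs {embedding_params(V, H, E):,}")
# → 1,024,000,000 vs 32,524,288

# Tiny hypothetical matrices so the lookup runs instantly: V=10, E=4, H=8.
rng = random.Random(0)
table = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(10)]
proj = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(4)]
vec = factorized_lookup(3, table, proj)
print(len(vec))  # → 8
```

Because every embedding passes through the same E-dimensional bottleneck, the implied V x H table has rank at most E, which is where the "low rank" comes from.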
