Just got back from vacation, and super excited to finally release Griffin - a new hybrid LLM mixing RNN layers with Local Attention - scaled up to 14B params!
My co-authors have already posted about our amazing results, so here's a 🧵on how we got there!
Our work built on S4/S4D from @_albertgu et al., as well as our own work on the LRU led by @orvieto_antonio, which simplified S4/S4D without sacrificing performance:
These models are blazingly fast at inference, but no one had scaled them up on language yet.
Initially, at small scale, we saw that S4D/LRU layers actually performed worse than even well-tuned vanilla LSTMs!
Why not just use LSTMs then? LSTMs have the same inference speed benefits as S4D/LRU, but are too slow to train and therefore cannot be scaled up!
S4D/LRU is much easier to scale since it uses a diagonal recurrent matrix. But how do we improve performance?
Inspired by LSTMs, we added a Recurrent Gate to the LRU, allowing it to discard the input at time t and preserve info from the history (“RG-LRU”). This matched LSTM performance!
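A minimal sketch of the idea in numpy (this is an illustration of a gated diagonal recurrence, not the paper's exact parameterization; `rg_lru_step` and the gate sampling are stand-ins for the learned components):

```python
import numpy as np

# Gated diagonal linear recurrence in the spirit of RG-LRU (illustrative).
# The recurrence gate r_t modulates the per-channel decay a: as r_t -> 0,
# the effective decay a_eff -> 1, so the input is discarded and history
# is preserved, matching the behavior described above.
def rg_lru_step(h, x, a, r):
    """One step: h_t = a_eff * h_{t-1} + sqrt(1 - a_eff**2) * x_t."""
    a_eff = a ** r                       # gated decay, elementwise (diagonal matrix)
    return a_eff * h + np.sqrt(1.0 - a_eff ** 2) * x

d = 4
h = np.zeros(d)
a = np.full(d, 0.9)                      # per-channel decay in (0, 1)
rng = np.random.default_rng(0)
for t in range(10):
    x = rng.standard_normal(d)
    r = rng.uniform(size=d)              # stand-in for a learned sigmoid gate
    h = rg_lru_step(h, x, a, r)
print(h.shape)                           # (4,)
```

The sqrt(1 - a_eff^2) input scaling keeps the state magnitude bounded regardless of the gate value.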
What about the overall architecture?
1. The MLP block makes the model expressive!
We found Gated MLP blocks to be better than vanilla MLPs.
2. The recurrent block can be simple.
Temporal Conv1D helps capture local patterns that RNNs struggle to express.
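A toy sketch of what a causal temporal Conv1D does (the function name and per-channel filter shape are illustrative, not the paper's implementation):

```python
import numpy as np

# Causal temporal convolution: each timestep mixes the current input with a
# short window of the past only (left padding, so no future leakage).
def causal_conv1d(x, w):
    """x: (T, d) sequence; w: (k, d) per-channel filter. Returns (T, d)."""
    T, d = x.shape
    k = w.shape[0]
    pad = np.concatenate([np.zeros((k - 1, d)), x], axis=0)
    return np.stack([(pad[t:t + k] * w).sum(axis=0) for t in range(T)])

x = np.arange(6, dtype=float).reshape(6, 1)
w = np.array([[0.5], [0.5]])             # k=2 moving average, purely illustrative
y = causal_conv1d(x, w)
print(y.ravel())
```

With a tiny kernel (k of 3-4 is typical in such blocks), this is cheap and handles short-range structure, leaving the recurrence to carry long-range information.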
But this still doesn’t match the performance of a *well-tuned* transformer!
Solution: simply use a Local Attn (LA) block every 2 recurrent blocks!
LA has a fixed-size state, and therefore fast inference! LA captures the recent past, while the RG-LRU models global structure. This is Griffin!
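The layer pattern described above (one Local Attention block per 2 recurrent blocks) can be sketched in a few lines; the block names here are just labels for illustration:

```python
# Griffin's repeating layer pattern: two recurrent blocks, then one
# local-attention block (names illustrative).
def griffin_layer_pattern(n_layers):
    pattern = ["recurrent", "recurrent", "local_attention"]
    return [pattern[i % 3] for i in range(n_layers)]

print(griffin_layer_pattern(6))
# -> ['recurrent', 'recurrent', 'local_attention', 'recurrent', 'recurrent', 'local_attention']
```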
RG-LRU was still slow to train though, as it was memory-bound. Inspired by @tri_dao's Flash Attention, we used Pallas to optimize HBM accesses → 3x speedup! We also use a linear scan instead of a parallel scan, as it's faster on TPUs.
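To make the linear vs. parallel scan trade-off concrete, here's a numpy sketch of both ways to evaluate h_t = a_t * h_{t-1} + b_t (the associative version is written sequentially here; on hardware its combine steps run in parallel, in O(log T) rounds):

```python
import numpy as np

# Linear scan: O(T) sequential steps, minimal extra memory traffic.
def linear_scan(a, b):
    h = np.zeros_like(b[0])
    out = []
    for at, bt in zip(a, b):
        h = at * h + bt
        out.append(h)
    return np.stack(out)

# Associative (parallel-style) scan over the composition
# (a1, b1) then (a2, b2) -> (a1*a2, a2*b1 + b2).
def associative_scan(a, b):
    A, B = a.copy(), b.copy()
    T, off = len(a), 1
    while off < T:
        A2, B2 = A.copy(), B.copy()
        A2[off:] = A[:-off] * A[off:]
        B2[off:] = A[off:] * B[:-off] + B[off:]
        A, B = A2, B2
        off *= 2
    return B                             # h_t, assuming h_{-1} = 0

rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=(8, 4))
b = rng.standard_normal((8, 4))
h = linear_scan(a, b)
print(h.shape)                           # (8, 4)
```

Both compute the same result; the parallel scan does more total work and more memory movement, which is why the simple linear scan can win on TPUs for this memory-bound recurrence.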
Griffin now matches/beats Transformer training speed!
There were many more details involved in getting the project to succeed, but if I had to summarize our team's approach, we focused on:
1. Simple ideas that scale
2. Attention to detail in model design & implementation
3. Careful hyperparameter tuning
As an additional point: one thing we really focused on was running fair comparisons against a well-tuned Transformer baseline. Our baseline performs remarkably well on downstream evals, outperforming many well-known models while being trained on significantly fewer tokens.
@sohamde_ This is a really interesting & well-written paper; enjoyed reading it! I have a few questions:
From the paper, the scaling curves show continued improvements with model size. How do you expect performance to change as Griffin is scaled up even further, to 100B+ parameters or more?…
@pengzhangzhi1 Yes, we have: Gated MLPs work better than vanilla MLPs. We also have ablations on different Local Attention window sizes vs. Global Attention in the appendix of the paper.