Wow, this was super helpful - probably the most intuitive write up on LLM hardware economics that I’ve read anywhere.
One question for you - how does the length of the context window fit into this equation? AFAIK, longer context windows are more computationally expensive, even if you don’t fill them with tokens. How do you account for that in your calculations?
Ah, that’s a good question. The short answer is that it makes its way in through the number of parameters (and thus the flops).
I’ll try to add a longer discussion of this to the article when I have some free time.
Looking forward to it!
I tried to work out the math where you describe the optimal batch size for the memory-bound vs. compute-bound cases, and I think there may be an error. The multiplicative factor of B (batch size) should go with the compute latency calculation.
Kipply's blog has it that way too: https://kipp.ly/transformer-inference-arithmetic/#batch-sizes
ugh, yes, you're right. I fixed it on my blog but not here.
Ok, fixed. ty for pointing that out.
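To make the batch-size fix concrete, here is a minimal sketch of the two latencies with B on the compute side only. It assumes the simple model where a forward pass costs about 2 * B * P FLOPs and reads the P weights once, and it uses latency_memory = P * n_bytes / memory_bandwidth, the version proposed later in this thread. The hardware and model numbers are hypothetical round figures, not taken from the article.

```python
def compute_latency(batch_size: int, n_params: float, flops_per_s: float) -> float:
    """Arithmetic time: about 2 * P FLOPs per sequence, times B sequences."""
    return 2 * batch_size * n_params / flops_per_s


def memory_latency(n_params: float, n_bytes: int, bandwidth: float) -> float:
    """Time to stream the weights once per forward pass; independent of B."""
    return n_params * n_bytes / bandwidth


def crossover_batch_size(n_bytes: int, flops_per_s: float, bandwidth: float) -> float:
    """Solve 2 * B * P / flops == P * n_bytes / bandwidth for B (P cancels out)."""
    return flops_per_s * n_bytes / (2 * bandwidth)


# Hypothetical, round-number hardware and model.
FLOPS = 300e12      # FLOP/s
BANDWIDTH = 1.5e12  # bytes/s
P = 7e9             # parameters
N_BYTES = 2         # 16-bit weights

print(f"crossover batch size: {crossover_batch_size(N_BYTES, FLOPS, BANDWIDTH):.0f}")
for b in (1, 100, 400):
    c = compute_latency(b, P, FLOPS)
    m = memory_latency(P, N_BYTES, BANDWIDTH)
    # In this simple model the step is limited by whichever latency is larger.
    bound = "compute" if c > m else "memory"
    print(f"B={b:>3}: compute {c * 1e3:6.2f} ms, memory {m * 1e3:6.2f} ms -> {bound}-bound")
```

Below the crossover batch size the weight reads dominate (memory-bound); above it the arithmetic dominates (compute-bound). Because P cancels, the crossover depends only on the hardware's FLOPs-to-bandwidth ratio and the weight precision.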
Can you please explain the equation
latency_memory = 2 · P · n_bytes / memory_bandwidth ?
I am struggling with the factor of 2 and can't figure out where it comes from. Thank you!
It’s basically because we have to multiply and add every parameter when we do a matmul, so multiplying an m × n matrix by a vector takes 2 * m * n flops. Since we do this for every parameter in a transformer, the total becomes 2 * the number of parameters. The linked blog post by kipply goes into detail and links to lecture notes that derive this.
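A toy sketch of that multiply-and-add counting in plain Python. It counts one multiply and one add per weight (the add going into a zero-initialized accumulator), which gives 2 * m * n for an m × n matrix times a vector. The layer shapes are hypothetical, and the sketch ignores biases, attention scores, and other non-matmul work, so for a real transformer "2 * parameters" is an approximation.

```python
def matvec_with_flop_count(matrix: list[list[float]], vector: list[float]) -> tuple[list[float], int]:
    """Multiply an m x n matrix by a length-n vector, counting FLOPs as we go."""
    flops = 0
    result = []
    for row in matrix:
        acc = 0.0
        for weight, x in zip(row, vector):
            acc += weight * x  # one multiply and one add per weight
            flops += 2
        result.append(acc)
    return result, flops


m, n = 3, 4
W = [[1.0] * n for _ in range(m)]  # m * n = 12 weights ("parameters")
x = [1.0, 2.0, 3.0, 4.0]
y, flops = matvec_with_flop_count(W, x)
print(y, flops)  # [10.0, 10.0, 10.0] 24

# Summing 2 * rows * cols over every weight matrix gives 2 * (total parameters).
layer_shapes = [(3, 4), (4, 4), (4, 2)]  # hypothetical toy shapes
total_params = sum(r * c for r, c in layer_shapes)
total_flops = sum(2 * r * c for r, c in layer_shapes)
print(total_params, total_flops)
```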
Thanks, I see. I thought we only used that for latency_compute, not for memory.
I will go through the blog and the lecture notes.
I think that's a mistake; it should just be
latency_memory = P * n_bytes / memory_bandwidth
Thanks for looking into it.
Oh, touché. Hmm. That might be a mistake actually. Let me double check....
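One way to see why the proposed correction, latency_memory = P * n_bytes / memory_bandwidth, has no factor of 2: that factor comes from counting FLOPs (a multiply and an add per weight), while memory latency depends only on how many bytes are read, which is P * n_bytes. With 16-bit weights (n_bytes = 2) the two counts happen to be numerically equal, so they are easy to conflate. A minimal sketch, assuming a hypothetical 7B-parameter model:

```python
def flops_per_token(n_params: int) -> int:
    """Compute side: one multiply and one add per weight."""
    return 2 * n_params


def weight_bytes(n_params: int, n_bytes: int) -> int:
    """Memory side: each weight is read once and occupies n_bytes (no multiply-add factor)."""
    return n_params * n_bytes


P = 7_000_000_000  # hypothetical 7B-parameter model
for dtype, n_bytes in [("fp32", 4), ("fp16", 2), ("int8", 1)]:
    print(f"{dtype}: {flops_per_token(P):.2e} FLOPs, {weight_bytes(P, n_bytes):.2e} bytes")
```

Only at fp16 do the two columns match; at fp32 or int8 they differ, which is why keeping both the 2 and n_bytes in the memory formula would overcount.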