Thanks for the post,
Most model weights I have seen are floating-point values between -1 and 1. If we got rid of the exponent bits, wouldn't we be able to save ~31% of the model weight size in a 16-bit float?
Presumably this would require changes to the underlying hardware itself, in order to perform calculations with this new floating-point format.
But I still find it bizarre that ML models have all these useless bits lying around.
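For reference, here is a minimal sketch of where the ~31% figure comes from, assuming the 16-bit format in question is IEEE 754 half precision (1 sign bit, 5 exponent bits, 10 mantissa bits). The helper name `fp16_fields` is made up for illustration. It also prints the exponent field for a few small weights, which suggests the exponent bits still vary for values inside (-1, 1).

```python
import struct

def fp16_fields(x):
    """Round x to IEEE 754 half precision and split it into its bit fields."""
    (bits,) = struct.unpack("<H", struct.pack("<e", x))
    sign = bits >> 15                # 1 sign bit
    exponent = (bits >> 10) & 0x1F   # 5 exponent bits (biased by 15)
    mantissa = bits & 0x3FF          # 10 mantissa bits
    return sign, exponent, mantissa

EXPONENT_BITS, TOTAL_BITS = 5, 16
exponent_share = EXPONENT_BITS / TOTAL_BITS  # share of each weight spent on the exponent

print(f"exponent share: {exponent_share:.2%}")
for w in (0.5, 0.01, -0.001):
    print(w, fp16_fields(w))
```

Under these assumptions, weights of different magnitudes below 1 land on different exponent values, so dropping the exponent entirely would amount to switching to a fixed-point format with a single shared scale.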
There's a lot of work going on right now to figure out the best format. JAX currently supports 5 different 8-bit floating-point types (https://jax.readthedocs.io/en/latest/_modules/jax/_src/dtypes.html). I suspect that we'll move to 8-bit floating-point numbers, and eventually to even lower precision, as we figure out the right mix of precision and magnitude.
Hardware support is a problem though: right now, only the most recent Nvidia GPUs support fp8.
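To make the precision/magnitude trade-off concrete, here is a rough sketch that computes the largest finite value and the relative step size for a given split of exponent and mantissa bits. It assumes plain IEEE-style conventions (the all-ones exponent is reserved for inf/NaN); the function name `ieee_like_summary` is hypothetical. Some real fp8 variants deviate from this (for example, E4M3FN gives up infinities and reaches 448 rather than the 240 computed here), so treat the numbers as illustrative.

```python
def ieee_like_summary(exp_bits, man_bits):
    """Return (largest finite value, relative step) for an IEEE-style float layout."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exp_field = 2 ** exp_bits - 2   # all-ones exponent field is reserved for inf/NaN
    max_finite = (2 - 2.0 ** -man_bits) * 2.0 ** (max_exp_field - bias)
    rel_step = 2.0 ** -man_bits         # spacing between neighbours, relative to the value
    return max_finite, rel_step

for name, (e, m) in {"fp16": (5, 10), "e5m2": (5, 2), "e4m3": (4, 3)}.items():
    max_finite, rel_step = ieee_like_summary(e, m)
    print(f"{name}: max finite = {max_finite}, relative step = {rel_step}")
```

More exponent bits buy range and more mantissa bits buy precision, which is presumably why several 8-bit variants exist side by side.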
The part about GPTQ is pretty bizarre. I would've thought quantization is just doing what you showed, at scale. Maybe it works because it does the rounding as a vectorized operation, rather than naive rounding, which is slower? That doesn't sound like I'm saying anything intelligent. A tad funny that we don't know exactly why quantization works.
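For comparison, this is roughly what "naive rounding" looks like: a minimal sketch of symmetric absmax round-to-nearest quantization. It is not GPTQ. As far as I understand it, GPTQ differs less in speed than in what it does after rounding each weight: it adjusts the weights that have not yet been quantized to compensate for the error just introduced. The function names below are made up for illustration.

```python
def quantize_absmax(weights, bits=8):
    """Round-to-nearest quantization with one shared scale (naive baseline, not GPTQ)."""
    qmax = 2 ** (bits - 1) - 1                          # 127 for 8 bits
    scale = (max(abs(w) for w in weights) / qmax) or 1.0  # avoid dividing by zero
    q = [round(w / scale) for w in weights]             # each weight rounded independently
    return q, scale

def dequantize(q, scale):
    """Map the integer codes back to approximate float weights."""
    return [v * scale for v in q]

weights = [0.5, -1.0, 0.25, 0.0]
q, scale = quantize_absmax(weights)
restored = dequantize(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, max_error)
```

Each weight here is rounded on its own, so the per-weight error is bounded by half a quantization step, but nothing accounts for how those errors add up in the layer's output.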
Seems like the figure comparing model sizes to levels of precision is missing.
Correct me if I'm being naive, but this seems like the same approach that was published in Exponentially Faster Language Modeling. Have you had the chance to check out their work?
Link to paper:
https://arxiv.org/abs/2308.14711
Efficiency in inference for large language models is paramount, and this article provides valuable insights. It highlights key strategies for optimizing inference, emphasizing code profiling and simple optimizations such as data structure changes to improve performance and reduce resource usage.