Quantize your graph weights

Quantization is an optimization technique that reduces the numeric precision of weights in a model. For example, models are usually trained with float32 weights, but you can quantize the values to a lower-precision type such as int8 or int4. That is, instead of storing each scalar value in 32 bits, you can use just 8 or 4 bits. This reduces the compute and memory demands of inference, which makes the model faster and lets it run on more systems.
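The following is a minimal plain-Python sketch (not MAX code) of one simple scheme, symmetric per-tensor int8 quantization, to make the trade-off concrete. Every value shares one scale. Each value is divided by that scale and rounded to an integer in [-127, 127], then multiplied back by the scale when it's used. The helper names are hypothetical. The encodings MAX provides, such as Q4_0, use their own block-based layouts.

```python
def quantize_int8(values):
    """Symmetric per-tensor int8 quantization: returns (int codes, scale)."""
    amax = max(abs(v) for v in values)
    # One shared scale maps the largest magnitude to 127.
    scale = amax / 127 if amax > 0 else 1.0
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize_int8(codes, scale):
    """Recover approximate float values from the int8 codes."""
    return [c * scale for c in codes]


weights = [0.15, -0.42, 0.9, 0.003, -1.27]
codes, scale = quantize_int8(weights)
restored = dequantize_int8(codes, scale)

# Each code fits in 8 bits instead of 32, and the rounding error per
# value is at most half a quantization step (scale / 2).
for w, r in zip(weights, restored):
    print(f"{w:+.4f} -> {r:+.4f}")
```

Small values such as 0.003 round to 0 here. That loss is why real encodings usually compute a separate scale for each small block of values rather than one scale for the whole tensor.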

To support quantization with MAX Graph, we’ve built an API designed for low-level graph engineers who want to quantize specific weights in a model. This API does not quantize an entire model. Like the MAX Graph API, this is a low-level API meant for engineers who want to build high-performance graphs in a systems programming language—specifically, in Mojo.

If you just want to read some code, check out the Quantize TinyStories pipeline, which quantizes a 15-million-parameter version of Llama 2 with Q4_0 (4-bit) encoding.

Note: This is post-training quantization. The Graph API does not support model training, so you must import your model weights, load them as Tensor values, and then quantize them.

Overview

When applied carefully, quantization usually has only a small effect on model accuracy. Quantization encodings differ in precision and storage format, and each comes with trade-offs: an encoding that works well for some models or graph operations ("ops") may not work well for others. Some models also work well with a mix of quantization types, where only certain ops perform low-precision calculations while others retain high precision.

To support this mixed-precision strategy, the quantization API in MAX Graph is declarative. That means you can quantize the weights in your model explicitly as you see fit, rather than picking one quantization format for the whole model. You can quantize different weights with different encodings, write custom ops that understand your quantizations, and even implement your own quantization encodings.

The primary API is the quantize() function (from the QuantizationEncoding trait). It takes a float32 tensor and returns a quantized tensor as a buffer of uint8 bytes: a type-erased blob of bytes that can hold any quantization encoding. You can call quantize() using one of the existing quantization encodings, such as Q4_0, Q4_K, and Q6_K (from GGML). Then, add the quantized tensor as a node in your graph.
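The following plain-Python sketch shows what such a byte buffer can contain. It is modeled on GGML's reference Q4_0 layout: every block of 32 floats becomes 18 bytes, made up of a float16 scale followed by 16 bytes that each pack two 4-bit codes. This is an illustration under that assumption, not MAX's Q4_0Encoding implementation. The name quantize_q4_0() is hypothetical.

```python
import struct

QK4_0 = 32  # values per block in GGML's Q4_0 layout


def quantize_q4_0(values):
    """Sketch of GGML-style Q4_0 quantization: list of floats -> bytes.

    Each block of 32 floats becomes 18 bytes: a float16 scale `d`
    followed by 16 bytes holding 32 packed 4-bit codes.
    """
    if len(values) % QK4_0 != 0:
        raise ValueError("length must be a multiple of 32")
    out = bytearray()
    half = QK4_0 // 2
    for start in range(0, len(values), QK4_0):
        block = values[start:start + QK4_0]
        # The signed value with the largest magnitude sets the scale,
        # so that value maps to -8 (stored as code 0).
        max_val = max(block, key=abs)
        d = max_val / -8
        inv_d = 1.0 / d if d != 0 else 0.0
        out += struct.pack("<e", d)  # little-endian float16 scale
        for j in range(half):
            # Shift from [-8, 8] into [0, 15]. Element j goes in the
            # low nibble, element j + 16 in the high nibble.
            lo = min(15, int(block[j] * inv_d + 8.5))
            hi = min(15, int(block[j + half] * inv_d + 8.5))
            out.append(lo | (hi << 4))
    return bytes(out)


weights = [0.15] * 64  # two blocks, like a row of the example constant
packed = quantize_q4_0(weights)
print(len(packed))  # 36: two 18-byte blocks
```

For a constant filled with 0.15, every value equals the block maximum. As a result every code is 0, and the scale alone carries the value: (0 - 8) × d = 0.15.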

Because the quantized data is just a blob of bytes with a special encoding for the values and scaling factors, any op that you pass this data into must know how to dequantize it before it can perform its calculation on the full-precision float32 values.
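The following plain-Python sketch shows what that decoding involves. It assumes GGML's Q4_0 block layout: a float16 scale followed by 16 bytes of packed 4-bit codes for every 32 values. An op that consumes Q4_0 data has to do the equivalent of dequantize_q4_0() (a hypothetical helper, not a MAX API) before it can compute with float32 values.

```python
import struct

QK4_0 = 32                    # values per block in GGML's Q4_0 layout
BLOCK_BYTES = 2 + QK4_0 // 2  # float16 scale + 16 bytes of packed codes


def dequantize_q4_0(data):
    """Sketch of GGML-style Q4_0 decoding: bytes -> list of floats."""
    if len(data) % BLOCK_BYTES != 0:
        raise ValueError("data is not a whole number of Q4_0 blocks")
    values = []
    half = QK4_0 // 2
    for start in range(0, len(data), BLOCK_BYTES):
        (d,) = struct.unpack_from("<e", data, start)
        codes = data[start + 2:start + BLOCK_BYTES]
        block = [0.0] * QK4_0
        for j, byte in enumerate(codes):
            # Low nibble holds element j, high nibble holds element j + 16.
            # Subtract 8 to undo the offset applied during quantization.
            block[j] = ((byte & 0x0F) - 8) * d
            block[j + half] = ((byte >> 4) - 8) * d
        values.extend(block)
    return values


# A hand-built block: scale 0.5, every byte 0x9A (low nibble 10, high nibble 9).
block = struct.pack("<e", 0.5) + bytes([0x9A] * 16)
decoded = dequantize_q4_0(block)
print(decoded[0], decoded[16])  # 1.0 0.5
```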

Currently, the only op included in MAX Graph that can operate on quantized data is qmatmul(). This takes a float32 tensor and a quantized tensor, and returns a float32 tensor. This op alone allows you to build a variety of quantized transformer models. However, using quantized weights with any other op in max.graph.ops doesn’t work as is, because they all expect float32 inputs. To make it work, you can create a custom op that accepts a quantized input, dequantizes it with the appropriate decoding, and then completes the operation.
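The following plain-Python reference (not MAX code) spells out the dequantize-then-compute pattern. It first decodes the quantized operand, then performs an ordinary float matmul in which the rhs is stored transposed, as qmatmul() expects. The function name reference_qmatmul and the per-row (scale, codes) format are hypothetical stand-ins for a real encoding such as Q4_0.

```python
def reference_qmatmul(lhs, rhs_rows, dequantize_row):
    """Dequantize-then-compute reference.

    lhs:            m x n list of float rows
    rhs_rows:       p encoded rows (rhs is transposed, so p x n once decoded)
    dequantize_row: function that decodes one encoded row to n floats
    Returns an m x p list of float rows.
    """
    # Decode the quantized operand first...
    rhs = [dequantize_row(row) for row in rhs_rows]
    n = len(lhs[0])
    if any(len(row) != n for row in rhs):
        raise ValueError("each decoded rhs row must have n elements")
    # ...then run an ordinary float matmul. Because rhs is transposed,
    # out[i][j] is the dot product of lhs row i and rhs row j.
    return [[sum(a * b for a, b in zip(lhs_row, rhs_row)) for rhs_row in rhs]
            for lhs_row in lhs]


def decode_int8_row(row):
    # Hypothetical encoding: one float scale plus integer codes per row.
    scale, codes = row
    return [c * scale for c in codes]


lhs = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]          # 2 x 3
rhs_rows = [(0.5, [2, 0, 2]), (0.25, [4, 4, 4])]  # decodes to 2 x 3
out = reference_qmatmul(lhs, rhs_rows, decode_int8_row)
print(out)  # [[4.0, 6.0], [0.0, 1.0]]
```

A custom op for another encoding would keep the same shape and swap in that encoding's decoder. A production kernel would typically decode block by block inside the inner loop instead of materializing the whole float32 matrix.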

Now let’s look at some simple code examples.

Quantize some weights

When you build a graph with MAX Graph, each set of weights starts as a Tensor that you add to the graph as a constant (a node created with Graph.constant()). To quantize those weights, pass the Tensor to the quantize() method of the encoding type you want to use before you add it to the graph.

For example, the following code quantizes a tensor with Q4_0Encoding (4-bit encoding), performs quantized matmul (using qmatmul()), and prints the results:

```mojo
from max.tensor import Tensor, TensorShape
from max.engine import InferenceSession
from max.graph import Graph, TensorType
from max.graph.quantization import Q4_0Encoding
from max.graph.ops.quantized_ops import qmatmul

def main():
    # Build a graph with one float32 input of shape [32, 64].
    graph = Graph(TensorType(DType.float32, 32, 64))

    # Perform matmul with the full-precision constant
    # constant_value = Tensor[DType.float32](TensorShape(64, 32), 0.15)
    # constant = graph.constant(constant_value)
    # matmul = graph[0] @ constant

    # Perform matmul with the quantized constant (transposed)
    constant_value = Tensor[DType.float32](TensorShape(32, 64), 0.15)
    quantized_value = Q4_0Encoding.quantize(constant_value)
    quantized_constant = graph.constant(quantized_value)
    matmul = qmatmul[Q4_0Encoding](graph[0], quantized_constant)

    graph.output(matmul)

    # Compile the graph, then run it on an input filled with 0.5.
    session = InferenceSession()
    model = session.load(graph)

    input = Tensor[DType.float32](TensorShape(32, 64), 0.5)
    results = model.execute("input0", input^)
    output = results.get[DType.float32]("output0")
    print(output)
```
Caution: qmatmul() expects the right-hand-side argument to be transposed. The normal matmul() op takes lhs and rhs with shapes $m \times n$ and $n \times p$, respectively, producing an output of shape $m \times p$. qmatmul() instead requires the rhs shape to be $p \times n$.

In the above example, the input shape is [32, 64] and the quantized shape is also [32, 64], making the output shape [32, 32].
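Written out elementwise, the transposed-rhs convention means qmatmul() contracts over the last dimension of both operands. Here rhs stands for the full-precision values that the quantized constant encodes. In the example above, $m = 32$, $n = 64$, and $p = 32$:

$$
\mathrm{out}_{ij} = \sum_{k=1}^{n} \mathrm{lhs}_{ik}\,\mathrm{rhs}_{jk},
\qquad \mathrm{lhs} \in \mathbb{R}^{m \times n},\;
\mathrm{rhs} \in \mathbb{R}^{p \times n},\;
\mathrm{out} \in \mathbb{R}^{m \times p}
$$

This is the same result as $\mathrm{lhs} \cdot \mathrm{rhs}^{\top}$ with an ordinary matmul.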

You probably noticed this code also includes the full-precision matmul as an option. If you uncomment the full-precision matmul lines, comment out the quantized ones, and run it again, you can see how close the results are. This holds even though the quantized constant uses roughly 1/8th of the memory (4-bit codes instead of 32-bit floats, plus a small per-block scaling factor).
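The "1/8th" figure counts only the 4-bit codes. The following quick calculation applies to the 32 × 64 constant in the example. It assumes that MAX's Q4_0Encoding uses GGML's 18-byte block layout: a float16 scale plus 16 bytes of codes for every 32 values. Under that assumption, the per-block scale brings the real ratio closer to 1/7, or 4.5 bits per weight. The helper names are hypothetical.

```python
QK4_0 = 32                 # values per Q4_0 block (GGML layout, assumed)
Q4_0_BLOCK_BYTES = 2 + 16  # float16 scale + 32 packed 4-bit codes


def float32_bytes(num_values):
    return num_values * 4


def q4_0_bytes(num_values):
    # Q4_0 works on whole blocks of 32 values.
    if num_values % QK4_0:
        raise ValueError("Q4_0 needs a multiple of 32 values")
    return (num_values // QK4_0) * Q4_0_BLOCK_BYTES


n = 32 * 64  # elements in the example constant
full = float32_bytes(n)
quant = q4_0_bytes(n)
print(full, quant, quant / full)  # 8192 1152 0.140625
```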

No matter which quantization encoding you choose, the quantize() method works the same way: it takes a full-precision value as a Tensor and returns the quantized value as a Tensor.

Alternatively, you can use Graph.quantize() to combine these two lines:

```mojo
quantized_value = Q4_0Encoding.quantize(constant_value)
quantized_constant = graph.constant(quantized_value)
```

Into one line:

```mojo
quantized_constant = graph.quantize[Q4_0Encoding](constant_value)
```

To see how we quantized a real model with this API, check out the Quantize TinyStories pipeline, which uses 4-bit encoding to shrink a 15-million-parameter model to about 10 MB.

Note: Because the Graph API builds a static computation graph, quantization happens at graph build time. That means you can't use quantize() on runtime inputs: every tensor you want to quantize must already be available when you build the graph, before the model executes.

Save and load tensors to disk

To avoid quantizing your weights every time you load a model, you can save them to disk and load them back with the save() and load() functions from max.graph.checkpoint. For example:

```mojo
from max.graph.checkpoint import load, save, TensorDict
from max.tensor import Tensor, TensorShape

def write_to_disk():
    # Collect named tensors in a TensorDict and write them to one checkpoint file.
    tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    save(tensors, "/path/to/checkpoint.max")

def read_from_disk():
    # Load the checkpoint, then fetch a tensor by name and dtype.
    tensors = load("/path/to/checkpoint.max")
    x = tensors.get[DType.int32]("x")
    print(x)

def main():
    write_to_disk()
    read_from_disk()
```

The TensorDict type is just a dictionary type for named tensors.

Custom quantization encodings

If you want a quantization encoding that the quantization package doesn't already provide, you can implement your own by building a type that conforms to the QuantizationEncoding trait.

Your custom QuantizationEncoding type must implement the quantize() function, which takes a float32 tensor and returns a uint8 byte tensor in the corresponding quantized buffer shape. With that type defined, you can call quantize() to produce your quantized tensors as you build the graph, but this doesn't handle decoding at runtime.

To decode your quantization type during inference, you also need to build a custom op for each op in your graph that takes in any of these quantized values. For example, if your graph performs matrix multiplication on quantized inputs, you'll need to implement a custom version of matmul that knows how to decode your custom quantization encoding, and then use that custom op instead of the built-in matmul op.
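The following plain-Python analogy illustrates the split between encoding and decoding. It is not Mojo and not the real QuantizationEncoding trait. A hypothetical Int8BlockEncoding class has a quantize() method that turns floats into a byte blob at build time. A hypothetical dequantize_int8_blocks() function stands in for the decoding that a custom op would perform at inference time. The 16-value block layout is invented for this sketch.

```python
import struct
from typing import List


class Int8BlockEncoding:
    """Hypothetical custom encoding: blocks of 16 values, each stored as a
    little-endian float32 scale followed by 16 signed 8-bit codes (20 bytes)."""

    BLOCK = 16

    def quantize(self, values: List[float]) -> bytes:
        # Build-time step: float values in, type-erased bytes out.
        if len(values) % self.BLOCK:
            raise ValueError("length must be a multiple of 16")
        out = bytearray()
        for start in range(0, len(values), self.BLOCK):
            block = values[start:start + self.BLOCK]
            amax = max(abs(v) for v in block)
            scale = amax / 127 if amax > 0 else 1.0
            codes = [max(-127, min(127, round(v / scale))) for v in block]
            out += struct.pack("<f", scale)
            out += struct.pack(f"<{self.BLOCK}b", *codes)
        return bytes(out)


def dequantize_int8_blocks(data: bytes, block: int = 16) -> List[float]:
    # Runtime step: the decoding a custom op must do before it computes.
    step = 4 + block
    values: List[float] = []
    for start in range(0, len(data), step):
        (scale,) = struct.unpack_from("<f", data, start)
        codes = struct.unpack_from(f"<{block}b", data, start + 4)
        values.extend(c * scale for c in codes)
    return values


weights = [i / 10 for i in range(-8, 8)]  # 16 values
blob = Int8BlockEncoding().quantize(weights)
restored = dequantize_int8_blocks(blob)
print(len(blob))  # 20
```

The encoding and the decoder share nothing but the byte layout. That shared layout is why every op that touches the blob has to agree on it.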

For more information, see how to create a custom op in MAX Graph.