
Build a graph with MAX Graph

Welcome to the quickstart tutorial for the MAX Graph API.

In this brief tutorial, you'll learn the basics of how to build a neural network graph using the MAX Graph API and then execute it with the MAX Engine API, all in Mojo. To learn more about what MAX Graph is, see the MAX Graph intro.

Before you begin, make sure you've installed the latest version of MAX.

Preview

This API is still in development and subject to change.

Get the code

You can copy all the code in this tutorial at once from the full code example at the end of the page.

Import Mojo packages

We need to import max.graph to build the graph, max.engine to execute it, and max.tensor for the Tensor and TensorShape types that hold our constant weights and input data:

from max.engine import InferenceSession
from max.graph import Graph, TensorType, ops
from max.tensor import Tensor, TensorShape

Create the graph

The Graph object is your starting point. It's a lot like a function: it has a name, it takes arguments (inputs), it performs calculations (feeds data through the graph), and it returns values (outputs).

We can instantiate a Graph that takes one input like this:

def main():
    graph = Graph(TensorType(DType.float32, 2, 6))
    print(graph)
"builtin.module"() ({ "mo.graph"() <{counter = 0 : i64, functionType = (!mo.tensor<[2, 6], f32>) -> (), inputParams = #kgen<param.decls[]>, resultParams = #kgen<param.decls[]>, signature = !kgen.signature<(!mo.tensor<[2, 6], f32>) -> ()>}> ({ ^bb0(%arg0: !mo.tensor<[2, 6], f32>): }) {sym_name = "graph"} : () -> () }) : () -> ()

We're printing the graph just to satisfy curiosity, but the result isn't useful yet. The graph has no ops and no outputs (note the empty () result type in its signature), so nothing connects the input to an output. This is basically an intermediate debug format for now.

When you initialize a Graph, you need to specify the data type and shape of each input using TensorType. In this case, the input is a 2x6 tensor of float32 values.

If your model takes multiple inputs or returns multiple outputs, you can pass a list of TensorType values like this (although we're still using just one item in each list to match the model we're building):

from max.graph import Type  # needed for List[Type] below

def main():
    graph = Graph(
        in_types=List[Type](TensorType(DType.float32, 2, 6)),
        out_types=List[Type](TensorType(DType.float32, 2, 1)),
    )
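Why is the output type 2x1? You can work it out by tracking shapes through the ops added in the next step. The plain-Python sketch below is not MAX code, and its helper names are hypothetical. It applies two standard rules:

- A matmul of an [m, k] tensor by a [k, n] tensor yields an [m, n] tensor.
- ReLU and softmax return a tensor with the same shape as their input.

```python
def matmul_shape(a, b):
    """Shape of the 2-D product a @ b; the inner dimensions must agree."""
    (m, k_a), (k_b, n) = a, b
    if k_a != k_b:
        raise ValueError(f"inner dimensions differ: {k_a} vs {k_b}")
    return (m, n)


def shape_preserving(x):
    """ReLU and softmax keep the shape of their input."""
    return x


input_shape = (2, 6)     # TensorType(DType.float32, 2, 6)
constant_shape = (6, 1)  # the 6x1 weights used for the matmul

out_shape = shape_preserving(shape_preserving(matmul_shape(input_shape, constant_shape)))
print(out_shape)  # (2, 1), the shape declared in out_types
```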

Add some ops

All ops receive inputs from either graph inputs, constants, or other op outputs. To build a sequence of ops, call each op function and pass it the appropriate inputs, which usually include the output from a previous op.

For example, we'll now add three simple ops to our graph:

  • A matrix multiplication, which takes the graph input and a constant.
  • A ReLU activation function, which takes the matmul output.
  • A softmax activation function, which takes the ReLU output.

To close the graph, we then pass the final softmax op as the output:

# Create a constant for usage in the matmul op below:
matmul_constant_value = Tensor[DType.float32](TensorShape(6, 1), 0.15)
matmul_constant = graph.constant(matmul_constant_value)

# Start adding a sequence of operator calls to build the graph.
# We use the index accessor to get the graph's first input tensor:
matmul = graph[0] @ matmul_constant
relu = ops.relu(matmul)
softmax = ops.softmax(relu)

# Add the sequence of ops as the graph output:
graph.output(softmax)

Notice that we get the graph input using graph[0], which denotes the graph's first input. It's the first (and, in this case, the only) TensorType passed to the Graph constructor. We then multiply it by the constant using the @ matrix-multiplication operator, which is equivalent to calling ops.matmul().

The value assigned to each variable (matmul, relu, and softmax) is a Symbol. Each one is a symbolic handle for the output of an op, not a real tensor. Because we're building a static graph, real values won't exist until execution time, and we can't execute the graph until we compile it with MAX Engine.
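If symbolic handles are new to you, this toy Python sketch may help. It is a hypothetical illustration of deferred evaluation, not how MAX implements Symbol. Building an expression only records which op produces each value and from which inputs. Numbers appear only when evaluate() runs the recorded ops against concrete inputs.

```python
class Symbol:
    """Toy symbolic handle: records an op and its inputs, but holds no data."""

    def __init__(self, op, *inputs):
        self.op = op
        self.inputs = inputs

    def __repr__(self):
        return f"{self.op}({', '.join(map(repr, self.inputs))})"


# Hypothetical kernels: how each op turns concrete inputs into a concrete output.
KERNELS = {
    "mul": lambda a, b: a * b,
    "relu": lambda x: max(0.0, x),
}


def evaluate(value, feeds):
    """Compute a real result, given concrete values for the named graph inputs."""
    if isinstance(value, str):         # a named graph input
        return feeds[value]
    if not isinstance(value, Symbol):  # a constant, already concrete
        return value
    args = [evaluate(v, feeds) for v in value.inputs]
    return KERNELS[value.op](*args)


# "Building the graph": y describes a computation; it is not a number.
y = Symbol("relu", Symbol("mul", "input0", 0.5))
print(y)                               # relu(mul('input0', 0.5))

# "Executing the graph": only now do real values exist.
print(evaluate(y, {"input0": 2.0}))    # 1.0
print(evaluate(y, {"input0": -2.0}))   # 0.0
```

The same handle y can be evaluated many times with different feeds, which is the point of separating graph construction from execution.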

The only concrete value in the above code is matmul_constant_value, which holds static weights that we then convert into a Symbol with Graph.constant().

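To finish the graph, we pass the final Symbol, softmax, to Graph.output(). We don't need to pass the earlier ops. They were recorded in the graph when we called them, and softmax depends on them through relu and matmul.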

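We can print the graph again, but it's still not pretty. The graph now has an output (its signature returns !mo.tensor<[2, 1], f32>). However, it's printed in an intermediate representation that contains many more ops than we added, because no optimization passes have run yet. Most of them (mo.shape_of, mo.slice, mo.concat, and the like) just compute tensor shapes. Only mo.matmul, mo.relu, and mo.softmax near the end correspond directly to our ops. We'll improve this output soon to make it more useful to you.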

print(graph)
module {
  mo.graph @graph(%arg0: !mo.tensor<[2, 6], f32>) -> !mo.tensor<[2, 1], f32> {
    %0 = mo.constant {value = #M.dense_array<1.500000e-01, 1.500000e-01, 1.500000e-01, 1.500000e-01, 1.500000e-01, 1.500000e-01> : tensor<6x1xf32>} : !mo.tensor<[6, 1], f32>
    %1 = mo.shape_of(%arg0) : (!mo.tensor<[2, 6], f32>) -> !mo.tensor<[2], si64>
    %2 = mo.shape_of(%0) : (!mo.tensor<[6, 1], f32>) -> !mo.tensor<[2], si64>
    %3 = mo.constant {value = #M.dense_array<-1> : tensor<si64>} : !mo.tensor<[], si64>
    %4 = mo.constant {value = #M.dense_array<1> : tensor<si64>} : !mo.tensor<[], si64>
    %5 = mo.constant {value = #M.dense_array<1> : tensor<si64>} : !mo.tensor<[], si64>
    %6 = mo.cast(%4) : (!mo.tensor<[], si64>) -> !mo.tensor<[], si64>
    %7 = mo.cast(%5) : (!mo.tensor<[], si64>) -> !mo.tensor<[], si64>
    %8 = mo.add(%6, %7) : !mo.tensor<[], si64>
    %9 = mo.shape_of(%1) : (!mo.tensor<[2], si64>) -> !mo.tensor<[1], si64>
    %10 = mo.constant {value = #M.dense_array<1> : tensor<si64>} : !mo.tensor<[], si64>
    %11 = mo.shape_of(%4) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %12 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %13 = mo.unsqueeze_shape(%11, %12) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %14 = mo.reshape(%4, %13) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %15 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %16 = mo.concat[%15 : !mo.tensor<[], si64>] (%14) : (!mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %17 = mo.shape_of(%8) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %18 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %19 = mo.unsqueeze_shape(%17, %18) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %20 = mo.reshape(%8, %19) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %21 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %22 = mo.concat[%21 : !mo.tensor<[], si64>] (%20) : (!mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %23 = mo.shape_of(%10) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %24 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %25 = mo.unsqueeze_shape(%23, %24) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %26 = mo.reshape(%10, %25) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %27 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %28 = mo.concat[%27 : !mo.tensor<[], si64>] (%26) : (!mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %29 = mo.slice(%1, %16, %22, %28) : (!mo.tensor<[2], si64>, !mo.tensor<[1], si64>, !mo.tensor<[1], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %30 = mo.shape_of(%29) : (!mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %31 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %32 = mo.squeeze_shape(%30, %31) : (!mo.tensor<[1], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[0], si64>
    %33 = mo.reshape(%29, %32) : (!mo.tensor<[1], si64>, !mo.tensor<[0], si64>) -> !mo.tensor<[], si64>
    %34 = mo.shape_of(%3) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %35 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %36 = mo.unsqueeze_shape(%34, %35) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %37 = mo.reshape(%3, %36) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %38 = mo.shape_of(%33) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %39 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %40 = mo.unsqueeze_shape(%38, %39) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %41 = mo.reshape(%33, %40) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %42 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %43 = mo.concat[%42 : !mo.tensor<[], si64>] (%37, %41) : (!mo.tensor<[1], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[2], si64>
    %44 = mo.constant {value = #M.dense_array<1> : tensor<si64>} : !mo.tensor<[], si64>
    %45 = mo.constant {value = #M.dense_array<1> : tensor<si64>} : !mo.tensor<[], si64>
    %46 = mo.shape_of(%1) : (!mo.tensor<[2], si64>) -> !mo.tensor<[1], si64>
    %47 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %48 = mo.shape_of(%47) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %49 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %50 = mo.unsqueeze_shape(%48, %49) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %51 = mo.reshape(%47, %50) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %52 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %53 = mo.concat[%52 : !mo.tensor<[], si64>] (%51) : (!mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %54 = mo.shape_of(%44) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %55 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %56 = mo.unsqueeze_shape(%54, %55) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %57 = mo.reshape(%44, %56) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %58 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %59 = mo.concat[%58 : !mo.tensor<[], si64>] (%57) : (!mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %60 = mo.shape_of(%45) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %61 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %62 = mo.unsqueeze_shape(%60, %61) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %63 = mo.reshape(%45, %62) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %64 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %65 = mo.concat[%64 : !mo.tensor<[], si64>] (%63) : (!mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %66 = mo.slice(%1, %53, %59, %65) : (!mo.tensor<[2], si64>, !mo.tensor<[1], si64>, !mo.tensor<[1], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %67 = mo.constant {value = #M.dense_array<1> : tensor<si64>} : !mo.tensor<[], si64>
    %68 = mo.constant {value = #M.dense_array<2> : tensor<si64>} : !mo.tensor<[], si64>
    %69 = mo.constant {value = #M.dense_array<1> : tensor<si64>} : !mo.tensor<[], si64>
    %70 = mo.shape_of(%2) : (!mo.tensor<[2], si64>) -> !mo.tensor<[1], si64>
    %71 = mo.shape_of(%67) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %72 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %73 = mo.unsqueeze_shape(%71, %72) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %74 = mo.reshape(%67, %73) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %75 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %76 = mo.concat[%75 : !mo.tensor<[], si64>] (%74) : (!mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %77 = mo.shape_of(%68) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %78 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %79 = mo.unsqueeze_shape(%77, %78) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %80 = mo.reshape(%68, %79) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %81 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %82 = mo.concat[%81 : !mo.tensor<[], si64>] (%80) : (!mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %83 = mo.shape_of(%69) : (!mo.tensor<[], si64>) -> !mo.tensor<[0], si64>
    %84 = mo.constant {value = #M.dense_array<0> : tensor<1xsi64>} : !mo.tensor<[1], si64>
    %85 = mo.unsqueeze_shape(%83, %84) : (!mo.tensor<[0], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %86 = mo.reshape(%69, %85) : (!mo.tensor<[], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %87 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %88 = mo.concat[%87 : !mo.tensor<[], si64>] (%86) : (!mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %89 = mo.slice(%2, %76, %82, %88) : (!mo.tensor<[2], si64>, !mo.tensor<[1], si64>, !mo.tensor<[1], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[1], si64>
    %90 = mo.constant {value = #M.dense_array<0> : tensor<si64>} : !mo.tensor<[], si64>
    %91 = mo.concat[%90 : !mo.tensor<[], si64>] (%66, %89) : (!mo.tensor<[1], si64>, !mo.tensor<[1], si64>) -> !mo.tensor<[2], si64>
    %92 = mo.reshape(%arg0, %43) : (!mo.tensor<[2, 6], f32>, !mo.tensor<[2], si64>) -> !mo.tensor<[?, 6], f32>
    %93 = mo.matmul(%92, %0) : (!mo.tensor<[?, 6], f32>, !mo.tensor<[6, 1], f32>) -> !mo.tensor<[?, 1], f32>
    %94 = mo.reshape(%93, %91) : (!mo.tensor<[?, 1], f32>, !mo.tensor<[2], si64>) -> !mo.tensor<[2, 1], f32>
    %95 = mo.relu(%94) : !mo.tensor<[2, 1], f32>
    %96 = mo.softmax(%95) : (!mo.tensor<[2, 1], f32>) -> !mo.tensor<[2, 1], f32>
    mo.output %96 : !mo.tensor<[2, 1], f32>
  }
}
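Near the end of that printout, you can see what became of the @ operator:

- %92 reshapes the input to [?, 6].
- %93 runs a 2-D mo.matmul against the [6, 1] constant.
- %94 reshapes the result back to [2, 1].

The shape ops before them compute those target shapes. The plain-Python sketch below shows the general flatten, multiply, restore pattern that this appears to implement. The pattern lets a single 2-D matmul handle inputs with any number of leading dimensions. It works on flat row-major lists and is an illustration of the technique, not MAX's implementation; matmul_nd is a hypothetical name.

```python
def matmul_nd(data, shape, weights, n):
    """Multiply a row-major tensor of shape [..., k] by a row-major [k, n] matrix.

    Returns (flat_result, result_shape), where result_shape is shape[:-1] + (n,).
    """
    k = shape[-1]
    if len(weights) != k * n:
        raise ValueError("weights must hold k * n values")
    rows = len(data) // k  # 1. flatten all leading dims: view data as [rows, k]
    out = []
    for r in range(rows):  # 2. an ordinary 2-D matmul: [rows, k] x [k, n]
        row = data[r * k:(r + 1) * k]
        for j in range(n):
            out.append(sum(row[i] * weights[i * n + j] for i in range(k)))
    return out, tuple(shape[:-1]) + (n,)  # 3. restore leading dims: [..., n]


# The tutorial's case: [2, 6] filled with 0.5, times [6, 1] filled with 0.15.
values, out_shape = matmul_nd([0.5] * 12, (2, 6), [0.15] * 6, n=1)
print(out_shape)  # (2, 1)

# A 3-D input works the same way; multiplying by a 2x2 identity returns the input.
values3, shape3 = matmul_nd(list(range(12)), (2, 3, 2), [1, 0, 0, 1], n=2)
print(shape3)  # (2, 3, 2)
```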

Execute the model

Now we can load the graph into a MAX Engine InferenceSession. Before we feed inputs to the model, we need the names of its input and output tensors, so let's print those:

session = InferenceSession()
model = session.load(graph)

in_names = model.get_model_input_names()
for name in in_names:
    print("Input:", name[])

out_names = model.get_model_output_names()
for name in out_names:
    print("Output:", name[])
Input: input0
Output: output0

Now that we know the tensor names, we can create our input as a Tensor, pass it into the graph, execute it, and get the output:

input = Tensor[DType.float32](TensorShape(2, 6), 0.5)
results = model.execute("input0", input^)
output = results.get[DType.float32]("output0")
print(output)
Tensor([[1.0], [1.0]], dtype=float32, shape=2x1)
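The all-ones result follows from the arithmetic:

- Each input row holds six 0.5 values, so the matmul gives 6 * 0.5 * 0.15 = 0.45 per row.
- ReLU leaves 0.45 unchanged.
- Softmax normalizes along the last axis, which has only one element here, so each output is exp(0.45) / exp(0.45) = 1.0.

If softmax had instead normalized down the column of two equal values, each output would have been 0.5. The plain-Python reference below reproduces the computation without MAX. It also shows that a row with two entries does not collapse to 1.0. It is a checking aid only, not part of the MAX API.

```python
import math


def softmax_last_axis(matrix):
    """Softmax over the last axis: each row is normalized independently."""
    result = []
    for row in matrix:
        m = max(row)  # subtracting the row max avoids overflow in exp()
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


x = [[0.5] * 6 for _ in range(2)]   # the 2x6 input, filled with 0.5
w = [[0.15] for _ in range(6)]      # the 6x1 constant, filled with 0.15

matmul = [[sum(x[i][k] * w[k][0] for k in range(6))] for i in range(2)]  # 2x1, ~0.45
relu = [[max(0.0, v) for v in row] for row in matmul]
print(softmax_last_axis(relu))      # [[1.0], [1.0]]

# With two entries per row, softmax produces a real distribution instead:
print(softmax_last_axis([[0.45, 0.0]]))  # roughly [[0.61, 0.39]]
```

To get a more informative output from this model, you would need more than one output column, for example a 6xN constant with N > 1.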

That's it! You just built a model with the MAX Graph API and ran it with MAX Engine. But this was just a brief introduction, using only ops that are built into the MAX library.

If you'd like to implement custom ops for your graph, check out the guide to create a custom op in MAX Graph.

And this is just the beginning of the Graph API. To stay up to date on what's coming, sign up for our newsletter.

For a larger code example, check out our MAX Graph implementation of Llama2.

Full code example

Here's all the code from above (also available on GitHub):

from max.engine import InferenceSession
from max.graph import Graph, TensorType, ops
from max.tensor import Tensor, TensorShape


def main():
    graph = Graph(TensorType(DType.float32, 2, 6))

    # Create a constant for usage in the matmul op below:
    matmul_constant_value = Tensor[DType.float32](TensorShape(6, 1), 0.15)
    matmul_constant = graph.constant(matmul_constant_value)

    # Start adding a sequence of operator calls to build the graph.
    # We can use the subscript notation to get the graph's first input tensor:
    matmul = graph[0] @ matmul_constant
    relu = ops.relu(matmul)
    softmax = ops.softmax(relu)
    graph.output(softmax)

    # Load the graph:
    session = InferenceSession()
    model = session.load(graph)

    # Print the input/output names:
    in_names = model.get_model_input_names()
    for name in in_names:
        print("Input:", name[])
    out_names = model.get_model_output_names()
    for name in out_names:
        print("Output:", name[])

    # Execute the model:
    input = Tensor[DType.float32](TensorShape(2, 6), 0.5)
    results = model.execute("input0", input^)
    output = results.get[DType.float32]("output0")
    print(output)