Mojo struct
TensorDict
A collection of keyed Tensor values used with checkpoint files. This is the type accepted by save() and returned by load().
For example:
from max.graph.checkpoint import load, save, TensorDict
from max.tensor import Tensor, TensorShape

def write_to_disk():
    tensors = TensorDict()
    # A 1x2x2 int32 tensor holding the values 1, 2, 3, 4.
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    # A 10x5 float32 tensor built from the scalar -1.23.
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    save(tensors, "/path/to/checkpoint.maxckpt")

def read_from_disk():
    tensors = load("/path/to/checkpoint.maxckpt")
    # Request the tensor with the same dtype it was saved with.
    x = tensors.get[DType.int32]("x")
Implemented traits
AnyType, Sized
Methods
__init__
__init__(inout self: Self)
Creates an empty dictionary.
__copyinit__
__copyinit__(inout self: Self, existing: Self)
Copies a dictionary.
Args:
- existing (Self): The existing dictionary.
__moveinit__
__moveinit__(inout self: Self, owned existing: Self)
Moves data of an existing dictionary into a new one.
Args:
- existing (Self): The existing dictionary.
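The copy and move constructors above are not called by name; Mojo invokes them implicitly. The following is a minimal sketch of when each one runs, assuming the ordinary Mojo rules: assigning a value that is still used later triggers __copyinit__, and the ^ transfer operator triggers __moveinit__. The function name copy_and_move is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape

def copy_and_move():
    tensors = TensorDict()
    tensors["x"] = Tensor[DType.int32](TensorShape(2), 1, 2)
    # `tensors` is still used below, so this assignment calls __copyinit__.
    backup = tensors
    # The ^ operator transfers ownership and calls __moveinit__;
    # `tensors` must not be used after this line.
    moved = tensors^
    print(len(backup), len(moved))
```

Moving avoids duplicating the stored tensors, so prefer ^ when the original dictionary is no longer needed.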
__setitem__
__setitem__[T: DType](inout self: Self, key: String, value: Tensor[T])
Supports setting items with the bracket accessor.
For example:
tensors = TensorDict()
tensors["x"] = Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4)
Args:
- key (String): The key to associate with the specified value.
- value (Tensor[T]): The data to store in the dictionary.
set
set[T: DType](inout self: Self, key: String, value: Tensor[T])
Adds or updates a tensor in the dictionary.
Args:
- key (String): The name of the tensor.
- value (Tensor[T]): The tensor to add.
get
get[type: DType](self: Self, key: String) -> Tensor[type]
Gets a tensor from the dictionary.
Currently, this returns a copy of the tensor. For better performance, use pop(), which moves the tensor out of the dictionary instead of copying it.
This method may change in the future to return an immutable reference instead of a mutable copy of the tensor.
Args:
- key (String): The name of the tensor.
Returns:
A copy of the tensor.
pop
pop[type: DType](inout self: Self, key: String) -> Tensor[type]
Removes a tensor from the dictionary.
This function moves the Tensor pointer out of the dictionary and returns it to the caller.
Args:
- key (String): The name of the tensor.
Returns:
The tensor.
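The difference between get() and pop() is easiest to see side by side. This is a minimal sketch, assuming the checkpoint was written by write_to_disk() above and therefore contains an int32 tensor named "x"; the path is a placeholder and the function name copy_then_take is hypothetical.

```mojo
from max.graph.checkpoint import load

def copy_then_take():
    tensors = load("/path/to/checkpoint.maxckpt")
    # get() returns a copy; "x" remains stored in the dictionary.
    x_copy = tensors.get[DType.int32]("x")
    # pop() moves the tensor out of the dictionary; "x" is removed.
    x = tensors.pop[DType.int32]("x")
    # One entry fewer than before the pop() call.
    print(len(tensors))
```

Use get() when the dictionary must keep the tensor (for example, before saving it again) and pop() when the caller takes ownership.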
__len__
__len__(self: Self) -> Int
Returns the number of tensors in the dictionary.
items
items(ref [self_is_lifetime] self: Self) -> _DictEntryIter[$0, String, _CheckpointTensor, $1._items, 1]
Gets an iterable view of all elements in the dictionary.
keys
keys(ref [self_is_lifetime] self: Self) -> _DictKeyIter[$0, String, _CheckpointTensor, $1._items, 1]
Gets an iterable view of all keys in the dictionary.
__iter__
__iter__(ref [self_is_lifetime] self: Self) -> _DictKeyIter[$0, String, _CheckpointTensor, $1._items, 1]
Iterates over the keys in the dictionary.
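Because TensorDict implements Sized and exposes key iterators, a loaded checkpoint can be inspected without knowing its contents in advance. This is a minimal sketch, assuming that, as with Mojo's standard Dict in the same release, each iteration step yields a reference that is dereferenced with []; check this against your Mojo version. The function name list_checkpoint and the path are placeholders.

```mojo
from max.graph.checkpoint import load

def list_checkpoint():
    tensors = load("/path/to/checkpoint.maxckpt")
    # len() works because TensorDict implements the Sized trait.
    print("number of tensors:", len(tensors))
    # keys() yields references to the String keys; [] dereferences each one.
    for key in tensors.keys():
        print(key[])
    # Iterating over the dictionary itself (__iter__) also visits the keys.
    for key in tensors:
        print(key[])
```

The keys alone do not record each tensor's dtype, so get() and pop() still need the dtype parameter supplied by the caller.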
__str__
__str__(self: Self) -> String
Returns a string representation of the dictionary.