├── .gitignore
├── CODEOWNERS
├── LICENSE
├── README.md
├── basalt
│   ├── __init__.mojo
│   ├── autograd
│   │   ├── __init__.mojo
│   │   ├── attributes.mojo
│   │   ├── graph.mojo
│   │   ├── node.mojo
│   │   ├── ops
│   │   │   ├── __init__.mojo
│   │   │   ├── basics.mojo
│   │   │   ├── conv.mojo
│   │   │   ├── dynamics.mojo
│   │   │   ├── matmul.mojo
│   │   │   ├── mlops.mojo
│   │   │   ├── ops.mojo
│   │   │   └── pool.mojo
│   │   ├── params.mojo
│   │   └── symbol.mojo
│   ├── nn
│   │   ├── __init__.mojo
│   │   ├── activations.mojo
│   │   ├── initializers.mojo
│   │   ├── layers
│   │   │   ├── __init__.mojo
│   │   │   ├── conv.mojo
│   │   │   ├── dropout.mojo
│   │   │   ├── linear.mojo
│   │   │   ├── pool.mojo
│   │   │   └── sequential.mojo
│   │   ├── loss.mojo
│   │   ├── model.mojo
│   │   ├── optim.mojo
│   │   └── tensor.mojo
│   └── utils
│       ├── __init__.mojo
│       ├── bytes.mojo
│       ├── collection.mojo
│       ├── dataloader.mojo
│       ├── datasets.mojo
│       ├── graph_render.py
│       ├── math_util.mojo
│       ├── onnx_utils.mojo
│       ├── perf_utils.mojo
│       ├── rand_utils.mojo
│       ├── tensor_creation_utils.mojo
│       └── tensorutils.mojo
├── examples
│   ├── data
│   │   ├── housing.csv
│   │   ├── mnist_test_small.csv
│   │   └── mnist_torch.onnx
│   ├── housing.mojo
│   ├── housing.py
│   ├── mnist.mojo
│   ├── mnist.py
│   ├── mnist_load_model.mojo
│   ├── sin_estimate.mojo
│   └── sin_estimate.py
├── magic.lock
├── mojoproject.toml
├── python-requirements.txt
└── tests
    ├── __init__.mojo
    ├── mojo
    │   ├── test_activations.mojo
    │   ├── test_attributes.mojo
    │   ├── test_backward.mojo
    │   ├── test_collection.mojo
    │   ├── test_dynamic_ops.mojo
    │   ├── test_loss.mojo
    │   ├── test_mlops.mojo
    │   ├── test_ops.mojo
    │   ├── test_tensorutils.mojo
    │   └── test_tensorutils_data.mojo
    ├── python
    │   ├── test_broadcast_shapes.mojo
    │   ├── test_conv.mojo
    │   ├── test_dynamic_ops_torch.mojo
    │   ├── test_mlops_torch.mojo
    │   ├── test_models_mnist.mojo
    │   ├── test_models_regression.mojo
    │   ├── test_models_sin_estimate.mojo
    │   ├── test_models_torch.py
    │   ├── test_ops_torch.mojo
    │   └── test_pool.mojo
    └── testing_utils.mojo
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | __pycache__
3 |
4 | basalt.📦
5 |
6 | examples/data/mnist_test.csv
7 | examples/data/mnist_train.csv
8 | examples/data/mnist_train_small.csv
9 |
10 | output_model.onnx
11 | Makefile
12 |
13 | ./temp
14 | flamegraph.svg
15 |
16 | .magic
17 |
18 | examples/data/yolov8n.onnx
19 |
20 | *.DS_Store
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @StijnWoestenborghs
2 | * @andresnowak
3 | * @Benny-Nottonson
4 | * @soraros
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Basalt
2 |
3 | A Machine Learning framework from scratch in pure Mojo 🔥
4 |
21 | ## About The Project
22 |
23 | Basalt is a stand-alone machine learning framework that leverages the power of Mojo.
24 |
25 | As [discussed](https://docs.modular.com/mojo/why-mojo) by Modular, Mojo is a language for the future of AI development. Built on top of MLIR technology, rather than existing GCC and LLVM approaches, Mojo looks and feels like Python code, yet performs much closer to languages like Rust or C++. Parametric functions and compile-time parameters allow the graph to be statically compiled, and a static graph opens the door to much more aggressive performance optimizations.
26 |
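As an illustration, here is a minimal sketch of how a statically defined graph is declared with the APIs in `basalt/` (function and parameter names are illustrative; the examples under `examples/` follow the same pattern):

```
from basalt import Graph, TensorShape
import basalt.nn as nn


fn linear_regression(batch_size: Int, n_inputs: Int, n_outputs: Int) -> Graph:
    var g = Graph()

    # Inputs and targets are declared up front, so every shape is known at compile time
    var x = g.input(TensorShape(batch_size, n_inputs))
    var y_true = g.input(TensorShape(batch_size, n_outputs))

    # A fully connected layer and the loss are recorded as nodes of the same graph
    var y_pred = nn.Linear(g, x, n_outputs)
    g.out(y_pred)

    var loss = nn.MSELoss(g, y_pred, y_true)
    g.loss(loss)

    return g ^
```

Such a graph can then be turned into an `nn.Model` (see `basalt/nn/model.mojo`) and trained with the Adam optimizer from `basalt/nn/optim.mojo`.
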
27 | Basalt, while still in its infancy, is able to achieve speeds comparable to well-established frameworks like PyTorch. Below is a snapshot of the current benchmarks. But stay tuned: there is much more room for improvement, and we are upgrading the project on a daily basis.
28 |
29 | 
30 |
31 |
32 | ## Quick Start
33 |
34 | Try out the benchmarks yourself:
35 |
36 | ```
37 | mojo -I . examples/housing.mojo
38 | ```
39 | ```
40 | mojo -I . examples/sin_estimate.mojo
41 | ```
42 | ```
43 | mojo -I . examples/mnist.mojo
44 | ```
45 |
46 | Compare to the alternative PyTorch implementations:
47 | Make sure to install the requirements from `python-requirements.txt` in your Python environment.
48 |
49 | ```
50 | python examples/housing.py
51 | python examples/sin_estimate.py
52 | python examples/mnist.py
53 | ```
54 |
55 | ## Roadmap
56 |
57 | ### v0.1.0 ✅
58 | - [x] Improve matrix multiplication and convolution kernels
59 | - [x] Switch to custom Tensor and TensorShape implementations
60 | - [x] Improve benchmarks and overall model execution performance
61 | - [x] Add profiling and additional performance tests
62 |
63 | ### v0.2.0 (WIP)
64 | - [ ] Add additional operators: Slice, (Un)Squeeze, Concat, Clip, Gather, Split, FMA ...
65 | - [ ] Better layer support and more activation functions
66 | - [ ] Graph submodules & graph concatenation
67 | - [ ] Computer vision benchmark.
68 |
69 | ### Long-Term
70 | - [ ] Better parallelization
71 | - [ ] GPU support
72 | - [ ] Reworked Dataloader
73 | - [ ] Autotuning and related features
74 | - [ ] Graph compilation optimizations
75 | - [ ] Operator fusion
76 | - [ ] ONNX / Max compatibility
77 |
78 | ## Contributing
79 |
80 | Basalt is built by community effort and relies on your expertise and enthusiasm!
81 | Small fixes and improvements are much appreciated. If you are considering larger contributions, feel free to contact us on Discord for a smoother communication channel. If you find a bug or have an idea for a feature, please use our issue tracker. Before creating a new issue, please:
82 | * Check if the issue already exists. If an issue is already reported, you can contribute by commenting on the existing issue.
83 | * If not, create a new issue and include all the necessary details to understand/recreate the problem or feature request.
84 |
85 | ### Creating A Pull Request
86 |
87 | 1. Fork the Project
88 | 2. Create your Feature Branch
89 | 3. Commit your Changes
90 | 4. Push to the Branch
91 | 5. Open a Pull Request
92 | > Once your changes are pushed, navigate to your fork on GitHub and create a pull request against the original basalt-org/basalt repository.
93 | > - Before creating a PR, make sure it doesn't break any of the unit tests (e.g. `mojo run -I . tests/mojo/test_ops.mojo`).
94 | > - Introducing big new features requires a new test!
95 | > - In the pull request, provide a detailed description of the changes and why they're needed. Link any relevant issues.
96 | > - If there are any specific instructions for testing or validating your changes, include those as well.
97 |
98 | ## License
99 |
100 | Distributed under the Apache 2.0 License with LLVM Exceptions. See [LICENSE](https://github.com/Basalt-Org/Basalt/blob/main/LICENSE) and the LLVM [License](https://llvm.org/LICENSE.txt) for more information.
101 |
102 | ## Acknowledgements
103 |
104 | * Built with [Mojo](https://github.com/modularml/mojo) created by [Modular](https://github.com/modularml)
105 |
--------------------------------------------------------------------------------
/basalt/__init__.mojo:
--------------------------------------------------------------------------------
1 | from .autograd import Graph, Symbol, OP
2 | from .nn import Tensor, TensorShape
3 | from sys.info import simdwidthof
4 | from basalt.utils.collection import Collection
5 |
6 | alias dtype = DType.float32
7 | alias nelts = 2 * simdwidthof[dtype]()
8 | alias seed = 42
9 | alias epsilon = 1e-12
10 |
--------------------------------------------------------------------------------
/basalt/autograd/__init__.mojo:
--------------------------------------------------------------------------------
1 | from .symbol import Symbol
2 | from .graph import Graph
3 | from .ops import OP
4 |
--------------------------------------------------------------------------------
/basalt/autograd/attributes.mojo:
--------------------------------------------------------------------------------
1 | from collections import Optional, OptionalReg
2 | from utils.static_tuple import StaticTuple
3 | from utils.index import IndexList
4 |
5 | from basalt.nn.tensor import Tensor, TensorShape, MAX_RANK
6 | from basalt.utils.bytes import Bytes, scalar_to_bytes, bytes_to_scalar
7 |
8 |
9 | alias MAX_ATTRS = 10
10 | alias MAX_NAME_CHARS = 16
11 | alias MAX_DATA_BYTES = 32
12 |
13 |
14 | @register_passable("trivial")
15 | struct AttributeType(Stringable):
16 | alias BOOL = AttributeType(0, "BOOL")
17 | alias INT = AttributeType(1, "INT")
18 | alias FLOAT = AttributeType(2, "FLOAT")
19 | alias STRING = AttributeType(3, "STRING")
20 | alias INTS = AttributeType(4, "INTS")
21 | alias FLOATS = AttributeType(5, "FLOATS")
22 |
23 | var id: UInt8
24 | var name: Bytes[MAX_NAME_CHARS]
25 |
26 | fn __init__(inout self, id: UInt8, name: String):
27 | self.id = id
28 | self.name = Bytes[MAX_NAME_CHARS](name)
29 |
30 | fn __init__(inout self, type: DType):
31 | if type.is_floating_point():
32 | self = AttributeType.FLOAT
33 | elif type == DType.bool:
34 | self = AttributeType.BOOL
35 | else:
36 | self = AttributeType.INT
37 |
38 | fn __eq__(self, other: Self) -> Bool:
39 | return self.id == other.id
40 |
41 | fn __str__(self) -> String:
42 | return str(self.name)
43 |
44 |
45 | @register_passable("trivial")
46 | struct AttributeVector(Sized, Stringable, CollectionElement):
47 | var attributes: StaticTuple[Attribute, MAX_ATTRS]
48 | var size: Int
49 |
50 | fn __init__(inout self, *attributes: Attribute):
51 | self.attributes = StaticTuple[Attribute, MAX_ATTRS](Attribute("", ""))
52 | self.size = len(attributes)
53 | for i in range(self.size):
54 | self.attributes[i] = attributes[i]
55 |
56 | @always_inline("nodebug")
57 | fn __len__(self) -> Int:
58 | return self.size
59 |
60 | @always_inline("nodebug")
61 | fn __getitem__(self, index: Int) -> Attribute:
62 | return self.attributes[index]
63 |
64 | @always_inline("nodebug")
65 | fn __getitem__(self, index: StringLiteral) -> OptionalReg[Attribute]:
66 | for i in range(self.size):
67 | if self.attributes[i].name == Bytes[MAX_NAME_CHARS](index):
68 | return self.attributes[i]
69 | return None
70 |
71 | fn __str__(self) -> String:
72 | var s: String = "["
73 | for i in range(self.size):
74 | s += str(self.attributes[i])
75 | if i < self.size - 1:
76 | s += ", "
77 | return s + "]"
78 |
79 |
80 | @register_passable("trivial")
81 | struct Attribute(Stringable, CollectionElement):
82 | var data_shape: IndexList[MAX_RANK]
83 | var name: Bytes[MAX_NAME_CHARS]
84 | var data: Bytes[MAX_DATA_BYTES]
85 | var type: AttributeType
86 | var size: Int
87 |
88 | fn __init__(inout self, name: String, value: String):
89 | self.data_shape = IndexList[MAX_RANK]()
90 | self.name = Bytes[MAX_NAME_CHARS](name)
91 | self.data = Bytes[MAX_DATA_BYTES](value)
92 | self.type = AttributeType.STRING
93 | self.size = len(value)
94 |
95 | fn __init__(inout self, name: String, value: TensorShape):
96 | self.data_shape = IndexList[MAX_RANK]()
97 | self.name = Bytes[MAX_NAME_CHARS](name)
98 | self.data = Bytes[MAX_DATA_BYTES]()
99 | self.type = AttributeType.INTS
100 | self.size = value.rank()
101 |
102 | for i in range(self.size):
103 | self.data_shape[i] = value._shape[i]
104 |
105 | fn __init__[N: Int](inout self, name: String, value: IndexList[N]):
106 | constrained[N < MAX_RANK, "Attribute rank must be less than MAX_RANK."]()
107 |
108 | self.data_shape = IndexList[MAX_RANK]()
109 | self.name = Bytes[MAX_NAME_CHARS](name)
110 | self.data = Bytes[MAX_DATA_BYTES]()
111 | self.type = AttributeType.INTS
112 | self.size = N
113 |
114 | for i in range(self.size):
115 | self.data_shape[i] = value[i]
116 |
117 | fn __init__(inout self, name: String, value: List[Int]):
118 | self.data_shape = IndexList[MAX_RANK]()
119 | self.name = Bytes[MAX_NAME_CHARS](name)
120 | self.data = Bytes[MAX_DATA_BYTES]()
121 | self.type = AttributeType.INTS
122 | self.size = len(value)
123 |
124 | for i in range(self.size):
125 | self.data_shape[i] = value[i]
126 |
127 | fn __init__(inout self, name: String, value: StaticTuple[Int, _]):
128 | self.data_shape = IndexList[MAX_RANK]()
129 | self.name = Bytes[MAX_NAME_CHARS](name)
130 | self.data = Bytes[MAX_DATA_BYTES]()
131 | self.type = AttributeType.INTS
132 | self.size = len(value)
133 |
134 | for i in range(self.size):
135 | self.data_shape[i] = value[i]
136 |
137 | fn __init__[dtype: DType](inout self, name: String, value: Scalar[dtype]):
138 | constrained[dtype.is_numeric(), "Attribute value must be numeric."]()
139 |
140 | self.data_shape = IndexList[MAX_RANK]()
141 | self.name = Bytes[MAX_NAME_CHARS](name)
142 | self.data = scalar_to_bytes[dtype, MAX_DATA_BYTES](value)
143 | self.type = AttributeType(dtype)
144 | self.size = 1
145 |
146 | fn __init__(inout self, name: String, value: Int):
147 | self.__init__(name, Int64(value))
148 | self.data_shape[0] = 1
149 |
150 | fn __init__(inout self, name: String, value: FloatLiteral):
151 | self.__init__(name, Float64(value))
152 | self.data_shape[0] = 1
153 |
154 | @always_inline("nodebug")
155 | fn __str__(self) -> String:
156 | return "Attribute(" + str(self.name) + ", " + "..." + ")"
157 |
158 | @always_inline("nodebug")
159 | fn to_string(self) -> String:
160 | return str(self.data)
161 |
162 | @always_inline("nodebug")
163 | fn to_list(self) -> List[Int]:
164 | var result = List[Int]()
165 |
166 | for i in range(self.size):
167 | result.append(self.data_shape[i])
168 |
169 | return result
170 |
171 | @always_inline("nodebug")
172 | fn to_shape(self) -> TensorShape:
173 | return TensorShape(rank=self.size, shape=self.data_shape)
174 |
175 | @always_inline("nodebug")
176 | fn to_static[N: Int](self) -> IndexList[N]:
177 | constrained[N < MAX_RANK, "Attribute rank must be less than MAX_RANK."]()
178 |
179 | var result = IndexList[N]()
180 |
181 | for i in range(N):
182 | result[i] = int(self.data_shape[i])
183 |
184 | return result
185 |
186 | @always_inline("nodebug")
187 | fn to_scalar[dtype: DType](self) -> Scalar[dtype]:
188 | constrained[dtype.is_numeric(), "Attribute value must be numeric."]()
189 |
190 | return bytes_to_scalar[dtype](self.data)
191 |
192 | @always_inline("nodebug")
193 | fn to_int(self) -> Int:
194 | return int(self.to_scalar[DType.int64]())
195 |
196 | fn json(self) -> String:
197 | var result = '{"name": "' + str(self.name) + '", '
198 |
199 | var type: String = ""
200 | var value: String = ""
201 |
202 | if self.type == AttributeType.STRING:
203 | type = "STRING"
204 | value = '"' + self.to_string() + '"'
205 | elif self.type == AttributeType.INTS:
206 | type = "INTS"
207 |
208 | var value_temp = self.to_shape()
209 | value = "["
210 | for i in range(value_temp.rank()):
211 | value += str(value_temp._shape[i])
212 | if i < value_temp.rank() - 1:
213 | value += ", "
214 | value += "]"
215 | elif self.type == AttributeType.FLOAT:
216 | type = "FLOAT"
217 | value = str(self.to_scalar[DType.float64]())
218 | elif self.type == AttributeType.INT:
219 | type = "INT"
220 | value = str(self.to_int())
221 | else:
222 | type = "UNKNOWN"
223 | value = "UNKNOWN"
224 |
225 | result += '"type": "' + type + '", ' + '"value": ' + value
226 |
227 | return result + "}"
228 |
--------------------------------------------------------------------------------
/basalt/autograd/graph.mojo:
--------------------------------------------------------------------------------
1 | from python.python import Python
2 | from collections.optional import Optional, OptionalReg
3 |
4 | from .node import Node
5 | from .attributes import AttributeVector, Attribute
6 | from .symbol import Symbol
7 | from .ops import OP, static_result_shape, dynamic_result_shape
8 | from .params import ParamDict, Param
9 |
10 | from basalt import seed, dtype
11 | from basalt import Tensor, TensorShape
12 |
13 |
14 | struct Graph:
15 | var inputs: List[Symbol]
16 | var params: ParamDict
17 | var nodes: List[Node]
18 | var outputs: List[Symbol]
19 | var loss_out: OptionalReg[Symbol]
20 | var symbol_count: UInt32
21 |
22 | fn __init__(inout self):
23 | self.inputs = List[Symbol]()
24 | self.params = ParamDict()
25 | self.nodes = List[Node]()
26 | self.outputs = List[Symbol]()
27 | self.loss_out = None
28 | self.symbol_count = 0
29 |
30 | fn __moveinit__(inout self, owned other: Graph):
31 | self.inputs = other.inputs^
32 | self.params = other.params^
33 | self.nodes = other.nodes^
34 | self.outputs = other.outputs^
35 | self.loss_out = other.loss_out
36 | self.symbol_count = other.symbol_count
37 |
38 | fn create_symbol(inout self, shape: TensorShape, data: Optional[Param] = None, trainable: Bool = False, is_input: Bool = False) -> Symbol:
39 | var symbol = Symbol(self.symbol_count, dtype, shape, trainable)
40 | self.symbol_count += 1
41 |
42 | if is_input:
43 | self.inputs.append(symbol)
44 | else:
45 | if data is not None:
46 | self.params.put(symbol, data.value())
47 | else:
48 | self.params.put(symbol)
49 |
50 | return symbol
51 |
52 | fn input(inout self, shape: TensorShape, trainable: Bool = False) -> Symbol:
53 | return self.create_symbol(shape, trainable=trainable, is_input=True)
54 |
55 | fn param(inout self, shape: TensorShape, init: Param, trainable: Bool = True) -> Symbol:
56 | return self.create_symbol(shape, init, trainable)
57 |
58 | fn param(inout self, shape: TensorShape, trainable: Bool = True) -> Symbol:
59 | return self.create_symbol(shape, trainable=trainable)
60 |
61 | fn scalar(inout self, value: Scalar[dtype]) -> Symbol:
62 | return self.create_symbol(TensorShape(1), Param(value), trainable=False)
63 |
64 | fn constant(inout self, shape: TensorShape, data: List[Scalar[dtype]]) -> Symbol:
65 | return self.create_symbol(shape, Param(data), trainable=False)
66 |
67 | fn out(inout self, symbol: Symbol):
68 | self.outputs.append(symbol)
69 |
70 | fn loss(inout self, symbol: Symbol):
71 | self.loss_out = symbol
72 |
73 | fn op(
74 | inout self,
75 | op: OP,
76 | *operands: Symbol,
77 | attributes: AttributeVector = AttributeVector(),
78 | ) -> Symbol:
79 | var res_shape = static_result_shape(op, operands, attributes)
80 | var res = Symbol(self.symbol_count, dtype, res_shape, self.result_trainable(operands))
81 | self.symbol_count += 1
82 |
83 | var inputs = List[Symbol]()
84 | inputs.reserve(len(operands))
85 |
86 | for operand in operands:
87 | inputs.append(operand)
88 |
89 | self.nodes.append(Node(op, inputs, List[Symbol](res), attributes))
90 | return res
91 |
92 | fn op(
93 | inout self,
94 | op: OP,
95 | operand_1: Symbol,
96 | operand_2: Float64,
97 | attributes: AttributeVector = AttributeVector(),
98 | ) -> Symbol:
99 | return self.op(op, operand_1, self.scalar(operand_2.cast[dtype]()), attributes=attributes)
100 |
101 | fn op(
102 | inout self,
103 | op: OP,
104 | operand_1: Float64,
105 | operand_2: Symbol,
106 | attributes: AttributeVector = AttributeVector(),
107 | ) -> Symbol:
108 | return self.op(op, self.scalar(operand_1.cast[dtype]()), operand_2, attributes=attributes)
109 |
110 | fn create_symbols(inout self, shapes: List[TensorShape], trainable: Bool = False) -> List[Symbol]:
111 | var symbols = List[Symbol]()
112 | symbols.reserve(len(shapes))
113 |
114 | for shape in shapes:
115 | symbols.append(Symbol(self.symbol_count, dtype, shape[], trainable))
116 | self.symbol_count += 1
117 |
118 | return symbols
119 |
120 | fn add_node(inout self, op: OP, inputs: List[Symbol], outputs: List[Symbol], attributes: AttributeVector):
121 | self.nodes.append(Node(op, inputs, outputs, attributes))
122 |
123 | fn concat(inout self, *operands: Symbol, dim: Int = 0) -> Symbol:
124 | var attributes = AttributeVector(Attribute("dim", dim))
125 | var res_shape = dynamic_result_shape(OP.CONCAT, operands, attributes)[0]
126 | var res_symbols = self.create_symbols(List[TensorShape](res_shape), self.result_trainable(operands))
127 |
128 | var operand_list = List[Symbol]()
129 | operand_list.reserve(len(operands))
130 | for operand in operands:
131 | operand_list.append(operand)
132 |
133 | self.add_node(OP.CONCAT, operand_list, res_symbols, attributes)
134 | return res_symbols[0]
135 |
136 | fn split(
137 | inout self, operand: Symbol, sections: List[Int], dim: Int = 0
138 | ) -> List[Symbol]:
139 | var attributes = AttributeVector(Attribute("sections", TensorShape(sections)), Attribute("dim", dim))
140 | var res_shapes = dynamic_result_shape(OP.SPLIT, operand, attributes)
141 | var trainable = self.result_trainable(operand)
142 | var result_symbols = self.create_symbols(res_shapes, trainable)
143 | self.add_node(OP.SPLIT, List[Symbol](operand), result_symbols, attributes)
144 | return result_symbols
145 |
146 | @staticmethod
147 | fn result_trainable(operands: VariadicList[Symbol]) -> Bool:
148 | for operand in operands:
149 | if operand.trainable:
150 | return True
151 | return False
152 |
153 | fn json(self) -> String:
154 | var result: String = '{"graph_name": "basalt", "nodes": ['
155 | for i in range(len(self.nodes)):
156 | result += self.nodes[i].json()
157 | if i < len(self.nodes) - 1:
158 | result += ", "
159 | result += '], "inputs": ['
160 | for i in range(len(self.inputs)):
161 | result += self.inputs[i].json()
162 | if i < len(self.inputs) - 1:
163 | result += ", "
164 | result += '], "outputs": ['
165 | for i in range(len(self.outputs)):
166 | result += self.outputs[i].json()
167 | if i < len(self.outputs) - 1:
168 | result += ", "
169 | if self.loss_out:
170 | result += '], "loss": ['
171 | result += self.loss_out.value().json()
172 | result += '], "params": ['
173 | for i in range(len(self.params)):
174 | result += self.params.symbols[i].json()
175 | if i < len(self.params) - 1:
176 | result += ", "
177 | result += "]}"
178 | return result
179 |
180 | fn render(self, render_type: String = "node") raises:
181 | Python.add_to_path("./basalt/utils")
182 | var renderer = Python.import_module("graph_render")
183 | var json = Python.import_module("json")
184 | _ = renderer.netron_render(json.loads(self.json()), render_type)
185 |
186 | fn compile(inout self):
187 | # 0. Sorting the graph
188 |         # The statically defined graph has an implicit topologically sorted order, because
189 |         # each new operation is appended to the list of nodes only after its dependencies have been added.
190 | # This eliminates the need for explicit topological sorting.
191 |
192 | # Possibilities:
193 | # - 1. Graph layout transformation (graph rewrite)
194 | # - Layer pruning (removing nodes that have no effect - with common sub-tree identification)
195 | # - Eliminate redundant intermediate data copies
196 | # - Operator replacement (e.g. replacing (combination of) costly ops with more efficient ones)
197 |         #       - (example of graph rewrite: https://dl.acm.org/doi/pdf/10.1145/3453483.3454083 - Table 4)
198 | # - Other intra-block optimizations: (e.g. data layout transformation BCHW -> BHWC, etc.)
199 | # - 2. Operator fusion (combining ops without materializing intermediate results)
200 | # - Fusion plan exploration
201 | # - Fusion plan generation (with subsequent intra-block optimizations)
202 | # - (example fusion plan algorithm: https://dl.acm.org/doi/pdf/10.1145/3453483.3454083 - Listing 1)
203 | # - 3. Fusion Code generation (behaviour)
204 | # - Code generation for planned fusion blocks
205 | # - Other inter-block optimizations (e.g. data layout transformation BCHW -> BHWC, etc.)
206 | # - 4. Auto-tuning (of vectorization-, parallelization-, tiling-, unrolling-parameters)
207 | # - (Might only work when memory is initialized)
208 |
209 | # Other considerations:
210 | # - Efficient Memory management:
211 | # - Memory reuse (in-place operations)
212 | # - Data layout from BCHW (batch, channel, height, width) to BHWC can lead to better utilization and efficiency
213 | # - VJP, JVP (for automatic differentiation)
214 |
215 | pass
216 |
--------------------------------------------------------------------------------
/basalt/autograd/node.mojo:
--------------------------------------------------------------------------------
1 | from collections.optional import Optional
2 | from utils.variant import Variant
3 |
4 | from basalt.autograd import Symbol
5 | from basalt.autograd.ops import OP
6 |
7 | from .attributes import AttributeVector
8 |
9 |
10 | @value
11 | struct Node(CollectionElement, Stringable):
12 | var operator: OP
13 | var inputs: List[Symbol]
14 | var outputs: List[Symbol]
15 | var attributes: AttributeVector
16 |
17 | fn __init__(
18 | inout self,
19 | operator: OP,
20 | inputs: List[Symbol],
21 | outputs: List[Symbol],
22 | attributes: AttributeVector = AttributeVector(),
23 | ):
24 | self.operator = operator
25 | self.inputs = inputs
26 | self.outputs = outputs
27 | self.attributes = attributes
28 |
29 | fn __str__(self) -> String:
30 | return self.json()
31 |
32 | fn json(self) -> String:
33 | var s: String = '{"operator": "' + str(self.operator.name) + '", "inputs": ['
34 | for i in range(len(self.inputs)):
35 | s += self.inputs[i].json()
36 | if i < len(self.inputs) - 1:
37 | s += ", "
38 | s += '], "outputs": ['
39 | for i in range(len(self.outputs)):
40 | s += self.outputs[i].json()
41 | if i < len(self.outputs) - 1:
42 | s += ", "
43 | s += '], "attributes": ['
44 | for i in range(len(self.attributes)):
45 | s += self.attributes[i].json()
46 | if i < len(self.attributes) - 1:
47 | s += ", "
48 | s += "]}"
49 | return s
50 |
--------------------------------------------------------------------------------
/basalt/autograd/ops/__init__.mojo:
--------------------------------------------------------------------------------
1 | from .ops import (
2 | OP,
3 | static_result_shape,
4 | dynamic_result_shape,
5 | forward_op,
6 | backward_op,
7 | )
8 |
--------------------------------------------------------------------------------
/basalt/autograd/ops/dynamics.mojo:
--------------------------------------------------------------------------------
1 | from basalt import Symbol
2 | from basalt.nn.model import Parameters
3 | from ..attributes import AttributeVector
4 |
5 | from memory import memcpy
6 |
7 |
8 | struct CONCAT:
9 | @staticmethod
10 | fn result_shape(
11 | input_shapes: List[TensorShape], attributes: AttributeVector
12 | ) -> List[TensorShape]:
13 | # Assumptions: all tensors have the same shape, except for the concatenating dimension
14 | var dim = attributes["dim"].value().to_int() if attributes["dim"] else 0
15 |
16 | var concat_size: Int = 0
17 | for i in range(len(input_shapes)):
18 | concat_size += input_shapes[i][dim]
19 |
20 | var res_shape = input_shapes[0]
21 | res_shape[dim] = concat_size
22 |
23 | return List[TensorShape](res_shape)
24 |
25 | @staticmethod
26 | fn calc_chunks(shape: TensorShape, dim: Int) -> Int:
27 | # Number of chunks up to the concatenating dimension
28 |         # Assuming tensors of equal shape, except for the concatenating dimension
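        # Illustrative example: for input shapes (2, 3, 4) and dim = 1, chunks = 2,
        # i.e. each input is treated as 2 contiguous blocks of 3 * 4 elements that
        # get interleaved into the output buffer.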
29 | var chunks = 1
30 | for i in range(dim):
31 | chunks *= shape[i]
32 | return chunks
33 |
34 | @staticmethod
35 | fn forward[attributes: AttributeVector](
36 | inputs: List[Symbol],
37 | outputs: List[Symbol],
38 | inout parameters: Parameters,
39 | ):
40 | alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0
41 | var n_chunks = Self.calc_chunks(inputs[0].shape, dim)
42 |
43 | var chunks = List[Int]()
44 | var chunk_offsets = List[Int](0)
45 | for i in range(len(inputs)):
46 | chunks.append(inputs[i].shape.num_elements() // n_chunks)
47 | chunk_offsets.append(chunk_offsets[i] + chunks[i])
48 |
49 | for i in range(n_chunks):
50 | for j in range(len(inputs)):
51 | memcpy(
52 | parameters.tensors[outputs[0]].data()
53 | + i * chunk_offsets[len(inputs)]
54 | + chunk_offsets[j],
55 | parameters.tensors[inputs[j]].data() + i * chunks[j],
56 | chunks[j],
57 | )
58 |
59 | @staticmethod
60 | fn backward[input_id: Int, attributes: AttributeVector](
61 | inputs: List[Symbol],
62 | outputs: List[Symbol],
63 | inout parameters: Parameters,
64 | ) -> Tensor[dtype]:
65 | alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0
66 | var n_chunks = Self.calc_chunks(inputs[0].shape, dim)
67 |
68 | var chunks = List[Int]()
69 | var chunk_offsets = List[Int](0)
70 | for i in range(len(inputs)):
71 | chunks.append(inputs[i].shape.num_elements() // n_chunks)
72 | chunk_offsets.append(chunk_offsets[i] + chunks[i])
73 |
74 | var res_grad = Tensor[dtype](inputs[input_id].shape)
75 | for i in range(n_chunks):
76 | memcpy(
77 | res_grad.data() + i * chunks[input_id],
78 | parameters.grads[outputs[0]].data()
79 | + i * chunk_offsets[len(inputs)]
80 | + chunk_offsets[input_id],
81 | chunks[input_id],
82 | )
83 |
84 | return res_grad ^
85 |
86 |
87 | struct SPLIT:
88 | @staticmethod
89 | fn result_shape(
90 | input_shapes: List[TensorShape], attributes: AttributeVector
91 | ) -> List[TensorShape]:
92 | # Assuming the sum of the sections is equal to the total size in the dim dimension.
93 | # E.g. sections = [5, 5, 2] -> shape (., 12, ., .) for dim = 1
94 | var dim = attributes["dim"].value().to_int() if attributes["dim"] else 0
95 | var sections = attributes["sections"].value().to_shape()
96 |
97 | var res_shapes = List[TensorShape]()
98 | for i in range(sections.rank()):
99 | var new_shape = input_shapes[0]
100 | new_shape[dim] = sections[i]
101 | res_shapes.append(new_shape)
102 |
103 | return res_shapes
104 |
105 | @staticmethod
106 | fn calc_chunks(shape: TensorShape, dim: Int) -> Int:
107 |         # Number of chunks up to the split dimension
108 |         # Assuming output tensors of equal shape, except for the split dimension
109 | var chunks = 1
110 | for i in range(dim):
111 | chunks *= shape[i]
112 | return chunks
113 |
114 | @staticmethod
115 | fn forward[attributes: AttributeVector](
116 | inputs: List[Symbol],
117 | outputs: List[Symbol],
118 | inout parameters: Parameters,
119 | ):
120 | alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0
121 | alias sections = attributes["sections"].value().to_shape()
122 | var n_chunks = Self.calc_chunks(inputs[0].shape, dim)
123 |
124 | var chunks = List[Int]()
125 | var chunk_offsets = List[Int](0)
126 | for i in range(len(outputs)):
127 | chunks.append(outputs[i].shape.num_elements() // n_chunks)
128 | chunk_offsets.append(chunk_offsets[i] + chunks[i])
129 |
130 | for i in range(n_chunks):
131 | for j in range(len(outputs)):
132 | memcpy(
133 | parameters.tensors[outputs[j]].data() + i * chunks[j],
134 | parameters.tensors[inputs[0]].data()
135 | + i * chunk_offsets[len(outputs)]
136 | + chunk_offsets[j],
137 | chunks[j],
138 | )
139 |
140 | @staticmethod
141 | fn backward[input_id: Int, attributes: AttributeVector](
142 | inputs: List[Symbol],
143 | outputs: List[Symbol],
144 | inout parameters: Parameters,
145 | ) -> Tensor[dtype]:
146 | alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0
147 | alias sections = attributes["sections"].value().to_shape()
148 | var n_chunks = Self.calc_chunks(inputs[0].shape, dim)
149 |
150 | var chunks = List[Int]()
151 | var chunk_offsets = List[Int](0)
152 | for i in range(len(outputs)):
153 | chunks.append(outputs[i].shape.num_elements() // n_chunks)
154 | chunk_offsets.append(chunk_offsets[i] + chunks[i])
155 |
156 | var res_grad = Tensor[dtype](inputs[input_id].shape)
157 |
158 | for i in range(n_chunks):
159 | for j in range(len(outputs)):
160 | memcpy(
161 | res_grad.data()
162 | + i * chunk_offsets[len(outputs)]
163 | + chunk_offsets[j],
164 | parameters.grads[outputs[j]].data() + i * chunks[j],
165 | chunks[j],
166 | )
167 |
168 | return res_grad ^
169 |
--------------------------------------------------------------------------------
/basalt/autograd/ops/matmul.mojo:
--------------------------------------------------------------------------------
1 | from basalt.utils.tensorutils import transpose_2D
2 |
3 | from algorithm import vectorize, parallelize
4 | from memory import memset_zero, stack_allocation, UnsafePointer
5 | from sys.info import simdwidthof
6 |
7 |
8 | @always_inline
9 | fn calculate_block[
10 | M: Int, N: Int, K: Int, BLOCK_M: Int, BLOCK_N: Int, nelts: Int
11 | ](
12 | res: UnsafePointer[Scalar[dtype]],
13 | t1: UnsafePointer[Scalar[dtype]],
14 | t2: UnsafePointer[Scalar[dtype]],
15 | bm: Int,
16 | bn: Int,
17 | ):
18 | # Compute tile
19 | var acc = stack_allocation[BLOCK_M * BLOCK_N, dtype]()
20 | memset_zero(acc, BLOCK_M * BLOCK_N)
21 |
22 | for k in range(K):
23 |
24 | @parameter
25 | for m in range(BLOCK_M):
26 |
27 | @parameter
28 | fn inner_n[nelts: Int](n: Int):
29 | acc.store(
30 | m * BLOCK_N + n,
31 | SIMD[dtype, nelts](t1[(bm + m) * K + k])
32 | .fma(
33 | t2.load[width=nelts](k * N + (bn + n)),
34 | acc.load[width=nelts](m * BLOCK_N + n),
35 | ),
36 | )
37 |
38 | vectorize[inner_n, nelts](BLOCK_N)
39 |
40 | # Store tile
41 | for m in range(BLOCK_M):
42 |
43 | @parameter
44 | fn vec_store[nelts: Int](n: Int):
45 | res.store(
46 | (bm + m) * N + (bn + n), acc.load[width=nelts](m * BLOCK_N + n)
47 | )
48 |
49 | vectorize[vec_store, nelts](BLOCK_N)
50 |
51 |
52 | @parameter
53 | @always_inline
54 | fn dot[
55 | t1_shape: TensorShape, t2_shape: TensorShape
56 | ](inout res: Tensor[dtype], t1: Tensor[dtype], t2: Tensor[dtype]):
57 | dot[t1_shape, t2_shape](res.data(), t1.data(), t2.data())
58 |
59 |
60 | @parameter
61 | @always_inline
62 | fn dot[
63 | t1_shape: TensorShape, t2_shape: TensorShape
64 | ](res: UnsafePointer[Scalar[dtype]], t1: UnsafePointer[Scalar[dtype]], t2: UnsafePointer[Scalar[dtype]]):
65 | alias M = t1_shape[0] # t1[0]
66 | alias K = t1_shape[1] # t1[1], t2[0]
67 | alias N = t2_shape[1] # t2[1]
68 |
69 |     # simdwidthof[dtype]() is e.g. 8 for float32 (hardware dependent)
70 | alias nelts = simdwidthof[dtype]()
71 | alias BLOCK_N = 8 * 2
72 | alias BLOCK_M = 6
73 | alias THREADS = 6 # num_logical_cores()
74 |
75 | alias BLOCK_N_REMAINDER = N % BLOCK_N
76 | alias BLOCK_M_REMAINDER = M % BLOCK_M
77 |
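    # The output is computed in BLOCK_M x BLOCK_N tiles: calculate_block accumulates one
    # tile over the full K dimension with FMA, and the rows of tiles (bm_par) are
    # distributed over worker threads by parallelize.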
78 | @parameter
79 | fn bm_par(m_outer: Int):
80 | var bm = m_outer * BLOCK_M
81 |
82 | for n_outer in range(0, N // BLOCK_N):
83 | var bn = n_outer * BLOCK_N
84 |
85 | calculate_block[M, N, K, BLOCK_M, BLOCK_N, nelts](res, t1, t2, bm, bn)
86 |
87 | # Handle the remainder of N
88 | @parameter
89 | if BLOCK_N_REMAINDER > 0:
90 | var bn = N - BLOCK_N_REMAINDER
91 |
92 | calculate_block[M, N, K, BLOCK_M, BLOCK_N_REMAINDER, nelts](
93 | res, t1, t2, bm, bn
94 | )
95 |
96 | parallelize[bm_par](M // BLOCK_M, M // BLOCK_M)
97 |
98 | # Handle the remainder of M
99 | @parameter
100 | if BLOCK_M_REMAINDER > 0:
101 | var bm = M - BLOCK_M_REMAINDER
102 |
103 | for n_outer in range(0, N // BLOCK_N):
104 | var bn = n_outer * BLOCK_N
105 |
106 | calculate_block[M, N, K, BLOCK_M_REMAINDER, BLOCK_N, nelts](
107 | res, t1, t2, bm, bn
108 | )
109 |
110 | # Handle corner remainder
111 | @parameter
112 | if BLOCK_N_REMAINDER > 0:
113 | var bn = N - BLOCK_N_REMAINDER
114 |
115 | calculate_block[M, N, K, BLOCK_M_REMAINDER, BLOCK_N_REMAINDER, nelts](
116 | res, t1, t2, bm, bn
117 | )
118 |
119 |
120 | fn dot_transpose_t2[
121 | A_shape: TensorShape, B_shape: TensorShape
122 | ](inout C: UnsafePointer[Scalar[dtype]], A: UnsafePointer[Scalar[dtype]], B: UnsafePointer[Scalar[dtype]]):
123 | dot[A_shape, TensorShape(B_shape[1], B_shape[0])](C, A, transpose_2D[B_shape](B))
124 |
125 |
126 | fn dot_transpose_t2[
127 | A_shape: TensorShape, B_shape: TensorShape
128 | ](inout C: Tensor[dtype], A: Tensor[dtype], B: Tensor[dtype]):
129 | memset_zero(C.data(), C.num_elements())
130 |
131 | dot[A_shape, TensorShape(B_shape[1], B_shape[0])](C, A, transpose_2D[B_shape](B))
132 |
133 | # @parameter
134 | # fn calc_row(i: Int):
135 | # for j in range(B_shape[0]):
136 |
137 | # @parameter
138 | # fn calc_row_A_B[nelts: Int](k: Int):
139 | # var A_pos = i * A.dim(1) + k
140 | # var B_pos = j * A.dim(1) + k
141 | # var t_new_pos = i * C.dim(1) + j
142 |
143 | # C[t_new_pos] += (
144 | # A.load[nelts](A_pos) * B.load[nelts](B_pos)
145 | # ).reduce_add()
146 |
147 | # vectorize[calc_row_A_B, nelts, size=A_shape[1]]()
148 |
149 | # parallelize[calc_row](A_shape[0], 1)
150 |
151 |
152 | fn dot_transpose_t1[
153 | A_shape: TensorShape, B_shape: TensorShape
154 | ](inout C: Tensor[dtype], A: Tensor[dtype], B: Tensor[dtype]):
155 | memset_zero(C.data(), C.num_elements())
156 |
157 | dot[TensorShape(A_shape[1], A_shape[0]), B_shape](C, transpose_2D[A_shape](A), B)
158 |
159 | # @parameter
160 | # fn calc_row(i: Int):
161 | # for j in range(A_shape[0]):
162 |
163 | # @parameter
164 | # fn calc_row_t_new_B[nelts: Int](k: Int):
165 | # var A_pos = j * A.dim(1) + i
166 | # var B_pos = j * B.dim(1) + k
167 | # var t_new_pos = i * C.dim(1) + k
168 |
169 | # C.store[nelts](
170 | # t_new_pos,
171 | # C.load[nelts](t_new_pos)
172 | # + A[A_pos] * B.load[nelts](B_pos),
173 | # )
174 |
175 | # vectorize[calc_row_t_new_B, nelts, size=B_shape[1]]()
176 |
177 | # parallelize[calc_row](A_shape[1], 1)
178 |
--------------------------------------------------------------------------------
/basalt/autograd/ops/pool.mojo:
--------------------------------------------------------------------------------
1 | from utils.numerics import min_or_neg_inf
2 |
3 | from basalt import Tensor, TensorShape
4 | from basalt.autograd.attributes import AttributeVector
5 | from basalt.autograd.ops.conv import get_result_shape
6 |
7 |
8 | struct MAXPOOL2D:
9 | @staticmethod
10 | fn result_shape(
11 | input_shape: TensorShape, attributes: AttributeVector
12 | ) -> TensorShape:
13 | var kernel_size = attributes["kernel_size"].value().to_static[2]()
14 | var padding = attributes["padding"].value().to_static[2]()
15 | var stride = attributes["stride"].value().to_static[2]()
16 | var dilation = attributes["dilation"].value().to_static[2]()
17 |
18 | var res = get_result_shape(
19 | input_shape,
20 | TensorShape(kernel_size[0], kernel_size[1]),
21 | padding,
22 | stride,
23 | dilation,
24 | )
25 |
26 | return TensorShape(input_shape[0], input_shape[1], res[0], res[1])
27 |
28 | @staticmethod
29 | fn forward[
30 | input_shape: TensorShape, attributes: AttributeVector
31 | ](inout outputs: Tensor[dtype], inputs: Tensor[dtype]):
32 | """
33 |         Returns the max value of each kernel window in the input tensor.
34 | inputs.shape [batch_size, channels, iX, iY]
35 | with kernel_size = (kX, kY)
36 | outputs.shape [batch_size, channels, oX, oY].
37 | """
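        # Illustrative example: input (1, 1, 4, 4) with kernel_size (2, 2), stride (2, 2),
        # padding (0, 0) and dilation (1, 1) -> output (1, 1, 2, 2).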
38 | alias kernel_size = attributes["kernel_size"].value().to_static[2]()
39 | alias padding = attributes["padding"].value().to_static[2]()
40 | alias stride = attributes["stride"].value().to_static[2]()
41 | alias dilation = attributes["dilation"].value().to_static[2]()
42 |
43 | alias inputs_strides = input_shape.strides()
44 | alias output_shape = Self.result_shape(input_shape, attributes)
45 | alias outputs_strides = output_shape.strides()
46 |
47 | for batch in range(input_shape[0]):
48 | for in_ch in range(input_shape[1]):
49 | for x in range(output_shape[2]):
50 | for y in range(output_shape[3]):
51 | var max_val: Scalar[dtype] = min_or_neg_inf[dtype]()
52 | var ix_base = x * stride[0] - padding[0]
53 | var iy_base = y * stride[1] - padding[1]
54 | for kx in range(kernel_size[0]):
55 | for ky in range(kernel_size[1]):
56 | var ix = ix_base + kx * dilation[0]
57 | var iy = iy_base + ky * dilation[1]
58 |
59 | if (
60 | ix < 0
61 | or iy < 0
62 | or ix >= input_shape[2]
63 | or iy >= input_shape[3]
64 | ):
65 | continue
66 |
67 | var idx = (
68 | batch * inputs_strides[0]
69 | + in_ch * inputs_strides[1]
70 | + ix * inputs_strides[2]
71 | + iy
72 | )
73 |
74 | var val = inputs[idx]
75 | if val > max_val:
76 | max_val = val
77 |
78 | var out_idx = (
79 | batch * outputs_strides[0]
80 | + in_ch * outputs_strides[1]
81 | + x * outputs_strides[2]
82 | + y
83 | )
84 |
85 | outputs[out_idx] = max_val
86 |
87 | @staticmethod
88 | fn backward[
89 | ug_shape: TensorShape, input_shape: TensorShape, attributes: AttributeVector
90 | ](ug: Tensor[dtype], inputs: Tensor[dtype]) -> Tensor[dtype]:
91 | """
92 | Backward operation of MAXPOOL2D.
93 |
94 | Upper gradient of shape: [batch_size, channels, uX, uY]
95 | """
96 | alias kernel_size = attributes["kernel_size"].value().to_static[2]()
97 | alias padding = attributes["padding"].value().to_static[2]()
98 | alias stride = attributes["stride"].value().to_static[2]()
99 | alias dilation = attributes["dilation"].value().to_static[2]()
100 |
101 | alias ug_strides = ug_shape.strides()
102 | alias inputs_strides = input_shape.strides()
103 |
104 | var res = Tensor[dtype](input_shape)
105 |
106 | for batch in range(input_shape[0]):
107 | for in_ch in range(input_shape[1]):
108 | for x in range(ug_shape[2]):
109 | for y in range(ug_shape[3]):
110 | var max_val: Scalar[dtype] = min_or_neg_inf[dtype]()
111 | var max_idx: Int = -1
112 | var ix_base = x * stride[0] - padding[0]
113 | var iy_base = y * stride[1] - padding[1]
114 | for kx in range(kernel_size[0]):
115 | for ky in range(kernel_size[1]):
116 | var ix = ix_base + kx * dilation[0]
117 | var iy = iy_base + ky * dilation[1]
118 |
119 | if (
120 | ix < 0
121 | or iy < 0
122 | or ix >= input_shape[2]
123 | or iy >= input_shape[3]
124 | ):
125 | continue
126 |
127 | var idx = (
128 | batch * inputs_strides[0]
129 | + in_ch * inputs_strides[1]
130 | + ix * inputs_strides[2]
131 | + iy
132 | )
133 |
134 | var val = inputs[idx]
135 | if val > max_val:
136 | max_val = val
137 | max_idx = idx
138 |
139 | var ug_idx = (
140 | batch * ug_strides[0]
141 | + in_ch * ug_strides[1]
142 | + x * ug_strides[2]
143 | + y
144 | )
145 |
146 | res[max_idx] += ug[ug_idx]
147 |
148 | return res
149 |
--------------------------------------------------------------------------------
/basalt/autograd/params.mojo:
--------------------------------------------------------------------------------
1 | from collections.optional import Optional
2 | from memory import UnsafePointer
3 |
4 | from basalt import dtype
5 | from basalt import Tensor, TensorShape
6 | from .symbol import Symbol
7 | from .attributes import Attribute
8 |
9 |
10 | @value
11 | struct Param(CollectionElement, Stringable):
12 | var data: Optional[List[Scalar[dtype]]]
13 | var initializer: Optional[Attribute]
14 |
15 | fn __init__(inout self):
16 | self.data = None
17 | self.initializer = None
18 |
19 | fn __init__(inout self, data: List[Scalar[dtype]]):
20 | self.data = data
21 | self.initializer = None
22 |
23 | fn __init__(inout self, data: Scalar[dtype]):
24 | self.data = List[Scalar[dtype]](data)
25 | self.initializer = None
26 |
27 | fn __init__(inout self, initializer: String, *args: Scalar[dtype]):
28 | # Supported initializers:
29 | # "random_uniform", lower_bound, upper_bound
30 | # "random_normal", mean, std
31 | # #TODO: "kaiming_uniform", mode, nonlinearity
32 | # #TODO: "kaiming_normal", mode, nonlinearity
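        #
        # Example (as used in the layers): Param("random_uniform", -bound, bound)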
33 | self.initializer = Attribute("initializer", initializer)
34 | var data = List[Scalar[dtype]]()
35 | for arg in args:
36 | data.append(arg)
37 | self.data = data
38 |
39 | fn __getitem__(self, i: Int) -> Optional[Scalar[dtype]]:
40 | if self.data:
41 | return self.data.value()[i]
42 | else:
43 | return None
44 |
45 | fn __str__(self) -> String:
46 | var s: String = ""
47 | if self.data:
48 | var data = self.data.value()
49 | s += "["
50 | for i in range(len(data)):
51 | s += str(data[i])
52 | if i < len(data) - 1:
53 | s += ", "
54 | s += "]"
55 | return s
56 |
57 |
58 | @value
59 | struct ParamDict(Sized):
60 | var symbols: List[Symbol]
61 | var values: List[Param]
62 |
63 | fn __init__(inout self):
64 | self.symbols = List[Symbol]()
65 | self.values = List[Param]()
66 |
67 | fn put(inout self, param_id: Symbol, value: Param = Param()):
68 | self.symbols.append(param_id)
69 | self.values.append(value)
70 |
71 | fn get_tensor(self, idx: Int) -> Tensor[dtype]:
72 | # May only be called at runtime
73 | var num = self.symbols[idx].shape.num_elements()
74 | var t = UnsafePointer[Scalar[dtype]].alloc(num)
75 | for i in range(num):
76 | t[i] = self.values[idx][i].value()
77 | return Tensor[dtype](t, self.symbols[idx].shape)
78 |
79 | fn __len__(self) -> Int:
80 | return len(self.symbols)
81 |
--------------------------------------------------------------------------------
/basalt/autograd/symbol.mojo:
--------------------------------------------------------------------------------
1 | from basalt import Tensor, TensorShape
2 |
3 |
4 | @value
5 | @register_passable("trivial")
6 | struct Symbol(CollectionElement, Stringable, EqualityComparable):
7 | var name: UInt32
8 | var dtype: DType
9 | var shape: TensorShape
10 | var trainable: Bool
11 |
12 | fn __eq__(self, other: Self) -> Bool:
13 | return self.name == other.name
14 |
15 | fn __ne__(self, other: Self) -> Bool:
16 | return self.name != other.name
17 |
18 | fn __str__(self) -> String:
19 | return self.json()
20 |
21 | fn json(self) -> String:
22 | return (
23 | '{"name": "'
24 | + str(self.name)
25 | + '", "dtype": "'
26 | + str(self.dtype)
27 | + '", "shape": "'
28 | + str(self.shape)
29 | + '", "trainable": "'
30 | + str(self.trainable)
31 | + '"}'
32 | )
33 |
--------------------------------------------------------------------------------
/basalt/nn/__init__.mojo:
--------------------------------------------------------------------------------
1 | from .tensor import Tensor, TensorShape
2 | from .model import Model
3 |
4 | from .layers.linear import Linear
5 | from .layers.conv import Conv2d
6 | from .layers.pool import MaxPool2d
7 |
8 | from .loss import MSELoss, CrossEntropyLoss
9 | from .activations import (
10 | Softmax,
11 | LogSoftmax,
12 | ReLU,
13 | LeakyReLU,
14 | Sigmoid,
15 | Tanh,
16 | )
17 |
--------------------------------------------------------------------------------
/basalt/nn/activations.mojo:
--------------------------------------------------------------------------------
1 | from basalt import Tensor, TensorShape
2 | from basalt import Graph, Symbol, OP
3 | from basalt.autograd.attributes import Attribute, AttributeVector
4 |
5 |
6 | # Activation functions.
7 | fn ReLU(inout g: Graph, input: Symbol) -> Symbol:
8 | return g.op(OP.RELU, input)
9 |
10 |
11 | fn LeakyReLU(
12 | inout g: Graph, input: Symbol, negative_slope: Scalar[dtype]
13 | ) -> Symbol:
14 | return g.op(
15 | OP.LEAKYRELU,
16 | input,
17 | attributes=AttributeVector(Attribute("negative_slope", negative_slope)),
18 | )
19 |
20 |
21 | fn Sigmoid(inout g: Graph, input: Symbol) -> Symbol:
22 | return g.op(OP.SIGMOID, input)
23 |
24 |
25 | fn Tanh(inout g: Graph, input: Symbol) -> Symbol:
26 | return g.op(OP.TANH, input)
27 |
28 |
29 | fn Softmax(inout g: Graph, input: Symbol, axis: Int) -> Symbol:
30 | # softmax: exp(x_i) / sum(exp(x_j))
31 | # stable softmax: exp(x_i - max(x_j)) / sum(exp(x_j - max(x_j)))
32 |
33 | var max_values = g.op(
34 | OP.MAX, input, attributes=AttributeVector(Attribute("axis", axis))
35 | )
36 | var input_minus_max = g.op(OP.SUB, input, max_values)
37 | var exp_values = g.op(OP.EXP, input_minus_max)
38 | var sum_values = g.op(
39 | OP.SUM, exp_values, attributes=AttributeVector(Attribute("axis", axis))
40 | )
41 |
42 | return g.op(OP.DIV, exp_values, sum_values)
43 |
44 |
45 | fn LogSoftmax(inout g: Graph, input: Symbol, axis: Int) -> Symbol:
46 | # stable logsoftmax: log(exp(x_i - max(x_j)) / sum(exp(x_j - max(x_j))))
47 | # stable logsoftmax: x_i - max(x_j) - log(sum(exp(x_j - max(x_j))))
48 |
49 | var max_values = g.op(
50 | OP.MAX, input, attributes=AttributeVector(Attribute("axis", axis))
51 | )
52 | var input_minus_max = g.op(OP.SUB, input, max_values)
53 | var exp_values = g.op(OP.EXP, input_minus_max)
54 | var sum_values = g.op(
55 | OP.SUM, exp_values, attributes=AttributeVector(Attribute("axis", axis))
56 | )
57 | var log_values = g.op(OP.LOG, sum_values)
58 |
59 | return g.op(OP.SUB, input_minus_max, log_values)
60 |
--------------------------------------------------------------------------------
/basalt/nn/initializers.mojo:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 |
3 | from basalt import dtype
4 | from basalt import Tensor, TensorShape
5 | from basalt.utils.rand_utils import rand_normal, rand_uniform
6 |
7 |
8 | fn initialize_tensor(
9 | shape: TensorShape, type: String, data: List[Scalar[dtype]]
10 | ) -> Tensor[dtype]:
11 | if type == "random_uniform":
12 | var low = data[0]
13 | var high = data[1]
14 | var t = Tensor[dtype](shape)
15 | rand_uniform(t, low=low, high=high)
16 | return t
17 | elif type == "random_normal":
18 | var mean = data[0].cast[DType.float64]()
19 | var std = data[1].cast[DType.float64]()
20 | var t = Tensor[dtype](shape)
21 | rand_normal(t, mean=mean, std=std)
22 | return t
23 | # elif type == "kaiming_uniform":
24 | # # mode, nonlinearity
25 | # var mode_id = data[0]
26 | # var mode = "fan_in" if mode_id == 0 else "fan_out"
27 | # return kaiming_uniform(shape, mode = mode)
28 | # elif type == "kaiming_normal":
29 | # # mode, nonlinearity
30 | # var mode_id = data[0]
31 | # var mode = "fan_in" if mode_id == 0 else "fan_out"
32 | # return kaiming_normal(shape, mode = mode)
33 | else:
34 | print("[ERROR] Unsupported initialization type: " + type)
35 | return Tensor[dtype]()
36 |
37 |
38 | fn calculate_fan(shape: TensorShape, mode: String) -> Scalar[dtype]:
39 | """
40 |     Calculate the fan-in or fan-out of a tensor, depending on `mode`.
41 |     """
42 |     # NOTE: shape.rank() should be >= 2
43 |     # mode: "fan_in" or "fan_out"
44 | if shape.rank() < 2:
45 | print(
46 | "[ERROR] Fan in and fan out can not be calculated for tensor with less than"
47 | " 2 dimensions"
48 | )
49 |
50 | var num_input_fmaps = shape[1]
51 | var num_output_fmaps = shape[0]
52 | var receptive_field_size = 1
53 | if shape.rank() > 2:
54 | for i in range(2, shape.rank()):
55 | receptive_field_size *= shape[i]
56 |
57 | var fan_in = num_input_fmaps * receptive_field_size
58 | var fan_out = num_output_fmaps * receptive_field_size
59 |
60 | if mode == "fan_in":
61 | return fan_in
62 | else:
63 | return fan_out
64 |
65 |
66 | # # TODO: https://pytorch.org/docs/stable/_modules/torch/nn/init.html
67 | # fn kaiming_uniform(shape: TensorShape, mode: String = "fan_in", nonlinearity: String = "leaky_relu") -> Tensor[dtype]:
68 | # var fan = calculate_fan(shape, mode)
69 |
70 | # # TODO: add support for other gains: https://github.com/pytorch/pytorch/blob/main/torch/nn/init.py#L68
71 | # # Gain for linear and conv layers is 1
72 | # var gain = 1
73 | # var std = gain / sqrt(fan)
74 |
75 | # # var bound = sqrt(3) * std.cast[dtype]()
76 | # var bound = std.cast[dtype]()
77 |
78 | # # print("Shape", shape, "Fan", fan, "Bound", bound)
79 |
80 | # var t = Tensor[dtype](shape)
81 | # rand_uniform(t, low = -bound, high = bound)
82 | # return t^
83 |
84 |
85 | # # TODO: https://pytorch.org/docs/stable/_modules/torch/nn/init.html
86 | # fn kaiming_normal(shape: TensorShape, mode: String = "fan_in", nonlinearity: String = "leaky_relu") -> Tensor[dtype]:
87 | # var fan = calculate_fan(shape, mode)
88 |
89 | # # TODO: add support for other gains: https://github.com/pytorch/pytorch/blob/main/torch/nn/init.py#L68
90 | # # Gain for linear and conv layers is 1
91 | # var gain = 1
92 | # var std = gain / sqrt(fan)
93 |
94 | # var t = Tensor[dtype](shape)
95 | # rand_normal(t, mean = 0, std = std.cast[DType.float64]())
96 | # return t^
97 |
--------------------------------------------------------------------------------
/basalt/nn/layers/__init__.mojo:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basalt-org/basalt/fe16eadcfcfee9271f9df0dd94d11b7c50d868ba/basalt/nn/layers/__init__.mojo
--------------------------------------------------------------------------------
/basalt/nn/layers/conv.mojo:
--------------------------------------------------------------------------------
1 | from basalt import Graph, Symbol, OP
2 | from basalt import Tensor, TensorShape
3 | from basalt.utils import q_sqrt
4 | from basalt.autograd.params import Param
5 | from basalt.autograd.attributes import AttributeVector, Attribute
6 |
7 | from utils.index import IndexList
8 |
9 |
10 | fn Conv2d(
11 | inout g: Graph,
12 | inputs: Symbol,
13 | out_channels: Int,
14 | kernel_size: IndexList[2],
15 | padding: IndexList[2] = 0,
16 | stride: IndexList[2] = 1,
17 | dilation: IndexList[2] = 1,
18 | ) -> Symbol:
19 | """
20 | A 2D Convolution Layer.
21 |
22 | Parameters
23 | inputs.shape [batch, in_channels, iX, iY]
24 | kernel.shape [out_channels, in_channels, kX, kY] (or weights)
25 | bias.shape [out_channels].
26 | output.shape [batch, out_channels, oX, oY].
27 | """
28 |
29 | var in_channels: Int = inputs.shape[1]
30 | var fan_in: Scalar[dtype] = in_channels * kernel_size[0] * kernel_size[1]
31 | var bound = q_sqrt(fan_in)
32 | var weights = g.param(
33 | TensorShape(out_channels, in_channels, kernel_size[0], kernel_size[1]),
34 | init=Param("random_uniform", -bound, bound)
35 | # init=Param("kaiming_uniform", 0)
36 | )
37 | var bias = g.param(
38 | TensorShape(out_channels), init=Param("random_uniform", -bound, bound)
39 | )
40 |
41 | return g.op(
42 | OP.CONV2D,
43 | inputs,
44 | weights,
45 | bias,
46 | attributes=AttributeVector(
47 | Attribute("padding", padding),
48 | Attribute("stride", stride),
49 | Attribute("dilation", dilation),
50 | ),
51 | )
52 |
--------------------------------------------------------------------------------
/basalt/nn/layers/dropout.mojo:
--------------------------------------------------------------------------------
1 | # TODO
2 |
--------------------------------------------------------------------------------
/basalt/nn/layers/linear.mojo:
--------------------------------------------------------------------------------
1 | from basalt import Tensor, TensorShape
2 | from basalt import Graph, Symbol, OP
3 | from basalt.utils import q_sqrt
4 | from basalt.autograd.params import Param
5 |
6 |
7 | fn Linear(
8 | inout g: Graph,
9 | inputs: Symbol,
10 | n_outputs: Int,
11 | ) -> Symbol:
12 | """
13 | A fully connected layer.
14 | """
15 |
16 | var fan_in: Scalar[dtype] = inputs.shape[1]
17 | var bound = q_sqrt(fan_in)
18 | var weights = g.param(
19 | TensorShape(inputs.shape[1], n_outputs),
20 | init=Param("random_uniform", -bound, bound)
21 | # init=Param("random_uniform", 1) # NOTE: mode: fan_out required as weight are defined transposed
22 | )
23 | var b = g.param(TensorShape(n_outputs), init=Param("random_uniform", -bound, bound))
24 |
25 | var res = g.op(OP.DOT, inputs, weights)
26 | return g.op(OP.ADD, res, b)
27 |
--------------------------------------------------------------------------------
/basalt/nn/layers/pool.mojo:
--------------------------------------------------------------------------------
1 | from basalt import Tensor, TensorShape
2 | from collections.optional import Optional
3 | from utils.index import IndexList
4 |
5 | from basalt import Graph, Symbol, OP
6 | from basalt.autograd.attributes import AttributeVector, Attribute
7 |
8 |
9 | fn set_static_stride(
10 | kernel_size: IndexList[2], stride: Optional[Int] = None
11 | ) -> IndexList[2]:
12 | if stride:
13 | return IndexList[2](stride.value(), stride.value())
14 | else:
15 | return kernel_size
16 |
17 |
18 | fn MaxPool2d(
19 | inout g: Graph,
20 | inputs: Symbol,
21 | kernel_size: IndexList[2],
22 | stride: Optional[Int] = None,
23 | padding: IndexList[2] = 0,
24 | dilation: IndexList[2] = 1,
25 | ) -> Symbol:
26 | """
27 | A 2D Max Pooling Layer.
28 |
29 | Kernel is unaware of the in_channels and out_channels of the input tensor.
30 | kernel.size (kX, kY)
31 | """
32 |
33 | # TODO: assert padding <= kernel_size / 2 (at compile time)
34 |
35 | var stride_temp = set_static_stride(kernel_size, stride)
36 |
37 | return MaxPool2d(g, inputs, kernel_size, stride_temp, padding, dilation)
38 |
39 |
40 | fn MaxPool2d(
41 | inout g: Graph,
42 | inputs: Symbol,
43 | kernel_size: IndexList[2],
44 | stride: IndexList[2], # stride should be 1 or more
45 | padding: IndexList[2] = 0,
46 | dilation: IndexList[2] = 1,
47 | ) -> Symbol:
48 | """
49 | A 2D Max Pooling Layer.
50 |
51 | Kernel is unaware of the in_channels and out_channels of the input tensor.
52 | kernel.size (kX, kY)
53 | """
54 | # TODO: assert padding <= kernel_size / 2 (at compile time)
55 |
56 | return g.op(
57 | OP.MAXPOOL2D,
58 | inputs,
59 | attributes=AttributeVector(
60 | Attribute("kernel_size", kernel_size),
61 | Attribute("padding", padding),
62 | Attribute("stride", stride),
63 | Attribute("dilation", dilation),
64 | ),
65 | )
66 |
67 |
68 | # # TODO
69 |
--------------------------------------------------------------------------------
/basalt/nn/layers/sequential.mojo:
--------------------------------------------------------------------------------
1 | # TODO
2 |
--------------------------------------------------------------------------------
/basalt/nn/loss.mojo:
--------------------------------------------------------------------------------
1 | import basalt.nn as nn
2 | from basalt import Tensor, TensorShape
3 | from basalt import Graph, Symbol, OP
4 |
5 |
6 | fn MSELoss(
7 | inout g: Graph,
8 | y_pred: Symbol,
9 | y_true: Symbol,
10 | ) -> Symbol:
11 | # 1/N * sum( (outputs - targets)^2 )
12 |
13 | var diff = g.op(OP.SUB, y_true, y_pred)
14 | var loss = g.op(OP.POW, diff, 2)
15 | var mean_loss = g.op(OP.MEAN, loss)
16 |
17 | return mean_loss
18 |
19 |
20 | fn CrossEntropyLoss(
21 | inout g: Graph,
22 | y_pred: Symbol,
23 | y_true: Symbol,
24 | ) -> Symbol:
25 | # -1/N * sum( targets * log_softmax(outputs) )
26 |
27 | var log_softmax = nn.LogSoftmax(g, y_pred, axis=1)
28 |
29 | # CrossEntropy (reduction Mean)
30 | var targets_log_softmax = g.op(OP.MUL, y_true, log_softmax)
31 | var ret = g.op(OP.SUM, targets_log_softmax)
32 | var negDivN = g.op(OP.MUL, ret, -1.0 / y_pred.shape[0])
33 |
34 | return negDivN
35 |
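Both losses reduce to a single scalar with a mean over the batch. CrossEntropyLoss expects one-hot targets (the MNIST example one-hot encodes its labels before the forward pass). A NumPy sketch of the same reduction, for illustration only:

    import numpy as np

    def cross_entropy_one_hot(y_pred: np.ndarray, y_true_one_hot: np.ndarray) -> float:
        # log_softmax along the class axis
        log_softmax = y_pred - np.log(np.sum(np.exp(y_pred), axis=1, keepdims=True))
        # -1/N * sum(targets * log_softmax(outputs)), with N the batch size
        return float(-np.sum(y_true_one_hot * log_softmax) / y_pred.shape[0])

    logits = np.array([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
    targets = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    print(cross_entropy_one_hot(logits, targets))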
--------------------------------------------------------------------------------
/basalt/nn/optim.mojo:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | from algorithm import vectorize, parallelize
3 |
4 | from .model import Parameters
5 | from basalt import Graph, Symbol, Tensor, TensorShape
6 | from basalt.utils.collection import Collection
7 | from basalt.utils.math_util import add, sub, mul, div
8 |
9 |
10 | fn get_trainable_parameters(g: Graph) -> List[Symbol]:
11 | """
12 | Get all symbols of trainable parameters.
13 | """
14 |
15 | var trainable_parameters = List[Symbol]()
16 |
17 | for i in range(len(g.params)):
18 | if g.params.symbols[i].trainable:
19 | trainable_parameters.append(g.params.symbols[i])
20 |
21 | return trainable_parameters ^
22 |
23 |
24 | @value
25 | struct Adam[
26 | g: Graph,
27 | trainable_parameters: List[Symbol] = get_trainable_parameters(g),
28 | ]:
29 | var parameters: Pointer[Parameters, MutableAnyOrigin]
30 |
31 | var lr: Scalar[dtype]
32 | var beta1: Scalar[dtype]
33 | var beta2: Scalar[dtype]
34 | var epsilon: Scalar[dtype]
35 | var iter: Int
36 |
37 | var rms_grads: Collection
38 | var momentum_grads: Collection
39 |
40 | fn __init__(
41 | inout self,
42 | ref[MutableAnyOrigin] parameters: Parameters,
43 | lr: Scalar[dtype] = 0.001,
44 | beta1: Scalar[dtype] = 0.9,
45 | beta2: Scalar[dtype] = 0.999,
46 | epsilon: Scalar[dtype] = 1e-8,
47 | ):
48 | self.parameters = Pointer.address_of(parameters)
49 |
50 | self.lr = lr
51 | self.beta1 = beta1
52 | self.beta2 = beta2
53 | self.epsilon = epsilon
54 | self.iter = 0
55 |
56 |         # Capacity of the collections should be the number of trainable parameters
57 | self.rms_grads = Collection(capacity=len(trainable_parameters))
58 | self.momentum_grads = Collection(capacity=len(trainable_parameters))
59 |
60 | self.allocate_rms_and_momentum()
61 |
62 | fn zero_grad(inout self):
63 | """Set all gradients to zero."""
64 | self.parameters[].grads.set_zero()
65 |
66 | fn step(inout self):
67 | """Update model parameters."""
68 | self.iter += 1
69 |
70 | # Loop over all trainable parameters
71 | @parameter
72 | fn p_step(i: Int):
73 | var param = trainable_parameters[i]
74 |
75 | @parameter
76 | fn v_step[nelts: Int](j: Int):
77 | var momentum_grads = self.momentum_grads[param].load[nelts](j)
78 | var rms_grads = self.rms_grads[param].load[nelts](j)
79 | var grads = self.parameters[].grads[param].load[nelts](j)
80 | var params = self.parameters[].tensors[param].load[nelts](j)
81 |
82 | # Momentum beta 1
83 | # f1 = beta1 * momentum + (1 - beta1) * grad
84 | momentum_grads = self.beta1 * momentum_grads + (1 - self.beta1) * grads
85 | self.momentum_grads[param].store[nelts](j, momentum_grads)
86 |
87 | # Bias correction
88 | # f2 = f1 / (1 - beta1 ** iter)
89 | momentum_grads = momentum_grads / (1 - self.beta1**self.iter)
90 |
91 | # RMS beta 2
92 | # f1 = beta2 * rms + (1 - beta2) * grad ** 2
93 | rms_grads = self.beta2 * rms_grads + (1 - self.beta2) * grads * grads
94 | self.rms_grads[param].store[nelts](j, rms_grads)
95 |
96 | # Bias correction
97 | # f2 = f1 / (1 - beta2 ** iter)
98 | rms_grads = rms_grads / (1 - self.beta2**self.iter)
99 |
100 | # tensor = tensor - lr * (f2 / (sqrt(rms) + epsilon))
101 | params = params - self.lr * (
102 | momentum_grads / (sqrt(rms_grads) + self.epsilon)
103 | )
104 | self.parameters[].tensors[param].store[nelts](j, params)
105 |
106 | vectorize[v_step, 1](param.shape.num_elements())
107 |
108 | parallelize[p_step](len(trainable_parameters))
109 |
110 | fn allocate_rms_and_momentum(inout self):
111 | # They are initialized to zero
112 | # Loop over all trainable parameters
113 | for i in range(len(trainable_parameters)):
114 | var param = trainable_parameters[i]
115 | self.rms_grads.append(Tensor[dtype](param.shape), param)
116 | self.momentum_grads.append(Tensor[dtype](param.shape), param)
117 |
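The per-element update in step() is the textbook Adam rule with bias correction, exactly as spelled out in the comments: update the first and second moment estimates, correct them by 1 - beta**iter, then take a scaled step. A scalar NumPy sketch of one update, for illustration only:

    import numpy as np

    def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        m = beta1 * m + (1 - beta1) * grad            # momentum estimate
        v = beta2 * v + (1 - beta2) * grad * grad     # RMS estimate
        m_hat = m / (1 - beta1**t)                    # bias correction, t is 1-based
        v_hat = v / (1 - beta2**t)
        param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
        return param, m, v

    p, m, v = 1.0, 0.0, 0.0
    for t in range(1, 4):
        p, m, v = adam_step(p, grad=0.5, m=m, v=v, t=t)
    print(p)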
--------------------------------------------------------------------------------
/basalt/nn/tensor.mojo:
--------------------------------------------------------------------------------
1 | from testing import assert_true
2 | from algorithm import vectorize
3 | from utils.index import IndexList
4 | from memory import memset_zero, memcpy, UnsafePointer
5 |
6 |
7 | alias MAX_RANK = 8
8 |
9 |
10 | @register_passable("trivial")
11 | struct TensorShape(Stringable):
12 | var _rank: Int
13 | var _shape: IndexList[MAX_RANK]
14 |
15 | fn __init__(inout self, *shape: Int):
16 | self._rank = len(shape)
17 | self._shape = IndexList[MAX_RANK]()
18 | for i in range(min(self._rank, MAX_RANK)):
19 | self._shape[i] = shape[i]
20 |
21 | fn __init__(inout self, shapes: VariadicList[Int]):
22 | self._rank = len(shapes)
23 | self._shape = IndexList[MAX_RANK]()
24 | for i in range(min(self._rank, MAX_RANK)):
25 | self._shape[i] = shapes[i]
26 |
27 | fn __init__(inout self, shape: List[Int]):
28 | self._rank = len(shape)
29 | self._shape = IndexList[MAX_RANK]()
30 | for i in range(min(self._rank, MAX_RANK)):
31 | self._shape[i] = shape[i]
32 |
33 | fn __init__[num: Int](inout self, shape: IndexList[num]):
34 | self._rank = num
35 | self._shape = IndexList[MAX_RANK]()
36 | for i in range(min(self._rank, MAX_RANK)):
37 | self._shape[i] = shape[i]
38 |
39 | fn __init__(inout self, rank: Int, shape: IndexList[MAX_RANK]):
40 | self._rank = rank
41 | self._shape = shape
42 |
43 | @always_inline("nodebug")
44 | fn __getitem__(self, index: Int) -> Int:
45 | return self._shape[index if index >= 0 else self._rank + index]
46 |
47 | @always_inline("nodebug")
48 | fn __setitem__(inout self, index: Int, value: Int):
49 | self._shape[index if index >= 0 else self._rank + index] = value
50 |
51 | @always_inline("nodebug")
52 | fn rank(self) -> Int:
53 | return self._rank
54 |
55 | fn num_elements(self) -> Int:
56 | var result = 1
57 | for i in range(self._rank):
58 | result *= self._shape[i]
59 | return result
60 |
61 | fn strides(self) -> IndexList[MAX_RANK]:
62 | var result = IndexList[MAX_RANK](0)
63 | result[self._rank - 1] = 1
64 | for i in range(self._rank - 2, -1, -1):
65 | result[i] = result[i + 1] * self._shape[i + 1]
66 | return result
67 |
68 | fn __str__(self) -> String:
69 | var s: String = "("
70 | for i in range(self._rank):
71 | s += str(self._shape[i])
72 | if i < self._rank - 1:
73 | s += ", "
74 | return s + ")"
75 |
76 | @always_inline("nodebug")
77 | fn __eq__(self, other: TensorShape) -> Bool:
78 | if self.rank() != other.rank():
79 | return False
80 | for i in range(self.rank()):
81 | if self[i] != other[i]:
82 | return False
83 | return True
84 |
85 | @always_inline("nodebug")
86 | fn __ne__(self, other: TensorShape) -> Bool:
87 | return not self.__eq__(other)
88 |
89 | fn __contains__(self, value: Int) -> Bool:
90 | for i in range(self.rank()):
91 | if self[i] == value:
92 | return True
93 | return False
94 |
95 | fn to_list(self) -> List[Int]:
96 | var result = List[Int]()
97 | for i in range(self.rank()):
98 | result.append(self[i])
99 | return result
100 |
101 |
102 | struct Tensor[dtype: DType](Stringable, Movable, CollectionElement):
103 | var _data: UnsafePointer[Scalar[dtype]]
104 | var _shape: TensorShape
105 |
106 | fn __init__(inout self, *dims: Int):
107 | self._shape = TensorShape(dims)
108 | self._data = UnsafePointer[Scalar[dtype]].alloc(self._shape.num_elements())
109 | memset_zero(self._data, self._shape.num_elements())
110 |
111 | fn __init__(inout self, owned shape: TensorShape):
112 | self._data = UnsafePointer[Scalar[dtype]].alloc(shape.num_elements())
113 | memset_zero(self._data, shape.num_elements())
114 | self._shape = shape
115 |
116 | fn __init__(inout self, shapes: VariadicList[Int]):
117 | self._shape = TensorShape(shapes)
118 | self._data = UnsafePointer[Scalar[dtype]].alloc(self._shape.num_elements())
119 | memset_zero(self._data, self._shape.num_elements())
120 |
121 | fn __init__(
122 | inout self, owned data: UnsafePointer[Scalar[dtype]], owned shape: TensorShape
123 | ):
124 |         # NOTE: Keep the tensor whose data you passed alive afterwards (e.g. _ = your_tensor), otherwise its destructor may run early and cause weird behavior here
125 | self._data = UnsafePointer[Scalar[dtype]].alloc(shape.num_elements())
126 | self._shape = shape
127 |
128 | memcpy(self._data, data, self._shape.num_elements())
129 | _ = data
130 |
131 | fn __moveinit__(inout self, owned other: Tensor[dtype]):
132 | self._data = other._data
133 | self._shape = other._shape
134 |
135 | fn __copyinit__(inout self, other: Tensor[dtype]):
136 | # print("[WARNING] Copying tensor")
137 | self._data = UnsafePointer[Scalar[dtype]].alloc(other._shape.num_elements())
138 | memcpy(self._data, other._data, other.num_elements())
139 | self._shape = other._shape
140 |
141 | @always_inline("nodebug")
142 | fn __getitem__(self, index: Int) -> Scalar[dtype]:
143 | return self._data[index]
144 |
145 | @always_inline("nodebug")
146 | fn __setitem__(self, index: Int, value: Scalar[dtype]):
147 | self._data[index] = value
148 |
149 | @always_inline("nodebug")
150 | fn data(self) -> UnsafePointer[Scalar[dtype]]:
151 | return self._data
152 |
153 | @always_inline("nodebug")
154 | fn shape(self) -> TensorShape:
155 | return self._shape
156 |
157 | @always_inline("nodebug")
158 | fn load[simd_width: Int](self, index: Int) -> SIMD[dtype, simd_width]:
159 | return self._data.load[width=simd_width](index)
160 |
161 | @always_inline("nodebug")
162 | fn store[simd_width: Int](self, index: Int, value: SIMD[dtype, simd_width]):
163 | self._data.store(index, value)
164 |
165 | @always_inline("nodebug")
166 | fn strides(self) -> IndexList[MAX_RANK]:
167 | return self._shape.strides()
168 |
169 | @always_inline("nodebug")
170 | fn rank(self) -> Int:
171 | return self._shape.rank()
172 |
173 | @always_inline("nodebug")
174 | fn num_elements(self) -> Int:
175 | return self._shape.num_elements()
176 |
177 | @always_inline("nodebug")
178 | fn dim(self, index: Int) -> Int:
179 | return self._shape[index]
180 |
181 | @always_inline("nodebug")
182 | fn zero(self):
183 | memset_zero(self._data, self.num_elements())
184 |
185 | @always_inline("nodebug")
186 | fn ireshape(inout self, new_shape: TensorShape) raises:
187 | # NOTE Consider not raising on error
188 | assert_true(self.num_elements() == new_shape.num_elements())
189 | self._shape = new_shape
190 |
191 | fn __str__(self) -> String:
192 | # temp fix
193 | var s: String = "["
194 | for i in range(self.num_elements()):
195 | s += str(self[i])
196 | if i < self.num_elements() - 1:
197 | s += ", "
198 | return s + "]"
199 |
200 |
201 | @always_inline("nodebug")
202 | fn __del__(owned self):
203 | self._data.free()
204 |
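strides() returns row-major (C-contiguous) strides, so the flat index of a coordinate is the dot product of that coordinate with the strides (the DataLoader, for example, uses strides()[0] as the row offset). A quick Python illustration:

    def row_major_strides(shape):
        strides = [1] * len(shape)
        for i in range(len(shape) - 2, -1, -1):
            strides[i] = strides[i + 1] * shape[i + 1]
        return strides

    shape = (2, 3, 4)
    strides = row_major_strides(shape)                   # [12, 4, 1]
    flat = sum(c * s for c, s in zip((1, 2, 3), strides))
    print(strides, flat)                                 # [12, 4, 1] 23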
--------------------------------------------------------------------------------
/basalt/utils/__init__.mojo:
--------------------------------------------------------------------------------
1 | from memory.unsafe import bitcast
2 |
3 |
4 | @always_inline("nodebug")
5 | fn q_sqrt(value: Float32) -> Float32:
6 | var y = bitcast[DType.float32](0x5F3759DF - (bitcast[DType.uint32](value) >> 1))
7 | return -y * ((0.5 * value * y).fma(y, -1.5))
8 |
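q_sqrt is the classic fast inverse square root bit trick: reinterpret the float bits, subtract from the magic constant 0x5F3759DF, then refine with one Newton-Raphson step, giving an approximation of 1/sqrt(value). That is why Linear can use it directly as its 1/sqrt(fan_in) init bound. A NumPy sketch of the same trick, for illustration:

    import numpy as np

    def q_sqrt(value: float) -> float:
        i = np.array([value], dtype=np.float32).view(np.uint32)
        y = (np.uint32(0x5F3759DF) - (i >> 1)).view(np.float32)[0]  # initial guess
        return float(y * (1.5 - 0.5 * value * y * y))               # one Newton step

    print(q_sqrt(4.0))  # ~0.5 == 1/sqrt(4)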
--------------------------------------------------------------------------------
/basalt/utils/bytes.mojo:
--------------------------------------------------------------------------------
1 | from math import nan
2 | from utils.numerics import inf
3 | from utils.static_tuple import StaticTuple
4 | from memory.unsafe import bitcast
5 | alias ScalarBytes = DType.uint64.sizeof()
6 |
7 |
8 | @register_passable("trivial")
9 | struct Bytes[capacity: Int](Stringable, CollectionElement, EqualityComparable):
10 | """
11 | Static sequence of bytes.
12 | """
13 |
14 | var data: StaticTuple[UInt8, capacity]
15 |
16 | fn __init__(inout self):
17 | var data = StaticTuple[UInt8, capacity](0)
18 |
19 | for i in range(capacity):
20 | data[i] = 0
21 |
22 | self.data = data
23 |
24 | fn __init__(inout self, s: String):
25 | var data = StaticTuple[UInt8, capacity](0)
26 | var length = len(s)
27 |
28 | for i in range(capacity):
29 | data[i] = ord(s[i]) if i < length else 0
30 |
31 | self.data = data
32 |
33 | @always_inline("nodebug")
34 | fn __len__(self) -> Int:
35 | return capacity
36 |
37 | @always_inline("nodebug")
38 | fn __setitem__(inout self, index: Int, value: UInt8):
39 | self.data[index] = value
40 |
41 | @always_inline("nodebug")
42 | fn __getitem__(self, index: Int) -> UInt8:
43 | return self.data[index]
44 |
45 | @always_inline("nodebug")
46 | fn __eq__(self, other: Self) -> Bool:
47 | for i in range(capacity):
48 | if self[i] != other[i]:
49 | return False
50 | return True
51 |
52 | @always_inline("nodebug")
53 | fn __ne__(self, other: Self) -> Bool:
54 | for i in range(capacity):
55 | if self[i] != other[i]:
56 | return True
57 | return False
58 |
59 | @always_inline("nodebug")
60 | fn __str__(self) -> String:
61 | var result: String = ""
62 |
63 | for i in range(capacity):
64 | var val = self[i]
65 | if val != 0:
66 | result += chr(int(val))
67 |
68 | return result
69 |
70 |
71 | fn scalar_to_bytes[
72 | dtype: DType, Size: Int = ScalarBytes
73 | ](value: Scalar[dtype]) -> Bytes[Size]:
74 | constrained[Size >= ScalarBytes, "Size must be at least ${ScalarBytes}"]()
75 |
76 | var bits = bitcast[DType.uint64](value.cast[expand_type[dtype]()]())
77 | var data = Bytes[Size]()
78 |
79 | for i in range(ScalarBytes):
80 | data[i] = (bits >> (i << 3)).cast[DType.uint8]()
81 |
82 | return data
83 |
84 |
85 | fn bytes_to_scalar[dtype: DType](data: Bytes) -> Scalar[dtype]:
86 | constrained[data.capacity >= ScalarBytes, "Size must be at least ${ScalarBytes}"]()
87 |
88 | var bits: UInt64 = 0
89 |
90 | for i in range(ScalarBytes):
91 | bits |= data[i].cast[DType.uint64]() << (i << 3)
92 |
93 | return bitcast[expand_type[dtype]()](bits).cast[dtype]()
94 |
95 |
96 | fn expand_type[dtype: DType]() -> DType:
97 | @parameter
98 | if dtype.is_floating_point():
99 | return DType.float64
100 | elif dtype.is_signed():
101 | return DType.int64
102 | elif dtype.is_integral():
103 | return DType.uint64
104 |
105 | constrained[False, "Type must be numeric"]()
106 | return DType.invalid
107 |
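scalar_to_bytes widens the value to the 64-bit type from expand_type and then writes the 8 bytes least-significant first (little-endian); bytes_to_scalar reverses the process. For a floating-point scalar, the equivalent round trip in plain Python, as a sanity check:

    import struct

    def scalar_to_bytes(value: float) -> bytes:
        # float -> float64 bits -> 8 bytes, little-endian (matches bits >> (i << 3))
        return struct.pack("<d", value)

    def bytes_to_scalar(data: bytes) -> float:
        return struct.unpack("<d", data[:8])[0]

    raw = scalar_to_bytes(3.5)
    print(list(raw), bytes_to_scalar(raw))  # [...] 3.5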
--------------------------------------------------------------------------------
/basalt/utils/collection.mojo:
--------------------------------------------------------------------------------
1 | from memory.unsafe_pointer import UnsafePointer
2 | from memory import memset_zero, memcpy
3 |
4 | from basalt import Tensor, Symbol
5 |
6 |
7 | struct Collection(CollectionElement, Sized):
8 | """
9 | A collection of tensors with associated symbols.
10 | """
11 |
12 | var size: Int
13 | var capacity: Int
14 | var data: UnsafePointer[Tensor[dtype]]
15 | var symbols: UnsafePointer[Scalar[DType.uint32]]
16 |
17 | @always_inline("nodebug")
18 | fn __init__(inout self, *, capacity: Int = 1):
19 | """
20 | Initializes a new Collection with the given capacity.
21 | """
22 | self.size = 0
23 | self.capacity = capacity
24 | self.data = UnsafePointer[Tensor[dtype]].alloc(capacity)
25 | UnsafePointer.init_pointee_move((self.data + self.size), Tensor[dtype]())
26 | self.symbols = UnsafePointer[Scalar[DType.uint32]].alloc(capacity)
27 |
28 | @always_inline("nodebug")
29 | fn __moveinit__(inout self, owned existing: Self):
30 | """
31 | Move initializes a Collection from an existing one.
32 | """
33 | self.size = existing.size
34 | self.capacity = existing.capacity
35 | self.data = existing.data
36 | self.symbols = existing.symbols
37 |
38 | @always_inline("nodebug")
39 | fn __copyinit__(inout self, existing: Self):
40 | """
41 | Copy initializes a Collection from an existing one.
42 | """
43 | self.capacity = existing.capacity
44 | self.size = existing.size
45 | self.data = UnsafePointer[Tensor[dtype]].alloc(existing.capacity)
46 | self.symbols = UnsafePointer[Scalar[DType.uint32]].alloc(existing.capacity)
47 | memcpy(self.symbols, existing.symbols, existing.capacity)
48 |
49 | for i in range(existing.size):
50 | UnsafePointer.init_pointee_move((self.data + i), (existing.data + i)[])
51 |
52 | @always_inline("nodebug")
53 | fn __del__(owned self):
54 | """
55 | Destructor for the Collection.
56 | """
57 | for i in range(self.size):
58 | UnsafePointer.destroy_pointee((self.data + i))
59 | if self.data:
60 | self.data.free()
61 | if self.symbols:
62 | self.symbols.free()
63 |
64 | @always_inline("nodebug")
65 | fn __len__(self) -> Int:
66 | """
67 | Returns the number of elements in the Collection.
68 | """
69 | return self.size
70 |
71 | @always_inline("nodebug")
72 | fn _realloc(inout self, new_capacity: Int):
73 | """
74 | Reallocates the Collection to the new capacity.
75 | """
76 | var new_data = UnsafePointer[Tensor[dtype]].alloc(new_capacity)
77 | var new_symbols = UnsafePointer[Scalar[DType.uint32]].alloc(new_capacity)
78 |
79 | for i in range(self.size):
80 | UnsafePointer.init_pointee_move((new_data + i), (self.data + i)[])
81 | new_symbols[i] = self.symbols[i]
82 |
83 | self.data.free()
84 | self.symbols.free()
85 |
86 | self.data = new_data
87 | self.symbols = new_symbols
88 | self.capacity = new_capacity
89 |
90 | @always_inline("nodebug")
91 | fn append(inout self, owned value: Tensor[dtype], symbol: Symbol):
92 | """
93 | Appends a tensor and its associated symbol to the Collection.
94 | """
95 | self.append(value ^, symbol.name)
96 |
97 | @always_inline("nodebug")
98 | fn append(inout self, owned value: Tensor[dtype], symbol_name: UInt32):
99 | """
100 | Appends a tensor and its associated symbol name to the Collection.
101 | """
102 | if self.size >= self.capacity:
103 | self._realloc(max(1, self.capacity * 2))
104 | UnsafePointer.init_pointee_move((self.data + self.size), value ^)
105 | self.symbols[self.size] = symbol_name
106 | self.size += 1
107 |
108 | @always_inline("nodebug")
109 | fn get_index(self, symbol_name: UInt32) -> Int:
110 | """
111 | Returns the index of the tensor with the given symbol name.
112 | """
113 | alias factor = 8
114 | # 2 -> 5.32s MNIST
115 | # 4 -> 4.95s MNIST
116 | # 8 -> 4.85s MNIST
117 | # 16 -> 5.19s MNIST
118 | # NOTE: This ideally should just be a hashmap
119 |
120 |         for i in range(0, (self.size // factor) * factor, factor):
121 | var elems = self.symbols.load[width=factor](i) == symbol_name
122 |
123 | for j in range(factor):
124 | if elems[j]:
125 | return i + j
126 |
127 | var split = divmod(self.size, factor)
128 |
129 | for i in range(split[1]):
130 |             var index = split[0] * factor + i
131 |
132 | if self.symbols[index] == symbol_name:
133 | return index
134 |
135 | return -1
136 |
137 | fn __getitem__(
138 | self,
139 | symbol: Symbol,
140 | ) -> ref[self.data[0]] Tensor[dtype]:
141 |         # TODO: This is a hack; we should use a dict instead. The symbol may not be present, and self.data[0] can refer to a value that doesn't exist when the list is empty (we work around that by pre-assigning an empty value).
142 | """
143 | Returns a reference to the tensor with the given symbol.
144 | """
145 | var index = self.get_index(symbol.name)
146 |
147 |
148 | return (self.data + index)[]
149 |
150 | @always_inline("nodebug")
151 | fn clear(inout self):
152 | """
153 | Clears the Collection, removing all tensors and symbols.
154 | """
155 | for i in range(self.size):
156 | UnsafePointer.destroy_pointee((self.data + i))
157 | memset_zero(self.symbols, self.capacity)
158 | self.size = 0
159 |
160 | @always_inline("nodebug")
161 | fn set_zero(self):
162 | """
163 | Zeroes out all the tensors in the collection.
164 | """
165 | for i in range(self.size):
166 | self.data[i].zero()
167 |
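get_index is a linear search over the stored symbol names, done in SIMD chunks of `factor` elements; the timings in its comment are what motivated factor = 8. The same idea in plain Python, for illustration:

    def get_index(symbols, target, factor=8):
        n = len(symbols)
        full = (n // factor) * factor
        # Compare a whole chunk at a time (the SIMD load + compare in the Mojo version)
        for i in range(0, full, factor):
            chunk = symbols[i:i + factor]
            if target in chunk:
                return i + chunk.index(target)
        # Check the remaining tail elements one by one
        for i in range(full, n):
            if symbols[i] == target:
                return i
        return -1

    print(get_index([10, 20, 30, 40, 50, 60, 70, 80, 90], 90))  # 8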
--------------------------------------------------------------------------------
/basalt/utils/dataloader.mojo:
--------------------------------------------------------------------------------
1 | from testing import assert_equal
2 | from math import min
3 | from memory import memcpy
4 |
5 | from basalt import dtype, nelts
6 | from basalt import Tensor, TensorShape
7 |
8 |
9 | @value
10 | struct Batch[dtype: DType](CollectionElement):
11 | var data: Tensor[dtype]
12 | var labels: Tensor[dtype]
13 |
14 | fn __init__(inout self, batch_data: Tensor[dtype], batch_labels: Tensor[dtype]):
15 | self.data = batch_data
16 | self.labels = batch_labels
17 |
18 | fn __init__(
19 | inout self,
20 | df_data: Tensor[dtype],
21 | df_labels: Tensor[dtype],
22 | start: Int,
23 | batch_data_shape: TensorShape,
24 | batch_labels_shape: TensorShape,
25 | ):
26 | # TODO: find a better way to do this
27 | # Links to the copies of the input tensors in model.forward()
28 | self.data = Tensor[dtype](batch_data_shape)
29 | self.labels = Tensor[dtype](batch_labels_shape)
30 | memcpy(
31 | self.data.data(),
32 | df_data.data().offset(start * batch_data_shape.strides()[0]),
33 | batch_data_shape.num_elements(),
34 | )
35 | memcpy(
36 | self.labels.data(),
37 | df_labels.data().offset(start * batch_labels_shape.strides()[0]),
38 | batch_labels_shape.num_elements(),
39 | )
40 |
41 | fn __getitem__(self, index: Int) -> Tensor[dtype]:
42 | if index == 0:
43 | return self.data
44 | elif index == 1:
45 | return self.labels
46 | else:
47 | print("[ERROR] Batch.__getitem__(): Index out of bounds")
48 | return Tensor[dtype]()
49 |
50 |
51 | @value
52 | struct DataLoader:
53 | var data: Tensor[dtype]
54 | var labels: Tensor[dtype]
55 | var batch_size: Int
56 | var _current_index: Int
57 | var _num_batches: Int
58 | var _data_batch_shape: TensorShape
59 | var _label_batch_shape: TensorShape
60 |
61 | fn __init__(
62 | inout self,
63 | data: Tensor[dtype],
64 | labels: Tensor[dtype],
65 | batch_size: Int,
66 | ):
67 | self.data = data
68 | self.labels = labels
69 | self.batch_size = batch_size
70 |
71 | # Number of batches to iter, NOTE: ignore the remainder for now
72 | # var remainder = 1 if self.data.dim(0) % self.batch_size != 0 else 0
73 | self._current_index = 0
74 | self._num_batches = self.data.dim(0) // self.batch_size # + remainder
75 |
76 | # Batch shapes
77 | self._data_batch_shape = self.data.shape()
78 | self._label_batch_shape = self.labels.shape()
79 | self._data_batch_shape[0] = self.batch_size
80 | self._label_batch_shape[0] = self.batch_size
81 |
82 | @always_inline
83 | fn __len__(self) -> Int:
84 | """
85 | Returns the number of the batches left in the dataset.
86 | """
87 | return self._num_batches
88 |
89 | fn __iter__(self) -> Self:
90 |         # TODO: Starting the iterator requires returning (COPY!) the whole dataloader, which contains the whole dataset.
91 |         # Does this mean the whole dataset is copied every epoch?!
92 | return self
93 |
94 | fn __has_next__(self) -> Bool:
95 | return self._num_batches > 0
96 |
97 | fn __next__(inout self) -> Batch[dtype]:
98 | # NOTE: ignore the remainder for now
99 | # var end = min(self._current_index + self.batch_size, self.data.dim(0))
100 | # self._data_shape[0] = end - self._current_index
101 | # self._label_shape[0] = end - self._current_index
102 |
103 | var temp_current_index = self._current_index
104 | self._current_index += self.batch_size
105 | self._num_batches -= 1
106 |
107 | return Batch[dtype](
108 | self.data,
109 | self.labels,
110 | temp_current_index,
111 | self._data_batch_shape,
112 | self._label_batch_shape,
113 | )
114 |
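Each __next__ call copies the next batch_size rows out of the flat data buffer: batch i covers rows [i * batch_size, (i + 1) * batch_size), using strides()[0] as the row offset, and any trailing rows that do not fill a whole batch are dropped. A NumPy sketch of the same slicing, for illustration:

    import numpy as np

    data = np.arange(10 * 3).reshape(10, 3)        # 10 samples, 3 features
    batch_size = 4
    num_batches = data.shape[0] // batch_size      # remainder ignored, as in DataLoader

    for i in range(num_batches):
        batch = data[i * batch_size:(i + 1) * batch_size]
        print(batch.shape)                         # (4, 3) twice; the last 2 rows are dropped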
--------------------------------------------------------------------------------
/basalt/utils/datasets.mojo:
--------------------------------------------------------------------------------
1 | from algorithm import vectorize
2 |
3 | from basalt import dtype
4 | from basalt import Tensor, TensorShape
5 | from basalt.utils.tensorutils import elwise_op, tmean, tstd
6 |
7 |
8 | @always_inline
9 | fn div[dtype: DType, simd_width: Int](a: SIMD[dtype, simd_width], b: Scalar[dtype]) -> SIMD[dtype, simd_width]:
10 | return a / b
11 |
12 |
13 | struct BostonHousing:
14 | alias n_inputs = 13
15 |
16 | var data: Tensor[dtype]
17 | var labels: Tensor[dtype]
18 |
19 | fn __init__(inout self, file_path: String) raises:
20 | var s = read_file(file_path)
21 |         # Skip the header (first) line and the trailing empty entry from the split
22 |         # (this assumes the last line of the file ends with a newline)
23 | var list_of_lines = s.split("\n")[1:-1]
24 |
25 | # Length is number of lines
26 | var N = len(list_of_lines)
27 |
28 | self.data = Tensor[dtype](N, self.n_inputs) # All columns except the last one
29 | self.labels = Tensor[dtype](N, 1) # Only the last column (MEDV)
30 |
31 | var line: List[String] = List[String]()
32 |
33 | # Load data in Tensor
34 | for item in range(N):
35 | line = list_of_lines[item].split(",")
36 | self.labels[item] = cast_string[dtype](line[-1])
37 |
38 | for n in range(self.n_inputs):
39 | self.data[item * self.n_inputs + n] = cast_string[dtype](line[n])
40 |
41 | # Normalize data
42 | # TODO: redo when tensorutils tmean2 and tstd2 are implemented
43 | alias nelts = simdwidthof[dtype]()
44 | var col = Tensor[dtype](N)
45 | for j in range(self.n_inputs):
46 | for k in range(N):
47 | col[k] = self.data[k * self.n_inputs + j]
48 | for i in range(N):
49 | self.data[i * self.n_inputs + j] = (self.data[i * self.n_inputs + j] - tmean(col)) / tstd(col)
50 |
51 |
52 | struct MNIST:
53 | var data: Tensor[dtype]
54 | var labels: Tensor[dtype]
55 |
56 | fn __init__(inout self, file_path: String) raises:
57 | var s = read_file(file_path)
58 |         # Skip the header (first) line and the trailing empty entry from the split
59 |         # (this assumes the last line of the file ends with a newline)
60 | var list_of_lines = s.split("\n")[1:-1]
61 |
62 | # Length is number of lines
63 | var N = len(list_of_lines)
64 | self.data = Tensor[dtype](N, 1, 28, 28)
65 | self.labels = Tensor[dtype](N)
66 |
67 | var line: List[String] = List[String]()
68 |
69 | # Load data in Tensor
70 | for item in range(N):
71 | line = list_of_lines[item].split(",")
72 | self.labels[item] = atol(line[0])
73 | for i in range(self.data.shape()[2]):
74 | for j in range(self.data.shape()[3]):
75 | self.data[item * 28 * 28 + i * 28 + j] = atol(line[i * 28 + j + 1])
76 |
77 | # Normalize data
78 | alias nelts = simdwidthof[dtype]()
79 |
80 | @parameter
81 | fn vecdiv[nelts: Int](idx: Int):
82 | self.data.store[nelts](idx, div(self.data.load[nelts](idx), 255.0))
83 |
84 | vectorize[vecdiv, nelts](self.data.num_elements())
85 |
86 |
87 | fn read_file(file_path: String) raises -> String:
88 | var s: String
89 | with open(file_path, "r") as f:
90 | s = f.read()
91 | return s
92 |
93 |
94 | fn find_first(s: String, delimiter: String) -> Int:
95 | for i in range(len(s)):
96 | if s[i] == delimiter:
97 | return i
98 | return -1
99 |
100 |
101 | fn cast_string[dtype: DType](s: String) raises -> Scalar[dtype]:
102 | """
103 | Cast a string with decimal to a SIMD vector of dtype.
104 | """
105 |
106 | var idx = find_first(s, delimiter=".")
107 | var x: Scalar[dtype] = -1
108 |
109 | if idx == -1:
110 | # No decimal point
111 | x = atol(s)
112 | return x
113 | else:
114 |         var c_int: Scalar[dtype] = atol(s[:idx])
115 |         var c_frac: Scalar[dtype] = atol(s[idx + 1 :])
116 |         # Re-apply the sign from the string so negative values (e.g. "-0.5") keep a negative fraction
117 |         var sign: Scalar[dtype] = -1 if s[0] == "-" else 1
118 |         x = sign * (abs(c_int) + c_frac / (10 ** len(s[idx + 1 :])))
119 |         return x
120 |
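cast_string parses a decimal string by hand: the integer and fractional parts are converted separately and recombined as sign * (|int| + frac / 10^len(frac)). The same parsing in Python, with a worked example:

    def cast_string(s: str) -> float:
        if "." not in s:
            return float(int(s))
        int_part, frac_part = s.split(".", 1)
        sign = -1.0 if s.startswith("-") else 1.0
        # e.g. "24.98" -> 24 + 98 / 10**2 = 24.98
        return sign * (abs(int(int_part)) + int(frac_part) / 10 ** len(frac_part))

    print(cast_string("24.98"), cast_string("-0.5"))  # 24.98 -0.5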
--------------------------------------------------------------------------------
/basalt/utils/graph_render.py:
--------------------------------------------------------------------------------
1 | import onnx
2 | from onnx import helper
3 | from onnx import TensorProto
4 | import netron
5 |
6 |
7 | def get_param_data(param_shape):
8 | factor = 1
9 | for dim in param_shape:
10 | factor *= dim
11 | return [0] * factor
12 |
13 |
14 | def create_onnx_graph_from_json(graph, type="node"):
15 | # Create a list to hold nodes, inputs, outputs, and initializers
16 | nodes = []
17 | inputs = []
18 | outputs = []
19 | initializers = []
20 | intermediates = []
21 |
22 | # Process params as initializers (if operator-graph)
23 | visited = []
24 | if type == "operator":
25 | onnx_inputs = graph["inputs"] + graph.get("params", [])
26 | elif type == "node":
27 | onnx_inputs = graph["inputs"]
28 |
29 | # Process params as initializers
30 | for initializer in graph.get("params", []):
31 | name = initializer["name"]
32 | dtype = TensorProto.FLOAT # TODO
33 | shape = list(map(int, initializer["shape"].split("x")))
34 | tensor = helper.make_tensor(name, dtype, shape, get_param_data(shape))
35 | initializers.append(tensor)
36 | visited.append(name)
37 |
38 | # Process inputs
39 | for input in onnx_inputs:
40 | name = input["name"]
41 | dtype = TensorProto.FLOAT # TODO
42 | shape = list(map(int, input["shape"].split("x")))
43 | inputs.append(helper.make_tensor_value_info(name, dtype, shape))
44 | visited.append(name)
45 |
46 | # Process outputs
47 | for output in graph["outputs"]:
48 | name = output["name"]
49 | dtype = TensorProto.FLOAT # TODO
50 | shape = list(map(int, output["shape"].split("x")))
51 | outputs.append(helper.make_tensor_value_info(name, dtype, shape))
52 | visited.append(name)
53 |
54 | # Process nodes
55 | for node in graph["nodes"]:
56 | operator = node["operator"]
57 | onnx_node = helper.make_node(
58 | operator,
59 | inputs=[input["name"] for input in node["inputs"]],
60 | outputs=[output["name"] for output in node["outputs"]],
61 | name=f"{node['operator']}_node",
62 | )
63 |
64 | # Process attributes
65 | for attribute in node["attributes"]:
66 | attr_type = 0
67 | if attribute["type"] == "FLOAT":
68 | attr_type = onnx.AttributeProto.FLOAT
69 | elif attribute["type"] == "INT":
70 | attr_type = onnx.AttributeProto.INT
71 | elif attribute["type"] == "STRING":
72 | attr_type = onnx.AttributeProto.STRING
73 | elif attribute["type"] == "FLOATS":
74 | attr_type = onnx.AttributeProto.FLOATS
75 | elif attribute["type"] == "INTS":
76 | attr_type = onnx.AttributeProto.INTS
77 | else:
78 | raise ValueError(f"Unsupported attribute type: {attribute['type']}")
79 |
80 | onnx_attribute = helper.make_attribute(
81 | attribute["name"], attribute["value"], attr_type=attr_type
82 | )
83 | onnx_node.attribute.append(onnx_attribute)
84 |
85 | nodes.append(onnx_node)
86 |
87 | # Process intermediates
88 | for output in node["outputs"]:
89 | if output["name"] not in visited:
90 | name = output["name"]
91 | dtype = TensorProto.FLOAT
92 | shape = list(map(int, output["shape"].split("x")))
93 | intermediates.append(helper.make_tensor_value_info(name, dtype, shape))
94 | visited.append(name)
95 |
96 | # Process loss
97 | if "loss" in graph.keys():
98 | loss = graph["loss"][0]
99 | name = loss["name"]
100 | if name not in visited:
101 | dtype = TensorProto.FLOAT
102 | shape = list(map(int, loss["shape"].split("x")))
103 | outputs.append(helper.make_tensor_value_info(name, dtype, shape))
104 | visited.append(name)
105 |
106 | # Create the graph
107 | graph_def = helper.make_graph(
108 | nodes,
109 | graph.get("graph_name", "basalt-ONNX"),
110 | inputs,
111 | outputs,
112 | initializer=initializers,
113 | value_info=intermediates,
114 | )
115 |
116 | # Create the model
117 | model_def = helper.make_model(graph_def, producer_name="basalt")
118 |
119 | # Save the model to a file
120 | onnx.save(model_def, "output_model.onnx")
121 |
122 |
123 | def netron_render(graph, type="node"):
124 | assert type in ["node", "operator"]
125 | create_onnx_graph_from_json(graph, type=type)
126 | netron.start("output_model.onnx")
127 |
--------------------------------------------------------------------------------
/basalt/utils/math_util.mojo:
--------------------------------------------------------------------------------
1 | @always_inline
2 | fn add[
3 | dtype: DType, simd_width: Int
4 | ](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[
5 | dtype, simd_width
6 | ]:
7 | return a + b
8 |
9 |
10 | @always_inline
11 | fn sub[
12 | dtype: DType, simd_width: Int
13 | ](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[
14 | dtype, simd_width
15 | ]:
16 | return a - b
17 |
18 |
19 | @always_inline
20 | fn mul[
21 | dtype: DType, simd_width: Int
22 | ](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[
23 | dtype, simd_width
24 | ]:
25 | return a * b
26 |
27 |
28 | @always_inline
29 | fn div[
30 | dtype: DType, simd_width: Int
31 | ](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[
32 | dtype, simd_width
33 | ]:
34 | return a / b
35 |
36 |
37 | @always_inline
38 | fn round_simd[
39 | dtype: DType, simd_width: Int
40 | ](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]:
41 | return round(x)
42 |
--------------------------------------------------------------------------------
/basalt/utils/perf_utils.mojo:
--------------------------------------------------------------------------------
1 | from time.time import monotonic as now
2 | from memory import UnsafePointer, memcpy, memset
3 |
4 | from basalt.autograd.node import Node
5 | from basalt import Graph
6 |
7 | @always_inline("nodebug")
8 | fn fit_string[num: Int](s: String) -> String:
9 | var data = UnsafePointer[Byte]().alloc(num + 1)
10 | var copy_len = min(num, len(s))
11 |
12 | memcpy(data, s.unsafe_ptr(), copy_len)
13 | memset(data + copy_len, ord(" "), num - copy_len)
14 | data[num] = 0
15 |
16 | return String(ptr=data, length=num + 1)
17 |
18 |
19 | @always_inline("nodebug")
20 | fn truncate_decimals[num: Int](s: String) -> String:
21 | try:
22 | var parts = s.split(".")
23 | var truncated = parts[0]
24 |
25 | if len(parts) > 1:
26 | var decimal_parts = parts[1].split("e")
27 | truncated += "." + fit_string[num](decimal_parts[0])
28 |
29 | if len(decimal_parts) > 1:
30 | truncated += "e" + decimal_parts[1]
31 |
32 | return truncated
33 | except e:
34 | print("[WARNING] could not truncate decimals: ", e)
35 | return s
36 |
37 |
38 | @value
39 | struct PerfMetricsValues:
40 | var node: Node
41 | var ns: Float64
42 |
43 |
44 | @value
45 | struct PerfMetrics:
46 | var forward_perf_metrics: List[PerfMetricsValues]
47 | var backward_perf_metrics: List[PerfMetricsValues]
48 | var epochs_forward: Int
49 | var epochs_backward: Int
50 | var start: Int
51 |
52 | fn __init__(inout self):
53 | self.forward_perf_metrics = List[PerfMetricsValues]()
54 | self.backward_perf_metrics = List[PerfMetricsValues]()
55 | self.epochs_forward = 0
56 | self.epochs_backward = 0
57 | self.start = 0
58 |
59 | fn __init__(inout self, graph: Graph):
60 | self.forward_perf_metrics = List[PerfMetricsValues]()
61 | self.backward_perf_metrics = List[PerfMetricsValues]()
62 |
63 | self.forward_perf_metrics.reserve(graph.nodes.size)
64 | self.backward_perf_metrics.reserve(graph.nodes.size)
65 |
66 | for i in range(graph.nodes.size):
67 | self.forward_perf_metrics.append(PerfMetricsValues(graph.nodes[i], 0.0))
68 | self.backward_perf_metrics.append(PerfMetricsValues(graph.nodes[i], 0.0))
69 |
70 | self.epochs_forward = 0
71 | self.epochs_backward = 0
72 | self.start = 0
73 |
74 | fn start_forward_pass(inout self):
75 | self.start = now()
76 |
77 | fn end_forward_pass(inout self, pos: Int):
78 | self.forward_perf_metrics[pos].ns += now() - self.start
79 | self.epochs_forward += 1
80 |
81 | fn start_backward_pass(inout self):
82 | self.start = now()
83 |
84 | fn end_backward_pass(inout self, pos: Int):
85 | self.backward_perf_metrics[pos].ns += now() - self.start
86 | self.epochs_backward += 1
87 |
88 | fn print_perf_metrics[
89 | type_part: String
90 | ](self, time_format: String = "ns", print_shape: Bool = False):
91 | constrained[type_part == "Forward" or type_part == "Backward", "Only 'Forward' or 'Backward' are accepted types."]()
92 |
93 | alias is_forward = type_part == "Forward"
94 |
95 | var metrics = self.forward_perf_metrics if is_forward else self.backward_perf_metrics
96 | var epochs = self.epochs_forward if is_forward else self.epochs_backward
97 | var size = len(metrics)
98 | var total_time: Float64 = 0
99 |
100 | if size == 0:
101 | return
102 |
103 | if is_forward:
104 | print("\n\nForward pass performance metrics:")
105 | else:
106 | print("\n\nBackward pass performance metrics:")
107 |
108 | for i in range(size):
109 | total_time += metrics[i].ns
110 |
111 | var header = (
112 | fit_string[5]("Node")
113 | + "| "
114 | + fit_string[15]("Operator")
115 | + "| "
116 | + fit_string[20]("Time [" + time_format + "]")
117 | + "| "
118 | + fit_string[20]("Percentage [%]")
119 | )
120 |
121 | if print_shape:
122 | header += "| " + fit_string[70]("Shape\t = OP( , , )")
123 |
124 | print(header)
125 |
126 | var header_length = len(header)
127 |         var separator = UnsafePointer[UInt8]().alloc(header_length + 1)
128 |
129 |         memset(separator, ord("-"), header_length)
130 |         separator[header_length] = 0
131 |
132 |         print(String(ptr=separator, length=len(header) + 1))
133 |
134 | for i in range(size):
135 | var value = metrics[i]
136 | var time = value.ns / epochs
137 |
138 | if time_format == "ms":
139 | time /= 1e6
140 | elif time_format == "s":
141 | time /= 1e9
142 |
143 | var percentage = (value.ns / total_time) * 100
144 |
145 | var print_value = (
146 | fit_string[5](str(i))
147 | + "| "
148 | + fit_string[15](str(value.node.operator))
149 | + "| "
150 | + fit_string[20](truncate_decimals[4](str(time)))
151 | + "| "
152 | + fit_string[20](truncate_decimals[3](str(percentage)) + " %")
153 | + "| "
154 | )
155 |
156 | if print_shape:
157 | var shape_str = fit_string[15]("<" + str(value.node.outputs[0].shape) + ">")
158 |
159 | for j in range(1, len(value.node.outputs)):
160 | shape_str += ", " + fit_string[15]("<" + str(value.node.outputs[j].shape) + ">")
161 |
162 | shape_str += fit_string[7](" = OP(") + fit_string[15]("<" + str(value.node.inputs[0].shape) + ">")
163 |
164 | for j in range(1, len(value.node.inputs)):
165 | shape_str += ", " + fit_string[15]("<" + str(value.node.inputs[j].shape) + ">")
166 |
167 | shape_str += ")"
168 |
169 | print(print_value, end="")
170 | print(shape_str)
171 | else:
172 | print(print_value)
173 |
174 | if time_format == "ms":
175 | total_time /= 1e6
176 | elif time_format == "s":
177 | total_time /= 1e9
178 |
179 | print(
180 | "\nTotal average "
181 | + type_part
182 | + " time: "
183 | + str(total_time)
184 | + " "
185 | + time_format
186 | )
187 |
188 |
189 | fn print_forward_perf_metrics(self, time_format: String = "ns", print_shape: Bool = False):
190 | self.print_perf_metrics["Forward"](time_format, print_shape)
191 |
192 | fn print_backward_perf_metrics(self, time_format: String = "ns", print_shape: Bool = False):
193 | self.print_perf_metrics["Backward"](time_format, print_shape)
194 |
--------------------------------------------------------------------------------
/basalt/utils/rand_utils.mojo:
--------------------------------------------------------------------------------
1 | from basalt import Tensor, nelts
2 | from random import rand, randn
3 | from algorithm import vectorize
4 | from utils.static_tuple import StaticTuple
5 |
6 |
7 | @always_inline
8 | fn rand_uniform[dtype: DType](inout res: Tensor[dtype], low: Scalar[dtype], high: Scalar[dtype]):
9 | var scale = high - low
10 |
11 | rand[dtype](res.data(), res.num_elements())
12 |
13 | @parameter
14 | fn vecscale[nelts: Int](idx: Int):
15 | res.store[nelts](idx, res.load[nelts](idx).fma(scale, low))
16 |
17 | vectorize[vecscale, nelts](res.num_elements())
18 |
19 |
20 | @always_inline
21 | fn rand_normal[dtype: DType](inout res: Tensor[dtype], mean: Float64, std: Float64):
22 | randn[dtype](res.data(), res.num_elements(), mean, std**2)
23 |
24 |
25 | @register_passable("trivial")
26 | struct MersenneTwister:
27 | """
28 | Pseudo-random generator Mersenne Twister (MT19937-32bit).
29 | """
30 |
31 | alias N: Int = 624
32 | alias M: Int = 397
33 | alias MATRIX_A: Int32 = 0x9908B0DF
34 | alias UPPER_MASK: Int32 = 0x80000000
35 | alias LOWER_MASK: Int32 = 0x7FFFFFFF
36 | alias TEMPERING_MASK_B: Int32 = 0x9D2C5680
37 | alias TEMPERING_MASK_C: Int32 = 0xEFC60000
38 |
39 | var state: StaticTuple[Int32, Self.N]
40 | var index: Int
41 |
42 | fn __init__(inout self, seed: Int):
43 | alias W: Int = 32
44 | alias F: Int32 = 1812433253
45 | alias D: Int32 = 0xFFFFFFFF
46 |
47 | self.index = Self.N
48 | self.state = StaticTuple[Int32, Self.N]()
49 | self.state[0] = seed & D
50 |
51 | for i in range(1, Self.N):
52 | var prev = self.state[i - 1]
53 | self.state[i] = (F * (prev ^ (prev >> (W - 2))) + i) & D
54 |
55 | fn next(inout self) -> Int32:
56 | if self.index >= Self.N:
57 | for i in range(Self.N):
58 | var x = (self.state[i] & Self.UPPER_MASK) + (self.state[(i + 1) % Self.N] & Self.LOWER_MASK)
59 | var xA = x >> 1
60 | if x % 2 != 0:
61 | xA ^= Self.MATRIX_A
62 | self.state[i] = self.state[(i + Self.M) % Self.N] ^ xA
63 | self.index = 0
64 |
65 | var y = self.state[self.index]
66 | y ^= y >> 11
67 | y ^= (y << 7) & Self.TEMPERING_MASK_B
68 | y ^= (y << 15) & Self.TEMPERING_MASK_C
69 | y ^= y >> 18
70 | self.index += 1
71 |
72 | return y
73 |
74 | fn next_ui8(inout self) -> UInt8:
75 | return self.next().value & int(0xFF)
76 |
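rand_uniform fills the tensor with uniform [0, 1) samples and then rescales each lane in place with a fused multiply-add, u * (high - low) + low; the MersenneTwister struct is a self-contained MT19937 (32-bit) generator. The rescaling step in NumPy terms, for illustration:

    import numpy as np

    def rand_uniform(shape, low, high, rng=np.random.default_rng(42)):
        u = rng.random(shape)           # uniform [0, 1), like rand[dtype](...)
        return u * (high - low) + low   # the fma(scale, low) rescale

    print(rand_uniform((2, 3), -0.5, 0.5))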
--------------------------------------------------------------------------------
/basalt/utils/tensor_creation_utils.mojo:
--------------------------------------------------------------------------------
1 | from python import Python, PythonObject
2 | from memory import memcpy, UnsafePointer
3 |
4 | # NOTE: Maybe these functions should live on the Tensor struct instead, e.g. tensor.to_numpy(), Tensor.__init__(np_array: PythonObject) to build a tensor from a numpy array, and tensor.copy_np_data(np_array: PythonObject) to copy numpy data into an existing tensor.
5 |
6 |
7 | fn to_numpy(tensor: Tensor) -> PythonObject:
8 | try:
9 | var np = Python.import_module("numpy")
10 |
11 | np.set_printoptions(4)
12 |
13 | var rank = tensor.rank()
14 | var dims = PythonObject([])
15 | for i in range(rank):
16 | dims.append(tensor.dim(i))
17 | var pyarray: PythonObject = np.empty(dims, dtype=np.float32)
18 |
19 | var pointer_d = pyarray.__array_interface__["data"][0].unsafe_get_as_pointer[DType.float32]()
20 | var d: UnsafePointer[Float32] = tensor.data().bitcast[Float32]()
21 | memcpy(pointer_d, d, tensor.num_elements())
22 |
23 | _ = tensor
24 |
25 | return pyarray^
26 | except e:
27 | print("Error in to numpy", e)
28 | return PythonObject()
29 |
30 |
31 | fn to_tensor(np_array: PythonObject) raises -> Tensor[dtype]:
32 | var shape = List[Int]()
33 | for i in range(np_array.ndim):
34 | shape.append(int(float(np_array.shape[i])))
35 | if np_array.ndim == 0:
36 |         # When the numpy array is a scalar, it must either be reshaped to a size-1 ndarray or handled like this; otherwise the memcpy hits a memory error (maybe because it is a register value?).
37 | var tensor = Tensor[dtype](TensorShape(1))
38 | tensor[0] = float(np_array).cast[dtype]()
39 | return tensor^
40 |
41 | var tensor = Tensor[dtype](TensorShape(shape))
42 |
43 | var np_array_2: PythonObject
44 | try:
45 | var np = Python.import_module("numpy")
46 | # copy is also necessary for ops like slices to make them contiguous instead of references.
47 | np_array_2 = np.float32(np_array.copy())
48 | except e:
49 | np_array_2 = np_array.copy()
50 | print("Error in to_tensor", e)
51 |
52 | var pointer_d = np_array_2.__array_interface__["data"][0].unsafe_get_as_pointer[dtype]()
53 | memcpy(tensor.data(), pointer_d, tensor.num_elements())
54 |
55 | _ = np_array_2
56 | _ = np_array
57 |
58 | return tensor^
59 |
60 |
61 | fn copy_np_data(inout tensor: Tensor, np_array: PythonObject) raises:
62 | var np_array_2: PythonObject
63 | try:
64 | var np = Python.import_module("numpy")
65 | # copy is also necessary for ops like slices to make them contiguous instead of references.
66 | np_array_2 = np.float32(np_array.copy())
67 | except e:
68 | np_array_2 = np_array.copy()
69 |         print("Error in copy_np_data", e)
70 |
71 | var pointer_d = np_array_2.__array_interface__["data"][0].unsafe_get_as_pointer[dtype]()
72 | var d: UnsafePointer[Float32] = tensor.data().bitcast[Float32]()
73 | memcpy(d, pointer_d, tensor.num_elements())
74 |
75 | # This shouldn't be necessary anymore, but I'm leaving it here for now.
76 | # _ = np_array_2
77 | # _ = np_array
78 | # _ = tensor
79 |
--------------------------------------------------------------------------------
/examples/data/mnist_torch.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basalt-org/basalt/fe16eadcfcfee9271f9df0dd94d11b7c50d868ba/examples/data/mnist_torch.onnx
--------------------------------------------------------------------------------
/examples/housing.mojo:
--------------------------------------------------------------------------------
1 | from time.time import monotonic as now
2 |
3 | import basalt.nn as nn
4 | from basalt import Tensor, TensorShape
5 | from basalt import Graph, Symbol, OP
6 | from basalt.utils.datasets import BostonHousing
7 | from basalt.utils.dataloader import DataLoader
8 | from basalt.nn.model import Parameters
9 |
10 |
11 | fn linear_regression(batch_size: Int, n_inputs: Int, n_outputs: Int) -> Graph:
12 | var g = Graph()
13 |
14 | var x = g.input(TensorShape(batch_size, n_inputs))
15 | var y_true = g.input(TensorShape(batch_size, n_outputs))
16 |
17 | var y_pred = nn.Linear(g, x, n_outputs)
18 | g.out(y_pred)
19 |
20 | var loss = nn.MSELoss(g, y_pred, y_true)
21 | g.loss(loss)
22 |
23 | return g ^
24 |
25 |
26 | fn main():
27 | # Train Parameters
28 | alias batch_size = 32
29 | alias num_epochs = 200
30 | alias learning_rate = 0.02
31 |
32 | alias graph = linear_regression(batch_size, 13, 1)
33 |
34 | # try: graph.render("operator")
35 | # except: print("Could not render graph")
36 |
37 | var model = nn.Model[graph]()
38 | var optim = nn.optim.Adam[graph](model.parameters, lr=learning_rate)
39 |
40 | # Batchwise data loader
41 | print("Loading data...")
42 | var train_data: BostonHousing
43 | try:
44 | train_data = BostonHousing(file_path="./examples/data/housing.csv")
45 | except:
46 | print("Could not load data")
47 | return
48 |
49 | var training_loader = DataLoader(
50 | data=train_data.data, labels=train_data.labels, batch_size=batch_size
51 | )
52 |
53 | print("Training started.")
54 | var start = now()
55 | for epoch in range(num_epochs):
56 | var num_batches: Int = 0
57 | var epoch_loss: Float32 = 0.0
58 | for batch in training_loader:
59 | # Forward pass
60 | var loss = model.forward(batch.data, batch.labels)
61 |
62 | # Backward pass
63 | optim.zero_grad()
64 | model.backward()
65 | optim.step()
66 |
67 | epoch_loss += loss[0]
68 | num_batches += 1
69 |
70 | print(
71 | "Epoch: [",
72 | epoch + 1,
73 | "/",
74 | num_epochs,
75 | "] \t Avg loss per epoch:",
76 | epoch_loss / num_batches,
77 | )
78 |
79 | print("Training finished: ", (now() - start) / 1e9, "seconds")
80 |
81 | # print("\n\nInferencing model...\n")
82 | # for batch in training_loader:
83 | # var output = model.inference(batch.data)
84 |
85 | # # Print first (and only output)
86 | # print("Predicted: ", output[0])
87 |
--------------------------------------------------------------------------------
/examples/housing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import optim
6 | from torch.utils.data import Dataset, DataLoader, TensorDataset
7 | import time
8 |
9 |
10 | class BostonHousing(Dataset):
11 | def __init__(self, data: pd.DataFrame):
12 | # Data: All columns except the last one / Target: Only the last column (MEDV)
13 | self.data = torch.tensor(data.iloc[:, :-1].values, dtype=torch.float32)
14 | self.target = torch.tensor(data.iloc[:, -1].values, dtype=torch.float32).view(
15 | -1, 1
16 | )
17 |
18 | # Normalize data
19 | self.data = (self.data - self.data.mean(dim=0)) / self.data.std(dim=0)
20 |
21 | # Create dataset
22 | self.dataset = TensorDataset(self.data, self.target)
23 |
24 | def __len__(self):
25 | return len(self.dataset)
26 |
27 | def __getitem__(self, idx):
28 | return self.dataset[idx]
29 |
30 |
31 | class LinearRegression(nn.Module):
32 | def __init__(self, input_dim):
33 | super(LinearRegression, self).__init__()
34 | self.linear = nn.Linear(input_dim, 1)
35 |
36 | def forward(self, x):
37 | return self.linear(x)
38 |
39 |
40 | if __name__ == "__main__":
41 | # Load data and split in training and testing sets
42 | df = pd.read_csv("./examples/data/housing.csv")
43 |
44 | TRAIN_PCT = 0.99
45 | shuffled_df = df.sample(frac=1, random_state=42)
46 | train_df = shuffled_df[: int(TRAIN_PCT * len(df))]
47 | test_df = shuffled_df[int(TRAIN_PCT * len(df)) :]
48 |
49 | train_data = BostonHousing(train_df)
50 | test_data = BostonHousing(test_df)
51 |
52 | # Train Parameters
53 | batch_size = 32
54 | num_epochs = 200
55 | learning_rate = 0.02
56 |
57 | # Batchwise data loader
58 | loaders = {
59 | "train": DataLoader(
60 | train_data, batch_size=batch_size, shuffle=False, num_workers=1
61 | ),
62 | "test": DataLoader(
63 | test_data, batch_size=batch_size, shuffle=False, num_workers=1
64 | ),
65 | }
66 |
67 | device = torch.device("cpu")
68 | # model = torch.compile(LinearRegression(train_data.data.shape[1]), fullgraph=True, options={"epilogue_fusion": True, "max_autotune": True})
69 | model = LinearRegression(train_data.data.shape[1])
70 | loss_func = nn.MSELoss()
71 | optimizer = optim.Adam(model.parameters(), lr=learning_rate)
72 | # optimizer = optim.SGD(model.parameters(), lr=learning_rate)
73 |
74 |     # It seems the Python-level for loop is what makes this slow (so PyTorch is at a disadvantage because of Python)
75 | model.train()
76 |
77 | start = time.time()
78 | for epoch in range(num_epochs):
79 | epoch_loss = 0
80 | num_batches = 0
81 | for batch_data, batch_labels in loaders["train"]:
82 | start_batch = time.time()
83 |
84 | # Forward pass
85 | outputs = model(batch_data)
86 | loss = loss_func(outputs, batch_labels)
87 |
88 | # Backward pass
89 | optimizer.zero_grad()
90 | loss.backward()
91 | optimizer.step()
92 |
93 | epoch_loss += loss.item()
94 | num_batches += 1
95 |
96 | # print time in ms
97 | # print(f'Batch time: {1000 * (time.time() - start_batch):.2f} ms') # The speed of a batch in basalt and pytorch are similar or pytorch can be faster
98 |
99 | print(
100 | f"Epoch [{epoch + 1}/{num_epochs}],\t Avg loss per epoch:"
101 | f" {epoch_loss / num_batches}"
102 | )
103 |
104 | print(f"Training time: {time.time() - start:.2f} seconds")
105 |
106 | # Evaluate the model
107 | model.eval()
108 | with torch.no_grad():
109 | test_predictions = model(test_data.data)
110 | mse_loss = loss_func(test_predictions, test_data.target).item()
111 | print(f"Mean Squared Error on Test Data: {mse_loss:.4f}")
112 |
--------------------------------------------------------------------------------
/examples/mnist.mojo:
--------------------------------------------------------------------------------
1 | from time.time import monotonic as now
2 |
3 | import basalt.nn as nn
4 | from basalt import Tensor, TensorShape
5 | from basalt import Graph, Symbol, OP, dtype
6 | from basalt.utils.datasets import MNIST
7 | from basalt.utils.dataloader import DataLoader
8 | from basalt.autograd.attributes import AttributeVector, Attribute
9 |
10 |
11 | # def plot_image(data: Tensor, num: Int):
12 | # from python.python import Python, PythonObject
13 |
14 | # np = Python.import_module("numpy")
15 | # plt = Python.import_module("matplotlib.pyplot")
16 |
17 | # var pyimage: PythonObject = np.empty((28, 28), np.float64)
18 | # for m in range(28):
19 | # for n in range(28):
20 | # pyimage.itemset((m, n), data[num * 28 * 28 + m * 28 + n])
21 |
22 | # plt.imshow(pyimage)
23 | # plt.show()
24 |
25 |
26 | fn create_CNN(batch_size: Int) -> Graph:
27 | var g = Graph()
28 | var x = g.input(TensorShape(batch_size, 1, 28, 28))
29 |
30 | var x1 = nn.Conv2d(g, x, out_channels=16, kernel_size=5, padding=2)
31 | var x2 = nn.ReLU(g, x1)
32 | var x3 = nn.MaxPool2d(g, x2, kernel_size=2)
33 | var x4 = nn.Conv2d(g, x3, out_channels=32, kernel_size=5, padding=2)
34 | var x5 = nn.ReLU(g, x4)
35 | var x6 = nn.MaxPool2d(g, x5, kernel_size=2)
36 | var x7 = g.op(
37 | OP.RESHAPE,
38 | x6,
39 | attributes=AttributeVector(
40 | Attribute(
41 | "shape",
42 | TensorShape(x6.shape[0], x6.shape[1] * x6.shape[2] * x6.shape[3]),
43 | )
44 | ),
45 | )
46 | var out = nn.Linear(g, x7, n_outputs=10)
47 | g.out(out)
48 |
49 | var y_true = g.input(TensorShape(batch_size, 10))
50 | var loss = nn.CrossEntropyLoss(g, out, y_true)
51 | # var loss = nn.MSELoss(g, out, y_true)
52 | g.loss(loss)
53 |
54 | return g ^
55 |
56 |
57 | fn main():
58 | alias num_epochs = 20
59 | alias batch_size = 4
60 | alias learning_rate = 1e-3
61 |
62 | alias graph = create_CNN(batch_size)
63 |
64 | # try: graph.render("operator")
65 | # except: print("Could not render graph")
66 |
67 | var model = nn.Model[graph]()
68 | var optim = nn.optim.Adam[graph](model.parameters, lr=learning_rate)
69 |
70 | print("Loading data ...")
71 | var train_data: MNIST
72 | try:
73 | train_data = MNIST(file_path="./examples/data/mnist_test_small.csv")
74 | # _ = plot_image(train_data.data, 1)
75 | except e:
76 | print("Could not load data")
77 | print(e)
78 | return
79 |
80 | var training_loader = DataLoader(
81 | data=train_data.data, labels=train_data.labels, batch_size=batch_size
82 | )
83 |
84 |     print("Training started.")
85 | var start = now()
86 |
87 | for epoch in range(num_epochs):
88 | var num_batches: Int = 0
89 | var epoch_loss: Float32 = 0.0
90 | var epoch_start = now()
91 | for batch in training_loader:
92 | # [ONE HOT ENCODING!]
93 | var labels_one_hot = Tensor[dtype](batch.labels.dim(0), 10)
94 | for bb in range(batch.labels.dim(0)):
95 | labels_one_hot[int((bb * 10 + batch.labels[bb]))] = 1.0
96 |
97 | # Forward pass
98 | var loss = model.forward(batch.data, labels_one_hot)
99 |
100 | # Backward pass
101 | optim.zero_grad()
102 | model.backward()
103 | optim.step()
104 |
105 | epoch_loss += loss[0]
106 | num_batches += 1
107 |
108 | print(
109 | "Epoch [",
110 | epoch + 1,
111 | "/",
112 | num_epochs,
113 | "],\t Step [",
114 | num_batches,
115 | "/",
116 | train_data.data.dim(0) // batch_size,
117 | "],\t Loss:",
118 | epoch_loss / num_batches,
119 | )
120 |
121 | print("Epoch time: ", (now() - epoch_start) / 1e9, "seconds")
122 |
123 | print("Training finished: ", (now() - start) / 1e9, "seconds")
124 |
125 | model.print_perf_metrics("ms", True)
126 |
--------------------------------------------------------------------------------
/examples/mnist.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | import os
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch import optim
10 | from torch.autograd import Variable
11 | from torch.utils.data import Dataset, DataLoader, TensorDataset
12 |
13 |
14 | class MNIST(Dataset):
15 | def __init__(self, csv_file):
16 | data = pd.read_csv(csv_file)
17 | self.labels = torch.tensor(data.iloc[:, 0].values, dtype=torch.int64)
18 | self.images = torch.tensor(
19 | data.iloc[:, 1:].values, dtype=torch.float32
20 | ).reshape(-1, 1, 28, 28)
21 |
22 | # Normalize data
23 | self.images = self.images / 255.0
24 |
25 | self.dataset = TensorDataset(self.images, self.labels)
26 |
27 | def __len__(self):
28 | return len(self.dataset)
29 |
30 | def __getitem__(self, idx):
31 | return self.dataset[idx]
32 |
33 |
34 | class CNN(nn.Module):
35 | def __init__(self):
36 | super(CNN, self).__init__()
37 | self.conv1 = nn.Sequential(
38 | nn.Conv2d(
39 | in_channels=1,
40 | out_channels=16,
41 | kernel_size=5,
42 | stride=1,
43 | padding=2,
44 | ),
45 | nn.ReLU(),
46 | nn.MaxPool2d(kernel_size=2),
47 | )
48 | self.conv2 = nn.Sequential(
49 | nn.Conv2d(16, 32, 5, 1, 2),
50 | nn.ReLU(),
51 | nn.MaxPool2d(2),
52 | )
53 | # fully connected layer, output 10 classes
54 | self.out = nn.Linear(32 * 7 * 7, 10)
55 |
56 | def forward(self, x):
57 | x = self.conv1(x)
58 | x = self.conv2(x)
59 | # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
60 | x = x.view(x.size(0), -1)
61 | output = self.out(x)
62 | return output
63 |
64 |
65 | if __name__ == "__main__":
66 | num_epochs = 20
67 | batch_size = 4
68 | learning_rate = 1e-3
69 |
70 | # Load data
71 | train_data = MNIST("./examples/data/mnist_test_small.csv")
72 |
73 | # Visualize data
74 | num = 0
75 | plt.imshow(np.array(train_data[num][0]).squeeze())
76 | plt.title("%i" % train_data[num][1])
77 | plt.show()
78 |
79 | # Batchwise data loader
80 | loaders = {
81 | "train": DataLoader(
82 | train_data, batch_size=batch_size, shuffle=True, num_workers=1
83 | ),
84 | }
85 |
86 | device = torch.device("cpu")
87 | cnn = CNN()
88 | loss_func = nn.CrossEntropyLoss()
89 | optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)
90 |
91 | # Train the model
92 | cnn.train()
93 | total_step = len(loaders["train"])
94 | start = time.time()
95 | for epoch in range(num_epochs):
96 | for i, (images, labels) in enumerate(loaders["train"]):
97 | b_x = Variable(images)
98 | b_y = Variable(labels)
99 |
100 | output = cnn(b_x)
101 | loss = loss_func(output, b_y)
102 |
103 | optimizer.zero_grad()
104 | loss.backward()
105 | optimizer.step()
106 |
107 | print(
108 | "Epoch [{}/{}],\t Step [{}/{}],\t Loss: {:.6f}".format(
109 | epoch + 1, num_epochs, i + 1, total_step, loss.item()
110 | )
111 | )
112 |
113 | print(f"Training time: {time.time() - start:.2f} seconds")
114 |
115 | # Export to ONNX
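    |     # (optional) e.g. run as `export_onnx=1 python examples/mnist.py` from the repo root
    |     # so the relative data paths resolve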
116 |     export_onnx = os.environ.get("export_onnx", "0")
117 | if export_onnx == "1":
118 | dummy_input = torch.randn(1, 1, 28, 28)
119 |
120 |         # cnn.out.weight = nn.Parameter(cnn.out.weight.T)  # transpose because torch stores linear layer weights as (output_dim, input_dim), i.e. already transposed relative to Basalt's (input_dim, output_dim) layout
121 | torch.onnx.export(cnn, dummy_input, "./examples/data/mnist_torch.onnx", verbose=True)
--------------------------------------------------------------------------------
/examples/mnist_load_model.mojo:
--------------------------------------------------------------------------------
1 | from time.time import monotonic as now
2 | from pathlib import Path
3 |
4 | import basalt.nn as nn
5 | from basalt import Tensor, TensorShape
6 | from basalt import Graph, Symbol, OP, dtype
7 | from basalt.utils.datasets import MNIST
8 | from basalt.utils.dataloader import DataLoader
9 | from basalt.autograd.attributes import AttributeVector, Attribute
10 |
11 |
12 | # def plot_image(data: Tensor, num: Int):
13 | # from python.python import Python, PythonObject
14 |
15 | # np = Python.import_module("numpy")
16 | # plt = Python.import_module("matplotlib.pyplot")
17 |
18 | # var pyimage: PythonObject = np.empty((28, 28), np.float64)
19 | # for m in range(28):
20 | # for n in range(28):
21 | # pyimage.itemset((m, n), data[num * 28 * 28 + m * 28 + n])
22 |
23 | # plt.imshow(pyimage)
24 | # plt.show()
25 |
26 |
27 | fn create_CNN(batch_size: Int) -> Graph:
28 | var g = Graph()
29 | var x = g.input(TensorShape(batch_size, 1, 28, 28))
30 |
31 | var x1 = nn.Conv2d(g, x, out_channels=16, kernel_size=5, padding=2)
32 | var x2 = nn.ReLU(g, x1)
33 | var x3 = nn.MaxPool2d(g, x2, kernel_size=2)
34 | var x4 = nn.Conv2d(g, x3, out_channels=32, kernel_size=5, padding=2)
35 | var x5 = nn.ReLU(g, x4)
36 | var x6 = nn.MaxPool2d(g, x5, kernel_size=2)
37 | var x7 = g.op(
38 | OP.RESHAPE,
39 | x6,
40 | attributes=AttributeVector(
41 | Attribute(
42 | "shape",
43 | TensorShape(x6.shape[0], x6.shape[1] * x6.shape[2] * x6.shape[3]),
44 | )
45 | ),
46 | )
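    |     # flatten (batch, 32, 7, 7) -> (batch, 32 * 7 * 7) before the linear layer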
47 | var out = nn.Linear(g, x7, n_outputs=10)
48 | g.out(out)
49 |
50 | return g ^
51 |
52 |
53 | fn main():
54 | alias num_epochs = 1
55 | alias batch_size = 4
56 | alias learning_rate = 1e-3
57 |
58 | alias graph = create_CNN(batch_size)
59 |
60 | # try: graph.render("operator")
61 | # except: print("Could not render graph")
62 |
63 | var model = nn.Model[graph]()
64 | model.load_model_data("./examples/data/mnist_torch.onnx")
65 |
66 | print("Loading data ...")
67 | var train_data: MNIST
68 | try:
69 | train_data = MNIST(file_path="./examples/data/mnist_test_small.csv")
70 | # _ = plot_image(train_data.data, 1)
71 | except e:
72 | print("Could not load data")
73 | print(e)
74 | return
75 |
76 | var training_loader = DataLoader(
77 | data=train_data.data, labels=train_data.labels, batch_size=batch_size
78 | )
79 |
80 | # Testing
81 | print("Testing started")
82 | var start = now()
83 |
84 | var correct = 0
85 | for batch in training_loader:
86 | var labels_one_hot = Tensor[dtype](batch.labels.dim(0), 10)
87 | for bb in range(batch.labels.dim(0)):
88 | labels_one_hot[int(bb * 10 + batch.labels[bb])] = 1.0
89 |
90 | var output = model.inference(batch.data, labels_one_hot)[0]
91 |
92 | fn argmax(tensor: Tensor[dtype], dim: Int) -> Tensor[dtype]:
93 | var result = Tensor[dtype](tensor.dim(0))
94 | for i in range(tensor.dim(0)):
95 | var max_val = tensor[i * 10]
96 | var max_idx = 0
97 | for j in range(1, 10):
98 | if tensor[i * 10 + j] > max_val:
99 | max_val = tensor[i * 10 + j]
100 | max_idx = j
101 | result[i] = max_idx
102 |
103 | return result
104 |
105 | var pred = argmax(output, dim=1)
106 |
107 | for i in range(batch.labels.dim(0)):
108 | if pred[i] == batch.labels[i]:
109 | correct += 1
110 |
111 | print("Accuracy: ", correct / train_data.data.dim(0) * 100, "%")
112 | print("Testing finished: ", (now() - start) / 1e9, "seconds")
113 |
114 | # model.print_perf_metrics("ms", True)
115 |
116 | model.export_model("./output_model.onnx")
--------------------------------------------------------------------------------
/examples/sin_estimate.mojo:
--------------------------------------------------------------------------------
1 | from random import rand
2 | from time.time import monotonic as now
3 | import math
4 |
5 | import basalt.nn as nn
6 | from basalt import Tensor, TensorShape
7 | from basalt import dtype
8 | from basalt import Graph, Symbol, OP
9 | from basalt.utils.tensorutils import fill
10 |
11 |
12 | fn create_simple_nn(batch_size: Int, n_inputs: Int, n_outputs: Int) -> Graph:
13 | var g = Graph()
14 |
15 | var x = g.input(TensorShape(batch_size, n_inputs))
16 | var y_true = g.input(TensorShape(batch_size, n_outputs))
17 |
18 | var x1 = nn.Linear(g, x, n_outputs=32)
19 | var x2 = nn.ReLU(g, x1)
20 | var x3 = nn.Linear(g, x2, n_outputs=32)
21 | var x4 = nn.ReLU(g, x3)
22 | var y_pred = nn.Linear(g, x4, n_outputs=n_outputs)
23 | g.out(y_pred)
24 |
25 | var loss = nn.MSELoss(g, y_pred, y_true)
26 | g.loss(loss)
27 |
28 | g.compile()
29 |
30 | return g ^
31 |
32 |
33 | fn main():
34 | alias batch_size = 32
35 | alias n_inputs = 1
36 | alias n_outputs = 1
37 | alias learning_rate = 0.01
38 |
39 | alias epochs = 20000
40 |
41 | alias graph = create_simple_nn(batch_size, n_inputs, n_outputs)
42 |
43 | # try: graph.render("operator")
44 | # except: print("Could not render graph")
45 |
46 | var model = nn.Model[graph]()
47 | var optimizer = nn.optim.Adam[graph](model.parameters, lr=learning_rate)
48 |
49 | var x_data = Tensor[dtype](batch_size, n_inputs)
50 | var y_data = Tensor[dtype](batch_size, n_outputs)
51 |
52 | print("Training started")
53 | var start = now()
54 | for i in range(epochs):
55 | rand[dtype](x_data.data(), x_data.num_elements())
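    |         # rand fills [0, 1); the loop below rescales x to [-1, 1) and uses sin(x) as the target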
56 |
57 | for j in range(batch_size):
58 | x_data[j] = x_data[j] * 2 - 1
59 | y_data[j] = math.sin(x_data[j])
60 |
61 | var out = model.forward(x_data, y_data)
62 |
63 | if (i + 1) % 1000 == 0:
64 | print("[", i + 1, "/", epochs, "] \tLoss: ", out[0])
65 |
66 | optimizer.zero_grad()
67 | model.backward()
68 | optimizer.step()
69 |
70 | print("Training finished: ", (now() - start) / 1e9, "seconds")
71 |
--------------------------------------------------------------------------------
/examples/sin_estimate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import optim
4 | import time
5 |
6 |
7 | class SimpleNN(nn.Module):
8 | def __init__(self, n_inputs, n_outputs):
9 | super(SimpleNN, self).__init__()
10 | self.linear1 = nn.Linear(in_features=n_inputs, out_features=32)
11 | self.relu1 = nn.ReLU()
12 | self.linear2 = nn.Linear(in_features=32, out_features=32)
13 | self.relu2 = nn.ReLU()
14 | self.linear3 = nn.Linear(in_features=32, out_features=n_outputs)
15 |
16 | def forward(self, x):
17 | x1 = self.linear1(x)
18 | x2 = self.relu1(x1)
19 | x3 = self.linear2(x2)
20 | x4 = self.relu2(x3)
21 | y_pred = self.linear3(x4)
22 | return y_pred
23 |
24 |
25 | if __name__ == "__main__":
26 | batch_size = 32
27 | n_inputs = 1
28 | n_outputs = 1
29 | learning_rate = 0.01
30 |
31 | device = torch.device("cpu")
32 | model = SimpleNN(n_inputs, n_outputs).to(device)
33 | loss_func = nn.MSELoss()
34 | optimizer = optim.Adam(model.parameters(), lr=learning_rate)
35 |
36 | x = torch.rand(batch_size, n_inputs).to(device) * 2 - 1
37 | y = torch.sin(x).to(device)
38 |
39 | epochs = 20000
40 |
41 | model.train()
42 | start = time.time()
43 | for i in range(epochs):
44 | x = torch.rand(batch_size, n_inputs).to(device) * 2 - 1
45 | y = torch.sin(x).to(device)
46 |
47 | outputs = model(x)
48 | loss = loss_func(outputs, y)
49 |
50 | # Backward pass
51 | optimizer.zero_grad()
52 | loss.backward()
53 | optimizer.step()
54 |
55 | if (i + 1) % 1000 == 0:
56 | print(f"Epoch [{i + 1}/{epochs}],\t Loss: {loss.item()}")
57 |
58 | print(f"Training time: {time.time() - start:.2f} seconds. Loss: {loss.item()}")
59 |
--------------------------------------------------------------------------------
/mojoproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | authors = ["stijn", "benny", "andres"]
3 | channels = ["conda-forge", "https://conda.modular.com/max"]
4 | description = "Basalt is a stand-alone machine learning framework that leverages the power of Mojo."
5 | name = "Basalt"
6 | platforms = ["osx-arm64", "linux-64"]
7 | version = "0.1.0"
8 |
9 | [tasks]
10 | test = { cmd = "magic run mojo test -I . tests" }
11 | test_mojo = { cmd = "magic run mojo test -I . tests/mojo" }
12 | test_python = { cmd = "magic run mojo test -I . tests/python" }
13 |
14 | [dependencies]
15 | max = ">=24.6.0,<25"
16 |
17 | [pypi-dependencies]
18 | torch = ">=2.5.1, <3"
19 | torchvision = ">=0.20.1, <0.21"
20 | torchaudio = ">=2.5.1, <3"
21 | onnx = ">=1.17.0, <2"
22 |
--------------------------------------------------------------------------------
/python-requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.1.0
2 | matplotlib==3.8.0
3 | pandas==2.1.1
4 | onnx
5 | netron
--------------------------------------------------------------------------------
/tests/__init__.mojo:
--------------------------------------------------------------------------------
1 | from .testing_utils import *
2 |
--------------------------------------------------------------------------------
/tests/mojo/test_activations.mojo:
--------------------------------------------------------------------------------
1 | from testing import assert_equal
2 |
3 | from basalt import dtype
4 | from basalt.nn import (
5 | Tensor,
6 | TensorShape,
7 | Model,
8 | Softmax,
9 | LogSoftmax,
10 | ReLU,
11 | LeakyReLU,
12 | Sigmoid,
13 | Tanh,
14 | )
15 | from basalt.autograd import Graph, Symbol
16 | from basalt.utils.tensorutils import fill
17 |
18 | from tests import assert_tensors_equal
19 |
20 |
21 | alias Activation = fn (inout g: Graph, input: Symbol) -> Symbol
22 | alias AxisActivation = fn (inout g: Graph, input: Symbol, axis: Int) -> Symbol
23 | alias LeakyReLUActivation = fn (
24 | inout g: Graph, input: Symbol, negative_slope: Scalar[dtype]
25 | ) -> Symbol
26 |
27 |
28 | fn create_graph[
29 | shape: TensorShape,
30 | func: AxisActivation,
31 | axis: Int,
32 | ]() -> Graph:
33 | var g = Graph()
34 | var x = g.input(shape)
35 | var activation = func(g, x, axis)
36 | g.out(activation)
37 | return g^
38 |
39 |
40 | fn create_graph[
41 | shape: TensorShape,
42 | func: LeakyReLUActivation,
43 | negative_slope: Scalar[dtype],
44 | ]() -> Graph:
45 | var g = Graph()
46 | var x = g.input(shape)
47 | var activation = func(g, x, negative_slope)
48 | g.out(activation)
49 | return g^
50 |
51 |
52 | fn create_graph[shape: TensorShape, func: Activation]() -> Graph:
53 | var g = Graph()
54 | var x = g.input(shape)
55 | var activation = func(g, x)
56 | g.out(activation)
57 | return g^
58 |
59 |
60 | fn test_graph[
61 | shape: TensorShape,
62 | func: AxisActivation,
63 | nodes: Int,
64 | axis: Int,
65 | ](input: Tensor[dtype], expected: Tensor[dtype]) raises:
66 | alias graph = create_graph[shape, func, axis]()
67 |
68 | var model = Model[graph](inference_only=True)
69 | var res = model.inference(input)[0]
70 |
71 | assert_tensors_equal["almost"](res, expected)
72 | assert_equal(len(graph.nodes), nodes)
73 |
74 |
75 | fn test_graph[
76 | shape: TensorShape,
77 | func: LeakyReLUActivation,
78 | nodes: Int,
79 | negative_slope: Scalar[dtype],
80 | ](input: Tensor[dtype], expected: Tensor[dtype]) raises:
81 | alias graph = create_graph[shape, func, negative_slope]()
82 |
83 | var model = Model[graph](inference_only=True)
84 | var res = model.inference(input)[0]
85 |
86 | assert_tensors_equal["almost"](res, expected)
87 | assert_equal(len(graph.nodes), nodes)
88 |
89 |
90 | # TODO: All these overloads feel redundant. Find a way to condense them
91 | fn test_graph[
92 | shape: TensorShape,
93 | func: Activation,
94 | nodes: Int,
95 | ](input: Tensor[dtype], expected: Tensor[dtype]) raises:
96 | alias graph = create_graph[shape, func]()
97 |
98 | var model = Model[graph](inference_only=True)
99 | var res = model.inference(input)[0]
100 |
101 | assert_tensors_equal["almost", "Tensor equality failed"](res, expected)
102 | assert_equal(len(graph.nodes), nodes, "Node count failed")
103 |
104 |
105 | fn test_SOFTMAX() raises:
106 | alias shape = TensorShape(2, 3, 2)
107 | alias nodes = 5
108 |
109 | var input = Tensor[dtype](shape)
110 | fill(input, 4)
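    |     # softmax of a constant input is uniform along the reduced axis:
    |     # 1/2 for axes 0 and 2 (size 2), 1/3 for axis 1 (size 3)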
111 |
112 | var expected = Tensor[dtype](shape)
113 |
114 | fill(expected, 0.5)
115 | test_graph[shape, Softmax, nodes, 0](input, expected)
116 |
117 | fill(expected, 1.0 / 3.0)
118 | test_graph[shape, Softmax, nodes, 1](input, expected)
119 |
120 | fill(expected, 0.5)
121 | test_graph[shape, Softmax, nodes, 2](input, expected)
122 |
123 |
124 | fn test_LOGSOFTMAX() raises:
125 | alias shape = TensorShape(2, 3, 2)
126 | alias nodes = 6
127 |
128 | var input = Tensor[dtype](shape)
129 | fill(input, 4)
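    |     # log-softmax of a constant input is log(1/d): log(1/2) ~ -0.6931, log(1/3) ~ -1.0986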
130 |
131 | var expected = Tensor[dtype](shape)
132 |
133 | fill(expected, -0.69314718)
134 | test_graph[shape, LogSoftmax, nodes, 0](input, expected)
135 |
136 | fill(expected, -1.09861231)
137 | test_graph[shape, LogSoftmax, nodes, 1](input, expected)
138 |
139 | fill(expected, -0.69314718)
140 | test_graph[shape, LogSoftmax, nodes, 2](input, expected)
141 |
142 |
143 | fn test_RELU() raises:
144 | alias shape = TensorShape(2, 3)
145 | alias nodes = 1
146 |
147 | var input = Tensor[dtype](shape)
148 |
149 | for i in range(6):
150 | input[i] = 3 if i < 3 else -3
151 |
152 | var expected = Tensor[dtype](shape)
153 |
154 | for i in range(6):
155 | expected[i] = 3 if i < 3 else 0
156 |
157 | test_graph[shape, ReLU, nodes](input, expected)
158 |
159 |
160 | fn test_LEAKYRELU() raises:
161 | alias negative_slope = Float32(0.1)
162 |
163 | alias shape = TensorShape(2, 3)
164 | alias nodes = 1
165 |
166 | var input = Tensor[dtype](shape)
167 |
168 | for i in range(6):
169 | input[i] = i - 3
170 |
171 | var expected = Tensor[dtype](shape)
172 |
173 | for i in range(6):
174 | expected[i] = i - 3 if i - 3 > 0 else negative_slope * (i - 3)
175 |
176 | test_graph[shape, LeakyReLU, nodes, negative_slope](input, expected)
177 |
178 |
179 | fn test_SIGMOID() raises:
180 | alias shape = TensorShape(2, 3)
181 | alias nodes = 1
182 |
183 | var input = Tensor[dtype](shape)
184 | fill(input, 0)
185 |
186 | var expected = Tensor[dtype](shape)
187 |
188 | fill(expected, 0.5)
189 | test_graph[shape, Sigmoid, nodes](input, expected)
190 |
191 |
192 | fn test_TANH() raises:
193 | alias shape = TensorShape(2, 3)
194 | alias nodes = 1
195 |
196 | var input = Tensor[dtype](shape)
197 | fill(input, 0)
198 |
199 | var expected = Tensor[dtype](shape)
200 |
201 | fill(expected, 0.0)
202 | test_graph[shape, Tanh, nodes](input, expected)
203 |
204 |
205 | fn main():
206 | try:
207 | test_SOFTMAX()
208 | test_LOGSOFTMAX()
209 | test_RELU()
210 | test_LEAKYRELU()
211 | test_SIGMOID()
212 | test_TANH()
213 | except e:
214 | print("[ERROR] Error in activations")
215 | print(e)
216 |
--------------------------------------------------------------------------------
/tests/mojo/test_attributes.mojo:
--------------------------------------------------------------------------------
1 | from testing import assert_equal, assert_true
2 | from utils.index import IndexList
3 |
4 | from basalt.nn import TensorShape
5 | from basalt.autograd.attributes import Attribute
6 |
7 |
8 | fn test_attribute_key() raises:
9 | alias a = Attribute(name="test", value=-1)
10 |
11 | assert_true(str(a.name) == "test")
12 |
13 |
14 | fn test_attribute_int() raises:
15 | alias value: Int = 1
16 | alias a = Attribute(name="test", value=value)
17 |
18 | assert_true(a.to_int() == 1)
19 |
20 |
21 | fn test_attribute_string() raises:
22 | alias value: String = "hello"
23 | alias a = Attribute(name="test", value=value)
24 |
25 | assert_true(a.to_string() == value)
26 |
27 |
28 | fn test_attribute_tensor_shape() raises:
29 | alias value: TensorShape = TensorShape(1, 2, 3)
30 | alias a = Attribute(name="test", value=value)
31 |
32 | assert_true(a.to_shape() == value)
33 |
34 |
35 | fn test_attribute_static_int_tuple() raises:
36 | alias value: IndexList[7] = IndexList[7](1, 2, 3, 4, 5, 6, 7)
37 | alias a = Attribute(name="test", value=value)
38 |
39 | assert_true(a.to_static[7]() == value)
40 |
41 |
42 | fn test_attribute_scalar() raises:
43 | fn test_float32() raises:
44 | alias value_a: Float32 = 1.23456
45 | alias a1 = Attribute(name="test", value=value_a)
46 | assert_true(
47 | a1.to_scalar[DType.float32]() == value_a,
48 | "Float32 scalar attribute failed",
49 | )
50 |
51 | alias value_b: Float32 = 65151
52 | alias a2 = Attribute(name="test", value=value_b)
53 | assert_true(
54 | a2.to_scalar[DType.float32]() == value_b,
55 | "Float32 scalar attribute failed",
56 | )
57 |
58 | fn test_float_literal() raises:
59 | alias value_c: FloatLiteral = -1.1
60 | alias a3 = Attribute(name="test", value=value_c)
61 | assert_true(
62 | a3.to_scalar[DType.float32]() == value_c,
63 | "FloatLiteral scalar attribute failed",
64 | )
65 |
66 | fn test_float64() raises:
67 | alias value_a: Float64 = -1.23456
68 | alias a1 = Attribute(name="test", value=value_a)
69 | assert_true(
70 | a1.to_scalar[DType.float64]() == value_a,
71 | "Float64 scalar attribute failed",
72 | )
73 |
74 | alias value_b: Float64 = 123456
75 | alias a2 = Attribute(name="test", value=value_b)
76 | assert_true(
77 | a2.to_scalar[DType.float64]() == value_b,
78 | "Float64 scalar attribute failed",
79 | )
80 |
81 | fn test_int32() raises:
82 | alias value_a: Int32 = 666
83 | alias a1 = Attribute(name="test", value=value_a)
84 | assert_true(
85 | a1.to_scalar[DType.int32]() == value_a,
86 | "Int32 scalar attribute failed",
87 | )
88 |
89 | alias value_b: Int32 = -666
90 | alias a2 = Attribute(name="test", value=value_b)
91 | assert_true(
92 | a2.to_scalar[DType.int32]() == value_b,
93 | "Int32 scalar attribute failed",
94 | )
95 |
96 | fn test_attribute_small_scalar() raises:
97 | alias value_a: Float32 = 1e-18
98 | alias a = Attribute(name="test", value=value_a)
99 | assert_true(
100 | a.to_scalar[DType.float32]() == value_a,
101 | "SMALL scalar attribute failed",
102 | )
103 |
104 | fn test_attribute_big_scalar() raises:
105 | alias value_a: Float32 = 1e40
106 | alias a = Attribute(name="test", value=value_a)
107 | assert_true(
108 | a.to_scalar[DType.float32]() == value_a,
109 | "BIG scalar attribute failed",
110 | )
111 |
112 | test_float32()
113 | test_float_literal()
114 | test_float64()
115 | test_int32()
116 | test_attribute_small_scalar()
117 | test_attribute_big_scalar()
118 |
119 |
120 | fn main():
121 | try:
122 | test_attribute_key()
123 | test_attribute_int()
124 | test_attribute_string()
125 | test_attribute_tensor_shape()
126 | test_attribute_static_int_tuple()
127 | test_attribute_scalar()
128 | except e:
129 | print("[ERROR] Error in attributes")
130 | print(e)
131 |
--------------------------------------------------------------------------------
/tests/mojo/test_collection.mojo:
--------------------------------------------------------------------------------
1 | from testing import assert_equal
2 |
3 | from basalt import dtype
4 | from basalt.nn import Tensor, TensorShape
5 | from basalt.autograd import Symbol
6 | from basalt.utils.collection import Collection
7 | from basalt.utils.tensorutils import fill
8 |
9 | from tests import assert_tensors_equal
10 |
11 |
12 | fn test_append_tensors() raises:
13 | alias t1_shape = TensorShape(1, 10)
14 | alias t2_shape = TensorShape(2, 20)
15 | var s1 = Symbol(0, dtype, t1_shape, True)
16 | var s2 = Symbol(1, dtype, t2_shape, True)
17 |
18 | var c = Collection(capacity=2)
19 | assert_equal(c.capacity, 2)
20 | assert_equal(c.size, 0)
21 |
22 | c.append(Tensor[dtype](s1.shape), s1)
23 | assert_equal(c.size, 1)
24 |
25 | c.append(Tensor[dtype](s2.shape), s2)
26 | assert_equal(c.size, 2)
27 |
28 |
29 | fn test_get_tensor_reference() raises:
30 | alias t1_shape = TensorShape(1, 10)
31 | alias t2_shape = TensorShape(2, 20)
32 | var s1 = Symbol(0, dtype, t1_shape, True)
33 | var s2 = Symbol(1, dtype, t2_shape, True)
34 |
35 | var t1 = Tensor[dtype](s1.shape)
36 | var t2 = Tensor[dtype](s2.shape)
37 | fill(t1, 1)
38 | fill(t2, 2)
39 |
40 | var c = Collection(capacity=2)
41 | c.append(t1 ^, s1)
42 | c.append(t2 ^, s2)
43 |
44 | var t1_expected = Tensor[dtype](s1.shape)
45 | var t2_expected = Tensor[dtype](s2.shape)
46 | fill(t1_expected, 1)
47 | fill(t2_expected, 2)
48 |
49 | assert_tensors_equal(c[s1], t1_expected)
50 | assert_tensors_equal(c[s2], t2_expected)
51 |
52 |
53 | fn test_resize_collection() raises:
54 | alias t1_shape = TensorShape(1, 10)
55 | alias t2_shape = TensorShape(2, 20)
56 | alias t3_shape = TensorShape(3, 30)
57 | var s1 = Symbol(0, dtype, t1_shape, True)
58 | var s2 = Symbol(1, dtype, t2_shape, True)
59 | var s3 = Symbol(2, dtype, t3_shape, True)
60 |
61 | var t1 = Tensor[dtype](s1.shape)
62 | var t2 = Tensor[dtype](s2.shape)
63 | var t3 = Tensor[dtype](s3.shape)
64 | fill(t1, 1)
65 | fill(t2, 2)
66 | fill(t3, 3)
67 |
68 | var c = Collection(capacity=1)
69 | assert_equal(c.size, 0)
70 | assert_equal(c.capacity, 1)
71 |
72 | c.append(t1 ^, s1)
73 | assert_equal(c.size, 1)
74 | assert_equal(c.capacity, 1)
75 |
76 | c.append(t2 ^, s2)
77 | assert_equal(c.size, 2)
78 | assert_equal(c.capacity, 2)
79 |
80 | c.append(t3 ^, s3)
81 | assert_equal(c.size, 3)
82 | assert_equal(c.capacity, 4)
83 |
84 | var t1_expected = Tensor[dtype](s1.shape)
85 | var t2_expected = Tensor[dtype](s2.shape)
86 | var t3_expected = Tensor[dtype](s3.shape)
87 | fill(t1_expected, 1)
88 | fill(t2_expected, 2)
89 | fill(t3_expected, 3)
90 |
91 | assert_tensors_equal(c[s1], t1_expected)
92 | assert_tensors_equal(c[s2], t2_expected)
93 | assert_tensors_equal(c[s3], t3_expected)
94 |
95 |
96 | fn test_set_zero() raises:
97 | alias t1_shape = TensorShape(1, 10)
98 | alias t2_shape = TensorShape(2, 20)
99 | var s1 = Symbol(0, dtype, t1_shape, True)
100 | var s2 = Symbol(1, dtype, t2_shape, True)
101 | var t1 = Tensor[dtype](s1.shape)
102 | var t2 = Tensor[dtype](s2.shape)
103 | fill(t1, 1)
104 | fill(t2, 2)
105 |
106 | var c = Collection(capacity=2)
107 | c.append(t1 ^, s1)
108 | c.append(t2 ^, s2)
109 |
110 | var t1_expected = Tensor[dtype](s1.shape)
111 | var t2_expected = Tensor[dtype](s2.shape)
112 | fill(t1_expected, 1)
113 | fill(t2_expected, 2)
114 | assert_tensors_equal(c[s1], t1_expected)
115 | assert_tensors_equal(c[s2], t2_expected)
116 |
117 | c.set_zero()
118 |
119 | assert_tensors_equal(c[s1], Tensor[dtype](t1_shape))
120 | assert_tensors_equal(c[s2], Tensor[dtype](t2_shape))
121 |
122 |
123 | fn test_operate_on_reference() raises:
124 | alias res_shape = TensorShape(1, 10)
125 | alias t1_shape = TensorShape(1, 10)
126 | var sr = Symbol(0, dtype, t1_shape, True)
127 | var s1 = Symbol(1, dtype, t1_shape, True)
128 | var res = Tensor[dtype](res_shape)
129 | var t1 = Tensor[dtype](s1.shape)
130 |
131 | var c = Collection(capacity=2)
132 | c.append(res ^, sr)
133 | c.append(t1 ^, s1)
134 |
135 | fn some_operation[
136 | res_shape: TensorShape, t_shape: TensorShape
137 | ](inout res: Tensor[dtype], t1: Tensor[dtype]):
138 | for i in range(res.num_elements()):
139 | res[i] = t1[i]
140 |
141 | for i in range(1, 10):
142 | some_operation[res_shape, t1_shape](c[sr], c[s1])
143 | fill(c[s1], i)
144 |
145 | var res_expected = Tensor[dtype](res_shape)
146 | var t1_expected = Tensor[dtype](t1_shape)
147 | fill(res_expected, i - 1)
148 | fill(t1_expected, i)
149 |
150 | assert_tensors_equal(c[sr], res_expected)
151 | assert_tensors_equal(c[s1], t1_expected)
152 |
153 |
154 | fn main() raises:
155 | try:
156 | test_append_tensors()
157 | test_get_tensor_reference()
158 | test_resize_collection()
159 | test_set_zero()
160 | test_operate_on_reference()
161 | except e:
162 | print(e)
163 | raise e
164 |
--------------------------------------------------------------------------------
/tests/mojo/test_loss.mojo:
--------------------------------------------------------------------------------
1 | from testing import assert_equal, assert_almost_equal
2 |
3 | from basalt import dtype, nelts
4 | from basalt.autograd import Graph, Symbol, OP
5 | from basalt.nn import Model, Tensor, TensorShape, MSELoss, CrossEntropyLoss
6 | from basalt.utils.tensorutils import fill
7 |
8 |
9 | fn test_MSE_perfect() raises:
10 | alias y_pred_shape = TensorShape(2, 10) # batch of 2, 10 classes
11 | alias y_true_shape = TensorShape(2, 10)
12 |
13 | fn create_graph() -> Graph:
14 | var g = Graph()
15 |
16 | var y_pred = g.input(y_pred_shape)
17 | var y_true = g.input(y_true_shape)
18 |
19 | var loss = MSELoss(g, y_pred, y_true)
20 |
21 | g.out(loss)
22 |
23 | return g ^
24 |
25 | alias graph = create_graph()
26 | assert_equal(len(graph.nodes), 3)
27 |
28 | var y_pred = Tensor[dtype](y_pred_shape)
29 | var y_true = Tensor[dtype](y_true_shape)
30 |
31 | fill(y_pred, 1)
32 | fill(y_true, 1)
33 |
34 | var model = Model[graph](inference_only=True)
35 |
36 | var loss = model.inference(y_pred, y_true)[0]
37 |
38 |     assert_equal(loss.dim(0), 1)  # MSE reduces over all elements to a single scalar
39 | assert_equal(loss[0], 0) # loss is 0
40 |
41 |
42 | fn test_MSE_imperfect() raises:
43 | alias y_pred_shape = TensorShape(1, 10) # batch of 1, 10 classes
44 | alias y_true_shape = TensorShape(1, 10)
45 |
46 | fn create_graph() -> Graph:
47 | var g = Graph()
48 |
49 | var y_pred = g.input(y_pred_shape)
50 | var y_true = g.input(y_true_shape)
51 |
52 | var loss = MSELoss(g, y_pred, y_true)
53 |
54 | g.out(loss)
55 |
56 | return g ^
57 |
58 | alias graph = create_graph()
59 | assert_equal(len(graph.nodes), 3)
60 |
61 | var y_pred = Tensor[dtype](y_pred_shape)
62 | var y_true = Tensor[dtype](y_true_shape)
63 |
64 | fill(y_pred, 1)
65 |
66 | for i in range(10):
67 | y_true[i] = i
68 |
69 | var model = Model[graph](inference_only=True)
70 |
71 | var loss = model.inference(y_pred, y_true)[0]
72 |
73 | var expected_loss: Scalar[dtype] = 0.0
74 |
75 | for i in range(10):
76 | expected_loss += (y_pred[i] - y_true[i]) ** 2
77 |
78 | expected_loss = expected_loss / y_true_shape[1]
79 |
80 | assert_almost_equal(loss[0], expected_loss)
81 |
82 |
83 | fn test_CrossEntropy_perfect() raises:
84 | alias y_pred_shape = TensorShape(2, 3) # batch of 2, 3 classes
85 | alias y_true_shape = TensorShape(2, 3)
86 |
87 | fn create_graph() -> Graph:
88 | var g = Graph()
89 |
90 | var y_pred = g.input(y_pred_shape)
91 | var y_true = g.input(y_true_shape)
92 |
93 | var loss = CrossEntropyLoss(g, y_pred, y_true)
94 |
95 | g.out(loss)
96 |
97 | return g ^
98 |
99 | alias graph = create_graph()
100 | assert_equal(len(graph.nodes), 9)
101 |
102 | var y_pred = Tensor[dtype](y_pred_shape)
103 | var y_true = Tensor[dtype](y_true_shape)
104 |
105 | y_pred[0 * y_pred.dim(1) + 0] = 0.1
106 | y_pred[0 * y_pred.dim(1) + 1] = 0.2
107 | y_pred[0 * y_pred.dim(1) + 2] = 0.7
108 | y_true[0 * y_true.dim(1) + 0] = 0
109 | y_true[0 * y_true.dim(1) + 1] = 0
110 | y_true[0 * y_true.dim(1) + 2] = 1
111 |
112 | y_pred[1 * y_pred.dim(1) + 0] = 0.7
113 | y_pred[1 * y_pred.dim(1) + 1] = 0.2
114 | y_pred[1 * y_pred.dim(1) + 2] = 0.1
115 | y_true[1 * y_true.dim(1) + 0] = 1
116 | y_true[1 * y_true.dim(1) + 1] = 0
117 | y_true[1 * y_true.dim(1) + 2] = 0
118 |
119 | var model = Model[graph](inference_only=True)
120 |
121 | var loss = model.inference(y_pred, y_true)[0]
122 |
123 | assert_equal(loss.shape(), TensorShape(1))
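    |     # a softmax is applied to the predictions inside the loss, so even this "perfect"
    |     # prediction yields -log(softmax(0.7)) ~ 0.768, averaged over the two samples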
124 | assert_almost_equal(loss[0], 0.76794958)
125 |
126 |
127 | fn test_CrossEntropy_imperfect() raises:
128 | alias y_pred_shape = TensorShape(2, 3) # batch of 2, 3 classes
129 | alias y_true_shape = TensorShape(2, 3)
130 |
131 | fn create_graph() -> Graph:
132 | var g = Graph()
133 |
134 | var y_pred = g.input(y_pred_shape)
135 | var y_true = g.input(y_true_shape)
136 |
137 | var loss = CrossEntropyLoss(g, y_pred, y_true)
138 |
139 | g.out(loss)
140 |
141 | return g ^
142 |
143 | alias graph = create_graph()
144 |
145 | var y_pred = Tensor[dtype](y_pred_shape)
146 | var y_true = Tensor[dtype](y_true_shape)
147 |
148 | y_pred[0 * y_pred.dim(1) + 0] = 0.1
149 | y_pred[0 * y_pred.dim(1) + 1] = 0.2
150 | y_pred[0 * y_pred.dim(1) + 2] = 0.7
151 | y_true[0 * y_true.dim(1) + 0] = 0
152 | y_true[0 * y_true.dim(1) + 1] = 1
153 | y_true[0 * y_true.dim(1) + 2] = 0
154 |
155 | y_pred[1 * y_pred.dim(1) + 0] = 0.7
156 | y_pred[1 * y_pred.dim(1) + 1] = 0.2
157 | y_pred[1 * y_pred.dim(1) + 2] = 0.1
158 | y_true[1 * y_true.dim(1) + 0] = 0
159 | y_true[1 * y_true.dim(1) + 1] = 0
160 | y_true[1 * y_true.dim(1) + 2] = 1
161 |
162 | var model = Model[graph](inference_only=True)
163 |
164 | var loss = model.inference(y_pred, y_true)[0]
165 |
166 | assert_equal(loss.shape(), TensorShape(1))
167 | assert_almost_equal(loss[0], 1.31794953)
168 |
169 |
170 | fn main():
171 | try:
172 | test_MSE_perfect()
173 | test_MSE_imperfect()
174 | test_CrossEntropy_perfect()
175 | test_CrossEntropy_imperfect()
176 | except e:
177 | print("[ERROR] Error in loss")
178 | print(e)
179 |
--------------------------------------------------------------------------------
/tests/mojo/test_ops.mojo:
--------------------------------------------------------------------------------
1 | from math import exp, log
2 | from utils.index import IndexList
3 |
4 | from basalt import dtype, nelts
5 | from basalt.autograd import OP
6 | from basalt.autograd.attributes import Attribute, AttributeVector
7 | from basalt.utils.tensorutils import fill
8 | from basalt.nn import Tensor, TensorShape
9 |
10 | from tests import test_unary_op, test_binary_op, test_ternary_op
11 |
12 |
13 | fn test_ADD() raises:
14 | alias t1_shape = TensorShape(2, 3)
15 | alias t2_shape = TensorShape(2, 3)
16 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
17 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape)
18 | fill(t1, 1.0)
19 | fill(t2, 1.0)
20 |
21 | var expected = Tensor[dtype](2, 3)
22 | fill(expected, 2.0)
23 |
24 | test_binary_op[OP.ADD, t1_shape, t2_shape](t1, t2, expected)
25 |
26 |
27 | fn test_SUB() raises:
28 | alias t1_shape = TensorShape(2, 3)
29 | alias t2_shape = TensorShape(2, 3)
30 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
31 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape)
32 | fill(t1, 2.0)
33 | fill(t2, 1.0)
34 |
35 | var expected = Tensor[dtype](2, 3)
36 | fill(expected, 1.0)
37 |
38 | test_binary_op[OP.SUB, t1_shape, t2_shape](t1, t2, expected)
39 |
40 |
41 | fn test_MUL() raises:
42 | alias t1_shape = TensorShape(2, 3)
43 | alias t2_shape = TensorShape(2, 3)
44 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
45 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape)
46 | fill(t1, 2.0)
47 | fill(t2, 3.0)
48 |
49 | var expected = Tensor[dtype](2, 3)
50 | fill(expected, 6.0)
51 |
52 | test_binary_op[OP.MUL, t1_shape, t2_shape](t1, t2, expected)
53 |
54 |
55 | fn test_DIV() raises:
56 | alias t1_shape = TensorShape(2, 3)
57 | alias t2_shape = TensorShape(2, 3)
58 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
59 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape)
60 | fill(t1, 6.0)
61 | fill(t2, 2.0)
62 |
63 | var expected = Tensor[dtype](2, 3)
64 | fill(expected, 3.0)
65 |
66 | test_binary_op[OP.DIV, t1_shape, t2_shape](t1, t2, expected)
67 |
68 |
69 | fn test_DOT() raises:
70 | alias t1_shape = TensorShape(2, 3)
71 | alias t2_shape = TensorShape(3, 2)
72 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
73 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape)
74 | fill(t1, 1.0)
75 | fill(t2, 2.0)
76 |
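    |     # each output element is an inner product over the shared dim of size 3: 3 * (1 * 2) = 6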
77 | var expected = Tensor[dtype](2, 2)
78 | fill(expected, 6.0)
79 |
80 | test_binary_op[OP.DOT, t1_shape, t2_shape](t1, t2, expected)
81 |
82 |
83 | fn test_EXP() raises:
84 | alias t1_shape = TensorShape(2, 3)
85 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
86 | fill(t1, 2.0)
87 |
88 | var expected = Tensor[dtype](2, 3)
89 | fill(expected, exp(SIMD[dtype, 1](2.0)))
90 |
91 | test_unary_op[OP.EXP, t1_shape](t1, expected)
92 |
93 |
94 | fn test_LOG() raises:
95 | alias t1_shape = TensorShape(2, 3)
96 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
97 | fill(t1, 2.0)
98 |
99 | var expected = Tensor[dtype](2, 3)
100 | fill(expected, log(SIMD[dtype, 1](2.0)))
101 |
102 | test_unary_op[OP.LOG, t1_shape](t1, expected)
103 |
104 |
105 | fn test_POW() raises:
106 | alias t1_shape = TensorShape(2, 3)
107 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
108 | fill(t1, 2.0)
109 |
110 | alias t2_shape = TensorShape(1)
111 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape)
112 | t2[0] = 2.0
113 |
114 | var expected = Tensor[dtype](2, 3)
115 | fill(expected, 4.0)
116 |
117 | test_binary_op[OP.POW, t1_shape, t2_shape](t1, t2, expected)
118 |
119 |
120 | fn test_SUM() raises:
121 | alias t1_shape = TensorShape(2, 3, 4)
122 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
123 | fill(t1, 1.0)
124 |
125 | # No axis specified
126 | var expected = Tensor[dtype](1)
127 | fill(expected, 24.0)
128 | test_unary_op[OP.SUM, t1_shape](t1, expected)
129 |
130 | # Test axis 1
131 | alias attrs = AttributeVector(Attribute("axis", 1))
132 | expected = Tensor[dtype](2, 1, 4)
133 | fill(expected, 3.0)
134 | test_unary_op[OP.SUM, t1_shape, attrs](t1, expected)
135 |
136 |
137 | fn test_MAX() raises:
138 | alias t1_shape = TensorShape(2, 3, 2)
139 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
140 | for i in range(t1_shape.num_elements()):
141 | t1[i] = i + 1
142 |
143 | # No axis specified
144 | var expected = Tensor[dtype](1)
145 | fill(expected, t1_shape.num_elements())
146 | test_unary_op[OP.MAX, t1_shape](t1, expected)
147 |
148 | @parameter
149 | fn fill_tensor[
150 | size: Int
151 | ](inout tensor: Tensor[dtype], values: IndexList[size]):
152 | for i in range(tensor.num_elements()):
153 | tensor[i] = values[i]
154 |
155 | # Test axis 0
156 | alias attrs = AttributeVector(Attribute("axis", 0))
157 | var expected_max_axis_0_temp = IndexList[6](7, 8, 9, 10, 11, 12)
158 | expected = Tensor[dtype](1, 3, 2)
159 | fill_tensor(expected, expected_max_axis_0_temp)
160 | test_unary_op[OP.MAX, t1_shape, attrs](t1, expected)
161 |
162 | # Test axis 1
163 | alias attrs_1 = AttributeVector(Attribute("axis", 1))
164 | var expected_max_axis_1_temp = IndexList[4](5, 6, 11, 12)
165 | expected = Tensor[dtype](2, 1, 2)
166 | fill_tensor(expected, expected_max_axis_1_temp)
167 | test_unary_op[OP.MAX, t1_shape, attrs_1](t1, expected)
168 |
169 | # Test axis 2
170 | alias attrs_2 = AttributeVector(Attribute("axis", 2))
171 | var expected_max_axis_2_temp = IndexList[6](2, 4, 6, 8, 10, 12)
172 | expected = Tensor[dtype](2, 3, 1)
173 | fill_tensor(expected, expected_max_axis_2_temp)
174 | test_unary_op[OP.MAX, t1_shape, attrs_2](t1, expected)
175 |
176 |
177 | fn test_MEAN() raises:
178 | alias t1_shape = TensorShape(2, 3)
179 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
180 | fill(t1, 5.0)
181 |
182 | # No axis specified
183 | var expected = Tensor[dtype](1)
184 | fill(expected, 5.0)
185 | test_unary_op[OP.MEAN, t1_shape](t1, expected)
186 |
187 | # Test axis 0
188 | alias attrs = AttributeVector(Attribute("axis", 0))
189 | expected = Tensor[dtype](1, 3)
190 | fill(expected, 5.0)
191 | test_unary_op[OP.MEAN, t1_shape, attrs](t1, expected)
192 |
193 | # Test axis 1
194 | alias attrs_1 = AttributeVector(Attribute("axis", 1))
195 | expected = Tensor[dtype](2, 1)
196 | fill(expected, 5.0)
197 | test_unary_op[OP.MEAN, t1_shape, attrs_1](t1, expected)
198 |
199 |
200 | fn test_TRANSPOSE() raises:
201 | alias t1_shape = TensorShape(2, 3, 4)
202 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
203 | for i in range(t1_shape.num_elements()):
204 | t1[i] = i + 1
205 |
206 |     # Test transpose (no attributes = reversing the axes by default)
207 | var expected = Tensor[dtype](4, 3, 2)
208 | var expected_strides = expected.strides()
209 | for i in range(t1_shape[0]):
210 | for j in range(t1_shape[1]):
211 | for k in range(t1_shape[2]):
212 | expected[k * expected_strides[0] + j * expected_strides[1] + i] = t1[
213 | i * t1_shape[1] * t1_shape[2] + j * t1_shape[2] + k
214 | ]
215 |
216 | test_unary_op[OP.TRANSPOSE, t1_shape](t1, expected)
217 |
218 |     # Test transpose with axes (1, 2, 0)
219 | alias attrs = AttributeVector(Attribute("axes", TensorShape(1, 2, 0)))
220 | var expected_axis_1 = Tensor[dtype](3, 4, 2)
221 | var expected_axis_1_strides = expected_axis_1.strides()
222 | for i in range(t1_shape[0]):
223 | for j in range(t1_shape[1]):
224 | for k in range(t1_shape[2]):
225 | expected_axis_1[
226 | j * expected_axis_1_strides[0] + k * expected_axis_1_strides[1] + i
227 | ] = t1[i * t1_shape[1] * t1_shape[2] + j * t1_shape[2] + k]
228 |
229 | test_unary_op[OP.TRANSPOSE, t1_shape, attrs](t1, expected_axis_1)
230 |
231 |
232 | fn test_FLATTEN() raises:
233 | alias t1_shape = TensorShape(2, 3, 4)
234 | var t1 = Tensor[dtype](t1_shape)
235 | fill(t1, 1.0)
236 |
237 | var expected = Tensor[dtype](24)
238 | fill(expected, 1.0)
239 |
240 | test_unary_op[OP.FLATTEN, t1_shape](t1, expected)
241 |
242 |
243 | fn test_RESHAPE() raises:
244 | alias t_shape = TensorShape(2, 2, 5)
245 | alias new_shape = TensorShape(2, 10)
246 |
247 | var t = Tensor[dtype](t_shape)
248 | var expected = Tensor[dtype](new_shape)
249 | for i in range(20):
250 | t[i] = i + 1
251 | expected[i] = i + 1
252 |
253 | alias attrs = AttributeVector(Attribute("shape", new_shape))
254 | test_unary_op[OP.RESHAPE, t_shape, attrs](t, expected)
255 |
256 |
257 | fn test_FMA() raises:
258 | alias t1_shape = TensorShape(2, 3)
259 | alias t2_shape = TensorShape(2, 3)
260 | alias t3_shape = TensorShape(2, 3)
261 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape)
262 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape)
263 | var t3: Tensor[dtype] = Tensor[dtype](t3_shape)
264 | fill(t1, 1.0)
265 | fill(t2, 2.0)
266 | fill(t3, 3.0)
267 |
268 | var expected = Tensor[dtype](2, 3)
269 | fill(expected, 1.0 * 2.0 + 3.0)
270 |
271 | test_ternary_op[OP.FMA, t1_shape, t2_shape, t3_shape](t1, t2, t3, expected)
272 |
273 |
274 | fn main():
275 | try:
276 | test_ADD()
277 | test_SUB()
278 | test_MUL()
279 | test_DIV()
280 | test_DOT()
281 | test_EXP()
282 | test_LOG()
283 | test_POW()
284 | test_SUM()
285 | test_MAX()
286 | test_MEAN()
287 | test_TRANSPOSE()
288 | test_FLATTEN()
289 | test_RESHAPE()
290 | test_FMA()
291 | except e:
292 | print("[ERROR] Error in ops")
293 | print(e)
294 |
--------------------------------------------------------------------------------
/tests/python/test_broadcast_shapes.mojo:
--------------------------------------------------------------------------------
1 | from python.python import Python, PythonObject
2 | from testing import assert_true
3 |
4 | from basalt.nn import Tensor, TensorShape
5 | from basalt.utils.tensorutils import broadcast_shapes
6 |
7 |
8 | fn to_tensor_shape(owned shape: PythonObject) raises -> TensorShape:
9 | var tensor_shape = List[Int]()
10 | for dim in shape:
11 | tensor_shape.append(int(float(dim)))
12 | return TensorShape(tensor_shape)
13 |
14 |
15 | fn np_broadcast_shapes(s1: TensorShape, s2: TensorShape) raises -> TensorShape:
16 | var np = Python.import_module("numpy")
17 | var s1_py: PythonObject = []
18 | var s2_py: PythonObject = []
19 | for i in range(s1.rank()):
20 | s1_py += [s1[i]]
21 | for i in range(s2.rank()):
22 | s2_py += [s2[i]]
23 |
24 | var py_shape = np.broadcast_shapes(s1_py, s2_py)
25 |
26 | return to_tensor_shape(py_shape)
27 |
28 |
29 | fn test_broadcast_shapes() raises:
30 | var s1 = TensorShape(3, 5, 2)
31 | var s2 = TensorShape(3, 5, 2)
32 | var s3 = broadcast_shapes(s1, s2)
33 | assert_true(s3 == np_broadcast_shapes(s1, s2))
34 |
35 | s1 = TensorShape(3, 5, 2)
36 | s2 = TensorShape(1, 2)
37 | s3 = broadcast_shapes(s1, s2)
38 | assert_true(s3 == np_broadcast_shapes(s1, s2))
39 |
40 | s1 = TensorShape(5, 1)
41 | s2 = TensorShape(3, 5, 1)
42 | s3 = broadcast_shapes(s1, s2)
43 | assert_true(s3 == np_broadcast_shapes(s1, s2))
44 |
45 | s1 = TensorShape(3, 1, 2)
46 | s2 = TensorShape(3, 5, 2)
47 | s3 = broadcast_shapes(s1, s2)
48 | assert_true(s3 == np_broadcast_shapes(s1, s2))
49 |
50 | s1 = TensorShape(1, 1, 1)
51 | s2 = TensorShape(3, 5, 2)
52 | s3 = broadcast_shapes(s1, s2)
53 | assert_true(s3 == np_broadcast_shapes(s1, s2))
54 |
55 | s1 = TensorShape(2)
56 | s2 = TensorShape(3, 5, 2)
57 | s3 = broadcast_shapes(s1, s2)
58 | assert_true(s3 == np_broadcast_shapes(s1, s2))
59 |
60 | s1 = TensorShape()
61 | s2 = TensorShape(3, 5, 2)
62 | s3 = broadcast_shapes(s1, s2)
63 | assert_true(s3 == np_broadcast_shapes(s1, s2))
64 |
65 | # # Both errors expected
66 | # print("EXPECTED RAISE!")
67 | # try:
68 | # s1 = TensorShape(3, 2, 2)
69 | # s2 = TensorShape(3, 5, 2)
70 | # s3 = broadcast_shapes(s1, s2)
71 | # _ = np_broadcast_shapes(s1, s2)
72 | # except e:
73 | # print("Numpy:", e)
74 |
75 | # print("EXPECTED RAISE!")
76 | # try:
77 | # s1 = TensorShape(3)
78 | # s2 = TensorShape(2)
79 | # s3 = broadcast_shapes(s1, s2)
80 | # _ = np_broadcast_shapes(s1, s2)
81 | # except e:
82 | # print("Numpy:", e)
83 |
84 |
85 | fn test_broadcast_shapes_multiple() raises:
86 | var np = Python.import_module("numpy")
87 |
88 | var s1 = TensorShape(1, 2)
89 | var s2 = TensorShape(3, 1)
90 | var s3 = TensorShape(3, 2)
91 | var res = broadcast_shapes(s1, s2, s3)
92 | var res_np = to_tensor_shape(np.broadcast_shapes((1, 2), (3, 1), (3, 2)))
93 | assert_true(res == res_np)
94 |
95 | s1 = TensorShape(6, 7)
96 | s2 = TensorShape(5, 6, 1)
97 | s3 = TensorShape(7)
98 | var s4 = TensorShape(5, 1, 7)
99 | res = broadcast_shapes(s1, s2, s3, s4)
100 | res_np = to_tensor_shape(np.broadcast_shapes((6, 7), (5, 6, 1), (7), (5, 1, 7)))
101 | assert_true(res == res_np)
102 |
103 |
104 | fn main():
105 | try:
106 | test_broadcast_shapes()
107 | test_broadcast_shapes_multiple()
108 | except e:
109 | print("[Error] In test broadcasting.")
110 | print(e)
111 |
--------------------------------------------------------------------------------
/tests/python/test_dynamic_ops_torch.mojo:
--------------------------------------------------------------------------------
1 | from random import rand
2 | from python.python import Python, PythonObject
3 |
4 | from basalt import dtype, nelts
5 | from basalt.autograd import Graph, Symbol, OP
6 | from basalt.autograd.attributes import Attribute, AttributeVector
7 | from basalt.nn import Model, Tensor, TensorShape
8 |
9 | from tests import (
10 | assert_tensors_equal,
11 | to_numpy,
12 | to_tensor,
13 | create_graph_concat,
14 | create_graph_split,
15 | )
16 |
17 |
18 | @value
19 | struct torch_output_cat:
20 | var expected: Tensor[dtype]
21 | var grad_1: Tensor[dtype]
22 | var grad_2: Tensor[dtype]
23 | var grad_3: Tensor[dtype]
24 |
25 |
26 | fn torch_cat(
27 | input_1: Tensor, input_2: Tensor, input_3: Tensor, upper_grad: Tensor, dim: Int
28 | ) -> torch_output_cat:
29 | try:
30 | var py = Python.import_module("builtins")
31 | var torch = Python.import_module("torch")
32 | var np = Python.import_module("numpy")
33 |
34 | var input_1 = torch.from_numpy(to_numpy(input_1)).requires_grad_(True)
35 | var input_2 = torch.from_numpy(to_numpy(input_2)).requires_grad_(True)
36 | var input_3 = torch.from_numpy(to_numpy(input_3)).requires_grad_(True)
37 |
38 | var expected: PythonObject
39 |
40 | var tensors = py.list()
41 | tensors.append(input_1)
42 | tensors.append(input_2)
43 | tensors.append(input_3)
44 | expected = torch.cat(tensors, dim=dim)
45 |
46 | # uppergrad & backwards
47 | var upper_grad = torch.from_numpy(to_numpy(upper_grad))
48 | _ = expected.backward(upper_grad)
49 |
50 | return torch_output_cat(
51 | to_tensor(expected.detach().numpy()),
52 | to_tensor(input_1.grad.numpy()),
53 | to_tensor(input_2.grad.numpy()),
54 | to_tensor(input_3.grad.numpy()),
55 | )
56 |
57 | except e:
58 | print("Error importing torch: ", e)
59 | var d = Tensor[dtype](1)
60 | return torch_output_cat(d, d, d, d)
61 |
62 |
63 | fn test_CONCAT() raises:
64 | alias t1_shape = TensorShape(11, 3, 17, 19)
65 | alias t2_shape = TensorShape(11, 3, 17, 19)
66 | alias t3_shape = TensorShape(11, 3, 17, 19)
67 | var t1 = Tensor[dtype](t1_shape)
68 | var t2 = Tensor[dtype](t2_shape)
69 | var t3 = Tensor[dtype](t3_shape)
70 | rand(t1.data(), t1.num_elements())
71 | rand(t2.data(), t2.num_elements())
72 | rand(t3.data(), t3.num_elements())
73 |
74 | # default: dim = 0
75 | alias graph = create_graph_concat(t1_shape, t2_shape, t3_shape, dim=0)
76 | var model = Model[graph]()
77 | var res = model.forward(t1, t2, t3)
78 |
79 | alias ug_shape = TensorShape(33, 3, 17, 19)
80 | var ug = Tensor[dtype](ug_shape)
81 | rand(ug.data(), ug.num_elements())
82 |
83 | var expected_and_grad = torch_cat(t1, t2, t3, ug, dim=0)
84 | model.backward(ug)
85 |
86 | assert_tensors_equal["almost"](res, expected_and_grad.expected)
87 | assert_tensors_equal["almost"](
88 | model.parameters.grads[graph.nodes[0].inputs[0]],
89 | expected_and_grad.grad_1,
90 | )
91 | assert_tensors_equal["almost"](
92 | model.parameters.grads[graph.nodes[0].inputs[1]],
93 | expected_and_grad.grad_2,
94 | )
95 | assert_tensors_equal["almost"](
96 | model.parameters.grads[graph.nodes[0].inputs[2]],
97 | expected_and_grad.grad_3,
98 | )
99 |
100 | # dim = 2
101 | alias graph_2 = create_graph_concat(t1_shape, t2_shape, t3_shape, dim=2)
102 | var model_2 = Model[graph_2]()
103 | var res_2 = model_2.forward(t1, t2, t3)
104 |
105 | alias ug_shape_2 = TensorShape(11, 3, 51, 19)
106 | var ug_2 = Tensor[dtype](ug_shape_2)
107 | rand(ug_2.data(), ug_2.num_elements())
108 |
109 | var expected_and_grad_2 = torch_cat(t1, t2, t3, ug_2, dim=2)
110 | model_2.backward(ug_2)
111 |
112 | assert_tensors_equal["almost"](res_2, expected_and_grad_2.expected)
113 | assert_tensors_equal["almost"](
114 | model_2.parameters.grads[graph_2.nodes[0].inputs[0]],
115 | expected_and_grad_2.grad_1,
116 | )
117 | assert_tensors_equal["almost"](
118 | model_2.parameters.grads[graph_2.nodes[0].inputs[1]],
119 | expected_and_grad_2.grad_2,
120 | )
121 | assert_tensors_equal["almost"](
122 | model_2.parameters.grads[graph_2.nodes[0].inputs[2]],
123 | expected_and_grad_2.grad_3,
124 | )
125 |
126 |
127 | @value
128 | struct torch_output_split:
129 | var expected1: Tensor[dtype]
130 | var expected2: Tensor[dtype]
131 | var expected3: Tensor[dtype]
132 | var grad: Tensor[dtype]
133 |
134 |
135 | fn torch_split(
136 | input: Tensor,
137 | upper_grad_1: Tensor,
138 | upper_grad_2: Tensor,
139 | upper_grad_3: Tensor,
140 | sections: List[Int],
141 | dim: Int,
142 | ) -> torch_output_split:
143 | try:
144 | var py = Python.import_module("builtins")
145 | var torch = Python.import_module("torch")
146 | var np = Python.import_module("numpy")
147 |
148 | var input = torch.from_numpy(to_numpy(input)).requires_grad_(True)
149 |
150 | var sizes = py.list()
151 | sizes.append(sections[0])
152 | sizes.append(sections[1])
153 | sizes.append(sections[2])
154 |
155 | var chunks: PythonObject = input.split(sizes, dim=dim)
156 |
157 | # uppergrad & backwards
158 | var upper_grad_1 = torch.from_numpy(to_numpy(upper_grad_1))
159 | var upper_grad_2 = torch.from_numpy(to_numpy(upper_grad_2))
160 | var upper_grad_3 = torch.from_numpy(to_numpy(upper_grad_3))
161 | _ = chunks[0].backward(upper_grad_1)
162 | _ = chunks[1].backward(upper_grad_2)
163 | _ = chunks[2].backward(upper_grad_3)
164 |
165 | return torch_output_split(
166 | to_tensor(chunks[0].detach().numpy()),
167 | to_tensor(chunks[1].detach().numpy()),
168 | to_tensor(chunks[2].detach().numpy()),
169 | to_tensor(input.grad.numpy()),
170 | )
171 |
172 | except e:
173 | print("Error importing torch: ", e)
174 | var d = Tensor[dtype](1)
175 | return torch_output_split(d, d, d, d)
176 |
177 |
178 | fn test_SPLIT() raises:
179 | alias t1_shape = TensorShape(11, 3, 17, 19)
180 | var t1 = Tensor[dtype](t1_shape)
181 | rand(t1.data(), t1.num_elements())
182 |
183 | # default: dim = 0
184 | alias sections = List[Int](3, 6, 2) # 11
185 | alias graph = create_graph_split(t1_shape, sections, dim=0)
186 | var model = Model[graph]()
187 | var results = model.inference(t1)
188 |
189 | alias ug1_shape = TensorShape(3, 3, 17, 19)
190 | alias ug2_shape = TensorShape(6, 3, 17, 19)
191 | alias ug3_shape = TensorShape(2, 3, 17, 19)
192 | var ug1 = Tensor[dtype](ug1_shape)
193 | var ug2 = Tensor[dtype](ug2_shape)
194 | var ug3 = Tensor[dtype](ug3_shape)
195 | rand(ug1.data(), ug1.num_elements())
196 | rand(ug2.data(), ug2.num_elements())
197 | rand(ug3.data(), ug3.num_elements())
198 |
199 | var expected_and_grad = torch_split(t1, ug1, ug2, ug3, sections, dim=0)
200 | model.backward(ug1, ug2, ug3)
201 |
202 | assert_tensors_equal["almost"](results[0], expected_and_grad.expected1)
203 | assert_tensors_equal["almost"](results[1], expected_and_grad.expected2)
204 | assert_tensors_equal["almost"](results[2], expected_and_grad.expected3)
205 | assert_tensors_equal["almost"](
206 | model.parameters.grads[graph.nodes[0].inputs[0]],
207 | expected_and_grad.grad,
208 | )
209 |
210 | # dim = 2
211 | alias sections_2 = List[Int](3, 6, 8) # 17
212 | alias graph_2 = create_graph_split(t1_shape, sections_2, dim=2)
213 | var model_2 = Model[graph_2]()
214 | var results_2 = model_2.inference(t1)
215 |
216 | alias ug1_shape_2 = TensorShape(11, 3, 3, 19)
217 | alias ug2_shape_2 = TensorShape(11, 3, 6, 19)
218 | alias ug3_shape_2 = TensorShape(11, 3, 8, 19)
219 | var ug1_2 = Tensor[dtype](ug1_shape_2)
220 | var ug2_2 = Tensor[dtype](ug2_shape_2)
221 | var ug3_2 = Tensor[dtype](ug3_shape_2)
222 | rand(ug1_2.data(), ug1_2.num_elements())
223 | rand(ug2_2.data(), ug2_2.num_elements())
224 | rand(ug3_2.data(), ug3_2.num_elements())
225 |
226 | var expected_and_grad_2 = torch_split(t1, ug1_2, ug2_2, ug3_2, sections_2, dim=2)
227 | model_2.backward(ug1_2, ug2_2, ug3_2)
228 |
229 | assert_tensors_equal["almost"](results_2[0], expected_and_grad_2.expected1)
230 | assert_tensors_equal["almost"](results_2[1], expected_and_grad_2.expected2)
231 | assert_tensors_equal["almost"](results_2[2], expected_and_grad_2.expected3)
232 | assert_tensors_equal["almost"](
233 | model_2.parameters.grads[graph_2.nodes[0].inputs[0]], expected_and_grad_2.grad
234 | )
235 |
236 |
237 | fn main():
238 | print("Running dynamic ops (compare with torch) tests")
239 | try:
240 | test_CONCAT()
241 | test_SPLIT()
242 | except e:
243 | print("[ERROR] Error in dynamic ops (compare with torch)")
244 | print(e)
245 | return
246 |
247 | print("Finished dynamic ops (compare with torch) tests")
248 |
--------------------------------------------------------------------------------
/tests/python/test_models_mnist.mojo:
--------------------------------------------------------------------------------
1 | from random import rand
2 | from python import Python
3 | from testing import assert_almost_equal
4 | from utils.index import IndexList
5 |
6 | from basalt import dtype
7 | from basalt.autograd import Graph, OP
8 | from basalt.autograd.attributes import AttributeVector, Attribute
9 | from basalt.nn import (
10 | Tensor,
11 | TensorShape,
12 | Model,
13 | ReLU,
14 | MaxPool2d,
15 | CrossEntropyLoss,
16 | optim,
17 | )
18 | from basalt.autograd.params import Param
19 |
20 | from tests import assert_tensors_equal, to_numpy, to_tensor
21 |
22 |
23 | fn create_CNN(
24 | batch_size: Int,
25 | conv1_weights: List[Scalar[dtype]],
26 | conv1_bias: List[Scalar[dtype]],
27 | conv2_weights: List[Scalar[dtype]],
28 | conv2_bias: List[Scalar[dtype]],
29 | linear1_weights: List[Scalar[dtype]],
30 | linear1_bias: List[Scalar[dtype]],
31 | ) -> Graph:
32 | var g = Graph()
33 | var x = g.input(TensorShape(batch_size, 1, 28, 28))
34 |
35 | # conv1
36 | # var x1 = nn.Conv2d(g, x, out_channels=16, kernel_size=5, padding=2)
37 | var c1_w = g.param(TensorShape(16, x.shape[1], 5, 5), init=Param(conv1_weights))
38 | var c1_b = g.param(TensorShape(16), init=Param(conv1_bias))
39 | var x1 = g.op(
40 | OP.CONV2D,
41 | x,
42 | c1_w,
43 | c1_b,
44 | attributes=AttributeVector(
45 | Attribute("padding", IndexList[2](2, 2)),
46 | Attribute("stride", IndexList[2](1, 1)),
47 | Attribute("dilation", IndexList[2](1, 1)),
48 | ),
49 | )
50 |
51 | var x2 = ReLU(g, x1)
52 | var x3 = MaxPool2d(g, x2, kernel_size=2)
53 |
54 | # conv2
55 | # var x4 = nn.Conv2d(g, x3, out_channels=32, kernel_size=5, padding=2)
56 | var c2_w = g.param(TensorShape(32, x3.shape[1], 5, 5), init=Param(conv2_weights))
57 | var c2_b = g.param(TensorShape(32), init=Param(conv2_bias))
58 | var x4 = g.op(
59 | OP.CONV2D,
60 | x3,
61 | c2_w,
62 | c2_b,
63 | attributes=AttributeVector(
64 | Attribute("padding", IndexList[2](2, 2)),
65 | Attribute("stride", IndexList[2](1, 1)),
66 | Attribute("dilation", IndexList[2](1, 1)),
67 | ),
68 | )
69 |
70 | var x5 = ReLU(g, x4)
71 | var x6 = MaxPool2d(g, x5, kernel_size=2)
72 | var x6_shape = x6.shape
73 | var x7 = g.op(
74 | OP.RESHAPE,
75 | x6,
76 | attributes=AttributeVector(
77 | Attribute(
78 | "shape",
79 | TensorShape(x6_shape[0], x6_shape[1] * x6_shape[2] * x6_shape[3]),
80 | )
81 | ),
82 | )
83 |
84 | # linear1
85 | # var out = nn.Linear(g, x7, n_outputs=10)
86 | var l1_w = g.param(TensorShape(x7.shape[1], 10), init=Param(linear1_weights))
87 | var l1_b = g.param(TensorShape(10), init=Param(linear1_bias))
88 | var res = g.op(OP.DOT, x7, l1_w)
89 | var out = g.op(OP.ADD, res, l1_b)
90 | g.out(out)
91 |
92 | var y_true = g.input(TensorShape(batch_size, 10))
93 | var loss = CrossEntropyLoss(g, out, y_true)
94 | # var loss = nn.MSELoss(g, out, y_true)
95 | g.loss(loss)
96 |
97 | return g ^
98 |
99 |
100 | fn run_mojo[
101 | batch_size: Int,
102 | conv1_weights: List[Scalar[dtype]],
103 | conv1_bias: List[Scalar[dtype]],
104 | conv2_weights: List[Scalar[dtype]],
105 | conv2_bias: List[Scalar[dtype]],
106 | linear1_weights: List[Scalar[dtype]],
107 | linear1_bias: List[Scalar[dtype]],
108 | ](
109 | epochs: Int,
110 | learning_rate: Float64,
111 | inputs: Tensor[dtype],
112 | labels: Tensor[dtype],
113 | ) -> List[Scalar[dtype]]:
114 | alias graph = create_CNN(
115 | batch_size,
116 | conv1_weights,
117 | conv1_bias,
118 | conv2_weights,
119 | conv2_bias,
120 | linear1_weights,
121 | linear1_bias,
122 | )
123 |
124 | var model = Model[graph]()
125 | var optim = optim.Adam[graph](model.parameters, lr=learning_rate.cast[dtype]())
126 |
127 | var losses = List[Scalar[dtype]]()
128 |
129 | for i in range(epochs):
130 | var loss = model.forward(inputs, labels)
131 |
132 | # Backward pass
133 | optim.zero_grad()
134 | model.backward()
135 | optim.step()
136 |
137 | losses.append(loss[0])
138 |
139 | return losses
140 |
141 |
142 | fn run_torch(
143 | epochs: Int,
144 | learning_rate: Float64,
145 | inputs: Tensor,
146 | labels: Tensor,
147 | owned conv1_weights: Tensor,
148 | owned conv1_bias: Tensor,
149 | owned conv2_weights: Tensor,
150 | owned conv2_bias: Tensor,
151 | owned linear1_weights: Tensor,
152 | owned linear1_bias: Tensor,
153 | ) -> List[Scalar[dtype]]:
154 | var out: List[Scalar[dtype]] = List[Scalar[dtype]]()
155 |
156 | try:
157 | var torch = Python.import_module("torch")
158 | var F = Python.import_module("torch.nn.functional")
159 | var np = Python.import_module("numpy")
160 | Python.add_to_path("./tests/python")
161 | var torch_models = Python.import_module("test_models_torch")
162 |
163 | var inputs = torch.from_numpy(to_numpy(inputs)).requires_grad_(True)
164 | var labels = torch.from_numpy(to_numpy(labels)).requires_grad_(True)
165 |
166 | var conv1_weights = torch.from_numpy(to_numpy(conv1_weights)).requires_grad_(
167 | True
168 | )
169 | var conv1_bias = torch.from_numpy(to_numpy(conv1_bias)).requires_grad_(True)
170 | var conv2_weights = torch.from_numpy(to_numpy(conv2_weights)).requires_grad_(
171 | True
172 | )
173 | var conv2_bias = torch.from_numpy(to_numpy(conv2_bias)).requires_grad_(True)
174 | var linear1_weights = torch.from_numpy(
175 | to_numpy(linear1_weights)
176 | ).requires_grad_(True)
177 | var linear1_bias = torch.from_numpy(to_numpy(linear1_bias)).requires_grad_(True)
178 |
179 | var cnn = torch_models.CNN(
180 | conv1_weights,
181 | conv1_bias,
182 | conv2_weights,
183 | conv2_bias,
184 | linear1_weights,
185 | linear1_bias,
186 | )
187 |
188 | var loss_func = torch_models.CrossEntropyLoss2()
189 | # var loss_func = torch.nn.CrossEntropyLoss()
190 | var optimizer = torch.optim.Adam(cnn.parameters(), learning_rate)
191 |
192 | for i in range(epochs):
193 | var output = cnn.forward(inputs)
194 | var loss = loss_func(output, labels)
195 |
196 | _ = optimizer.zero_grad()
197 | _ = loss.backward()
198 | _ = optimizer.step()
199 |
200 | out.append(to_tensor(loss)[0])
201 |
202 | return out
203 |
204 | except e:
205 | print("Error importing torch")
206 | print(e)
207 | return out
208 |
209 |
210 | fn create_weights(num_elements: Int, zero: Bool) -> List[Scalar[dtype]]:
211 | var weights = List[Scalar[dtype]](capacity=num_elements)
212 | for i in range(num_elements):
213 | if zero:
214 | weights.append(Scalar[dtype](0.0))
215 | else:
216 | weights.append(Scalar[dtype](0.02))
217 | return weights ^
218 |
219 |
220 | fn dv_to_tensor(dv: List[Scalar[dtype]], shape: TensorShape) -> Tensor[dtype]:
221 | var t = Tensor[dtype](shape)
222 | if t.num_elements() != len(dv):
223 |         print("[WARNING] tensor and dv do not have the same number of elements")
224 | for i in range(t.num_elements()):
225 | t[i] = dv[i]
226 | return t ^
227 |
228 |
229 | fn main():
230 | alias learning_rate = 1e-3
231 | alias epochs = 100
232 | alias batch_size = 4
233 |
234 | var inputs = Tensor[dtype](batch_size, 1, 28, 28)
235 | rand[dtype](inputs.data(), inputs.num_elements())
236 | var labels = Tensor[dtype](batch_size, 10) # one-hot encoded (probabilities)
237 | for i in range(4):
238 | labels[i * 10 + i] = 1.0
239 |
240 | alias cv1_w_shape = TensorShape(16, 1, 5, 5)
241 | alias conv1_weights = create_weights(cv1_w_shape.num_elements(), zero=False)
242 | alias cv1_b_shape = TensorShape(16)
243 | alias conv1_bias = create_weights(16, zero=True)
244 |
245 | alias cv2_w_shape = TensorShape(32, 16, 5, 5)
246 | alias conv2_weights = create_weights(cv2_w_shape.num_elements(), zero=False)
247 | alias cv2_b_shape = TensorShape(32)
248 | alias conv2_bias = create_weights(32, zero=True)
249 |
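    |     # 5x5 convs with padding=2 keep 28x28; each 2x2 maxpool halves it: 28 -> 14 -> 7, hence 32 * 7 * 7 inputs.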
250 | alias l1_w_shape = TensorShape(32 * 7 * 7, 10)
251 | alias linear1_weights = create_weights(l1_w_shape.num_elements(), zero=False)
252 | alias l1_b_shape = TensorShape(10)
253 | alias linear1_bias = create_weights(10, zero=True)
254 |
255 | var losses_mojo = run_mojo[
256 | batch_size,
257 | conv1_weights,
258 | conv1_bias,
259 | conv2_weights,
260 | conv2_bias,
261 | linear1_weights,
262 | linear1_bias,
263 | ](
264 | epochs,
265 | learning_rate,
266 | inputs,
267 | labels,
268 | )
269 |
270 | var losses_torch = run_torch(
271 | epochs,
272 | learning_rate,
273 | inputs,
274 | labels,
275 | dv_to_tensor(conv1_weights, cv1_w_shape),
276 | dv_to_tensor(conv1_bias, cv1_b_shape),
277 | dv_to_tensor(conv2_weights, cv2_w_shape),
278 | dv_to_tensor(conv2_bias, cv2_b_shape),
279 | dv_to_tensor(linear1_weights, l1_w_shape),
280 | dv_to_tensor(linear1_bias, l1_b_shape),
281 | )
282 |
283 | for i in range(epochs):
284 | print("loss_mojo: ", losses_mojo[i], " loss_torch: ", losses_torch[i])
285 |
286 | for i in range(epochs):
287 | var loss_mojo = losses_mojo[i]
288 | var loss_torch = losses_torch[i]
289 | print("loss_mojo: ", loss_mojo, " loss_torch: ", loss_torch)
290 | try:
291 | assert_almost_equal(loss_mojo, loss_torch, rtol=1e-5)
292 | except e:
293 | print("Losses not equal")
294 | print(e)
295 | break
296 |
--------------------------------------------------------------------------------
/tests/python/test_models_regression.mojo:
--------------------------------------------------------------------------------
1 | from random import rand
2 | from python import Python
3 | from utils.numerics import max_finite
4 | from testing import assert_almost_equal
5 |
6 | from basalt import dtype
7 | from basalt.autograd import Graph, OP
8 | from basalt.nn import Tensor, TensorShape, Model, MSELoss, optim
9 | from basalt.utils.rand_utils import MersenneTwister
10 | from basalt.autograd.params import Param
11 |
12 | from tests import to_numpy, to_tensor
13 |
14 |
15 | fn create_linear_regression(
16 | batch_size: Int,
17 | n_outputs: Int,
18 | linear1_weights: List[Scalar[dtype]],
19 | linear1_bias: List[Scalar[dtype]],
20 | ) -> Graph:
21 | var g = Graph()
22 | var x = g.input(TensorShape(batch_size, 13))
23 |
24 | # linear1
25 | # var out = nn.Linear(g, x, n_outputs=1)
26 | var l1_w = g.param(TensorShape(13, n_outputs), init=Param(linear1_weights))
27 | var l1_b = g.param(TensorShape(n_outputs), init=Param(linear1_bias))
28 | var res = g.op(OP.DOT, x, l1_w)
29 | var out = g.op(OP.ADD, res, l1_b)
30 | g.out(out)
31 |
32 | var y_true = g.input(TensorShape(batch_size, n_outputs))
33 | var loss = MSELoss(g, out, y_true)
34 | g.loss(loss)
35 |
36 | return g ^
37 |
38 |
39 | fn run_mojo[
40 | batch_size: Int,
41 | n_outputs: Int,
42 | linear1_weights: List[Scalar[dtype]],
43 | linear1_bias: List[Scalar[dtype]],
44 | ](
45 | epochs: Int,
46 | learning_rate: Float64,
47 | inputs: Tensor[dtype],
48 | labels: Tensor[dtype],
49 | ) -> List[Scalar[dtype]]:
50 | alias graph = create_linear_regression(
51 | batch_size,
52 | n_outputs,
53 | linear1_weights,
54 | linear1_bias,
55 | )
56 |
57 | var model = Model[graph]()
58 | var optim = optim.Adam[graph](model.parameters, lr=learning_rate.cast[dtype]())
59 |
60 | var losses = List[Scalar[dtype]]()
61 |
62 | for i in range(epochs):
63 | var loss = model.forward(inputs, labels)
64 |
65 | # Backward pass
66 | optim.zero_grad()
67 | model.backward()
68 | optim.step()
69 |
70 | losses.append(loss[0])
71 |
72 | return losses
73 |
74 |
75 | fn run_torch(
76 | epochs: Int,
77 | learning_rate: Float64,
78 | inputs: Tensor,
79 | labels: Tensor,
80 | owned linear1_weights: Tensor,
81 | owned linear1_bias: Tensor,
82 | ) -> List[Scalar[dtype]]:
83 | var out: List[Scalar[dtype]] = List[Scalar[dtype]]()
84 |
85 | try:
86 | var torch = Python.import_module("torch")
87 | var F = Python.import_module("torch.nn.functional")
88 | var np = Python.import_module("numpy")
89 | Python.add_to_path("./tests/python")
90 | var torch_models = Python.import_module("test_models_torch")
91 |
92 | var inputs = torch.from_numpy(to_numpy(inputs)).requires_grad_(True)
93 | var labels = torch.from_numpy(to_numpy(labels)).requires_grad_(True)
94 |
95 | var linear1_weights = torch.from_numpy(
96 | to_numpy(linear1_weights)
97 | ).requires_grad_(True)
98 | var linear1_bias = torch.from_numpy(to_numpy(linear1_bias)).requires_grad_(True)
99 |
100 | var regression = torch_models.LinearRegression(
101 | linear1_weights,
102 | linear1_bias,
103 | )
104 |
105 | var loss_func = torch_models.MSELoss()
106 | var optimizer = torch.optim.Adam(regression.parameters(), learning_rate)
107 |
108 | for i in range(epochs):
109 | var output = regression.forward(inputs)
110 | var loss = loss_func(output, labels)
111 |
112 | _ = optimizer.zero_grad()
113 | _ = loss.backward()
114 | _ = optimizer.step()
115 |
116 | out.append(to_tensor(loss)[0].cast[dtype]())
117 |
118 | return out
119 |
120 | except e:
121 | print("Error importing torch")
122 | print(e)
123 | return out
124 |
125 |
126 | fn create_weights(num_elements: Int, zero: Bool) -> List[Scalar[dtype]]:
127 | var prng = MersenneTwister(123456)
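    |     # Fixed seed: the generated weights are deterministic, so the Mojo and PyTorch models start from identical parameters.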
128 | var weights = List[Scalar[dtype]](capacity=num_elements)
129 | for i in range(num_elements):
130 | if zero:
131 | weights.append(Scalar[dtype](0.0))
132 | else:
133 | var rand_float = prng.next().cast[dtype]() / max_finite[DType.int32]().cast[
134 | dtype
135 | ]()
136 | weights.append(Scalar[dtype](rand_float / 10))
137 | return weights ^
138 |
139 |
140 | fn dv_to_tensor(dv: List[Scalar[dtype]], shape: TensorShape) -> Tensor[dtype]:
141 | var t = Tensor[dtype](shape)
142 | if t.num_elements() != len(dv):
143 |         print("[WARNING] tensor and dv do not have the same number of elements")
144 | for i in range(t.num_elements()):
145 | t[i] = dv[i]
146 | return t ^
147 |
148 |
149 | fn main():
150 | alias learning_rate = 1e-3
151 | alias epochs = 100
152 | alias batch_size = 64
153 | alias n_outputs = 10
154 |
155 | var inputs = Tensor[dtype](batch_size, 13)
156 | rand[dtype](inputs.data(), inputs.num_elements())
157 | var labels = Tensor[dtype](batch_size, n_outputs)
158 | for i in range(batch_size):
159 | for j in range(n_outputs):
160 | labels[i * n_outputs + j] = 1
161 |
162 | alias l1_w_shape = TensorShape(13, n_outputs)
163 | alias linear1_weights = create_weights(l1_w_shape.num_elements(), zero=False)
164 | alias l1_b_shape = TensorShape(n_outputs)
165 | alias linear1_bias = create_weights(l1_b_shape.num_elements(), zero=False)
166 |
167 | var losses_mojo = run_mojo[batch_size, n_outputs, linear1_weights, linear1_bias,](
168 | epochs,
169 | learning_rate,
170 | inputs,
171 | labels,
172 | )
173 |
174 | var losses_torch = run_torch(
175 | epochs,
176 | learning_rate,
177 | inputs,
178 | labels,
179 | dv_to_tensor(linear1_weights, l1_w_shape),
180 | dv_to_tensor(linear1_bias, l1_b_shape),
181 | )
182 |
183 | var success = True
184 | for i in range(epochs):
185 | var loss_mojo = losses_mojo[i]
186 | var loss_torch = losses_torch[i]
187 | # print("loss_mojo: ", loss_mojo, " loss_torch: ", loss_torch)
188 | try:
189 | assert_almost_equal(loss_mojo, loss_torch, rtol=1e-4)
190 | except e:
191 | print("Losses not equal")
192 | print(e)
193 | success = False
194 | break
195 |
196 | if success:
197 | print("SUCCESS: All losses in Linear Regression model are equal.")
198 |
--------------------------------------------------------------------------------
/tests/python/test_models_sin_estimate.mojo:
--------------------------------------------------------------------------------
1 | from random import rand
2 | from python import Python
3 | from utils.numerics import max_finite
4 | from testing import assert_almost_equal
5 | import math
6 |
7 | from basalt import dtype
8 | from basalt.autograd import Graph, OP
9 | from basalt.nn import Tensor, TensorShape, Model, ReLU, MSELoss, optim
10 | from basalt.utils.rand_utils import MersenneTwister
11 | from basalt.autograd.params import Param
12 |
13 | from tests import to_numpy, to_tensor
14 |
15 |
16 | fn create_simple_nn(
17 | batch_size: Int,
18 | linear1_weights: List[Scalar[dtype]],
19 | linear1_bias: List[Scalar[dtype]],
20 | linear2_weights: List[Scalar[dtype]],
21 | linear2_bias: List[Scalar[dtype]],
22 | linear3_weights: List[Scalar[dtype]],
23 | linear3_bias: List[Scalar[dtype]],
24 | ) -> Graph:
25 | var g = Graph()
26 |
27 | var x = g.input(TensorShape(batch_size, 1))
28 | var y_true = g.input(TensorShape(batch_size, 1))
29 |
30 | # Linear 1: nn.Linear(g, x, n_outputs=32)
31 | var l1_w = g.param(TensorShape(1, 32), init=Param(linear1_weights))
32 | var l1_b = g.param(TensorShape(32), init=Param(linear1_bias))
33 | var res_1 = g.op(OP.DOT, x, l1_w)
34 | var x1 = g.op(OP.ADD, res_1, l1_b)
35 |
36 | # ReLU 1
37 | var x2 = ReLU(g, x1)
38 |
39 | # Linear 2: nn.Linear(g, x2, n_outputs=32)
40 | var l2_w = g.param(TensorShape(32, 32), init=Param(linear2_weights))
41 | var l2_b = g.param(TensorShape(32), init=Param(linear2_bias))
42 | var res_2 = g.op(OP.DOT, x2, l2_w)
43 | var x3 = g.op(OP.ADD, res_2, l2_b)
44 |
45 | # ReLU 2
46 | var x4 = ReLU(g, x3)
47 |
48 | # Linear 3: nn.Linear(g, x4, n_outputs=1)
49 | var l3_w = g.param(TensorShape(32, 1), init=Param(linear3_weights))
50 | var l3_b = g.param(TensorShape(1), init=Param(linear3_bias))
51 | var res_3 = g.op(OP.DOT, x4, l3_w)
52 | var y_pred = g.op(OP.ADD, res_3, l3_b)
53 | g.out(y_pred)
54 |
55 | var loss = MSELoss(g, y_pred, y_true)
56 | g.loss(loss)
57 |
58 | return g ^
59 |
60 |
61 | fn run_mojo[
62 | batch_size: Int,
63 | linear1_weights: List[Scalar[dtype]],
64 | linear1_bias: List[Scalar[dtype]],
65 | linear2_weights: List[Scalar[dtype]],
66 | linear2_bias: List[Scalar[dtype]],
67 | linear3_weights: List[Scalar[dtype]],
68 | linear3_bias: List[Scalar[dtype]],
69 | ](
70 | epochs: Int,
71 | learning_rate: Float64,
72 | inputs: Tensor[dtype],
73 | labels: Tensor[dtype],
74 | ) -> List[Scalar[dtype]]:
75 | alias graph = create_simple_nn(
76 | batch_size,
77 | linear1_weights,
78 | linear1_bias,
79 | linear2_weights,
80 | linear2_bias,
81 | linear3_weights,
82 | linear3_bias,
83 | )
84 |
85 | var model = Model[graph]()
86 | var optim = optim.Adam[graph](model.parameters, lr=learning_rate.cast[dtype]())
87 |
88 | var losses = List[Scalar[dtype]]()
89 |
90 | for i in range(epochs):
91 | var loss = model.forward(inputs, labels)
92 |
93 | # Backward pass
94 | optim.zero_grad()
95 | model.backward()
96 | optim.step()
97 |
98 | losses.append(loss[0])
99 |
100 | return losses
101 |
102 |
103 | fn run_torch(
104 | epochs: Int,
105 | learning_rate: Float64,
106 | inputs: Tensor,
107 | labels: Tensor,
108 | owned linear1_weights: Tensor,
109 | owned linear1_bias: Tensor,
110 | owned linear2_weights: Tensor,
111 | owned linear2_bias: Tensor,
112 | owned linear3_weights: Tensor,
113 | owned linear3_bias: Tensor,
114 | ) -> List[Scalar[dtype]]:
115 | var out: List[Scalar[dtype]] = List[Scalar[dtype]]()
116 |
117 | try:
118 | var torch = Python.import_module("torch")
119 | var F = Python.import_module("torch.nn.functional")
120 | var np = Python.import_module("numpy")
121 | Python.add_to_path("./tests/python")
122 | var torch_models = Python.import_module("test_models_torch")
123 |
124 | var inputs = torch.from_numpy(to_numpy(inputs)).requires_grad_(True)
125 | var labels = torch.from_numpy(to_numpy(labels)).requires_grad_(True)
126 |
127 | var linear1_weights = torch.from_numpy(
128 | to_numpy(linear1_weights)
129 | ).requires_grad_(True)
130 | var linear1_bias = torch.from_numpy(to_numpy(linear1_bias)).requires_grad_(True)
131 | var linear2_weights = torch.from_numpy(
132 | to_numpy(linear2_weights)
133 | ).requires_grad_(True)
134 | var linear2_bias = torch.from_numpy(to_numpy(linear2_bias)).requires_grad_(True)
135 | var linear3_weights = torch.from_numpy(
136 | to_numpy(linear3_weights)
137 | ).requires_grad_(True)
138 | var linear3_bias = torch.from_numpy(to_numpy(linear3_bias)).requires_grad_(True)
139 |
140 | var regression = torch_models.SimpleNN(
141 | linear1_weights,
142 | linear1_bias,
143 | linear2_weights,
144 | linear2_bias,
145 | linear3_weights,
146 | linear3_bias,
147 | )
148 |
149 | var loss_func = torch_models.MSELoss()
150 | var optimizer = torch.optim.Adam(regression.parameters(), learning_rate)
151 |
152 | for i in range(epochs):
153 | var output = regression.forward(inputs)
154 | var loss = loss_func(output, labels)
155 |
156 | _ = optimizer.zero_grad()
157 | _ = loss.backward()
158 | _ = optimizer.step()
159 |
160 | out.append(to_tensor(loss)[0].cast[dtype]())
161 |
162 | return out
163 |
164 | except e:
165 | print("Error importing torch")
166 | print(e)
167 | return out
168 |
169 |
170 | fn create_weights(num_elements: Int, zero: Bool) -> List[Scalar[dtype]]:
171 | var prng = MersenneTwister(123456)
172 | var weights = List[Scalar[dtype]](capacity=num_elements)
173 | for i in range(num_elements):
174 | if zero:
175 | weights.append(Scalar[dtype](0.0))
176 | else:
177 | var rand_float = prng.next().cast[dtype]() / max_finite[DType.int32]().cast[
178 | dtype
179 | ]()
180 | weights.append(Scalar[dtype](rand_float / 10))
181 | return weights ^
182 |
183 |
184 | fn dv_to_tensor(dv: List[Scalar[dtype]], shape: TensorShape) -> Tensor[dtype]:
185 | var t = Tensor[dtype](shape)
186 | if t.num_elements() != len(dv):
187 |         print("[WARNING] tensor and dv do not have the same number of elements")
188 | for i in range(t.num_elements()):
189 | t[i] = dv[i]
190 | return t ^
191 |
192 |
193 | fn main():
194 | alias learning_rate = 1e-3
195 | alias epochs = 100
196 | alias batch_size = 64
197 | alias n_outputs = 10
198 |
199 | var x_data = Tensor[dtype](batch_size, 1)
200 | rand[dtype](x_data.data(), x_data.num_elements())
201 | var y_data = Tensor[dtype](batch_size, 1)
202 | for j in range(batch_size):
203 | x_data[j] = x_data[j] * 2 - 1
204 | y_data[j] = math.sin(x_data[j])
205 |
206 | alias l1_w_shape = TensorShape(1, 32)
207 | alias l1_b_shape = TensorShape(32)
208 | alias l2_w_shape = TensorShape(32, 32)
209 | alias l2_b_shape = TensorShape(32)
210 | alias l3_w_shape = TensorShape(32, 1)
211 | alias l3_b_shape = TensorShape(1)
212 |
213 | alias linear1_weights = create_weights(l1_w_shape.num_elements(), zero=False)
214 | alias linear1_bias = create_weights(l1_b_shape.num_elements(), zero=False)
215 | alias linear2_weights = create_weights(l2_w_shape.num_elements(), zero=False)
216 | alias linear2_bias = create_weights(l2_b_shape.num_elements(), zero=False)
217 | alias linear3_weights = create_weights(l3_w_shape.num_elements(), zero=False)
218 | alias linear3_bias = create_weights(l3_b_shape.num_elements(), zero=False)
219 |
220 | var losses_mojo = run_mojo[
221 | batch_size,
222 | linear1_weights,
223 | linear1_bias,
224 | linear2_weights,
225 | linear2_bias,
226 | linear3_weights,
227 | linear3_bias,
228 | ](epochs, learning_rate, x_data, y_data)
229 |
230 | var losses_torch = run_torch(
231 | epochs,
232 | learning_rate,
233 | x_data,
234 | y_data,
235 | dv_to_tensor(linear1_weights, l1_w_shape),
236 | dv_to_tensor(linear1_bias, l1_b_shape),
237 | dv_to_tensor(linear2_weights, l2_w_shape),
238 | dv_to_tensor(linear2_bias, l2_b_shape),
239 | dv_to_tensor(linear3_weights, l3_w_shape),
240 | dv_to_tensor(linear3_bias, l3_b_shape),
241 | )
242 |
243 | var success = True
244 | for i in range(epochs):
245 | var loss_mojo = losses_mojo[i]
246 | var loss_torch = losses_torch[i]
247 | # print("loss_mojo: ", loss_mojo, " loss_torch: ", loss_torch)
248 | try:
249 | assert_almost_equal(loss_mojo, loss_torch, rtol=1e-4)
250 | except e:
251 | print("Losses not equal")
252 | print(e)
253 | success = False
254 | break
255 |
256 | if success:
257 | print("SUCCESS: All losses in Sin estimate model are equal.")
258 |
--------------------------------------------------------------------------------
/tests/python/test_models_torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class LinearRegression(nn.Module):
7 | def __init__(self, linear1_weights, linear1_bias):
8 | super(LinearRegression, self).__init__()
9 |
10 | self.linear1_weights = nn.Parameter(linear1_weights)
11 | self.linear1_bias = nn.Parameter(linear1_bias)
12 |
13 | def forward(self, x):
14 | output = F.linear(x, self.linear1_weights.T, self.linear1_bias)
15 | return output
16 |
17 |
18 | class MSELoss(nn.Module):
19 | def __init__(self):
20 | super(MSELoss, self).__init__()
21 |
22 | def forward(self, output, target):
23 | loss = F.mse_loss(output, target)
24 | return loss
25 |
26 |
27 | class CrossEntropyLoss(nn.Module):
28 | def __init__(self):
29 | super(CrossEntropyLoss, self).__init__()
30 |
31 | def forward(self, output, target):
32 | loss = -torch.sum(target * torch.log(output)) / output.size(0)
33 | return loss
34 |
35 |
36 | # Implement the class for crossentropy loss with logsoftmax
37 | class CrossEntropyLoss2(nn.Module):
38 | def __init__(self):
39 | super(CrossEntropyLoss2, self).__init__()
40 |
41 | def forward(self, output, target):
42 | loss = -torch.sum(target * F.log_softmax(output, dim=1)) / output.size(0)
43 | return loss
44 |
45 |
46 | class CNN(nn.Module):
47 | def __init__(
48 | self,
49 | conv1_weights,
50 | conv1_bias,
51 | conv2_weights,
52 | conv2_bias,
53 | linear1_weights,
54 | linear1_bias,
55 | ):
56 | super(CNN, self).__init__()
57 |
58 | self.conv1_weights = nn.Parameter(conv1_weights)
59 | self.conv1_bias = nn.Parameter(conv1_bias)
60 | self.conv2_weights = nn.Parameter(conv2_weights)
61 | self.conv2_bias = nn.Parameter(conv2_bias)
62 | self.linear1_weights = nn.Parameter(linear1_weights)
63 | self.linear1_bias = nn.Parameter(linear1_bias)
64 |
65 | def forward(self, x):
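    |         # Reference forward pass: (conv -> relu -> 2x2 maxpool) twice, flatten, then a single linear layer.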
66 | x = F.conv2d(x, self.conv1_weights, self.conv1_bias, stride=1, padding=2)
67 | x = F.relu(x)
68 | x = F.max_pool2d(x, 2)
69 | x = F.conv2d(x, self.conv2_weights, self.conv2_bias, stride=1, padding=2)
70 | x = F.relu(x)
71 | x = F.max_pool2d(x, 2)
72 | x = x.view(x.size(0), -1)
73 | output = F.linear(x, self.linear1_weights.T, self.linear1_bias)
74 | return output
75 |
76 | def print_grads(self):
77 | print("\nCONV1 WEIGHTS", self.conv1_weights.grad.shape)
78 | print(self.conv1_weights.grad)
79 |
80 | print("\nCONV1 BIAS", self.conv1_bias.grad.shape)
81 | print(self.conv1_bias.grad)
82 |
83 | print("\nCONV2 WEIGHTS", self.conv2_weights.grad.shape)
84 | print(self.conv2_weights.grad)
85 |
86 | print("\nCONV2 BIAS", self.conv2_bias.grad.shape)
87 | print(self.conv2_bias.grad)
88 |
89 | print("\nLINEAR1 WEIGHTS", self.linear1_weights.grad.shape)
90 | print(self.linear1_weights.grad)
91 |
92 | print("\nLINEAR1 BIAS", self.linear1_bias.grad.shape)
93 | print(self.linear1_bias.grad)
94 |
95 |
96 | class SimpleNN(nn.Module):
97 | def __init__(
98 | self,
99 | linear1_weights,
100 | linear1_bias,
101 | linear2_weights,
102 | linear2_bias,
103 | linear3_weights,
104 | linear3_bias,
105 | ):
106 | super(SimpleNN, self).__init__()
107 |
108 | self.linear1_weights = nn.Parameter(linear1_weights)
109 | self.linear1_bias = nn.Parameter(linear1_bias)
110 | self.linear2_weights = nn.Parameter(linear2_weights)
111 | self.linear2_bias = nn.Parameter(linear2_bias)
112 | self.linear3_weights = nn.Parameter(linear3_weights)
113 | self.linear3_bias = nn.Parameter(linear3_bias)
114 |
115 | self.relu1 = nn.ReLU()
116 | self.relu2 = nn.ReLU()
117 |
118 | def forward(self, x):
119 | x1 = F.linear(x, self.linear1_weights.T, self.linear1_bias)
120 | x2 = self.relu1(x1)
121 | x3 = F.linear(x2, self.linear2_weights.T, self.linear2_bias)
122 | x4 = self.relu2(x3)
123 | y_pred = F.linear(x4, self.linear3_weights.T, self.linear3_bias)
124 | return y_pred
125 |
--------------------------------------------------------------------------------
/tests/python/test_pool.mojo:
--------------------------------------------------------------------------------
1 | from random import rand
2 | from python.python import Python
3 | from testing import assert_equal
4 | from utils.index import IndexList
5 |
6 | from basalt import dtype, nelts
7 | from basalt.autograd import Graph, OP
8 | from basalt.autograd.ops.pool import MAXPOOL2D
9 | from basalt.autograd.ops.conv import get_result_shape
10 | from basalt.autograd.attributes import Attribute, AttributeVector
11 | from basalt.nn import Tensor, TensorShape, Model
12 |
13 | from tests import assert_tensors_equal, to_numpy, to_tensor
14 |
15 |
16 | @value
17 | struct torch_maxpool2d_output:
18 | var expected: Tensor[dtype]
19 | var expected_grad: Tensor[dtype]
20 |
21 |
22 | fn torch_maxpool2d(
23 | inputs: Tensor,
24 | kernel_size: IndexList[2],
25 | padding: IndexList[2],
26 | stride: IndexList[2],
27 | dilation: IndexList[2],
28 | upper_grad: Tensor,
29 | ) -> torch_maxpool2d_output:
30 | var out: torch_maxpool2d_output
31 |
32 | try:
33 | var torch = Python.import_module("torch")
34 | var F = Python.import_module("torch.nn.functional")
35 | var np = Python.import_module("numpy")
36 |
37 | var inputs = torch.from_numpy(to_numpy(inputs)).requires_grad_(True)
38 |
39 | var expected = F.max_pool2d(
40 | inputs,
41 | (kernel_size[0], kernel_size[1]),
42 | (stride[0], stride[1]),
43 | (padding[0], padding[1]),
44 | (dilation[0], dilation[1]),
45 | )
46 |
47 | # uppergrad & backwards
48 | var upper_grad = torch.from_numpy(to_numpy(upper_grad))
49 | _ = expected.backward(upper_grad)
50 |
51 | # expected
52 | out = torch_maxpool2d_output(
53 | to_tensor(expected.detach().numpy()), to_tensor(inputs.grad.numpy())
54 | )
55 | return out
56 |
57 | except:
58 | print("Error in torch_maxpool2d")
59 | var d = Tensor[dtype](1)
60 | var out = torch_maxpool2d_output(d, d)
61 | return out
62 |
63 |
64 | fn test_pool_forward[
65 | input_shape: TensorShape,
66 | kernel_size: IndexList[2],
67 | padding: IndexList[2],
68 | stride: IndexList[2],
69 | dilation: IndexList[2],
70 | ](inputs: Tensor[dtype]) raises:
71 | fn create_graph() -> Graph:
72 | var g = Graph()
73 | var inp = g.input(input_shape)
74 |
75 | var res = g.op(
76 | OP.MAXPOOL2D,
77 | inp,
78 | attributes=AttributeVector(
79 | Attribute("kernel_size", kernel_size),
80 | Attribute("padding", padding),
81 | Attribute("stride", stride),
82 | Attribute("dilation", dilation),
83 | ),
84 | )
85 | g.out(res)
86 |
87 | return g ^
88 |
89 | alias graph = create_graph()
90 | assert_equal(len(graph.nodes), 1)
91 |
92 | var model = Model[graph](inference_only=True)
93 | var res = model.inference(inputs)[0]
94 |
95 | var torch_out = torch_maxpool2d(
96 | inputs,
97 | kernel_size=kernel_size,
98 | padding=padding,
99 | stride=stride,
100 | dilation=dilation,
101 | upper_grad=Tensor[dtype](res.shape()),
102 | )
103 |
104 | assert_tensors_equal(res, torch_out.expected)
105 |
106 |
107 | fn test_forward_1() raises:
108 | # padding=2, stride=1, dilation=1
109 | # input shape: (4, 1, 28, 28) kernel size: (5, 5)
110 | alias kernel_size = 5
111 | alias padding = 2
112 | alias stride = 1
113 | alias dilation = 1
114 | alias input_shape = TensorShape(4, 1, 28, 28)
115 | var inputs = Tensor[dtype](input_shape)
116 | rand[dtype](inputs.data(), inputs.num_elements())
117 |
118 | test_pool_forward[input_shape, kernel_size, padding, stride, dilation](inputs)
119 |
120 |
121 | fn test_forward_2() raises:
122 | # padding=0, stride=1, dilation=1
123 | # input shape: (4, 1, 32, 17) kernel size: (2, 2)
124 | alias kernel_size = IndexList[2](2, 2)
125 | alias padding = 0
126 | alias stride = 1
127 | alias dilation = 1
128 | alias input_shape = TensorShape(4, 1, 32, 17)
129 | var inputs = Tensor[dtype](input_shape)
130 | rand[dtype](inputs.data(), inputs.num_elements())
131 |
132 | test_pool_forward[input_shape, kernel_size, padding, stride, dilation](inputs)
133 |
134 |
135 | fn test_forward_3() raises:
136 | # padding=(3, 1), stride=(2, 3), dilation=(2, 3)
137 | # input shape: (4, 3, 32, 17) kernel size: (6, 6)
138 | alias kernel_size = IndexList[2](6, 6)
139 | alias padding = IndexList[2](3, 1)
140 | alias stride = IndexList[2](2, 3)
141 | alias dilation = IndexList[2](2, 3)
142 | alias input_shape = TensorShape(4, 3, 32, 17)
143 | var inputs = Tensor[dtype](input_shape)
144 | rand[dtype](inputs.data(), inputs.num_elements())
145 |
146 | test_pool_forward[input_shape, kernel_size, padding, stride, dilation](inputs)
147 |
148 |
149 | fn test_pool_backward[
150 | ug_shape: TensorShape,
151 | input_shape: TensorShape,
152 | kernel_size: IndexList[2],
153 | padding: IndexList[2],
154 | stride: IndexList[2],
155 | dilation: IndexList[2],
156 | ](ug: Tensor[dtype], inputs: Tensor[dtype]) raises:
157 | alias attributes = AttributeVector(
158 | Attribute("kernel_size", kernel_size),
159 | Attribute("padding", padding),
160 | Attribute("stride", stride),
161 | Attribute("dilation", dilation),
162 | )
163 |
164 | var grad = MAXPOOL2D.backward[ug_shape, input_shape, attributes](ug, inputs)
165 |
166 | var torch_out = torch_maxpool2d(
167 | inputs,
168 | kernel_size=kernel_size,
169 | padding=padding,
170 | stride=stride,
171 | dilation=dilation,
172 | upper_grad=ug,
173 | )
174 |
175 | assert_tensors_equal["almost"](grad, torch_out.expected_grad)
176 |
177 |
178 | fn test_backward_1() raises:
179 | # padding=2, stride=1, dilation=1
180 | # input shape: (4, 1, 28, 28) kernel size: (5, 5)
181 | alias kernel_size = 5
182 | alias padding = 2
183 | alias stride = 1
184 | alias dilation = 1
185 | alias input_shape = TensorShape(4, 1, 28, 28)
186 | var inputs = Tensor[dtype](input_shape)
187 | rand[dtype](inputs.data(), inputs.num_elements())
188 |
189 | # uppergrad
190 | alias res = get_result_shape(
191 | input_shape, TensorShape(kernel_size, kernel_size), padding, stride, dilation
192 | )
193 | alias ug_shape = TensorShape(input_shape[0], input_shape[1], res[0], res[1])
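    |     # Upper gradient has the same shape as the maxpool output: (batch, channels, out_height, out_width).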
194 | var ug = Tensor[dtype](ug_shape)
195 | rand[dtype](ug.data(), ug.num_elements())
196 |
197 | test_pool_backward[ug_shape, input_shape, kernel_size, padding, stride, dilation](
198 | ug, inputs
199 | )
200 |
201 |
202 | fn test_backward_2() raises:
203 | # padding=0, stride=1, dilation=1
204 | # input shape: (4, 1, 32, 17) kernel size: (2, 2)
205 | alias kernel_size = 2
206 | alias padding = 0
207 | alias stride = 1
208 | alias dilation = 1
209 | alias input_shape = TensorShape(4, 1, 32, 17)
210 | var inputs = Tensor[dtype](input_shape)
211 | rand[dtype](inputs.data(), inputs.num_elements())
212 |
213 | # uppergrad
214 | alias res = get_result_shape(
215 | input_shape, TensorShape(kernel_size, kernel_size), padding, stride, dilation
216 | )
217 | alias ug_shape = TensorShape(input_shape[0], input_shape[1], res[0], res[1])
218 | var ug = Tensor[dtype](ug_shape)
219 | rand[dtype](ug.data(), ug.num_elements())
220 |
221 | test_pool_backward[ug_shape, input_shape, kernel_size, padding, stride, dilation](
222 | ug, inputs
223 | )
224 |
225 |
226 | fn test_backward_3() raises:
227 | # padding=(3, 1), stride=(2, 3), dilation=(2, 3)
228 | # input shape: (4, 3, 32, 17) kernel size: (6, 6)
229 | alias kernel_size = IndexList[2](6, 6)
230 | alias padding = IndexList[2](3, 1)
231 | alias stride = IndexList[2](2, 3)
232 | alias dilation = IndexList[2](2, 3)
233 | alias input_shape = TensorShape(4, 3, 32, 17)
234 | var inputs = Tensor[dtype](input_shape)
235 | rand[dtype](inputs.data(), inputs.num_elements())
236 |
237 | # uppergrad
238 | alias kernel_size_static: IndexList[2] = kernel_size
239 | alias res = get_result_shape(
240 | input_shape, TensorShape(kernel_size_static), padding, stride, dilation
241 | )
242 | alias ug_shape = TensorShape(input_shape[0], input_shape[1], res[0], res[1])
243 | var ug = Tensor[dtype](ug_shape)
244 | rand[dtype](ug.data(), ug.num_elements())
245 |
246 | test_pool_backward[ug_shape, input_shape, kernel_size, padding, stride, dilation](
247 | ug, inputs
248 | )
249 |
250 |
251 | fn main():
252 | try:
253 | test_forward_1()
254 | test_forward_2()
255 | test_forward_3()
256 | test_backward_1()
257 | test_backward_2()
258 | test_backward_3()
259 | except e:
260 | print("[Error] Error in MaxPool2D")
261 | print(e)
262 |
--------------------------------------------------------------------------------
/tests/testing_utils.mojo:
--------------------------------------------------------------------------------
1 | from python.python import Python
2 | from collections import OptionalReg
3 | from testing import assert_equal, assert_almost_equal
4 |
5 | from basalt import dtype
6 | from basalt.autograd import Graph, OP
7 | from basalt.autograd.ops.ops import backward_op
8 | from basalt.autograd.attributes import AttributeVector
9 | from basalt.nn import Tensor, TensorShape, Model
10 | from basalt.utils.tensor_creation_utils import to_numpy, to_tensor
11 |
12 |
13 | # The below regex should be used to convert deprecated calls
14 | # assert_tensors_equal\(([^,]+),\s*([^,]+),\s*"([^"]+)"\)
15 | # assert_tensors_equal["$3"]($1, $2)
16 | fn assert_tensors_equal[
17 | mode: String = "exact", msg: String = "Error"
18 | ](t1: Tensor[dtype], t2: Tensor[dtype]) raises:
19 | constrained[
20 | mode == "exact" or mode == "almost", "Mode must be either 'exact' or 'almost'"
21 | ]()
22 |
23 | assert_equal(t1.shape(), t2.shape(), "Tensor shape mismatch")
24 |
25 | for i in range(t1.num_elements()):
26 | if mode == "almost":
27 | assert_almost_equal(t1[i], t2[i], rtol=1e-5, atol=1e-5, msg=msg)
28 | else:
29 | assert_equal(t1[i], t2[i], msg=msg)
30 |
31 |
32 | fn test_unary_op[
33 | op: OP, t1_shape: TensorShape, attrs: OptionalReg[AttributeVector] = None
34 | ](t1: Tensor[dtype], expected: Tensor[dtype]) raises:
35 | fn create_graph() -> Graph:
36 | var g = Graph()
37 | var t1 = g.input(t1_shape)
38 |
39 | if attrs:
40 | var res = g.op(op, t1, attributes=attrs.value())
41 | g.out(res)
42 | return g ^
43 | else:
44 | var res = g.op(op, t1)
45 | g.out(res)
46 | return g ^
47 |
48 | alias graph = create_graph()
49 | assert_equal(len(graph.nodes), 1)
50 |
51 | var model = Model[graph](inference_only=True)
52 | var res = model.inference(t1)[0]
53 |
54 | assert_tensors_equal["almost"](res, expected)
55 |
56 |
57 | fn test_binary_op[
58 | op: OP,
59 | t1_shape: TensorShape,
60 | t2_shape: TensorShape,
61 | attrs: OptionalReg[AttributeVector] = None,
62 | ](t1: Tensor[dtype], t2: Tensor[dtype], expected: Tensor[dtype]) raises:
63 | fn create_graph() -> Graph:
64 | var g = Graph()
65 | var t1 = g.input(t1_shape)
66 | var t2 = g.input(t2_shape)
67 |
68 | if attrs:
69 | var res = g.op(op, t1, t2, attributes=attrs.value())
70 | g.out(res)
71 | return g ^
72 | else:
73 | var res = g.op(op, t1, t2)
74 | g.out(res)
75 | return g ^
76 |
77 | alias graph = create_graph()
78 | assert_equal(len(graph.nodes), 1)
79 |
80 | var model = Model[graph](inference_only=True)
81 | var res = model.inference(t1, t2)[0]
82 |
83 | assert_tensors_equal["almost"](res, expected)
84 |
85 |
86 | fn test_ternary_op[
87 | op: OP, t1_shape: TensorShape, t2_shape: TensorShape, t3_shape: TensorShape
88 | ](
89 | t1: Tensor[dtype], t2: Tensor[dtype], t3: Tensor[dtype], expected: Tensor[dtype]
90 | ) raises:
91 | @parameter
92 | fn create_graph() -> Graph:
93 | var g = Graph()
94 | var t1 = g.input(t1_shape)
95 | var t2 = g.input(t2_shape)
96 | var t3 = g.input(t3_shape)
97 |
98 | var res = g.op(op, t1, t2, t3)
99 | g.out(res)
100 |
101 | return g ^
102 |
103 | alias graph = create_graph()
104 | assert_equal(len(graph.nodes), 1)
105 |
106 | var model = Model[graph](inference_only=True)
107 | var res = model.inference(t1, t2, t3)[0]
108 |
109 | assert_tensors_equal["almost"](res, expected)
110 |
111 |
112 | fn test_unary_op_backward[
113 | op: OP,
114 | t1_shape: TensorShape,
115 | ug_shape: TensorShape,
116 | attrs: AttributeVector = AttributeVector(),
117 | ](t1: Tensor[dtype], ug: Tensor[dtype], grad_1_expected: Tensor[dtype],) raises:
118 | var grad_1 = Tensor[dtype](t1_shape)
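    |     # The leading parameter of backward_op selects which operand's gradient is computed (0 -> t1).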
119 | backward_op[0, op, ug_shape, t1_shape, attrs](ug, t1, grad_1)
120 | assert_tensors_equal["almost"](grad_1, grad_1_expected)
121 |
122 |
123 | fn test_binary_op_backward[
124 | op: OP,
125 | t1_shape: TensorShape,
126 | t2_shape: TensorShape,
127 | ug_shape: TensorShape,
128 | attrs: AttributeVector = AttributeVector(),
129 | ](
130 | t1: Tensor[dtype],
131 | t2: Tensor[dtype],
132 | ug: Tensor[dtype],
133 | grad_1_expected: Tensor[dtype],
134 | grad_2_expected: Tensor[dtype],
135 | ) raises:
136 | var grad_1 = Tensor[dtype](t1_shape)
137 | backward_op[0, op, ug_shape, t1_shape, t2_shape, attrs](ug, t1, t2, grad_1)
138 | assert_tensors_equal["almost"](grad_1, grad_1_expected)
139 |
140 | var grad_2 = Tensor[dtype](t2_shape)
141 | backward_op[1, op, ug_shape, t1_shape, t2_shape, attrs](ug, t1, t2, grad_2)
142 | assert_tensors_equal["almost"](grad_2, grad_2_expected)
143 |
144 |
145 | fn test_ternary_op_backward[
146 | op: OP,
147 | t1_shape: TensorShape,
148 | t2_shape: TensorShape,
149 | t3_shape: TensorShape,
150 | ug_shape: TensorShape,
151 | attrs: AttributeVector = AttributeVector(),
152 | ](
153 | t1: Tensor[dtype],
154 | t2: Tensor[dtype],
155 | t3: Tensor[dtype],
156 | ug: Tensor[dtype],
157 | grad_1_expected: Tensor[dtype],
158 | grad_2_expected: Tensor[dtype],
159 | grad_3_expected: Tensor[dtype],
160 | ) raises:
161 | var grad_1 = Tensor[dtype](t1_shape)
162 | backward_op[0, op, ug_shape, t1_shape, t2_shape, t3_shape, attrs](
163 | ug, t1, t2, t3, grad_1
164 | )
165 | assert_tensors_equal["almost"](grad_1, grad_1_expected)
166 |
167 | var grad_2 = Tensor[dtype](t2_shape)
168 | backward_op[1, op, ug_shape, t1_shape, t2_shape, t3_shape, attrs](
169 | ug, t1, t2, t3, grad_2
170 | )
171 | assert_tensors_equal["almost"](grad_2, grad_2_expected)
172 |
173 | var grad_3 = Tensor[dtype](t3_shape)
174 | backward_op[2, op, ug_shape, t1_shape, t2_shape, t3_shape, attrs](
175 | ug, t1, t2, t3, grad_3
176 | )
177 | assert_tensors_equal["almost"](grad_3, grad_3_expected)
178 |
179 |
180 | fn create_graph_concat(
181 | t1_shape: TensorShape, t2_shape: TensorShape, t3_shape: TensorShape, dim: Int
182 | ) -> Graph:
183 | # Testing with 3 operands
184 | var g = Graph()
185 | var t1 = g.input(t1_shape, trainable=True)
186 | var t2 = g.input(t2_shape, trainable=True)
187 | var t3 = g.input(t3_shape, trainable=True)
188 | var res = g.concat(t1, t2, t3, dim=dim)
189 | g.out(res)
190 | g.loss(res)
191 | return g ^
192 |
193 |
194 | fn create_graph_split(t_shape: TensorShape, sections: List[Int], dim: Int) -> Graph:
195 | var g = Graph()
196 | var t = g.input(t_shape, trainable=True)
197 | var results = g.split(t, sections=sections, dim=dim)
198 | for i in range(len(sections)):
199 | g.out(results[i])
200 | g.loss(results[0]) # Any one
201 | return g ^
202 |
--------------------------------------------------------------------------------