├── .gitignore
├── CODEOWNERS
├── LICENSE
├── README.md
├── basalt
│   ├── __init__.mojo
│   ├── autograd
│   │   ├── __init__.mojo
│   │   ├── attributes.mojo
│   │   ├── graph.mojo
│   │   ├── node.mojo
│   │   ├── ops
│   │   │   ├── __init__.mojo
│   │   │   ├── basics.mojo
│   │   │   ├── conv.mojo
│   │   │   ├── dynamics.mojo
│   │   │   ├── matmul.mojo
│   │   │   ├── mlops.mojo
│   │   │   ├── ops.mojo
│   │   │   └── pool.mojo
│   │   ├── params.mojo
│   │   └── symbol.mojo
│   ├── nn
│   │   ├── __init__.mojo
│   │   ├── activations.mojo
│   │   ├── initializers.mojo
│   │   ├── layers
│   │   │   ├── __init__.mojo
│   │   │   ├── conv.mojo
│   │   │   ├── dropout.mojo
│   │   │   ├── linear.mojo
│   │   │   ├── pool.mojo
│   │   │   └── sequential.mojo
│   │   ├── loss.mojo
│   │   ├── model.mojo
│   │   ├── optim.mojo
│   │   └── tensor.mojo
│   └── utils
│       ├── __init__.mojo
│       ├── bytes.mojo
│       ├── collection.mojo
│       ├── dataloader.mojo
│       ├── datasets.mojo
│       ├── graph_render.py
│       ├── math_util.mojo
│       ├── onnx_utils.mojo
│       ├── perf_utils.mojo
│       ├── rand_utils.mojo
│       ├── tensor_creation_utils.mojo
│       └── tensorutils.mojo
├── examples
│   ├── data
│   │   ├── housing.csv
│   │   ├── mnist_test_small.csv
│   │   └── mnist_torch.onnx
│   ├── housing.mojo
│   ├── housing.py
│   ├── mnist.mojo
│   ├── mnist.py
│   ├── mnist_load_model.mojo
│   ├── sin_estimate.mojo
│   └── sin_estimate.py
├── magic.lock
├── mojoproject.toml
├── python-requirements.txt
└── tests
    ├── __init__.mojo
    ├── mojo
    │   ├── test_activations.mojo
    │   ├── test_attributes.mojo
    │   ├── test_backward.mojo
    │   ├── test_collection.mojo
    │   ├── test_dynamic_ops.mojo
    │   ├── test_loss.mojo
    │   ├── test_mlops.mojo
    │   ├── test_ops.mojo
    │   ├── test_tensorutils.mojo
    │   └── test_tensorutils_data.mojo
    ├── python
    │   ├── test_broadcast_shapes.mojo
    │   ├── test_conv.mojo
    │   ├── test_dynamic_ops_torch.mojo
    │   ├── test_mlops_torch.mojo
    │   ├── test_models_mnist.mojo
    │   ├── test_models_regression.mojo
    │   ├── test_models_sin_estimate.mojo
    │   ├── test_models_torch.py
    │   ├── test_ops_torch.mojo
    │   └── test_pool.mojo
    └── testing_utils.mojo
/.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | __pycache__ 3 | 4 | basalt.📦 5 | 6 | examples/data/mnist_test.csv 7 | examples/data/mnist_train.csv 8 | examples/data/mnist_train_small.csv 9 | 10 | output_model.onnx 11 | Makefile 12 | 13 | ./temp 14 | flamegraph.svg 15 | 16 | .magic 17 | 18 | examples/data/yolov8n.onnx 19 | 20 | *.DS_Store -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @StijnWoestenborghs 2 | * @andresnowak 3 | * @Benny-Nottonson 4 | * @soraros -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | 4 | Logo 5 | 6 | 7 |

Basalt

8 | 9 |

10 | A Machine Learning framework from scratch in pure Mojo 🔥 11 |

12 |

13 | 14 |
15 | 16 | 17 | 18 |
19 | 20 | 21 | ## About The Project 22 | 23 | Basalt is a stand-alone machine learning framework that leverages the power of Mojo. 24 | 25 | As [discussed](https://docs.modular.com/mojo/why-mojo) by Modular, Mojo is a language for the future of AI development. Built on top of MLIR technology, rather than existing GCC and LLVM approaches, Mojo looks and feels like Python code, yet performs much closer to languages like Rust or C++. Parametric functions and compile-time parameters allow the graph to be statically compiled, and having a static graph allows for much more aggressive performance optimizations. 26 | 27 | Basalt, while still in its infancy, is able to achieve speeds comparable to well-established frameworks like PyTorch. Below is a snapshot of the current benchmarks. Stay tuned: there is much more room for improvement, and we are upgrading the project on a daily basis. 28 | 29 | ![basalt_benchmark](https://github.com/basalt-org/basalt/assets/46826967/83037770-a9e3-440d-bdca-f51af0aebee0) 30 | 31 | 32 | ## Quick Start 33 | 34 | Try out the benchmarks yourself: 35 | 36 | ``` 37 | mojo -I . examples/housing.mojo 38 | ``` 39 | ``` 40 | mojo -I . examples/sin_estimate.mojo 41 | ``` 42 | ``` 43 | mojo -I . examples/mnist.mojo 44 | ``` 45 | 46 | Compare to the alternative PyTorch implementation: 47 | Make sure to install the requirements in `python-requirements.txt` in your Python environment. 48 | 49 | ``` 50 | python examples/housing.py 51 | python examples/sin_estimate.py 52 | python examples/mnist.py 53 | ``` 54 | 55 | ## Roadmap 56 | 57 | ### v0.1.0 ✅ 58 | - [x] Improve matrix multiplication and convolution kernels 59 | - [x] Switch to custom Tensor and TensorShape implementations 60 | - [x] Improve benchmarks and overall model execution performance 61 | - [x] Add profiling and additional performance tests 62 | 63 | ### v0.2.0 (WIP) 64 | - [ ] Add additional operators: Slice, (Un)Squeeze, Concat, Clip, Gather, Split, FMA ... 65 | - [ ] Better layer support and more activation functions 66 | - [ ] Graph submodules & graph concatenation 67 | - [ ] Computer vision benchmark 68 | 69 | ### Long-Term 70 | - [ ] Better parallelization 71 | - [ ] GPU support 72 | - [ ] Reworked Dataloader 73 | - [ ] Autotuning and related features 74 | - [ ] Graph compilation optimizations 75 | - [ ] Operator fusion 76 | - [ ] ONNX / Max compatibility 77 | 78 | ## Contributing 79 | 80 | Basalt is built by community efforts and relies on your expertise and enthusiasm! 81 | Small fixes and improvements are much appreciated. If you are considering larger contributions, feel free to contact us on Discord for a smoother communication channel. If you find a bug or have an idea for a feature, please use our issue tracker. Before creating a new issue, please: 82 | * Check if the issue already exists. If an issue is already reported, you can contribute by commenting on the existing issue. 83 | * If not, create a new issue and include all the necessary details to understand/recreate the problem or feature request. 84 | 85 | ### Creating A Pull Request 86 | 87 | 1. Fork the Project 88 | 2. Create your Feature Branch 89 | 3. Commit your Changes 90 | 4. Push to the Branch 91 | 5. Open a Pull Request 92 | > Once your changes are pushed, navigate to your fork on GitHub and create a pull request against the original basalt-org/basalt repository. 93 | > - Before creating a PR, make sure it doesn't break any of the unit tests. (e.g. `mojo run -I . tests/mojo/test_ops.mojo`) 94 | > - Introducing big new features requires a new test!
95 | > - In the pull request, provide a detailed description of the changes and why they're needed. Link any relevant issues. 96 | > - If there are any specific instructions for testing or validating your changes, include those as well. 97 | 98 | ## License 99 | 100 | Distributed under the Apache 2.0 License with LLVM Exceptions. See [LICENSE](https://github.com/Basalt-Org/Basalt/blob/main/LICENSE) and the LLVM [License](https://llvm.org/LICENSE.txt) for more information. 101 | 102 | ## Acknowledgements 103 | 104 | * Built with [Mojo](https://github.com/modularml/mojo) created by [Modular](https://github.com/modularml) 105 | -------------------------------------------------------------------------------- /basalt/__init__.mojo: -------------------------------------------------------------------------------- 1 | from .autograd import Graph, Symbol, OP 2 | from .nn import Tensor, TensorShape 3 | from sys.info import simdwidthof 4 | from basalt.utils.collection import Collection 5 | 6 | alias dtype = DType.float32 7 | alias nelts = 2 * simdwidthof[dtype]() 8 | alias seed = 42 9 | alias epsilon = 1e-12 10 | -------------------------------------------------------------------------------- /basalt/autograd/__init__.mojo: -------------------------------------------------------------------------------- 1 | from .symbol import Symbol 2 | from .graph import Graph 3 | from .ops import OP 4 | -------------------------------------------------------------------------------- /basalt/autograd/attributes.mojo: -------------------------------------------------------------------------------- 1 | from collections import Optional, OptionalReg 2 | from utils.static_tuple import StaticTuple 3 | from utils.index import IndexList 4 | 5 | from basalt.nn.tensor import Tensor, TensorShape, MAX_RANK 6 | from basalt.utils.bytes import Bytes, scalar_to_bytes, bytes_to_scalar 7 | 8 | 9 | alias MAX_ATTRS = 10 10 | alias MAX_NAME_CHARS = 16 11 | alias MAX_DATA_BYTES = 32 12 | 13 | 14 | @register_passable("trivial") 15 | struct AttributeType(Stringable): 16 | alias BOOL = AttributeType(0, "BOOL") 17 | alias INT = AttributeType(1, "INT") 18 | alias FLOAT = AttributeType(2, "FLOAT") 19 | alias STRING = AttributeType(3, "STRING") 20 | alias INTS = AttributeType(4, "INTS") 21 | alias FLOATS = AttributeType(5, "FLOATS") 22 | 23 | var id: UInt8 24 | var name: Bytes[MAX_NAME_CHARS] 25 | 26 | fn __init__(inout self, id: UInt8, name: String): 27 | self.id = id 28 | self.name = Bytes[MAX_NAME_CHARS](name) 29 | 30 | fn __init__(inout self, type: DType): 31 | if type.is_floating_point(): 32 | self = AttributeType.FLOAT 33 | elif type == DType.bool: 34 | self = AttributeType.BOOL 35 | else: 36 | self = AttributeType.INT 37 | 38 | fn __eq__(self, other: Self) -> Bool: 39 | return self.id == other.id 40 | 41 | fn __str__(self) -> String: 42 | return str(self.name) 43 | 44 | 45 | @register_passable("trivial") 46 | struct AttributeVector(Sized, Stringable, CollectionElement): 47 | var attributes: StaticTuple[Attribute, MAX_ATTRS] 48 | var size: Int 49 | 50 | fn __init__(inout self, *attributes: Attribute): 51 | self.attributes = StaticTuple[Attribute, MAX_ATTRS](Attribute("", "")) 52 | self.size = len(attributes) 53 | for i in range(self.size): 54 | self.attributes[i] = attributes[i] 55 | 56 | @always_inline("nodebug") 57 | fn __len__(self) -> Int: 58 | return self.size 59 | 60 | @always_inline("nodebug") 61 | fn __getitem__(self, index: Int) -> Attribute: 62 | return self.attributes[index] 63 | 64 | @always_inline("nodebug") 65 | fn 
__getitem__(self, index: StringLiteral) -> OptionalReg[Attribute]: 66 | for i in range(self.size): 67 | if self.attributes[i].name == Bytes[MAX_NAME_CHARS](index): 68 | return self.attributes[i] 69 | return None 70 | 71 | fn __str__(self) -> String: 72 | var s: String = "[" 73 | for i in range(self.size): 74 | s += str(self.attributes[i]) 75 | if i < self.size - 1: 76 | s += ", " 77 | return s + "]" 78 | 79 | 80 | @register_passable("trivial") 81 | struct Attribute(Stringable, CollectionElement): 82 | var data_shape: IndexList[MAX_RANK] 83 | var name: Bytes[MAX_NAME_CHARS] 84 | var data: Bytes[MAX_DATA_BYTES] 85 | var type: AttributeType 86 | var size: Int 87 | 88 | fn __init__(inout self, name: String, value: String): 89 | self.data_shape = IndexList[MAX_RANK]() 90 | self.name = Bytes[MAX_NAME_CHARS](name) 91 | self.data = Bytes[MAX_DATA_BYTES](value) 92 | self.type = AttributeType.STRING 93 | self.size = len(value) 94 | 95 | fn __init__(inout self, name: String, value: TensorShape): 96 | self.data_shape = IndexList[MAX_RANK]() 97 | self.name = Bytes[MAX_NAME_CHARS](name) 98 | self.data = Bytes[MAX_DATA_BYTES]() 99 | self.type = AttributeType.INTS 100 | self.size = value.rank() 101 | 102 | for i in range(self.size): 103 | self.data_shape[i] = value._shape[i] 104 | 105 | fn __init__[N: Int](inout self, name: String, value: IndexList[N]): 106 | constrained[N < MAX_RANK, "Attribute rank must be less than MAX_RANK."]() 107 | 108 | self.data_shape = IndexList[MAX_RANK]() 109 | self.name = Bytes[MAX_NAME_CHARS](name) 110 | self.data = Bytes[MAX_DATA_BYTES]() 111 | self.type = AttributeType.INTS 112 | self.size = N 113 | 114 | for i in range(self.size): 115 | self.data_shape[i] = value[i] 116 | 117 | fn __init__(inout self, name: String, value: List[Int]): 118 | self.data_shape = IndexList[MAX_RANK]() 119 | self.name = Bytes[MAX_NAME_CHARS](name) 120 | self.data = Bytes[MAX_DATA_BYTES]() 121 | self.type = AttributeType.INTS 122 | self.size = len(value) 123 | 124 | for i in range(self.size): 125 | self.data_shape[i] = value[i] 126 | 127 | fn __init__(inout self, name: String, value: StaticTuple[Int, _]): 128 | self.data_shape = IndexList[MAX_RANK]() 129 | self.name = Bytes[MAX_NAME_CHARS](name) 130 | self.data = Bytes[MAX_DATA_BYTES]() 131 | self.type = AttributeType.INTS 132 | self.size = len(value) 133 | 134 | for i in range(self.size): 135 | self.data_shape[i] = value[i] 136 | 137 | fn __init__[dtype: DType](inout self, name: String, value: Scalar[dtype]): 138 | constrained[dtype.is_numeric(), "Attribute value must be numeric."]() 139 | 140 | self.data_shape = IndexList[MAX_RANK]() 141 | self.name = Bytes[MAX_NAME_CHARS](name) 142 | self.data = scalar_to_bytes[dtype, MAX_DATA_BYTES](value) 143 | self.type = AttributeType(dtype) 144 | self.size = 1 145 | 146 | fn __init__(inout self, name: String, value: Int): 147 | self.__init__(name, Int64(value)) 148 | self.data_shape[0] = 1 149 | 150 | fn __init__(inout self, name: String, value: FloatLiteral): 151 | self.__init__(name, Float64(value)) 152 | self.data_shape[0] = 1 153 | 154 | @always_inline("nodebug") 155 | fn __str__(self) -> String: 156 | return "Attribute(" + str(self.name) + ", " + "..." 
+ ")" 157 | 158 | @always_inline("nodebug") 159 | fn to_string(self) -> String: 160 | return str(self.data) 161 | 162 | @always_inline("nodebug") 163 | fn to_list(self) -> List[Int]: 164 | var result = List[Int]() 165 | 166 | for i in range(self.size): 167 | result.append(self.data_shape[i]) 168 | 169 | return result 170 | 171 | @always_inline("nodebug") 172 | fn to_shape(self) -> TensorShape: 173 | return TensorShape(rank=self.size, shape=self.data_shape) 174 | 175 | @always_inline("nodebug") 176 | fn to_static[N: Int](self) -> IndexList[N]: 177 | constrained[N < MAX_RANK, "Attribute rank must be less than MAX_RANK."]() 178 | 179 | var result = IndexList[N]() 180 | 181 | for i in range(N): 182 | result[i] = int(self.data_shape[i]) 183 | 184 | return result 185 | 186 | @always_inline("nodebug") 187 | fn to_scalar[dtype: DType](self) -> Scalar[dtype]: 188 | constrained[dtype.is_numeric(), "Attribute value must be numeric."]() 189 | 190 | return bytes_to_scalar[dtype](self.data) 191 | 192 | @always_inline("nodebug") 193 | fn to_int(self) -> Int: 194 | return int(self.to_scalar[DType.int64]()) 195 | 196 | fn json(self) -> String: 197 | var result = '{"name": "' + str(self.name) + '", ' 198 | 199 | var type: String = "" 200 | var value: String = "" 201 | 202 | if self.type == AttributeType.STRING: 203 | type = "STRING" 204 | value = '"' + self.to_string() + '"' 205 | elif self.type == AttributeType.INTS: 206 | type = "INTS" 207 | 208 | var value_temp = self.to_shape() 209 | value = "[" 210 | for i in range(value_temp.rank()): 211 | value += str(value_temp._shape[i]) 212 | if i < value_temp.rank() - 1: 213 | value += ", " 214 | value += "]" 215 | elif self.type == AttributeType.FLOAT: 216 | type = "FLOAT" 217 | value = str(self.to_scalar[DType.float64]()) 218 | elif self.type == AttributeType.INT: 219 | type = "INT" 220 | value = str(self.to_int()) 221 | else: 222 | type = "UNKNOWN" 223 | value = "UNKNOWN" 224 | 225 | result += '"type": "' + type + '", ' + '"value": ' + value 226 | 227 | return result + "}" 228 | -------------------------------------------------------------------------------- /basalt/autograd/graph.mojo: -------------------------------------------------------------------------------- 1 | from python.python import Python 2 | from collections.optional import Optional, OptionalReg 3 | 4 | from .node import Node 5 | from .attributes import AttributeVector, Attribute 6 | from .symbol import Symbol 7 | from .ops import OP, static_result_shape, dynamic_result_shape 8 | from .params import ParamDict, Param 9 | 10 | from basalt import seed, dtype 11 | from basalt import Tensor, TensorShape 12 | 13 | 14 | struct Graph: 15 | var inputs: List[Symbol] 16 | var params: ParamDict 17 | var nodes: List[Node] 18 | var outputs: List[Symbol] 19 | var loss_out: OptionalReg[Symbol] 20 | var symbol_count: UInt32 21 | 22 | fn __init__(inout self): 23 | self.inputs = List[Symbol]() 24 | self.params = ParamDict() 25 | self.nodes = List[Node]() 26 | self.outputs = List[Symbol]() 27 | self.loss_out = None 28 | self.symbol_count = 0 29 | 30 | fn __moveinit__(inout self, owned other: Graph): 31 | self.inputs = other.inputs^ 32 | self.params = other.params^ 33 | self.nodes = other.nodes^ 34 | self.outputs = other.outputs^ 35 | self.loss_out = other.loss_out 36 | self.symbol_count = other.symbol_count 37 | 38 | fn create_symbol(inout self, shape: TensorShape, data: Optional[Param] = None, trainable: Bool = False, is_input: Bool = False) -> Symbol: 39 | var symbol = Symbol(self.symbol_count, dtype, shape, 
trainable) 40 | self.symbol_count += 1 41 | 42 | if is_input: 43 | self.inputs.append(symbol) 44 | else: 45 | if data is not None: 46 | self.params.put(symbol, data.value()) 47 | else: 48 | self.params.put(symbol) 49 | 50 | return symbol 51 | 52 | fn input(inout self, shape: TensorShape, trainable: Bool = False) -> Symbol: 53 | return self.create_symbol(shape, trainable=trainable, is_input=True) 54 | 55 | fn param(inout self, shape: TensorShape, init: Param, trainable: Bool = True) -> Symbol: 56 | return self.create_symbol(shape, init, trainable) 57 | 58 | fn param(inout self, shape: TensorShape, trainable: Bool = True) -> Symbol: 59 | return self.create_symbol(shape, trainable=trainable) 60 | 61 | fn scalar(inout self, value: Scalar[dtype]) -> Symbol: 62 | return self.create_symbol(TensorShape(1), Param(value), trainable=False) 63 | 64 | fn constant(inout self, shape: TensorShape, data: List[Scalar[dtype]]) -> Symbol: 65 | return self.create_symbol(shape, Param(data), trainable=False) 66 | 67 | fn out(inout self, symbol: Symbol): 68 | self.outputs.append(symbol) 69 | 70 | fn loss(inout self, symbol: Symbol): 71 | self.loss_out = symbol 72 | 73 | fn op( 74 | inout self, 75 | op: OP, 76 | *operands: Symbol, 77 | attributes: AttributeVector = AttributeVector(), 78 | ) -> Symbol: 79 | var res_shape = static_result_shape(op, operands, attributes) 80 | var res = Symbol(self.symbol_count, dtype, res_shape, self.result_trainable(operands)) 81 | self.symbol_count += 1 82 | 83 | var inputs = List[Symbol]() 84 | inputs.reserve(len(operands)) 85 | 86 | for operand in operands: 87 | inputs.append(operand) 88 | 89 | self.nodes.append(Node(op, inputs, List[Symbol](res), attributes)) 90 | return res 91 | 92 | fn op( 93 | inout self, 94 | op: OP, 95 | operand_1: Symbol, 96 | operand_2: Float64, 97 | attributes: AttributeVector = AttributeVector(), 98 | ) -> Symbol: 99 | return self.op(op, operand_1, self.scalar(operand_2.cast[dtype]()), attributes=attributes) 100 | 101 | fn op( 102 | inout self, 103 | op: OP, 104 | operand_1: Float64, 105 | operand_2: Symbol, 106 | attributes: AttributeVector = AttributeVector(), 107 | ) -> Symbol: 108 | return self.op(op, self.scalar(operand_1.cast[dtype]()), operand_2, attributes=attributes) 109 | 110 | fn create_symbols(inout self, shapes: List[TensorShape], trainable: Bool = False) -> List[Symbol]: 111 | var symbols = List[Symbol]() 112 | symbols.reserve(len(shapes)) 113 | 114 | for shape in shapes: 115 | symbols.append(Symbol(self.symbol_count, dtype, shape[], trainable)) 116 | self.symbol_count += 1 117 | 118 | return symbols 119 | 120 | fn add_node(inout self, op: OP, inputs: List[Symbol], outputs: List[Symbol], attributes: AttributeVector): 121 | self.nodes.append(Node(op, inputs, outputs, attributes)) 122 | 123 | fn concat(inout self, *operands: Symbol, dim: Int = 0) -> Symbol: 124 | var attributes = AttributeVector(Attribute("dim", dim)) 125 | var res_shape = dynamic_result_shape(OP.CONCAT, operands, attributes)[0] 126 | var res_symbols = self.create_symbols(List[TensorShape](res_shape), self.result_trainable(operands)) 127 | 128 | var operand_list = List[Symbol]() 129 | operand_list.reserve(len(operands)) 130 | for operand in operands: 131 | operand_list.append(operand) 132 | 133 | self.add_node(OP.CONCAT, operand_list, res_symbols, attributes) 134 | return res_symbols[0] 135 | 136 | fn split( 137 | inout self, operand: Symbol, sections: List[Int], dim: Int = 0 138 | ) -> List[Symbol]: 139 | var attributes = AttributeVector(Attribute("sections", 
TensorShape(sections)), Attribute("dim", dim)) 140 | var res_shapes = dynamic_result_shape(OP.SPLIT, operand, attributes) 141 | var trainable = self.result_trainable(operand) 142 | var result_symbols = self.create_symbols(res_shapes, trainable) 143 | self.add_node(OP.SPLIT, List[Symbol](operand), result_symbols, attributes) 144 | return result_symbols 145 | 146 | @staticmethod 147 | fn result_trainable(operands: VariadicList[Symbol]) -> Bool: 148 | for operand in operands: 149 | if operand.trainable: 150 | return True 151 | return False 152 | 153 | fn json(self) -> String: 154 | var result: String = '{"graph_name": "basalt", "nodes": [' 155 | for i in range(len(self.nodes)): 156 | result += self.nodes[i].json() 157 | if i < len(self.nodes) - 1: 158 | result += ", " 159 | result += '], "inputs": [' 160 | for i in range(len(self.inputs)): 161 | result += self.inputs[i].json() 162 | if i < len(self.inputs) - 1: 163 | result += ", " 164 | result += '], "outputs": [' 165 | for i in range(len(self.outputs)): 166 | result += self.outputs[i].json() 167 | if i < len(self.outputs) - 1: 168 | result += ", " 169 | if self.loss_out: 170 | result += '], "loss": [' 171 | result += self.loss_out.value().json() 172 | result += '], "params": [' 173 | for i in range(len(self.params)): 174 | result += self.params.symbols[i].json() 175 | if i < len(self.params) - 1: 176 | result += ", " 177 | result += "]}" 178 | return result 179 | 180 | fn render(self, render_type: String = "node") raises: 181 | Python.add_to_path("./basalt/utils") 182 | var renderer = Python.import_module("graph_render") 183 | var json = Python.import_module("json") 184 | _ = renderer.netron_render(json.loads(self.json()), render_type) 185 | 186 | fn compile(inout self): 187 | # 0. Sorting the graph 188 | # The statically defined graph has an implicit topological order because 189 | # each new operation is added to the list of nodes after its dependencies have been calculated. 190 | # This eliminates the need for explicit topological sorting. 191 | 192 | # Possibilities: 193 | # - 1. Graph layout transformation (graph rewrite) 194 | # - Layer pruning (removing nodes that have no effect - with common sub-tree identification) 195 | # - Eliminate redundant intermediate data copies 196 | # - Operator replacement (e.g. replacing (combination of) costly ops with more efficient ones) 197 | # - (example of graph rewrite: https://dl.acm.org/doi/pdf/10.1145/3453483.3454083 - Table 4) 198 | # - Other intra-block optimizations: (e.g. data layout transformation BCHW -> BHWC, etc.) 199 | # - 2. Operator fusion (combining ops without materializing intermediate results) 200 | # - Fusion plan exploration 201 | # - Fusion plan generation (with subsequent intra-block optimizations) 202 | # - (example fusion plan algorithm: https://dl.acm.org/doi/pdf/10.1145/3453483.3454083 - Listing 1) 203 | # - 3. Fusion Code generation (behaviour) 204 | # - Code generation for planned fusion blocks 205 | # - Other inter-block optimizations (e.g. data layout transformation BCHW -> BHWC, etc.) 206 | # - 4. 
Auto-tuning (of vectorization-, parallelization-, tiling-, unrolling-parameters) 207 | # - (Might only work when memory is initialized) 208 | 209 | # Other considerations: 210 | # - Efficient Memory management: 211 | # - Memory reuse (in-place operations) 212 | # - Data layout from BCHW (batch, channel, height, width) to BHWC can lead to better utilization and efficiency 213 | # - VJP, JVP (for automatic differentiation) 214 | 215 | pass 216 | -------------------------------------------------------------------------------- /basalt/autograd/node.mojo: -------------------------------------------------------------------------------- 1 | from collections.optional import Optional 2 | from utils.variant import Variant 3 | 4 | from basalt.autograd import Symbol 5 | from basalt.autograd.ops import OP 6 | 7 | from .attributes import AttributeVector 8 | 9 | 10 | @value 11 | struct Node(CollectionElement, Stringable): 12 | var operator: OP 13 | var inputs: List[Symbol] 14 | var outputs: List[Symbol] 15 | var attributes: AttributeVector 16 | 17 | fn __init__( 18 | inout self, 19 | operator: OP, 20 | inputs: List[Symbol], 21 | outputs: List[Symbol], 22 | attributes: AttributeVector = AttributeVector(), 23 | ): 24 | self.operator = operator 25 | self.inputs = inputs 26 | self.outputs = outputs 27 | self.attributes = attributes 28 | 29 | fn __str__(self) -> String: 30 | return self.json() 31 | 32 | fn json(self) -> String: 33 | var s: String = '{"operator": "' + str(self.operator.name) + '", "inputs": [' 34 | for i in range(len(self.inputs)): 35 | s += self.inputs[i].json() 36 | if i < len(self.inputs) - 1: 37 | s += ", " 38 | s += '], "outputs": [' 39 | for i in range(len(self.outputs)): 40 | s += self.outputs[i].json() 41 | if i < len(self.outputs) - 1: 42 | s += ", " 43 | s += '], "attributes": [' 44 | for i in range(len(self.attributes)): 45 | s += self.attributes[i].json() 46 | if i < len(self.attributes) - 1: 47 | s += ", " 48 | s += "]}" 49 | return s 50 | -------------------------------------------------------------------------------- /basalt/autograd/ops/__init__.mojo: -------------------------------------------------------------------------------- 1 | from .ops import ( 2 | OP, 3 | static_result_shape, 4 | dynamic_result_shape, 5 | forward_op, 6 | backward_op, 7 | ) 8 | -------------------------------------------------------------------------------- /basalt/autograd/ops/dynamics.mojo: -------------------------------------------------------------------------------- 1 | from basalt import Symbol 2 | from basalt.nn.model import Parameters 3 | from ..attributes import AttributeVector 4 | 5 | from memory import memcpy 6 | 7 | 8 | struct CONCAT: 9 | @staticmethod 10 | fn result_shape( 11 | input_shapes: List[TensorShape], attributes: AttributeVector 12 | ) -> List[TensorShape]: 13 | # Assumptions: all tensors have the same shape, except for the concatenating dimension 14 | var dim = attributes["dim"].value().to_int() if attributes["dim"] else 0 15 | 16 | var concat_size: Int = 0 17 | for i in range(len(input_shapes)): 18 | concat_size += input_shapes[i][dim] 19 | 20 | var res_shape = input_shapes[0] 21 | res_shape[dim] = concat_size 22 | 23 | return List[TensorShape](res_shape) 24 | 25 | @staticmethod 26 | fn calc_chunks(shape: TensorShape, dim: Int) -> Int: 27 | # Number of chunks up to the concatenating dimension 28 | # Assuming tensor of equal shape, except for the concatenating dimension 29 | var chunks = 1 30 | for i in range(dim): 31 | chunks *= shape[i] 32 | return chunks 33 | 34 | 
@staticmethod 35 | fn forward[attributes: AttributeVector]( 36 | inputs: List[Symbol], 37 | outputs: List[Symbol], 38 | inout parameters: Parameters, 39 | ): 40 | alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0 41 | var n_chunks = Self.calc_chunks(inputs[0].shape, dim) 42 | 43 | var chunks = List[Int]() 44 | var chunk_offsets = List[Int](0) 45 | for i in range(len(inputs)): 46 | chunks.append(inputs[i].shape.num_elements() // n_chunks) 47 | chunk_offsets.append(chunk_offsets[i] + chunks[i]) 48 | 49 | for i in range(n_chunks): 50 | for j in range(len(inputs)): 51 | memcpy( 52 | parameters.tensors[outputs[0]].data() 53 | + i * chunk_offsets[len(inputs)] 54 | + chunk_offsets[j], 55 | parameters.tensors[inputs[j]].data() + i * chunks[j], 56 | chunks[j], 57 | ) 58 | 59 | @staticmethod 60 | fn backward[input_id: Int, attributes: AttributeVector]( 61 | inputs: List[Symbol], 62 | outputs: List[Symbol], 63 | inout parameters: Parameters, 64 | ) -> Tensor[dtype]: 65 | alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0 66 | var n_chunks = Self.calc_chunks(inputs[0].shape, dim) 67 | 68 | var chunks = List[Int]() 69 | var chunk_offsets = List[Int](0) 70 | for i in range(len(inputs)): 71 | chunks.append(inputs[i].shape.num_elements() // n_chunks) 72 | chunk_offsets.append(chunk_offsets[i] + chunks[i]) 73 | 74 | var res_grad = Tensor[dtype](inputs[input_id].shape) 75 | for i in range(n_chunks): 76 | memcpy( 77 | res_grad.data() + i * chunks[input_id], 78 | parameters.grads[outputs[0]].data() 79 | + i * chunk_offsets[len(inputs)] 80 | + chunk_offsets[input_id], 81 | chunks[input_id], 82 | ) 83 | 84 | return res_grad ^ 85 | 86 | 87 | struct SPLIT: 88 | @staticmethod 89 | fn result_shape( 90 | input_shapes: List[TensorShape], attributes: AttributeVector 91 | ) -> List[TensorShape]: 92 | # Assuming the sum of the sections is equal to the total size in the dim dimension. 93 | # E.g. sections = [5, 5, 2] -> shape (., 12, ., .) 
for dim = 1 94 | var dim = attributes["dim"].value().to_int() if attributes["dim"] else 0 95 | var sections = attributes["sections"].value().to_shape() 96 | 97 | var res_shapes = List[TensorShape]() 98 | for i in range(sections.rank()): 99 | var new_shape = input_shapes[0] 100 | new_shape[dim] = sections[i] 101 | res_shapes.append(new_shape) 102 | 103 | return res_shapes 104 | 105 | @staticmethod 106 | fn calc_chunks(shape: TensorShape, dim: Int) -> Int: 107 | # Number of chunks up to the concatenating dimension 108 | # Assuming tensor of equal shape, except for the concatenating dimension 109 | var chunks = 1 110 | for i in range(dim): 111 | chunks *= shape[i] 112 | return chunks 113 | 114 | @staticmethod 115 | fn forward[attributes: AttributeVector]( 116 | inputs: List[Symbol], 117 | outputs: List[Symbol], 118 | inout parameters: Parameters, 119 | ): 120 | alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0 121 | alias sections = attributes["sections"].value().to_shape() 122 | var n_chunks = Self.calc_chunks(inputs[0].shape, dim) 123 | 124 | var chunks = List[Int]() 125 | var chunk_offsets = List[Int](0) 126 | for i in range(len(outputs)): 127 | chunks.append(outputs[i].shape.num_elements() // n_chunks) 128 | chunk_offsets.append(chunk_offsets[i] + chunks[i]) 129 | 130 | for i in range(n_chunks): 131 | for j in range(len(outputs)): 132 | memcpy( 133 | parameters.tensors[outputs[j]].data() + i * chunks[j], 134 | parameters.tensors[inputs[0]].data() 135 | + i * chunk_offsets[len(outputs)] 136 | + chunk_offsets[j], 137 | chunks[j], 138 | ) 139 | 140 | @staticmethod 141 | fn backward[input_id: Int, attributes: AttributeVector]( 142 | inputs: List[Symbol], 143 | outputs: List[Symbol], 144 | inout parameters: Parameters, 145 | ) -> Tensor[dtype]: 146 | alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0 147 | alias sections = attributes["sections"].value().to_shape() 148 | var n_chunks = Self.calc_chunks(inputs[0].shape, dim) 149 | 150 | var chunks = List[Int]() 151 | var chunk_offsets = List[Int](0) 152 | for i in range(len(outputs)): 153 | chunks.append(outputs[i].shape.num_elements() // n_chunks) 154 | chunk_offsets.append(chunk_offsets[i] + chunks[i]) 155 | 156 | var res_grad = Tensor[dtype](inputs[input_id].shape) 157 | 158 | for i in range(n_chunks): 159 | for j in range(len(outputs)): 160 | memcpy( 161 | res_grad.data() 162 | + i * chunk_offsets[len(outputs)] 163 | + chunk_offsets[j], 164 | parameters.grads[outputs[j]].data() + i * chunks[j], 165 | chunks[j], 166 | ) 167 | 168 | return res_grad ^ 169 | -------------------------------------------------------------------------------- /basalt/autograd/ops/matmul.mojo: -------------------------------------------------------------------------------- 1 | from basalt.utils.tensorutils import transpose_2D 2 | 3 | from algorithm import vectorize, parallelize 4 | from memory import memset_zero, stack_allocation, UnsafePointer 5 | from sys.info import simdwidthof 6 | 7 | 8 | @always_inline 9 | fn calculate_block[ 10 | M: Int, N: Int, K: Int, BLOCK_M: Int, BLOCK_N: Int, nelts: Int 11 | ]( 12 | res: UnsafePointer[Scalar[dtype]], 13 | t1: UnsafePointer[Scalar[dtype]], 14 | t2: UnsafePointer[Scalar[dtype]], 15 | bm: Int, 16 | bn: Int, 17 | ): 18 | # Compute tile 19 | var acc = stack_allocation[BLOCK_M * BLOCK_N, dtype]() 20 | memset_zero(acc, BLOCK_M * BLOCK_N) 21 | 22 | for k in range(K): 23 | 24 | @parameter 25 | for m in range(BLOCK_M): 26 | 27 | @parameter 28 | fn inner_n[nelts: Int](n: Int): 29 | 
acc.store( 30 | m * BLOCK_N + n, 31 | SIMD[dtype, nelts](t1[(bm + m) * K + k]) 32 | .fma( 33 | t2.load[width=nelts](k * N + (bn + n)), 34 | acc.load[width=nelts](m * BLOCK_N + n), 35 | ), 36 | ) 37 | 38 | vectorize[inner_n, nelts](BLOCK_N) 39 | 40 | # Store tile 41 | for m in range(BLOCK_M): 42 | 43 | @parameter 44 | fn vec_store[nelts: Int](n: Int): 45 | res.store( 46 | (bm + m) * N + (bn + n), acc.load[width=nelts](m * BLOCK_N + n) 47 | ) 48 | 49 | vectorize[vec_store, nelts](BLOCK_N) 50 | 51 | 52 | @parameter 53 | @always_inline 54 | fn dot[ 55 | t1_shape: TensorShape, t2_shape: TensorShape 56 | ](inout res: Tensor[dtype], t1: Tensor[dtype], t2: Tensor[dtype]): 57 | dot[t1_shape, t2_shape](res.data(), t1.data(), t2.data()) 58 | 59 | 60 | @parameter 61 | @always_inline 62 | fn dot[ 63 | t1_shape: TensorShape, t2_shape: TensorShape 64 | ](res: UnsafePointer[Scalar[dtype]], t1: UnsafePointer[Scalar[dtype]], t2: UnsafePointer[Scalar[dtype]]): 65 | alias M = t1_shape[0] # t1[0] 66 | alias K = t1_shape[1] # t1[1], t2[0] 67 | alias N = t2_shape[1] # t2[1] 68 | 69 | # simdwidthof[dtype]() = 8 for float32 70 | alias nelts = simdwidthof[dtype]() 71 | alias BLOCK_N = 8 * 2 72 | alias BLOCK_M = 6 73 | alias THREADS = 6 # num_logical_cores() 74 | 75 | alias BLOCK_N_REMAINDER = N % BLOCK_N 76 | alias BLOCK_M_REMAINDER = M % BLOCK_M 77 | 78 | @parameter 79 | fn bm_par(m_outer: Int): 80 | var bm = m_outer * BLOCK_M 81 | 82 | for n_outer in range(0, N // BLOCK_N): 83 | var bn = n_outer * BLOCK_N 84 | 85 | calculate_block[M, N, K, BLOCK_M, BLOCK_N, nelts](res, t1, t2, bm, bn) 86 | 87 | # Handle the remainder of N 88 | @parameter 89 | if BLOCK_N_REMAINDER > 0: 90 | var bn = N - BLOCK_N_REMAINDER 91 | 92 | calculate_block[M, N, K, BLOCK_M, BLOCK_N_REMAINDER, nelts]( 93 | res, t1, t2, bm, bn 94 | ) 95 | 96 | parallelize[bm_par](M // BLOCK_M, M // BLOCK_M) 97 | 98 | # Handle the remainder of M 99 | @parameter 100 | if BLOCK_M_REMAINDER > 0: 101 | var bm = M - BLOCK_M_REMAINDER 102 | 103 | for n_outer in range(0, N // BLOCK_N): 104 | var bn = n_outer * BLOCK_N 105 | 106 | calculate_block[M, N, K, BLOCK_M_REMAINDER, BLOCK_N, nelts]( 107 | res, t1, t2, bm, bn 108 | ) 109 | 110 | # Handle corner remainder 111 | @parameter 112 | if BLOCK_N_REMAINDER > 0: 113 | var bn = N - BLOCK_N_REMAINDER 114 | 115 | calculate_block[M, N, K, BLOCK_M_REMAINDER, BLOCK_N_REMAINDER, nelts]( 116 | res, t1, t2, bm, bn 117 | ) 118 | 119 | 120 | fn dot_transpose_t2[ 121 | A_shape: TensorShape, B_shape: TensorShape 122 | ](inout C: UnsafePointer[Scalar[dtype]], A: UnsafePointer[Scalar[dtype]], B: UnsafePointer[Scalar[dtype]]): 123 | dot[A_shape, TensorShape(B_shape[1], B_shape[0])](C, A, transpose_2D[B_shape](B)) 124 | 125 | 126 | fn dot_transpose_t2[ 127 | A_shape: TensorShape, B_shape: TensorShape 128 | ](inout C: Tensor[dtype], A: Tensor[dtype], B: Tensor[dtype]): 129 | memset_zero(C.data(), C.num_elements()) 130 | 131 | dot[A_shape, TensorShape(B_shape[1], B_shape[0])](C, A, transpose_2D[B_shape](B)) 132 | 133 | # @parameter 134 | # fn calc_row(i: Int): 135 | # for j in range(B_shape[0]): 136 | 137 | # @parameter 138 | # fn calc_row_A_B[nelts: Int](k: Int): 139 | # var A_pos = i * A.dim(1) + k 140 | # var B_pos = j * A.dim(1) + k 141 | # var t_new_pos = i * C.dim(1) + j 142 | 143 | # C[t_new_pos] += ( 144 | # A.load[nelts](A_pos) * B.load[nelts](B_pos) 145 | # ).reduce_add() 146 | 147 | # vectorize[calc_row_A_B, nelts, size=A_shape[1]]() 148 | 149 | # parallelize[calc_row](A_shape[0], 1) 150 | 151 | 152 | fn dot_transpose_t1[ 153 | 
A_shape: TensorShape, B_shape: TensorShape 154 | ](inout C: Tensor[dtype], A: Tensor[dtype], B: Tensor[dtype]): 155 | memset_zero(C.data(), C.num_elements()) 156 | 157 | dot[TensorShape(A_shape[1], A_shape[0]), B_shape](C, transpose_2D[A_shape](A), B) 158 | 159 | # @parameter 160 | # fn calc_row(i: Int): 161 | # for j in range(A_shape[0]): 162 | 163 | # @parameter 164 | # fn calc_row_t_new_B[nelts: Int](k: Int): 165 | # var A_pos = j * A.dim(1) + i 166 | # var B_pos = j * B.dim(1) + k 167 | # var t_new_pos = i * C.dim(1) + k 168 | 169 | # C.store[nelts]( 170 | # t_new_pos, 171 | # C.load[nelts](t_new_pos) 172 | # + A[A_pos] * B.load[nelts](B_pos), 173 | # ) 174 | 175 | # vectorize[calc_row_t_new_B, nelts, size=B_shape[1]]() 176 | 177 | # parallelize[calc_row](A_shape[1], 1) 178 | -------------------------------------------------------------------------------- /basalt/autograd/ops/pool.mojo: -------------------------------------------------------------------------------- 1 | from utils.numerics import min_or_neg_inf 2 | 3 | from basalt import Tensor, TensorShape 4 | from basalt.autograd.attributes import AttributeVector 5 | from basalt.autograd.ops.conv import get_result_shape 6 | 7 | 8 | struct MAXPOOL2D: 9 | @staticmethod 10 | fn result_shape( 11 | input_shape: TensorShape, attributes: AttributeVector 12 | ) -> TensorShape: 13 | var kernel_size = attributes["kernel_size"].value().to_static[2]() 14 | var padding = attributes["padding"].value().to_static[2]() 15 | var stride = attributes["stride"].value().to_static[2]() 16 | var dilation = attributes["dilation"].value().to_static[2]() 17 | 18 | var res = get_result_shape( 19 | input_shape, 20 | TensorShape(kernel_size[0], kernel_size[1]), 21 | padding, 22 | stride, 23 | dilation, 24 | ) 25 | 26 | return TensorShape(input_shape[0], input_shape[1], res[0], res[1]) 27 | 28 | @staticmethod 29 | fn forward[ 30 | input_shape: TensorShape, attributes: AttributeVector 31 | ](inout outputs: Tensor[dtype], inputs: Tensor[dtype]): 32 | """ 33 | Returns the max value of each kernel in the input tensor. 34 | inputs.shape [batch_size, channels, iX, iY] 35 | with kernel_size = (kX, kY) 36 | outputs.shape [batch_size, channels, oX, oY]. 
37 | """ 38 | alias kernel_size = attributes["kernel_size"].value().to_static[2]() 39 | alias padding = attributes["padding"].value().to_static[2]() 40 | alias stride = attributes["stride"].value().to_static[2]() 41 | alias dilation = attributes["dilation"].value().to_static[2]() 42 | 43 | alias inputs_strides = input_shape.strides() 44 | alias output_shape = Self.result_shape(input_shape, attributes) 45 | alias outputs_strides = output_shape.strides() 46 | 47 | for batch in range(input_shape[0]): 48 | for in_ch in range(input_shape[1]): 49 | for x in range(output_shape[2]): 50 | for y in range(output_shape[3]): 51 | var max_val: Scalar[dtype] = min_or_neg_inf[dtype]() 52 | var ix_base = x * stride[0] - padding[0] 53 | var iy_base = y * stride[1] - padding[1] 54 | for kx in range(kernel_size[0]): 55 | for ky in range(kernel_size[1]): 56 | var ix = ix_base + kx * dilation[0] 57 | var iy = iy_base + ky * dilation[1] 58 | 59 | if ( 60 | ix < 0 61 | or iy < 0 62 | or ix >= input_shape[2] 63 | or iy >= input_shape[3] 64 | ): 65 | continue 66 | 67 | var idx = ( 68 | batch * inputs_strides[0] 69 | + in_ch * inputs_strides[1] 70 | + ix * inputs_strides[2] 71 | + iy 72 | ) 73 | 74 | var val = inputs[idx] 75 | if val > max_val: 76 | max_val = val 77 | 78 | var out_idx = ( 79 | batch * outputs_strides[0] 80 | + in_ch * outputs_strides[1] 81 | + x * outputs_strides[2] 82 | + y 83 | ) 84 | 85 | outputs[out_idx] = max_val 86 | 87 | @staticmethod 88 | fn backward[ 89 | ug_shape: TensorShape, input_shape: TensorShape, attributes: AttributeVector 90 | ](ug: Tensor[dtype], inputs: Tensor[dtype]) -> Tensor[dtype]: 91 | """ 92 | Backward operation of MAXPOOL2D. 93 | 94 | Upper gradient of shape: [batch_size, channels, uX, uY] 95 | """ 96 | alias kernel_size = attributes["kernel_size"].value().to_static[2]() 97 | alias padding = attributes["padding"].value().to_static[2]() 98 | alias stride = attributes["stride"].value().to_static[2]() 99 | alias dilation = attributes["dilation"].value().to_static[2]() 100 | 101 | alias ug_strides = ug_shape.strides() 102 | alias inputs_strides = input_shape.strides() 103 | 104 | var res = Tensor[dtype](input_shape) 105 | 106 | for batch in range(input_shape[0]): 107 | for in_ch in range(input_shape[1]): 108 | for x in range(ug_shape[2]): 109 | for y in range(ug_shape[3]): 110 | var max_val: Scalar[dtype] = min_or_neg_inf[dtype]() 111 | var max_idx: Int = -1 112 | var ix_base = x * stride[0] - padding[0] 113 | var iy_base = y * stride[1] - padding[1] 114 | for kx in range(kernel_size[0]): 115 | for ky in range(kernel_size[1]): 116 | var ix = ix_base + kx * dilation[0] 117 | var iy = iy_base + ky * dilation[1] 118 | 119 | if ( 120 | ix < 0 121 | or iy < 0 122 | or ix >= input_shape[2] 123 | or iy >= input_shape[3] 124 | ): 125 | continue 126 | 127 | var idx = ( 128 | batch * inputs_strides[0] 129 | + in_ch * inputs_strides[1] 130 | + ix * inputs_strides[2] 131 | + iy 132 | ) 133 | 134 | var val = inputs[idx] 135 | if val > max_val: 136 | max_val = val 137 | max_idx = idx 138 | 139 | var ug_idx = ( 140 | batch * ug_strides[0] 141 | + in_ch * ug_strides[1] 142 | + x * ug_strides[2] 143 | + y 144 | ) 145 | 146 | res[max_idx] += ug[ug_idx] 147 | 148 | return res 149 | -------------------------------------------------------------------------------- /basalt/autograd/params.mojo: -------------------------------------------------------------------------------- 1 | from collections.optional import Optional 2 | from memory import UnsafePointer 3 | 4 | from basalt import dtype 5 | from 
basalt import Tensor, TensorShape 6 | from .symbol import Symbol 7 | from .attributes import Attribute 8 | 9 | 10 | @value 11 | struct Param(CollectionElement, Stringable): 12 | var data: Optional[List[Scalar[dtype]]] 13 | var initializer: Optional[Attribute] 14 | 15 | fn __init__(inout self): 16 | self.data = None 17 | self.initializer = None 18 | 19 | fn __init__(inout self, data: List[Scalar[dtype]]): 20 | self.data = data 21 | self.initializer = None 22 | 23 | fn __init__(inout self, data: Scalar[dtype]): 24 | self.data = List[Scalar[dtype]](data) 25 | self.initializer = None 26 | 27 | fn __init__(inout self, initializer: String, *args: Scalar[dtype]): 28 | # Supported initializers: 29 | # "random_uniform", lower_bound, upper_bound 30 | # "random_normal", mean, std 31 | # #TODO: "kaiming_uniform", mode, nonlinearity 32 | # #TODO: "kaiming_normal", mode, nonlinearity 33 | self.initializer = Attribute("initializer", initializer) 34 | var data = List[Scalar[dtype]]() 35 | for arg in args: 36 | data.append(arg) 37 | self.data = data 38 | 39 | fn __getitem__(self, i: Int) -> Optional[Scalar[dtype]]: 40 | if self.data: 41 | return self.data.value()[i] 42 | else: 43 | return None 44 | 45 | fn __str__(self) -> String: 46 | var s: String = "" 47 | if self.data: 48 | var data = self.data.value() 49 | s += "[" 50 | for i in range(len(data)): 51 | s += str(data[i]) 52 | if i < len(data) - 1: 53 | s += ", " 54 | s += "]" 55 | return s 56 | 57 | 58 | @value 59 | struct ParamDict(Sized): 60 | var symbols: List[Symbol] 61 | var values: List[Param] 62 | 63 | fn __init__(inout self): 64 | self.symbols = List[Symbol]() 65 | self.values = List[Param]() 66 | 67 | fn put(inout self, param_id: Symbol, value: Param = Param()): 68 | self.symbols.append(param_id) 69 | self.values.append(value) 70 | 71 | fn get_tensor(self, idx: Int) -> Tensor[dtype]: 72 | # May only be called at runtime 73 | var num = self.symbols[idx].shape.num_elements() 74 | var t = UnsafePointer[Scalar[dtype]].alloc(num) 75 | for i in range(num): 76 | t[i] = self.values[idx][i].value() 77 | return Tensor[dtype](t, self.symbols[idx].shape) 78 | 79 | fn __len__(self) -> Int: 80 | return len(self.symbols) 81 | -------------------------------------------------------------------------------- /basalt/autograd/symbol.mojo: -------------------------------------------------------------------------------- 1 | from basalt import Tensor, TensorShape 2 | 3 | 4 | @value 5 | @register_passable("trivial") 6 | struct Symbol(CollectionElement, Stringable, EqualityComparable): 7 | var name: UInt32 8 | var dtype: DType 9 | var shape: TensorShape 10 | var trainable: Bool 11 | 12 | fn __eq__(self, other: Self) -> Bool: 13 | return self.name == other.name 14 | 15 | fn __ne__(self, other: Self) -> Bool: 16 | return self.name != other.name 17 | 18 | fn __str__(self) -> String: 19 | return self.json() 20 | 21 | fn json(self) -> String: 22 | return ( 23 | '{"name": "' 24 | + str(self.name) 25 | + '", "dtype": "' 26 | + str(self.dtype) 27 | + '", "shape": "' 28 | + str(self.shape) 29 | + '", "trainable": "' 30 | + str(self.trainable) 31 | + '"}' 32 | ) 33 | -------------------------------------------------------------------------------- /basalt/nn/__init__.mojo: -------------------------------------------------------------------------------- 1 | from .tensor import Tensor, TensorShape 2 | from .model import Model 3 | 4 | from .layers.linear import Linear 5 | from .layers.conv import Conv2d 6 | from .layers.pool import MaxPool2d 7 | 8 | from .loss import MSELoss, 
CrossEntropyLoss 9 | from .activations import ( 10 | Softmax, 11 | LogSoftmax, 12 | ReLU, 13 | LeakyReLU, 14 | Sigmoid, 15 | Tanh, 16 | ) 17 | -------------------------------------------------------------------------------- /basalt/nn/activations.mojo: -------------------------------------------------------------------------------- 1 | from basalt import Tensor, TensorShape 2 | from basalt import Graph, Symbol, OP 3 | from basalt.autograd.attributes import Attribute, AttributeVector 4 | 5 | 6 | # '''Activation functions.''' 7 | fn ReLU(inout g: Graph, input: Symbol) -> Symbol: 8 | return g.op(OP.RELU, input) 9 | 10 | 11 | fn LeakyReLU( 12 | inout g: Graph, input: Symbol, negative_slope: Scalar[dtype] 13 | ) -> Symbol: 14 | return g.op( 15 | OP.LEAKYRELU, 16 | input, 17 | attributes=AttributeVector(Attribute("negative_slope", negative_slope)), 18 | ) 19 | 20 | 21 | fn Sigmoid(inout g: Graph, input: Symbol) -> Symbol: 22 | return g.op(OP.SIGMOID, input) 23 | 24 | 25 | fn Tanh(inout g: Graph, input: Symbol) -> Symbol: 26 | return g.op(OP.TANH, input) 27 | 28 | 29 | fn Softmax(inout g: Graph, input: Symbol, axis: Int) -> Symbol: 30 | # softmax: exp(x_i) / sum(exp(x_j)) 31 | # stable softmax: exp(x_i - max(x_j)) / sum(exp(x_j - max(x_j))) 32 | 33 | var max_values = g.op( 34 | OP.MAX, input, attributes=AttributeVector(Attribute("axis", axis)) 35 | ) 36 | var input_minus_max = g.op(OP.SUB, input, max_values) 37 | var exp_values = g.op(OP.EXP, input_minus_max) 38 | var sum_values = g.op( 39 | OP.SUM, exp_values, attributes=AttributeVector(Attribute("axis", axis)) 40 | ) 41 | 42 | return g.op(OP.DIV, exp_values, sum_values) 43 | 44 | 45 | fn LogSoftmax(inout g: Graph, input: Symbol, axis: Int) -> Symbol: 46 | # stable logsoftmax: log(exp(x_i - max(x_j)) / sum(exp(x_j - max(x_j)))) 47 | # stable logsoftmax: x_i - max(x_j) - log(sum(exp(x_j - max(x_j)))) 48 | 49 | var max_values = g.op( 50 | OP.MAX, input, attributes=AttributeVector(Attribute("axis", axis)) 51 | ) 52 | var input_minus_max = g.op(OP.SUB, input, max_values) 53 | var exp_values = g.op(OP.EXP, input_minus_max) 54 | var sum_values = g.op( 55 | OP.SUM, exp_values, attributes=AttributeVector(Attribute("axis", axis)) 56 | ) 57 | var log_values = g.op(OP.LOG, sum_values) 58 | 59 | return g.op(OP.SUB, input_minus_max, log_values) 60 | -------------------------------------------------------------------------------- /basalt/nn/initializers.mojo: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | from basalt import dtype 4 | from basalt import Tensor, TensorShape 5 | from basalt.utils.rand_utils import rand_normal, rand_uniform 6 | 7 | 8 | fn initialize_tensor( 9 | shape: TensorShape, type: String, data: List[Scalar[dtype]] 10 | ) -> Tensor[dtype]: 11 | if type == "random_uniform": 12 | var low = data[0] 13 | var high = data[1] 14 | var t = Tensor[dtype](shape) 15 | rand_uniform(t, low=low, high=high) 16 | return t 17 | elif type == "random_normal": 18 | var mean = data[0].cast[DType.float64]() 19 | var std = data[1].cast[DType.float64]() 20 | var t = Tensor[dtype](shape) 21 | rand_normal(t, mean=mean, std=std) 22 | return t 23 | # elif type == "kaiming_uniform": 24 | # # mode, nonlinearity 25 | # var mode_id = data[0] 26 | # var mode = "fan_in" if mode_id == 0 else "fan_out" 27 | # return kaiming_uniform(shape, mode = mode) 28 | # elif type == "kaiming_normal": 29 | # # mode, nonlinearity 30 | # var mode_id = data[0] 31 | # var mode = "fan_in" if mode_id == 0 else "fan_out" 32 | # 
return kaiming_normal(shape, mode = mode) 33 | else: 34 | print("[ERROR] Unsupported initialization type: " + type) 35 | return Tensor[dtype]() 36 | 37 | 38 | fn calculate_fan(shape: TensorShape, mode: String) -> Scalar[dtype]: 39 | """ 40 | Calculate the fan-in and fan-out of any tensor. 41 | """ 42 | # NOTE: shape.rank() should be > 2 43 | # mode: "fan_in" or "fan_out" 44 | if shape.rank() < 2: 45 | print( 46 | "[ERROR] Fan in and fan out can not be calculated for tensor with less than" 47 | " 2 dimensions" 48 | ) 49 | 50 | var num_input_fmaps = shape[1] 51 | var num_output_fmaps = shape[0] 52 | var receptive_field_size = 1 53 | if shape.rank() > 2: 54 | for i in range(2, shape.rank()): 55 | receptive_field_size *= shape[i] 56 | 57 | var fan_in = num_input_fmaps * receptive_field_size 58 | var fan_out = num_output_fmaps * receptive_field_size 59 | 60 | if mode == "fan_in": 61 | return fan_in 62 | else: 63 | return fan_out 64 | 65 | 66 | # # TODO: https://pytorch.org/docs/stable/_modules/torch/nn/init.html 67 | # fn kaiming_uniform(shape: TensorShape, mode: String = "fan_in", nonlinearity: String = "leaky_relu") -> Tensor[dtype]: 68 | # var fan = calculate_fan(shape, mode) 69 | 70 | # # TODO: add support for other gains: https://github.com/pytorch/pytorch/blob/main/torch/nn/init.py#L68 71 | # # Gain for linear and conv layers is 1 72 | # var gain = 1 73 | # var std = gain / sqrt(fan) 74 | 75 | # # var bound = sqrt(3) * std.cast[dtype]() 76 | # var bound = std.cast[dtype]() 77 | 78 | # # print("Shape", shape, "Fan", fan, "Bound", bound) 79 | 80 | # var t = Tensor[dtype](shape) 81 | # rand_uniform(t, low = -bound, high = bound) 82 | # return t^ 83 | 84 | 85 | # # TODO: https://pytorch.org/docs/stable/_modules/torch/nn/init.html 86 | # fn kaiming_normal(shape: TensorShape, mode: String = "fan_in", nonlinearity: String = "leaky_relu") -> Tensor[dtype]: 87 | # var fan = calculate_fan(shape, mode) 88 | 89 | # # TODO: add support for other gains: https://github.com/pytorch/pytorch/blob/main/torch/nn/init.py#L68 90 | # # Gain for linear and conv layers is 1 91 | # var gain = 1 92 | # var std = gain / sqrt(fan) 93 | 94 | # var t = Tensor[dtype](shape) 95 | # rand_normal(t, mean = 0, std = std.cast[DType.float64]()) 96 | # return t^ 97 | -------------------------------------------------------------------------------- /basalt/nn/layers/__init__.mojo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basalt-org/basalt/fe16eadcfcfee9271f9df0dd94d11b7c50d868ba/basalt/nn/layers/__init__.mojo -------------------------------------------------------------------------------- /basalt/nn/layers/conv.mojo: -------------------------------------------------------------------------------- 1 | from basalt import Graph, Symbol, OP 2 | from basalt import Tensor, TensorShape 3 | from basalt.utils import q_sqrt 4 | from basalt.autograd.params import Param 5 | from basalt.autograd.attributes import AttributeVector, Attribute 6 | 7 | from utils.index import IndexList 8 | 9 | 10 | fn Conv2d( 11 | inout g: Graph, 12 | inputs: Symbol, 13 | out_channels: Int, 14 | kernel_size: IndexList[2], 15 | padding: IndexList[2] = 0, 16 | stride: IndexList[2] = 1, 17 | dilation: IndexList[2] = 1, 18 | ) -> Symbol: 19 | """ 20 | A 2D Convolution Layer. 21 | 22 | Parameters 23 | inputs.shape [batch, in_channels, iX, iY] 24 | kernel.shape [out_channels, in_channels, kX, kY] (or weights) 25 | bias.shape [out_channels]. 26 | output.shape [batch, out_channels, oX, oY]. 
27 | """ 28 | 29 | var in_channels: Int = inputs.shape[1] 30 | var fan_in: Scalar[dtype] = in_channels * kernel_size[0] * kernel_size[1] 31 | var bound = q_sqrt(fan_in) 32 | var weights = g.param( 33 | TensorShape(out_channels, in_channels, kernel_size[0], kernel_size[1]), 34 | init=Param("random_uniform", -bound, bound) 35 | # init=Param("kaiming_uniform", 0) 36 | ) 37 | var bias = g.param( 38 | TensorShape(out_channels), init=Param("random_uniform", -bound, bound) 39 | ) 40 | 41 | return g.op( 42 | OP.CONV2D, 43 | inputs, 44 | weights, 45 | bias, 46 | attributes=AttributeVector( 47 | Attribute("padding", padding), 48 | Attribute("stride", stride), 49 | Attribute("dilation", dilation), 50 | ), 51 | ) 52 | -------------------------------------------------------------------------------- /basalt/nn/layers/dropout.mojo: -------------------------------------------------------------------------------- 1 | # TODO 2 | -------------------------------------------------------------------------------- /basalt/nn/layers/linear.mojo: -------------------------------------------------------------------------------- 1 | from basalt import Tensor, TensorShape 2 | from basalt import Graph, Symbol, OP 3 | from basalt.utils import q_sqrt 4 | from basalt.autograd.params import Param 5 | 6 | 7 | fn Linear( 8 | inout g: Graph, 9 | inputs: Symbol, 10 | n_outputs: Int, 11 | ) -> Symbol: 12 | """ 13 | A fully connected layer. 14 | """ 15 | 16 | var fan_in: Scalar[dtype] = inputs.shape[1] 17 | var bound = q_sqrt(fan_in) 18 | var weights = g.param( 19 | TensorShape(inputs.shape[1], n_outputs), 20 | init=Param("random_uniform", -bound, bound) 21 | # init=Param("random_uniform", 1) # NOTE: mode: fan_out required as weight are defined transposed 22 | ) 23 | var b = g.param(TensorShape(n_outputs), init=Param("random_uniform", -bound, bound)) 24 | 25 | var res = g.op(OP.DOT, inputs, weights) 26 | return g.op(OP.ADD, res, b) 27 | -------------------------------------------------------------------------------- /basalt/nn/layers/pool.mojo: -------------------------------------------------------------------------------- 1 | from basalt import Tensor, TensorShape 2 | from collections.optional import Optional 3 | from utils.index import IndexList 4 | 5 | from basalt import Graph, Symbol, OP 6 | from basalt.autograd.attributes import AttributeVector, Attribute 7 | 8 | 9 | fn set_static_stride( 10 | kernel_size: IndexList[2], stride: Optional[Int] = None 11 | ) -> IndexList[2]: 12 | if stride: 13 | return IndexList[2](stride.value(), stride.value()) 14 | else: 15 | return kernel_size 16 | 17 | 18 | fn MaxPool2d( 19 | inout g: Graph, 20 | inputs: Symbol, 21 | kernel_size: IndexList[2], 22 | stride: Optional[Int] = None, 23 | padding: IndexList[2] = 0, 24 | dilation: IndexList[2] = 1, 25 | ) -> Symbol: 26 | """ 27 | A 2D Max Pooling Layer. 28 | 29 | Kernel is unaware of the in_channels and out_channels of the input tensor. 30 | kernel.size (kX, kY) 31 | """ 32 | 33 | # TODO: assert padding <= kernel_size / 2 (at compile time) 34 | 35 | var stride_temp = set_static_stride(kernel_size, stride) 36 | 37 | return MaxPool2d(g, inputs, kernel_size, stride_temp, padding, dilation) 38 | 39 | 40 | fn MaxPool2d( 41 | inout g: Graph, 42 | inputs: Symbol, 43 | kernel_size: IndexList[2], 44 | stride: IndexList[2], # stride should be 1 or more 45 | padding: IndexList[2] = 0, 46 | dilation: IndexList[2] = 1, 47 | ) -> Symbol: 48 | """ 49 | A 2D Max Pooling Layer. 50 | 51 | Kernel is unaware of the in_channels and out_channels of the input tensor. 
52 | kernel.size (kX, kY) 53 | """ 54 | # TODO: assert padding <= kernel_size / 2 (at compile time) 55 | 56 | return g.op( 57 | OP.MAXPOOL2D, 58 | inputs, 59 | attributes=AttributeVector( 60 | Attribute("kernel_size", kernel_size), 61 | Attribute("padding", padding), 62 | Attribute("stride", stride), 63 | Attribute("dilation", dilation), 64 | ), 65 | ) 66 | 67 | 68 | # # TODO 69 | -------------------------------------------------------------------------------- /basalt/nn/layers/sequential.mojo: -------------------------------------------------------------------------------- 1 | # TODO 2 | -------------------------------------------------------------------------------- /basalt/nn/loss.mojo: -------------------------------------------------------------------------------- 1 | import basalt.nn as nn 2 | from basalt import Tensor, TensorShape 3 | from basalt import Graph, Symbol, OP 4 | 5 | 6 | fn MSELoss( 7 | inout g: Graph, 8 | y_pred: Symbol, 9 | y_true: Symbol, 10 | ) -> Symbol: 11 | # 1/N * sum( (outputs - targets)^2 ) 12 | 13 | var diff = g.op(OP.SUB, y_true, y_pred) 14 | var loss = g.op(OP.POW, diff, 2) 15 | var mean_loss = g.op(OP.MEAN, loss) 16 | 17 | return mean_loss 18 | 19 | 20 | fn CrossEntropyLoss( 21 | inout g: Graph, 22 | y_pred: Symbol, 23 | y_true: Symbol, 24 | ) -> Symbol: 25 | # -1/N * sum( targets * log_softmax(outputs) ) 26 | 27 | var log_softmax = nn.LogSoftmax(g, y_pred, axis=1) 28 | 29 | # CrossEntropy (reduction Mean) 30 | var targets_log_softmax = g.op(OP.MUL, y_true, log_softmax) 31 | var ret = g.op(OP.SUM, targets_log_softmax) 32 | var negDivN = g.op(OP.MUL, ret, -1.0 / y_pred.shape[0]) 33 | 34 | return negDivN 35 | -------------------------------------------------------------------------------- /basalt/nn/optim.mojo: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from algorithm import vectorize, parallelize 3 | 4 | from .model import Parameters 5 | from basalt import Graph, Tensor, TensorShape 6 | from basalt.utils.collection import Collection 7 | from basalt.utils.math_util import add, sub, mul, div 8 | 9 | 10 | fn get_trainable_parameters(g: Graph) -> List[Symbol]: 11 | """ 12 | Get all symbols of trainable parameters. 
13 | """ 14 | 15 | var trainable_parameters = List[Symbol]() 16 | 17 | for i in range(len(g.params)): 18 | if g.params.symbols[i].trainable: 19 | trainable_parameters.append(g.params.symbols[i]) 20 | 21 | return trainable_parameters ^ 22 | 23 | 24 | @value 25 | struct Adam[ 26 | g: Graph, 27 | trainable_parameters: List[Symbol] = get_trainable_parameters(g), 28 | ]: 29 | var parameters: Pointer[Parameters, MutableAnyOrigin] 30 | 31 | var lr: Scalar[dtype] 32 | var beta1: Scalar[dtype] 33 | var beta2: Scalar[dtype] 34 | var epsilon: Scalar[dtype] 35 | var iter: Int 36 | 37 | var rms_grads: Collection 38 | var momentum_grads: Collection 39 | 40 | fn __init__( 41 | inout self, 42 | ref[MutableAnyOrigin] parameters: Parameters, 43 | lr: Scalar[dtype] = 0.001, 44 | beta1: Scalar[dtype] = 0.9, 45 | beta2: Scalar[dtype] = 0.999, 46 | epsilon: Scalar[dtype] = 1e-8, 47 | ): 48 | self.parameters = Pointer.address_of(parameters) 49 | 50 | self.lr = lr 51 | self.beta1 = beta1 52 | self.beta2 = beta2 53 | self.epsilon = epsilon 54 | self.iter = 0 55 | 56 | # Capacity of the collections should be the n of trainable parameters 57 | self.rms_grads = Collection(capacity=len(trainable_parameters)) 58 | self.momentum_grads = Collection(capacity=len(trainable_parameters)) 59 | 60 | self.allocate_rms_and_momentum() 61 | 62 | fn zero_grad(inout self): 63 | """Set all gradients to zero.""" 64 | self.parameters[].grads.set_zero() 65 | 66 | fn step(inout self): 67 | """Update model parameters.""" 68 | self.iter += 1 69 | 70 | # Loop over all trainable parameters 71 | @parameter 72 | fn p_step(i: Int): 73 | var param = trainable_parameters[i] 74 | 75 | @parameter 76 | fn v_step[nelts: Int](j: Int): 77 | var momentum_grads = self.momentum_grads[param].load[nelts](j) 78 | var rms_grads = self.rms_grads[param].load[nelts](j) 79 | var grads = self.parameters[].grads[param].load[nelts](j) 80 | var params = self.parameters[].tensors[param].load[nelts](j) 81 | 82 | # Momentum beta 1 83 | # f1 = beta1 * momentum + (1 - beta1) * grad 84 | momentum_grads = self.beta1 * momentum_grads + (1 - self.beta1) * grads 85 | self.momentum_grads[param].store[nelts](j, momentum_grads) 86 | 87 | # Bias correction 88 | # f2 = f1 / (1 - beta1 ** iter) 89 | momentum_grads = momentum_grads / (1 - self.beta1**self.iter) 90 | 91 | # RMS beta 2 92 | # f1 = beta2 * rms + (1 - beta2) * grad ** 2 93 | rms_grads = self.beta2 * rms_grads + (1 - self.beta2) * grads * grads 94 | self.rms_grads[param].store[nelts](j, rms_grads) 95 | 96 | # Bias correction 97 | # f2 = f1 / (1 - beta2 ** iter) 98 | rms_grads = rms_grads / (1 - self.beta2**self.iter) 99 | 100 | # tensor = tensor - lr * (f2 / (sqrt(rms) + epsilon)) 101 | params = params - self.lr * ( 102 | momentum_grads / (sqrt(rms_grads) + self.epsilon) 103 | ) 104 | self.parameters[].tensors[param].store[nelts](j, params) 105 | 106 | vectorize[v_step, 1](param.shape.num_elements()) 107 | 108 | parallelize[p_step](len(trainable_parameters)) 109 | 110 | fn allocate_rms_and_momentum(inout self): 111 | # They are initialized to zero 112 | # Loop over all trainable parameters 113 | for i in range(len(trainable_parameters)): 114 | var param = trainable_parameters[i] 115 | self.rms_grads.append(Tensor[dtype](param.shape), param) 116 | self.momentum_grads.append(Tensor[dtype](param.shape), param) 117 | -------------------------------------------------------------------------------- /basalt/nn/tensor.mojo: -------------------------------------------------------------------------------- 1 | from testing import 
assert_true 2 | from algorithm import vectorize 3 | from utils.index import IndexList 4 | from memory import memset_zero, memcpy, UnsafePointer 5 | 6 | 7 | alias MAX_RANK = 8 8 | 9 | 10 | @register_passable("trivial") 11 | struct TensorShape(Stringable): 12 | var _rank: Int 13 | var _shape: IndexList[MAX_RANK] 14 | 15 | fn __init__(inout self, *shape: Int): 16 | self._rank = len(shape) 17 | self._shape = IndexList[MAX_RANK]() 18 | for i in range(min(self._rank, MAX_RANK)): 19 | self._shape[i] = shape[i] 20 | 21 | fn __init__(inout self, shapes: VariadicList[Int]): 22 | self._rank = len(shapes) 23 | self._shape = IndexList[MAX_RANK]() 24 | for i in range(min(self._rank, MAX_RANK)): 25 | self._shape[i] = shapes[i] 26 | 27 | fn __init__(inout self, shape: List[Int]): 28 | self._rank = len(shape) 29 | self._shape = IndexList[MAX_RANK]() 30 | for i in range(min(self._rank, MAX_RANK)): 31 | self._shape[i] = shape[i] 32 | 33 | fn __init__[num: Int](inout self, shape: IndexList[num]): 34 | self._rank = num 35 | self._shape = IndexList[MAX_RANK]() 36 | for i in range(min(self._rank, MAX_RANK)): 37 | self._shape[i] = shape[i] 38 | 39 | fn __init__(inout self, rank: Int, shape: IndexList[MAX_RANK]): 40 | self._rank = rank 41 | self._shape = shape 42 | 43 | @always_inline("nodebug") 44 | fn __getitem__(self, index: Int) -> Int: 45 | return self._shape[index if index >= 0 else self._rank + index] 46 | 47 | @always_inline("nodebug") 48 | fn __setitem__(inout self, index: Int, value: Int): 49 | self._shape[index if index >= 0 else self._rank + index] = value 50 | 51 | @always_inline("nodebug") 52 | fn rank(self) -> Int: 53 | return self._rank 54 | 55 | fn num_elements(self) -> Int: 56 | var result = 1 57 | for i in range(self._rank): 58 | result *= self._shape[i] 59 | return result 60 | 61 | fn strides(self) -> IndexList[MAX_RANK]: 62 | var result = IndexList[MAX_RANK](0) 63 | result[self._rank - 1] = 1 64 | for i in range(self._rank - 2, -1, -1): 65 | result[i] = result[i + 1] * self._shape[i + 1] 66 | return result 67 | 68 | fn __str__(self) -> String: 69 | var s: String = "(" 70 | for i in range(self._rank): 71 | s += str(self._shape[i]) 72 | if i < self._rank - 1: 73 | s += ", " 74 | return s + ")" 75 | 76 | @always_inline("nodebug") 77 | fn __eq__(self, other: TensorShape) -> Bool: 78 | if self.rank() != other.rank(): 79 | return False 80 | for i in range(self.rank()): 81 | if self[i] != other[i]: 82 | return False 83 | return True 84 | 85 | @always_inline("nodebug") 86 | fn __ne__(self, other: TensorShape) -> Bool: 87 | return not self.__eq__(other) 88 | 89 | fn __contains__(self, value: Int) -> Bool: 90 | for i in range(self.rank()): 91 | if self[i] == value: 92 | return True 93 | return False 94 | 95 | fn to_list(self) -> List[Int]: 96 | var result = List[Int]() 97 | for i in range(self.rank()): 98 | result.append(self[i]) 99 | return result 100 | 101 | 102 | struct Tensor[dtype: DType](Stringable, Movable, CollectionElement): 103 | var _data: UnsafePointer[Scalar[dtype]] 104 | var _shape: TensorShape 105 | 106 | fn __init__(inout self, *dims: Int): 107 | self._shape = TensorShape(dims) 108 | self._data = UnsafePointer[Scalar[dtype]].alloc(self._shape.num_elements()) 109 | memset_zero(self._data, self._shape.num_elements()) 110 | 111 | fn __init__(inout self, owned shape: TensorShape): 112 | self._data = UnsafePointer[Scalar[dtype]].alloc(shape.num_elements()) 113 | memset_zero(self._data, shape.num_elements()) 114 | self._shape = shape 115 | 116 | fn __init__(inout self, shapes: 
VariadicList[Int]): 117 | self._shape = TensorShape(shapes) 118 | self._data = UnsafePointer[Scalar[dtype]].alloc(self._shape.num_elements()) 119 | memset_zero(self._data, self._shape.num_elements()) 120 | 121 | fn __init__( 122 | inout self, owned data: UnsafePointer[Scalar[dtype]], owned shape: TensorShape 123 | ): 124 | # NOTE: Remember to use _ = your_tensor that you passed, so there is no weird behavior in this function 125 | self._data = UnsafePointer[Scalar[dtype]].alloc(shape.num_elements()) 126 | self._shape = shape 127 | 128 | memcpy(self._data, data, self._shape.num_elements()) 129 | _ = data 130 | 131 | fn __moveinit__(inout self, owned other: Tensor[dtype]): 132 | self._data = other._data 133 | self._shape = other._shape 134 | 135 | fn __copyinit__(inout self, other: Tensor[dtype]): 136 | # print("[WARNING] Copying tensor") 137 | self._data = UnsafePointer[Scalar[dtype]].alloc(other._shape.num_elements()) 138 | memcpy(self._data, other._data, other.num_elements()) 139 | self._shape = other._shape 140 | 141 | @always_inline("nodebug") 142 | fn __getitem__(self, index: Int) -> Scalar[dtype]: 143 | return self._data[index] 144 | 145 | @always_inline("nodebug") 146 | fn __setitem__(self, index: Int, value: Scalar[dtype]): 147 | self._data[index] = value 148 | 149 | @always_inline("nodebug") 150 | fn data(self) -> UnsafePointer[Scalar[dtype]]: 151 | return self._data 152 | 153 | @always_inline("nodebug") 154 | fn shape(self) -> TensorShape: 155 | return self._shape 156 | 157 | @always_inline("nodebug") 158 | fn load[simd_width: Int](self, index: Int) -> SIMD[dtype, simd_width]: 159 | return self._data.load[width=simd_width](index) 160 | 161 | @always_inline("nodebug") 162 | fn store[simd_width: Int](self, index: Int, value: SIMD[dtype, simd_width]): 163 | self._data.store(index, value) 164 | 165 | @always_inline("nodebug") 166 | fn strides(self) -> IndexList[MAX_RANK]: 167 | return self._shape.strides() 168 | 169 | @always_inline("nodebug") 170 | fn rank(self) -> Int: 171 | return self._shape.rank() 172 | 173 | @always_inline("nodebug") 174 | fn num_elements(self) -> Int: 175 | return self._shape.num_elements() 176 | 177 | @always_inline("nodebug") 178 | fn dim(self, index: Int) -> Int: 179 | return self._shape[index] 180 | 181 | @always_inline("nodebug") 182 | fn zero(self): 183 | memset_zero(self._data, self.num_elements()) 184 | 185 | @always_inline("nodebug") 186 | fn ireshape(inout self, new_shape: TensorShape) raises: 187 | # NOTE Consider not raising on error 188 | assert_true(self.num_elements() == new_shape.num_elements()) 189 | self._shape = new_shape 190 | 191 | fn __str__(self) -> String: 192 | # temp fix 193 | var s: String = "[" 194 | for i in range(self.num_elements()): 195 | s += str(self[i]) 196 | if i < self.num_elements() - 1: 197 | s += ", " 198 | return s + "]" 199 | 200 | 201 | @always_inline("nodebug") 202 | fn __del__(owned self): 203 | self._data.free() 204 | -------------------------------------------------------------------------------- /basalt/utils/__init__.mojo: -------------------------------------------------------------------------------- 1 | from memory.unsafe import bitcast 2 | 3 | 4 | @always_inline("nodebug") 5 | fn q_sqrt(value: Float32) -> Float32: 6 | var y = bitcast[DType.float32](0x5F3759DF - (bitcast[DType.uint32](value) >> 1)) 7 | return -y * ((0.5 * value * y).fma(y, -1.5)) 8 | -------------------------------------------------------------------------------- /basalt/utils/bytes.mojo: 
-------------------------------------------------------------------------------- 1 | from math import nan 2 | from utils.numerics import inf 3 | from utils.static_tuple import StaticTuple 4 | 5 | alias ScalarBytes = DType.uint64.sizeof() 6 | 7 | 8 | @register_passable("trivial") 9 | struct Bytes[capacity: Int](Stringable, CollectionElement, EqualityComparable): 10 | """ 11 | Static sequence of bytes. 12 | """ 13 | 14 | var data: StaticTuple[UInt8, capacity] 15 | 16 | fn __init__(inout self): 17 | var data = StaticTuple[UInt8, capacity](0) 18 | 19 | for i in range(capacity): 20 | data[i] = 0 21 | 22 | self.data = data 23 | 24 | fn __init__(inout self, s: String): 25 | var data = StaticTuple[UInt8, capacity](0) 26 | var length = len(s) 27 | 28 | for i in range(capacity): 29 | data[i] = ord(s[i]) if i < length else 0 30 | 31 | self.data = data 32 | 33 | @always_inline("nodebug") 34 | fn __len__(self) -> Int: 35 | return capacity 36 | 37 | @always_inline("nodebug") 38 | fn __setitem__(inout self, index: Int, value: UInt8): 39 | self.data[index] = value 40 | 41 | @always_inline("nodebug") 42 | fn __getitem__(self, index: Int) -> UInt8: 43 | return self.data[index] 44 | 45 | @always_inline("nodebug") 46 | fn __eq__(self, other: Self) -> Bool: 47 | for i in range(capacity): 48 | if self[i] != other[i]: 49 | return False 50 | return True 51 | 52 | @always_inline("nodebug") 53 | fn __ne__(self, other: Self) -> Bool: 54 | for i in range(capacity): 55 | if self[i] != other[i]: 56 | return True 57 | return False 58 | 59 | @always_inline("nodebug") 60 | fn __str__(self) -> String: 61 | var result: String = "" 62 | 63 | for i in range(capacity): 64 | var val = self[i] 65 | if val != 0: 66 | result += chr(int(val)) 67 | 68 | return result 69 | 70 | 71 | fn scalar_to_bytes[ 72 | dtype: DType, Size: Int = ScalarBytes 73 | ](value: Scalar[dtype]) -> Bytes[Size]: 74 | constrained[Size >= ScalarBytes, "Size must be at least ${ScalarBytes}"]() 75 | 76 | var bits = bitcast[DType.uint64](value.cast[expand_type[dtype]()]()) 77 | var data = Bytes[Size]() 78 | 79 | for i in range(ScalarBytes): 80 | data[i] = (bits >> (i << 3)).cast[DType.uint8]() 81 | 82 | return data 83 | 84 | 85 | fn bytes_to_scalar[dtype: DType](data: Bytes) -> Scalar[dtype]: 86 | constrained[data.capacity >= ScalarBytes, "Size must be at least ${ScalarBytes}"]() 87 | 88 | var bits: UInt64 = 0 89 | 90 | for i in range(ScalarBytes): 91 | bits |= data[i].cast[DType.uint64]() << (i << 3) 92 | 93 | return bitcast[expand_type[dtype]()](bits).cast[dtype]() 94 | 95 | 96 | fn expand_type[dtype: DType]() -> DType: 97 | @parameter 98 | if dtype.is_floating_point(): 99 | return DType.float64 100 | elif dtype.is_signed(): 101 | return DType.int64 102 | elif dtype.is_integral(): 103 | return DType.uint64 104 | 105 | constrained[False, "Type must be numeric"]() 106 | return DType.invalid 107 | -------------------------------------------------------------------------------- /basalt/utils/collection.mojo: -------------------------------------------------------------------------------- 1 | from memory.unsafe_pointer import UnsafePointer 2 | from memory import memset_zero, memcpy 3 | 4 | from basalt import Tensor, Symbol 5 | 6 | 7 | struct Collection(CollectionElement, Sized): 8 | """ 9 | A collection of tensors with associated symbols. 
10 | """ 11 | 12 | var size: Int 13 | var capacity: Int 14 | var data: UnsafePointer[Tensor[dtype]] 15 | var symbols: UnsafePointer[Scalar[DType.uint32]] 16 | 17 | @always_inline("nodebug") 18 | fn __init__(inout self, *, capacity: Int = 1): 19 | """ 20 | Initializes a new Collection with the given capacity. 21 | """ 22 | self.size = 0 23 | self.capacity = capacity 24 | self.data = UnsafePointer[Tensor[dtype]].alloc(capacity) 25 | UnsafePointer.init_pointee_move((self.data + self.size), Tensor[dtype]()) 26 | self.symbols = UnsafePointer[Scalar[DType.uint32]].alloc(capacity) 27 | 28 | @always_inline("nodebug") 29 | fn __moveinit__(inout self, owned existing: Self): 30 | """ 31 | Move initializes a Collection from an existing one. 32 | """ 33 | self.size = existing.size 34 | self.capacity = existing.capacity 35 | self.data = existing.data 36 | self.symbols = existing.symbols 37 | 38 | @always_inline("nodebug") 39 | fn __copyinit__(inout self, existing: Self): 40 | """ 41 | Copy initializes a Collection from an existing one. 42 | """ 43 | self.capacity = existing.capacity 44 | self.size = existing.size 45 | self.data = UnsafePointer[Tensor[dtype]].alloc(existing.capacity) 46 | self.symbols = UnsafePointer[Scalar[DType.uint32]].alloc(existing.capacity) 47 | memcpy(self.symbols, existing.symbols, existing.capacity) 48 | 49 | for i in range(existing.size): 50 | UnsafePointer.init_pointee_move((self.data + i), (existing.data + i)[]) 51 | 52 | @always_inline("nodebug") 53 | fn __del__(owned self): 54 | """ 55 | Destructor for the Collection. 56 | """ 57 | for i in range(self.size): 58 | UnsafePointer.destroy_pointee((self.data + i)) 59 | if self.data: 60 | self.data.free() 61 | if self.symbols: 62 | self.symbols.free() 63 | 64 | @always_inline("nodebug") 65 | fn __len__(self) -> Int: 66 | """ 67 | Returns the number of elements in the Collection. 68 | """ 69 | return self.size 70 | 71 | @always_inline("nodebug") 72 | fn _realloc(inout self, new_capacity: Int): 73 | """ 74 | Reallocates the Collection to the new capacity. 75 | """ 76 | var new_data = UnsafePointer[Tensor[dtype]].alloc(new_capacity) 77 | var new_symbols = UnsafePointer[Scalar[DType.uint32]].alloc(new_capacity) 78 | 79 | for i in range(self.size): 80 | UnsafePointer.init_pointee_move((new_data + i), (self.data + i)[]) 81 | new_symbols[i] = self.symbols[i] 82 | 83 | self.data.free() 84 | self.symbols.free() 85 | 86 | self.data = new_data 87 | self.symbols = new_symbols 88 | self.capacity = new_capacity 89 | 90 | @always_inline("nodebug") 91 | fn append(inout self, owned value: Tensor[dtype], symbol: Symbol): 92 | """ 93 | Appends a tensor and its associated symbol to the Collection. 94 | """ 95 | self.append(value ^, symbol.name) 96 | 97 | @always_inline("nodebug") 98 | fn append(inout self, owned value: Tensor[dtype], symbol_name: UInt32): 99 | """ 100 | Appends a tensor and its associated symbol name to the Collection. 101 | """ 102 | if self.size >= self.capacity: 103 | self._realloc(max(1, self.capacity * 2)) 104 | UnsafePointer.init_pointee_move((self.data + self.size), value ^) 105 | self.symbols[self.size] = symbol_name 106 | self.size += 1 107 | 108 | @always_inline("nodebug") 109 | fn get_index(self, symbol_name: UInt32) -> Int: 110 | """ 111 | Returns the index of the tensor with the given symbol name. 
112 | """ 113 | alias factor = 8 114 | # 2 -> 5.32s MNIST 115 | # 4 -> 4.95s MNIST 116 | # 8 -> 4.85s MNIST 117 | # 16 -> 5.19s MNIST 118 | # NOTE: This ideally should just be a hashmap 119 | 120 | for i in range(0, self.size, factor): 121 | var elems = self.symbols.load[width=factor](i) == symbol_name 122 | 123 | for j in range(factor): 124 | if elems[j]: 125 | return i + j 126 | 127 | var split = divmod(self.size, factor) 128 | 129 | for i in range(split[1]): 130 | var index = split[0] + i 131 | 132 | if self.symbols[index] == symbol_name: 133 | return index 134 | 135 | return -1 136 | 137 | fn __getitem__( 138 | self, 139 | symbol: Symbol, 140 | ) -> ref[self.data[0]] Tensor[dtype]: 141 | # TODO: This is a hack, we should instead use dict, because there can be cases where the object doesn't exist and also self.data[0] can be a value that doesn't exit because the list is empty (but we hack this by assigning an empty value) 142 | """ 143 | Returns a reference to the tensor with the given symbol. 144 | """ 145 | var index = self.get_index(symbol.name) 146 | 147 | 148 | return (self.data + index)[] 149 | 150 | @always_inline("nodebug") 151 | fn clear(inout self): 152 | """ 153 | Clears the Collection, removing all tensors and symbols. 154 | """ 155 | for i in range(self.size): 156 | UnsafePointer.destroy_pointee((self.data + i)) 157 | memset_zero(self.symbols, self.capacity) 158 | self.size = 0 159 | 160 | @always_inline("nodebug") 161 | fn set_zero(self): 162 | """ 163 | Zeroes out all the tensors in the collection. 164 | """ 165 | for i in range(self.size): 166 | self.data[i].zero() 167 | -------------------------------------------------------------------------------- /basalt/utils/dataloader.mojo: -------------------------------------------------------------------------------- 1 | from testing import assert_equal 2 | from math import min 3 | from memory import memcpy 4 | 5 | from basalt import dtype, nelts 6 | from basalt import Tensor, TensorShape 7 | 8 | 9 | @value 10 | struct Batch[dtype: DType](CollectionElement): 11 | var data: Tensor[dtype] 12 | var labels: Tensor[dtype] 13 | 14 | fn __init__(inout self, batch_data: Tensor[dtype], batch_labels: Tensor[dtype]): 15 | self.data = batch_data 16 | self.labels = batch_labels 17 | 18 | fn __init__( 19 | inout self, 20 | df_data: Tensor[dtype], 21 | df_labels: Tensor[dtype], 22 | start: Int, 23 | batch_data_shape: TensorShape, 24 | batch_labels_shape: TensorShape, 25 | ): 26 | # TODO: find a better way to do this 27 | # Links to the copies of the input tensors in model.forward() 28 | self.data = Tensor[dtype](batch_data_shape) 29 | self.labels = Tensor[dtype](batch_labels_shape) 30 | memcpy( 31 | self.data.data(), 32 | df_data.data().offset(start * batch_data_shape.strides()[0]), 33 | batch_data_shape.num_elements(), 34 | ) 35 | memcpy( 36 | self.labels.data(), 37 | df_labels.data().offset(start * batch_labels_shape.strides()[0]), 38 | batch_labels_shape.num_elements(), 39 | ) 40 | 41 | fn __getitem__(self, index: Int) -> Tensor[dtype]: 42 | if index == 0: 43 | return self.data 44 | elif index == 1: 45 | return self.labels 46 | else: 47 | print("[ERROR] Batch.__getitem__(): Index out of bounds") 48 | return Tensor[dtype]() 49 | 50 | 51 | @value 52 | struct DataLoader: 53 | var data: Tensor[dtype] 54 | var labels: Tensor[dtype] 55 | var batch_size: Int 56 | var _current_index: Int 57 | var _num_batches: Int 58 | var _data_batch_shape: TensorShape 59 | var _label_batch_shape: TensorShape 60 | 61 | fn __init__( 62 | inout self, 63 | data: 
Tensor[dtype], 64 | labels: Tensor[dtype], 65 | batch_size: Int, 66 | ): 67 | self.data = data 68 | self.labels = labels 69 | self.batch_size = batch_size 70 | 71 | # Number of batches to iter, NOTE: ignore the remainder for now 72 | # var remainder = 1 if self.data.dim(0) % self.batch_size != 0 else 0 73 | self._current_index = 0 74 | self._num_batches = self.data.dim(0) // self.batch_size # + remainder 75 | 76 | # Batch shapes 77 | self._data_batch_shape = self.data.shape() 78 | self._label_batch_shape = self.labels.shape() 79 | self._data_batch_shape[0] = self.batch_size 80 | self._label_batch_shape[0] = self.batch_size 81 | 82 | @always_inline 83 | fn __len__(self) -> Int: 84 | """ 85 | Returns the number of the batches left in the dataset. 86 | """ 87 | return self._num_batches 88 | 89 | fn __iter__(self) -> Self: 90 | # TODO: Starting the iterator requires to return (COPY!) the whole dataloader which containts the whole dataset 91 | # Does this mean that the whole dataset is copied every epoch ?! 92 | return self 93 | 94 | fn __has_next__(self) -> Bool: 95 | return self._num_batches > 0 96 | 97 | fn __next__(inout self) -> Batch[dtype]: 98 | # NOTE: ignore the remainder for now 99 | # var end = min(self._current_index + self.batch_size, self.data.dim(0)) 100 | # self._data_shape[0] = end - self._current_index 101 | # self._label_shape[0] = end - self._current_index 102 | 103 | var temp_current_index = self._current_index 104 | self._current_index += self.batch_size 105 | self._num_batches -= 1 106 | 107 | return Batch[dtype]( 108 | self.data, 109 | self.labels, 110 | temp_current_index, 111 | self._data_batch_shape, 112 | self._label_batch_shape, 113 | ) 114 | -------------------------------------------------------------------------------- /basalt/utils/datasets.mojo: -------------------------------------------------------------------------------- 1 | from algorithm import vectorize 2 | 3 | from basalt import dtype 4 | from basalt import Tensor, TensorShape 5 | from basalt.utils.tensorutils import elwise_op, tmean, tstd 6 | 7 | 8 | @always_inline 9 | fn div[dtype: DType, simd_width: Int](a: SIMD[dtype, simd_width], b: Scalar[dtype]) -> SIMD[dtype, simd_width]: 10 | return a / b 11 | 12 | 13 | struct BostonHousing: 14 | alias n_inputs = 13 15 | 16 | var data: Tensor[dtype] 17 | var labels: Tensor[dtype] 18 | 19 | fn __init__(inout self, file_path: String) raises: 20 | var s = read_file(file_path) 21 | # Skip the first and last lines 22 | # This does assume your last line in the file has a newline at the end 23 | var list_of_lines = s.split("\n")[1:-1] 24 | 25 | # Length is number of lines 26 | var N = len(list_of_lines) 27 | 28 | self.data = Tensor[dtype](N, self.n_inputs) # All columns except the last one 29 | self.labels = Tensor[dtype](N, 1) # Only the last column (MEDV) 30 | 31 | var line: List[String] = List[String]() 32 | 33 | # Load data in Tensor 34 | for item in range(N): 35 | line = list_of_lines[item].split(",") 36 | self.labels[item] = cast_string[dtype](line[-1]) 37 | 38 | for n in range(self.n_inputs): 39 | self.data[item * self.n_inputs + n] = cast_string[dtype](line[n]) 40 | 41 | # Normalize data 42 | # TODO: redo when tensorutils tmean2 and tstd2 are implemented 43 | alias nelts = simdwidthof[dtype]() 44 | var col = Tensor[dtype](N) 45 | for j in range(self.n_inputs): 46 | for k in range(N): 47 | col[k] = self.data[k * self.n_inputs + j] 48 | for i in range(N): 49 | self.data[i * self.n_inputs + j] = (self.data[i * self.n_inputs + j] - tmean(col)) / tstd(col) 50 | 51 
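# Normalization note: the loop above applies a per-column z-score, rescaling each
# feature column j to zero mean and unit standard deviation with
# (x - tmean(col)) / tstd(col). It is the Mojo counterpart of the torch version in
# examples/housing.py: (data - data.mean(dim=0)) / data.std(dim=0).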
| 52 | struct MNIST: 53 | var data: Tensor[dtype] 54 | var labels: Tensor[dtype] 55 | 56 | fn __init__(inout self, file_path: String) raises: 57 | var s = read_file(file_path) 58 | # Skip the first and last lines 59 | # This does assume your last line in the file has a newline at the end 60 | var list_of_lines = s.split("\n")[1:-1] 61 | 62 | # Length is number of lines 63 | var N = len(list_of_lines) 64 | self.data = Tensor[dtype](N, 1, 28, 28) 65 | self.labels = Tensor[dtype](N) 66 | 67 | var line: List[String] = List[String]() 68 | 69 | # Load data in Tensor 70 | for item in range(N): 71 | line = list_of_lines[item].split(",") 72 | self.labels[item] = atol(line[0]) 73 | for i in range(self.data.shape()[2]): 74 | for j in range(self.data.shape()[3]): 75 | self.data[item * 28 * 28 + i * 28 + j] = atol(line[i * 28 + j + 1]) 76 | 77 | # Normalize data 78 | alias nelts = simdwidthof[dtype]() 79 | 80 | @parameter 81 | fn vecdiv[nelts: Int](idx: Int): 82 | self.data.store[nelts](idx, div(self.data.load[nelts](idx), 255.0)) 83 | 84 | vectorize[vecdiv, nelts](self.data.num_elements()) 85 | 86 | 87 | fn read_file(file_path: String) raises -> String: 88 | var s: String 89 | with open(file_path, "r") as f: 90 | s = f.read() 91 | return s 92 | 93 | 94 | fn find_first(s: String, delimiter: String) -> Int: 95 | for i in range(len(s)): 96 | if s[i] == delimiter: 97 | return i 98 | return -1 99 | 100 | 101 | fn cast_string[dtype: DType](s: String) raises -> Scalar[dtype]: 102 | """ 103 | Cast a string with decimal to a SIMD vector of dtype. 104 | """ 105 | 106 | var idx = find_first(s, delimiter=".") 107 | var x: Scalar[dtype] = -1 108 | 109 | if idx == -1: 110 | # No decimal point 111 | x = atol(s) 112 | return x 113 | else: 114 | var c_int: Scalar[dtype] 115 | var c_frac: Scalar[dtype] 116 | c_int = atol(s[:idx]) 117 | c_frac = atol(s[idx + 1 :]) 118 | x = c_int + c_frac / (10 ** len(s[idx + 1 :])) 119 | return x 120 | -------------------------------------------------------------------------------- /basalt/utils/graph_render.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | from onnx import helper 3 | from onnx import TensorProto 4 | import netron 5 | 6 | 7 | def get_param_data(param_shape): 8 | factor = 1 9 | for dim in param_shape: 10 | factor *= dim 11 | return [0] * factor 12 | 13 | 14 | def create_onnx_graph_from_json(graph, type="node"): 15 | # Create a list to hold nodes, inputs, outputs, and initializers 16 | nodes = [] 17 | inputs = [] 18 | outputs = [] 19 | initializers = [] 20 | intermediates = [] 21 | 22 | # Process params as initializers (if operator-graph) 23 | visited = [] 24 | if type == "operator": 25 | onnx_inputs = graph["inputs"] + graph.get("params", []) 26 | elif type == "node": 27 | onnx_inputs = graph["inputs"] 28 | 29 | # Process params as initializers 30 | for initializer in graph.get("params", []): 31 | name = initializer["name"] 32 | dtype = TensorProto.FLOAT # TODO 33 | shape = list(map(int, initializer["shape"].split("x"))) 34 | tensor = helper.make_tensor(name, dtype, shape, get_param_data(shape)) 35 | initializers.append(tensor) 36 | visited.append(name) 37 | 38 | # Process inputs 39 | for input in onnx_inputs: 40 | name = input["name"] 41 | dtype = TensorProto.FLOAT # TODO 42 | shape = list(map(int, input["shape"].split("x"))) 43 | inputs.append(helper.make_tensor_value_info(name, dtype, shape)) 44 | visited.append(name) 45 | 46 | # Process outputs 47 | for output in graph["outputs"]: 48 | name = output["name"] 49 | 
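        # Every graph output becomes an ONNX ValueInfoProto: the element type is
        # fixed to float32 for now (see the TODO below) and the "x"-separated
        # shape string is split and converted to a list of ints.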
dtype = TensorProto.FLOAT # TODO 50 | shape = list(map(int, output["shape"].split("x"))) 51 | outputs.append(helper.make_tensor_value_info(name, dtype, shape)) 52 | visited.append(name) 53 | 54 | # Process nodes 55 | for node in graph["nodes"]: 56 | operator = node["operator"] 57 | onnx_node = helper.make_node( 58 | operator, 59 | inputs=[input["name"] for input in node["inputs"]], 60 | outputs=[output["name"] for output in node["outputs"]], 61 | name=f"{node['operator']}_node", 62 | ) 63 | 64 | # Process attributes 65 | for attribute in node["attributes"]: 66 | attr_type = 0 67 | if attribute["type"] == "FLOAT": 68 | attr_type = onnx.AttributeProto.FLOAT 69 | elif attribute["type"] == "INT": 70 | attr_type = onnx.AttributeProto.INT 71 | elif attribute["type"] == "STRING": 72 | attr_type = onnx.AttributeProto.STRING 73 | elif attribute["type"] == "FLOATS": 74 | attr_type = onnx.AttributeProto.FLOATS 75 | elif attribute["type"] == "INTS": 76 | attr_type = onnx.AttributeProto.INTS 77 | else: 78 | raise ValueError(f"Unsupported attribute type: {attribute['type']}") 79 | 80 | onnx_attribute = helper.make_attribute( 81 | attribute["name"], attribute["value"], attr_type=attr_type 82 | ) 83 | onnx_node.attribute.append(onnx_attribute) 84 | 85 | nodes.append(onnx_node) 86 | 87 | # Process intermediates 88 | for output in node["outputs"]: 89 | if output["name"] not in visited: 90 | name = output["name"] 91 | dtype = TensorProto.FLOAT 92 | shape = list(map(int, output["shape"].split("x"))) 93 | intermediates.append(helper.make_tensor_value_info(name, dtype, shape)) 94 | visited.append(name) 95 | 96 | # Process loss 97 | if "loss" in graph.keys(): 98 | loss = graph["loss"][0] 99 | name = loss["name"] 100 | if name not in visited: 101 | dtype = TensorProto.FLOAT 102 | shape = list(map(int, loss["shape"].split("x"))) 103 | outputs.append(helper.make_tensor_value_info(name, dtype, shape)) 104 | visited.append(name) 105 | 106 | # Create the graph 107 | graph_def = helper.make_graph( 108 | nodes, 109 | graph.get("graph_name", "basalt-ONNX"), 110 | inputs, 111 | outputs, 112 | initializer=initializers, 113 | value_info=intermediates, 114 | ) 115 | 116 | # Create the model 117 | model_def = helper.make_model(graph_def, producer_name="basalt") 118 | 119 | # Save the model to a file 120 | onnx.save(model_def, "output_model.onnx") 121 | 122 | 123 | def netron_render(graph, type="node"): 124 | assert type in ["node", "operator"] 125 | create_onnx_graph_from_json(graph, type=type) 126 | netron.start("output_model.onnx") 127 | -------------------------------------------------------------------------------- /basalt/utils/math_util.mojo: -------------------------------------------------------------------------------- 1 | @always_inline 2 | fn add[ 3 | dtype: DType, simd_width: Int 4 | ](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[ 5 | dtype, simd_width 6 | ]: 7 | return a + b 8 | 9 | 10 | @always_inline 11 | fn sub[ 12 | dtype: DType, simd_width: Int 13 | ](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[ 14 | dtype, simd_width 15 | ]: 16 | return a - b 17 | 18 | 19 | @always_inline 20 | fn mul[ 21 | dtype: DType, simd_width: Int 22 | ](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[ 23 | dtype, simd_width 24 | ]: 25 | return a * b 26 | 27 | 28 | @always_inline 29 | fn div[ 30 | dtype: DType, simd_width: Int 31 | ](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[ 32 | dtype, simd_width 33 | ]: 34 | return a / b 35 | 36 | 37 | @always_inline 38 | 
fn round_simd[ 39 | dtype: DType, simd_width: Int 40 | ](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]: 41 | return round(x) 42 | -------------------------------------------------------------------------------- /basalt/utils/perf_utils.mojo: -------------------------------------------------------------------------------- 1 | from time.time import monotonic as now 2 | from memory import UnsafePointer, memcpy, memset 3 | 4 | from basalt.autograd.node import Node 5 | 6 | 7 | @always_inline("nodebug") 8 | fn fit_string[num: Int](s: String) -> String: 9 | var data = UnsafePointer[Byte]().alloc(num + 1) 10 | var copy_len = min(num, len(s)) 11 | 12 | memcpy(data, s.unsafe_ptr(), copy_len) 13 | memset(data + copy_len, ord(" "), num - copy_len) 14 | data[num] = 0 15 | 16 | return String(ptr=data, length=num + 1) 17 | 18 | 19 | @always_inline("nodebug") 20 | fn truncate_decimals[num: Int](s: String) -> String: 21 | try: 22 | var parts = s.split(".") 23 | var truncated = parts[0] 24 | 25 | if len(parts) > 1: 26 | var decimal_parts = parts[1].split("e") 27 | truncated += "." + fit_string[num](decimal_parts[0]) 28 | 29 | if len(decimal_parts) > 1: 30 | truncated += "e" + decimal_parts[1] 31 | 32 | return truncated 33 | except e: 34 | print("[WARNING] could not truncate decimals: ", e) 35 | return s 36 | 37 | 38 | @value 39 | struct PerfMetricsValues: 40 | var node: Node 41 | var ns: Float64 42 | 43 | 44 | @value 45 | struct PerfMetrics: 46 | var forward_perf_metrics: List[PerfMetricsValues] 47 | var backward_perf_metrics: List[PerfMetricsValues] 48 | var epochs_forward: Int 49 | var epochs_backward: Int 50 | var start: Int 51 | 52 | fn __init__(inout self): 53 | self.forward_perf_metrics = List[PerfMetricsValues]() 54 | self.backward_perf_metrics = List[PerfMetricsValues]() 55 | self.epochs_forward = 0 56 | self.epochs_backward = 0 57 | self.start = 0 58 | 59 | fn __init__(inout self, graph: Graph): 60 | self.forward_perf_metrics = List[PerfMetricsValues]() 61 | self.backward_perf_metrics = List[PerfMetricsValues]() 62 | 63 | self.forward_perf_metrics.reserve(graph.nodes.size) 64 | self.backward_perf_metrics.reserve(graph.nodes.size) 65 | 66 | for i in range(graph.nodes.size): 67 | self.forward_perf_metrics.append(PerfMetricsValues(graph.nodes[i], 0.0)) 68 | self.backward_perf_metrics.append(PerfMetricsValues(graph.nodes[i], 0.0)) 69 | 70 | self.epochs_forward = 0 71 | self.epochs_backward = 0 72 | self.start = 0 73 | 74 | fn start_forward_pass(inout self): 75 | self.start = now() 76 | 77 | fn end_forward_pass(inout self, pos: Int): 78 | self.forward_perf_metrics[pos].ns += now() - self.start 79 | self.epochs_forward += 1 80 | 81 | fn start_backward_pass(inout self): 82 | self.start = now() 83 | 84 | fn end_backward_pass(inout self, pos: Int): 85 | self.backward_perf_metrics[pos].ns += now() - self.start 86 | self.epochs_backward += 1 87 | 88 | fn print_perf_metrics[ 89 | type_part: String 90 | ](self, time_format: String = "ns", print_shape: Bool = False): 91 | constrained[type_part == "Forward" or type_part == "Backward", "Only 'Forward' or 'Backward' are accepted types."]() 92 | 93 | alias is_forward = type_part == "Forward" 94 | 95 | var metrics = self.forward_perf_metrics if is_forward else self.backward_perf_metrics 96 | var epochs = self.epochs_forward if is_forward else self.epochs_backward 97 | var size = len(metrics) 98 | var total_time: Float64 = 0 99 | 100 | if size == 0: 101 | return 102 | 103 | if is_forward: 104 | print("\n\nForward pass performance metrics:") 105 | else: 106 | 
print("\n\nBackward pass performance metrics:") 107 | 108 | for i in range(size): 109 | total_time += metrics[i].ns 110 | 111 | var header = ( 112 | fit_string[5]("Node") 113 | + "| " 114 | + fit_string[15]("Operator") 115 | + "| " 116 | + fit_string[20]("Time [" + time_format + "]") 117 | + "| " 118 | + fit_string[20]("Percentage [%]") 119 | ) 120 | 121 | if print_shape: 122 | header += "| " + fit_string[70]("Shape\t = OP( , , )") 123 | 124 | print(header) 125 | 126 | var header_length = len(header) 127 | var seperator = UnsafePointer[UInt8]().alloc(header_length + 1) 128 | 129 | memset(seperator, ord("-"), header_length) 130 | seperator[header_length] = 0 131 | 132 | print(String(ptr=seperator, length=len(header) + 1)) 133 | 134 | for i in range(size): 135 | var value = metrics[i] 136 | var time = value.ns / epochs 137 | 138 | if time_format == "ms": 139 | time /= 1e6 140 | elif time_format == "s": 141 | time /= 1e9 142 | 143 | var percentage = (value.ns / total_time) * 100 144 | 145 | var print_value = ( 146 | fit_string[5](str(i)) 147 | + "| " 148 | + fit_string[15](str(value.node.operator)) 149 | + "| " 150 | + fit_string[20](truncate_decimals[4](str(time))) 151 | + "| " 152 | + fit_string[20](truncate_decimals[3](str(percentage)) + " %") 153 | + "| " 154 | ) 155 | 156 | if print_shape: 157 | var shape_str = fit_string[15]("<" + str(value.node.outputs[0].shape) + ">") 158 | 159 | for j in range(1, len(value.node.outputs)): 160 | shape_str += ", " + fit_string[15]("<" + str(value.node.outputs[j].shape) + ">") 161 | 162 | shape_str += fit_string[7](" = OP(") + fit_string[15]("<" + str(value.node.inputs[0].shape) + ">") 163 | 164 | for j in range(1, len(value.node.inputs)): 165 | shape_str += ", " + fit_string[15]("<" + str(value.node.inputs[j].shape) + ">") 166 | 167 | shape_str += ")" 168 | 169 | print(print_value, end="") 170 | print(shape_str) 171 | else: 172 | print(print_value) 173 | 174 | if time_format == "ms": 175 | total_time /= 1e6 176 | elif time_format == "s": 177 | total_time /= 1e9 178 | 179 | print( 180 | "\nTotal average " 181 | + type_part 182 | + " time: " 183 | + str(total_time) 184 | + " " 185 | + time_format 186 | ) 187 | 188 | 189 | fn print_forward_perf_metrics(self, time_format: String = "ns", print_shape: Bool = False): 190 | self.print_perf_metrics["Forward"](time_format, print_shape) 191 | 192 | fn print_backward_perf_metrics(self, time_format: String = "ns", print_shape: Bool = False): 193 | self.print_perf_metrics["Backward"](time_format, print_shape) 194 | -------------------------------------------------------------------------------- /basalt/utils/rand_utils.mojo: -------------------------------------------------------------------------------- 1 | from basalt import Tensor 2 | from random import rand, randn 3 | from algorithm import vectorize 4 | from utils.static_tuple import StaticTuple 5 | 6 | 7 | @always_inline 8 | fn rand_uniform[dtype: DType](inout res: Tensor[dtype], low: Scalar[dtype], high: Scalar[dtype]): 9 | var scale = high - low 10 | 11 | rand[dtype](res.data(), res.num_elements()) 12 | 13 | @parameter 14 | fn vecscale[nelts: Int](idx: Int): 15 | res.store[nelts](idx, res.load[nelts](idx).fma(scale, low)) 16 | 17 | vectorize[vecscale, nelts](res.num_elements()) 18 | 19 | 20 | @always_inline 21 | fn rand_normal[dtype: DType](inout res: Tensor[dtype], mean: Float64, std: Float64): 22 | randn[dtype](res.data(), res.num_elements(), mean, std**2) 23 | 24 | 25 | @register_passable("trivial") 26 | struct MersenneTwister: 27 | """ 28 | Pseudo-random 
generator Mersenne Twister (MT19937-32bit). 29 | """ 30 | 31 | alias N: Int = 624 32 | alias M: Int = 397 33 | alias MATRIX_A: Int32 = 0x9908B0DF 34 | alias UPPER_MASK: Int32 = 0x80000000 35 | alias LOWER_MASK: Int32 = 0x7FFFFFFF 36 | alias TEMPERING_MASK_B: Int32 = 0x9D2C5680 37 | alias TEMPERING_MASK_C: Int32 = 0xEFC60000 38 | 39 | var state: StaticTuple[Int32, Self.N] 40 | var index: Int 41 | 42 | fn __init__(inout self, seed: Int): 43 | alias W: Int = 32 44 | alias F: Int32 = 1812433253 45 | alias D: Int32 = 0xFFFFFFFF 46 | 47 | self.index = Self.N 48 | self.state = StaticTuple[Int32, Self.N]() 49 | self.state[0] = seed & D 50 | 51 | for i in range(1, Self.N): 52 | var prev = self.state[i - 1] 53 | self.state[i] = (F * (prev ^ (prev >> (W - 2))) + i) & D 54 | 55 | fn next(inout self) -> Int32: 56 | if self.index >= Self.N: 57 | for i in range(Self.N): 58 | var x = (self.state[i] & Self.UPPER_MASK) + (self.state[(i + 1) % Self.N] & Self.LOWER_MASK) 59 | var xA = x >> 1 60 | if x % 2 != 0: 61 | xA ^= Self.MATRIX_A 62 | self.state[i] = self.state[(i + Self.M) % Self.N] ^ xA 63 | self.index = 0 64 | 65 | var y = self.state[self.index] 66 | y ^= y >> 11 67 | y ^= (y << 7) & Self.TEMPERING_MASK_B 68 | y ^= (y << 15) & Self.TEMPERING_MASK_C 69 | y ^= y >> 18 70 | self.index += 1 71 | 72 | return y 73 | 74 | fn next_ui8(inout self) -> UInt8: 75 | return self.next().value & int(0xFF) 76 | -------------------------------------------------------------------------------- /basalt/utils/tensor_creation_utils.mojo: -------------------------------------------------------------------------------- 1 | from python import Python, PythonObject 2 | from memory import memcpy, UnsafePointer 3 | 4 | # maybe this functions should be from the Tensor struct (like tensor.to_numpy()) and tensor.__init__(np_array: PythonObject) to create a tensor from a numpy array and tensor.copy_np_data(np_array: PythonObject) to copy the numpy array to the tensor. 5 | 6 | 7 | fn to_numpy(tensor: Tensor) -> PythonObject: 8 | try: 9 | var np = Python.import_module("numpy") 10 | 11 | np.set_printoptions(4) 12 | 13 | var rank = tensor.rank() 14 | var dims = PythonObject([]) 15 | for i in range(rank): 16 | dims.append(tensor.dim(i)) 17 | var pyarray: PythonObject = np.empty(dims, dtype=np.float32) 18 | 19 | var pointer_d = pyarray.__array_interface__["data"][0].unsafe_get_as_pointer[DType.float32]() 20 | var d: UnsafePointer[Float32] = tensor.data().bitcast[Float32]() 21 | memcpy(pointer_d, d, tensor.num_elements()) 22 | 23 | _ = tensor 24 | 25 | return pyarray^ 26 | except e: 27 | print("Error in to numpy", e) 28 | return PythonObject() 29 | 30 | 31 | fn to_tensor(np_array: PythonObject) raises -> Tensor[dtype]: 32 | var shape = List[Int]() 33 | for i in range(np_array.ndim): 34 | shape.append(int(float(np_array.shape[i]))) 35 | if np_array.ndim == 0: 36 | # When the numpy array is a scalar, you need or the reshape to a size 1 ndarray or do this, if not the memcpy gets a memory error (Maybe because it is a register value?). 37 | var tensor = Tensor[dtype](TensorShape(1)) 38 | tensor[0] = float(np_array).cast[dtype]() 39 | return tensor^ 40 | 41 | var tensor = Tensor[dtype](TensorShape(shape)) 42 | 43 | var np_array_2: PythonObject 44 | try: 45 | var np = Python.import_module("numpy") 46 | # copy is also necessary for ops like slices to make them contiguous instead of references. 
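        # Casting with np.float32 additionally aligns the buffer's element type with
        # the raw memcpy into the Tensor further down (assuming basalt's `dtype`
        # alias is a 32-bit float); if numpy cannot be imported, the except branch
        # falls back to a plain copy without the cast.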
47 | np_array_2 = np.float32(np_array.copy()) 48 | except e: 49 | np_array_2 = np_array.copy() 50 | print("Error in to_tensor", e) 51 | 52 | var pointer_d = np_array_2.__array_interface__["data"][0].unsafe_get_as_pointer[dtype]() 53 | memcpy(tensor.data(), pointer_d, tensor.num_elements()) 54 | 55 | _ = np_array_2 56 | _ = np_array 57 | 58 | return tensor^ 59 | 60 | 61 | fn copy_np_data(inout tensor: Tensor, np_array: PythonObject) raises: 62 | var np_array_2: PythonObject 63 | try: 64 | var np = Python.import_module("numpy") 65 | # copy is also necessary for ops like slices to make them contiguous instead of references. 66 | np_array_2 = np.float32(np_array.copy()) 67 | except e: 68 | np_array_2 = np_array.copy() 69 | print("Error in to_tensor", e) 70 | 71 | var pointer_d = np_array_2.__array_interface__["data"][0].unsafe_get_as_pointer[dtype]() 72 | var d: UnsafePointer[Float32] = tensor.data().bitcast[Float32]() 73 | memcpy(d, pointer_d, tensor.num_elements()) 74 | 75 | # This shouldn't be necessary anymore, but I'm leaving it here for now. 76 | # _ = np_array_2 77 | # _ = np_array 78 | # _ = tensor 79 | -------------------------------------------------------------------------------- /examples/data/mnist_torch.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basalt-org/basalt/fe16eadcfcfee9271f9df0dd94d11b7c50d868ba/examples/data/mnist_torch.onnx -------------------------------------------------------------------------------- /examples/housing.mojo: -------------------------------------------------------------------------------- 1 | from time.time import monotonic as now 2 | 3 | import basalt.nn as nn 4 | from basalt import Tensor, TensorShape 5 | from basalt import Graph, Symbol, OP 6 | from basalt.utils.datasets import BostonHousing 7 | from basalt.utils.dataloader import DataLoader 8 | from basalt.nn.model import Parameters 9 | 10 | 11 | fn linear_regression(batch_size: Int, n_inputs: Int, n_outputs: Int) -> Graph: 12 | var g = Graph() 13 | 14 | var x = g.input(TensorShape(batch_size, n_inputs)) 15 | var y_true = g.input(TensorShape(batch_size, n_outputs)) 16 | 17 | var y_pred = nn.Linear(g, x, n_outputs) 18 | g.out(y_pred) 19 | 20 | var loss = nn.MSELoss(g, y_pred, y_true) 21 | g.loss(loss) 22 | 23 | return g ^ 24 | 25 | 26 | fn main(): 27 | # Train Parameters 28 | alias batch_size = 32 29 | alias num_epochs = 200 30 | alias learning_rate = 0.02 31 | 32 | alias graph = linear_regression(batch_size, 13, 1) 33 | 34 | # try: graph.render("operator") 35 | # except: print("Could not render graph") 36 | 37 | var model = nn.Model[graph]() 38 | var optim = nn.optim.Adam[graph](model.parameters, lr=learning_rate) 39 | 40 | # Batchwise data loader 41 | print("Loading data...") 42 | var train_data: BostonHousing 43 | try: 44 | train_data = BostonHousing(file_path="./examples/data/housing.csv") 45 | except: 46 | print("Could not load data") 47 | return 48 | 49 | var training_loader = DataLoader( 50 | data=train_data.data, labels=train_data.labels, batch_size=batch_size 51 | ) 52 | 53 | print("Training started.") 54 | var start = now() 55 | for epoch in range(num_epochs): 56 | var num_batches: Int = 0 57 | var epoch_loss: Float32 = 0.0 58 | for batch in training_loader: 59 | # Forward pass 60 | var loss = model.forward(batch.data, batch.labels) 61 | 62 | # Backward pass 63 | optim.zero_grad() 64 | model.backward() 65 | optim.step() 66 | 67 | epoch_loss += loss[0] 68 | num_batches += 1 69 | 70 | print( 71 | "Epoch: [", 72 | 
epoch + 1, 73 | "/", 74 | num_epochs, 75 | "] \t Avg loss per epoch:", 76 | epoch_loss / num_batches, 77 | ) 78 | 79 | print("Training finished: ", (now() - start) / 1e9, "seconds") 80 | 81 | # print("\n\nInferencing model...\n") 82 | # for batch in training_loader: 83 | # var output = model.inference(batch.data) 84 | 85 | # # Print first (and only output) 86 | # print("Predicted: ", output[0]) 87 | -------------------------------------------------------------------------------- /examples/housing.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import optim 6 | from torch.utils.data import Dataset, DataLoader, TensorDataset 7 | import time 8 | 9 | 10 | class BostonHousing(Dataset): 11 | def __init__(self, data: pd.DataFrame): 12 | # Data: All columns except the last one / Target: Only the last column (MEDV) 13 | self.data = torch.tensor(data.iloc[:, :-1].values, dtype=torch.float32) 14 | self.target = torch.tensor(data.iloc[:, -1].values, dtype=torch.float32).view( 15 | -1, 1 16 | ) 17 | 18 | # Normalize data 19 | self.data = (self.data - self.data.mean(dim=0)) / self.data.std(dim=0) 20 | 21 | # Create dataset 22 | self.dataset = TensorDataset(self.data, self.target) 23 | 24 | def __len__(self): 25 | return len(self.dataset) 26 | 27 | def __getitem__(self, idx): 28 | return self.dataset[idx] 29 | 30 | 31 | class LinearRegression(nn.Module): 32 | def __init__(self, input_dim): 33 | super(LinearRegression, self).__init__() 34 | self.linear = nn.Linear(input_dim, 1) 35 | 36 | def forward(self, x): 37 | return self.linear(x) 38 | 39 | 40 | if __name__ == "__main__": 41 | # Load data and split in training and testing sets 42 | df = pd.read_csv("./examples/data/housing.csv") 43 | 44 | TRAIN_PCT = 0.99 45 | shuffled_df = df.sample(frac=1, random_state=42) 46 | train_df = shuffled_df[: int(TRAIN_PCT * len(df))] 47 | test_df = shuffled_df[int(TRAIN_PCT * len(df)) :] 48 | 49 | train_data = BostonHousing(train_df) 50 | test_data = BostonHousing(test_df) 51 | 52 | # Train Parameters 53 | batch_size = 32 54 | num_epochs = 200 55 | learning_rate = 0.02 56 | 57 | # Batchwise data loader 58 | loaders = { 59 | "train": DataLoader( 60 | train_data, batch_size=batch_size, shuffle=False, num_workers=1 61 | ), 62 | "test": DataLoader( 63 | test_data, batch_size=batch_size, shuffle=False, num_workers=1 64 | ), 65 | } 66 | 67 | device = torch.device("cpu") 68 | # model = torch.compile(LinearRegression(train_data.data.shape[1]), fullgraph=True, options={"epilogue_fusion": True, "max_autotune": True}) 69 | model = LinearRegression(train_data.data.shape[1]) 70 | loss_func = nn.MSELoss() 71 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 72 | # optimizer = optim.SGD(model.parameters(), lr=learning_rate) 73 | 74 | # it seems the python for loop is what is making the program slow (so pytorch has a disadvantage thanks to python) 75 | model.train() 76 | 77 | start = time.time() 78 | for epoch in range(num_epochs): 79 | epoch_loss = 0 80 | num_batches = 0 81 | for batch_data, batch_labels in loaders["train"]: 82 | start_batch = time.time() 83 | 84 | # Forward pass 85 | outputs = model(batch_data) 86 | loss = loss_func(outputs, batch_labels) 87 | 88 | # Backward pass 89 | optimizer.zero_grad() 90 | loss.backward() 91 | optimizer.step() 92 | 93 | epoch_loss += loss.item() 94 | num_batches += 1 95 | 96 | # print time in ms 97 | # print(f'Batch time: {1000 * (time.time() - start_batch):.2f} 
ms') # The speed of a batch in basalt and pytorch are similar or pytorch can be faster 98 | 99 | print( 100 | f"Epoch [{epoch + 1}/{num_epochs}],\t Avg loss per epoch:" 101 | f" {epoch_loss / num_batches}" 102 | ) 103 | 104 | print(f"Training time: {time.time() - start:.2f} seconds") 105 | 106 | # Evaluate the model 107 | model.eval() 108 | with torch.no_grad(): 109 | test_predictions = model(test_data.data) 110 | mse_loss = loss_func(test_predictions, test_data.target).item() 111 | print(f"Mean Squared Error on Test Data: {mse_loss:.4f}") 112 | -------------------------------------------------------------------------------- /examples/mnist.mojo: -------------------------------------------------------------------------------- 1 | from time.time import monotonic as now 2 | 3 | import basalt.nn as nn 4 | from basalt import Tensor, TensorShape 5 | from basalt import Graph, Symbol, OP, dtype 6 | from basalt.utils.datasets import MNIST 7 | from basalt.utils.dataloader import DataLoader 8 | from basalt.autograd.attributes import AttributeVector, Attribute 9 | 10 | 11 | # def plot_image(data: Tensor, num: Int): 12 | # from python.python import Python, PythonObject 13 | 14 | # np = Python.import_module("numpy") 15 | # plt = Python.import_module("matplotlib.pyplot") 16 | 17 | # var pyimage: PythonObject = np.empty((28, 28), np.float64) 18 | # for m in range(28): 19 | # for n in range(28): 20 | # pyimage.itemset((m, n), data[num * 28 * 28 + m * 28 + n]) 21 | 22 | # plt.imshow(pyimage) 23 | # plt.show() 24 | 25 | 26 | fn create_CNN(batch_size: Int) -> Graph: 27 | var g = Graph() 28 | var x = g.input(TensorShape(batch_size, 1, 28, 28)) 29 | 30 | var x1 = nn.Conv2d(g, x, out_channels=16, kernel_size=5, padding=2) 31 | var x2 = nn.ReLU(g, x1) 32 | var x3 = nn.MaxPool2d(g, x2, kernel_size=2) 33 | var x4 = nn.Conv2d(g, x3, out_channels=32, kernel_size=5, padding=2) 34 | var x5 = nn.ReLU(g, x4) 35 | var x6 = nn.MaxPool2d(g, x5, kernel_size=2) 36 | var x7 = g.op( 37 | OP.RESHAPE, 38 | x6, 39 | attributes=AttributeVector( 40 | Attribute( 41 | "shape", 42 | TensorShape(x6.shape[0], x6.shape[1] * x6.shape[2] * x6.shape[3]), 43 | ) 44 | ), 45 | ) 46 | var out = nn.Linear(g, x7, n_outputs=10) 47 | g.out(out) 48 | 49 | var y_true = g.input(TensorShape(batch_size, 10)) 50 | var loss = nn.CrossEntropyLoss(g, out, y_true) 51 | # var loss = nn.MSELoss(g, out, y_true) 52 | g.loss(loss) 53 | 54 | return g ^ 55 | 56 | 57 | fn main(): 58 | alias num_epochs = 20 59 | alias batch_size = 4 60 | alias learning_rate = 1e-3 61 | 62 | alias graph = create_CNN(batch_size) 63 | 64 | # try: graph.render("operator") 65 | # except: print("Could not render graph") 66 | 67 | var model = nn.Model[graph]() 68 | var optim = nn.optim.Adam[graph](model.parameters, lr=learning_rate) 69 | 70 | print("Loading data ...") 71 | var train_data: MNIST 72 | try: 73 | train_data = MNIST(file_path="./examples/data/mnist_test_small.csv") 74 | # _ = plot_image(train_data.data, 1) 75 | except e: 76 | print("Could not load data") 77 | print(e) 78 | return 79 | 80 | var training_loader = DataLoader( 81 | data=train_data.data, labels=train_data.labels, batch_size=batch_size 82 | ) 83 | 84 | print("Training started/") 85 | var start = now() 86 | 87 | for epoch in range(num_epochs): 88 | var num_batches: Int = 0 89 | var epoch_loss: Float32 = 0.0 90 | var epoch_start = now() 91 | for batch in training_loader: 92 | # [ONE HOT ENCODING!] 
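            # The DataLoader yields labels as class indices (0..9), while the
            # CrossEntropyLoss in this graph multiplies the targets elementwise with
            # log_softmax(outputs). The indices are therefore expanded into a
            # (batch_size, 10) one-hot tensor; the flat index bb * 10 + label
            # addresses row bb of that row-major tensor.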
93 | var labels_one_hot = Tensor[dtype](batch.labels.dim(0), 10) 94 | for bb in range(batch.labels.dim(0)): 95 | labels_one_hot[int((bb * 10 + batch.labels[bb]))] = 1.0 96 | 97 | # Forward pass 98 | var loss = model.forward(batch.data, labels_one_hot) 99 | 100 | # Backward pass 101 | optim.zero_grad() 102 | model.backward() 103 | optim.step() 104 | 105 | epoch_loss += loss[0] 106 | num_batches += 1 107 | 108 | print( 109 | "Epoch [", 110 | epoch + 1, 111 | "/", 112 | num_epochs, 113 | "],\t Step [", 114 | num_batches, 115 | "/", 116 | train_data.data.dim(0) // batch_size, 117 | "],\t Loss:", 118 | epoch_loss / num_batches, 119 | ) 120 | 121 | print("Epoch time: ", (now() - epoch_start) / 1e9, "seconds") 122 | 123 | print("Training finished: ", (now() - start) / 1e9, "seconds") 124 | 125 | model.print_perf_metrics("ms", True) 126 | -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | import os 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import optim 10 | from torch.autograd import Variable 11 | from torch.utils.data import Dataset, DataLoader, TensorDataset 12 | 13 | 14 | class MNIST(Dataset): 15 | def __init__(self, csv_file): 16 | data = pd.read_csv(csv_file) 17 | self.labels = torch.tensor(data.iloc[:, 0].values, dtype=torch.int64) 18 | self.images = torch.tensor( 19 | data.iloc[:, 1:].values, dtype=torch.float32 20 | ).reshape(-1, 1, 28, 28) 21 | 22 | # Normalize data 23 | self.images = self.images / 255.0 24 | 25 | self.dataset = TensorDataset(self.images, self.labels) 26 | 27 | def __len__(self): 28 | return len(self.dataset) 29 | 30 | def __getitem__(self, idx): 31 | return self.dataset[idx] 32 | 33 | 34 | class CNN(nn.Module): 35 | def __init__(self): 36 | super(CNN, self).__init__() 37 | self.conv1 = nn.Sequential( 38 | nn.Conv2d( 39 | in_channels=1, 40 | out_channels=16, 41 | kernel_size=5, 42 | stride=1, 43 | padding=2, 44 | ), 45 | nn.ReLU(), 46 | nn.MaxPool2d(kernel_size=2), 47 | ) 48 | self.conv2 = nn.Sequential( 49 | nn.Conv2d(16, 32, 5, 1, 2), 50 | nn.ReLU(), 51 | nn.MaxPool2d(2), 52 | ) 53 | # fully connected layer, output 10 classes 54 | self.out = nn.Linear(32 * 7 * 7, 10) 55 | 56 | def forward(self, x): 57 | x = self.conv1(x) 58 | x = self.conv2(x) 59 | # flatten the output of conv2 to (batch_size, 32 * 7 * 7) 60 | x = x.view(x.size(0), -1) 61 | output = self.out(x) 62 | return output 63 | 64 | 65 | if __name__ == "__main__": 66 | num_epochs = 20 67 | batch_size = 4 68 | learning_rate = 1e-3 69 | 70 | # Load data 71 | train_data = MNIST("./examples/data/mnist_test_small.csv") 72 | 73 | # Visualize data 74 | num = 0 75 | plt.imshow(np.array(train_data[num][0]).squeeze()) 76 | plt.title("%i" % train_data[num][1]) 77 | plt.show() 78 | 79 | # Batchwise data loader 80 | loaders = { 81 | "train": DataLoader( 82 | train_data, batch_size=batch_size, shuffle=True, num_workers=1 83 | ), 84 | } 85 | 86 | device = torch.device("cpu") 87 | cnn = CNN() 88 | loss_func = nn.CrossEntropyLoss() 89 | optimizer = optim.Adam(cnn.parameters(), lr=learning_rate) 90 | 91 | # Train the model 92 | cnn.train() 93 | total_step = len(loaders["train"]) 94 | start = time.time() 95 | for epoch in range(num_epochs): 96 | for i, (images, labels) in enumerate(loaders["train"]): 97 | b_x = Variable(images) 98 | b_y = Variable(labels) 99 | 100 | output = 
cnn(b_x) 101 | loss = loss_func(output, b_y) 102 | 103 | optimizer.zero_grad() 104 | loss.backward() 105 | optimizer.step() 106 | 107 | print( 108 | "Epoch [{}/{}],\t Step [{}/{}],\t Loss: {:.6f}".format( 109 | epoch + 1, num_epochs, i + 1, total_step, loss.item() 110 | ) 111 | ) 112 | 113 | print(f"Training time: {time.time() - start:.2f} seconds") 114 | 115 | # Export to ONNX 116 | export_onnx = os.environ.get("export_onnx", 0) 117 | if export_onnx == "1": 118 | dummy_input = torch.randn(1, 1, 28, 28) 119 | 120 | # cnn.out.weight = nn.Parameter(cnn.out.weight.T) # transpose because torch saves the weight of linear layer as (output_dim, input_dim) (so they transposed and there is not a real reason for this) 121 | torch.onnx.export(cnn, dummy_input, "./examples/data/mnist_torch.onnx", verbose=True) -------------------------------------------------------------------------------- /examples/mnist_load_model.mojo: -------------------------------------------------------------------------------- 1 | from time.time import monotonic as now 2 | from pathlib import Path 3 | 4 | import basalt.nn as nn 5 | from basalt import Tensor, TensorShape 6 | from basalt import Graph, Symbol, OP, dtype 7 | from basalt.utils.datasets import MNIST 8 | from basalt.utils.dataloader import DataLoader 9 | from basalt.autograd.attributes import AttributeVector, Attribute 10 | 11 | 12 | # def plot_image(data: Tensor, num: Int): 13 | # from python.python import Python, PythonObject 14 | 15 | # np = Python.import_module("numpy") 16 | # plt = Python.import_module("matplotlib.pyplot") 17 | 18 | # var pyimage: PythonObject = np.empty((28, 28), np.float64) 19 | # for m in range(28): 20 | # for n in range(28): 21 | # pyimage.itemset((m, n), data[num * 28 * 28 + m * 28 + n]) 22 | 23 | # plt.imshow(pyimage) 24 | # plt.show() 25 | 26 | 27 | fn create_CNN(batch_size: Int) -> Graph: 28 | var g = Graph() 29 | var x = g.input(TensorShape(batch_size, 1, 28, 28)) 30 | 31 | var x1 = nn.Conv2d(g, x, out_channels=16, kernel_size=5, padding=2) 32 | var x2 = nn.ReLU(g, x1) 33 | var x3 = nn.MaxPool2d(g, x2, kernel_size=2) 34 | var x4 = nn.Conv2d(g, x3, out_channels=32, kernel_size=5, padding=2) 35 | var x5 = nn.ReLU(g, x4) 36 | var x6 = nn.MaxPool2d(g, x5, kernel_size=2) 37 | var x7 = g.op( 38 | OP.RESHAPE, 39 | x6, 40 | attributes=AttributeVector( 41 | Attribute( 42 | "shape", 43 | TensorShape(x6.shape[0], x6.shape[1] * x6.shape[2] * x6.shape[3]), 44 | ) 45 | ), 46 | ) 47 | var out = nn.Linear(g, x7, n_outputs=10) 48 | g.out(out) 49 | 50 | return g ^ 51 | 52 | 53 | fn main(): 54 | alias num_epochs = 1 55 | alias batch_size = 4 56 | alias learning_rate = 1e-3 57 | 58 | alias graph = create_CNN(batch_size) 59 | 60 | # try: graph.render("operator") 61 | # except: print("Could not render graph") 62 | 63 | var model = nn.Model[graph]() 64 | model.load_model_data("./examples/data/mnist_torch.onnx") 65 | 66 | print("Loading data ...") 67 | var train_data: MNIST 68 | try: 69 | train_data = MNIST(file_path="./examples/data/mnist_test_small.csv") 70 | # _ = plot_image(train_data.data, 1) 71 | except e: 72 | print("Could not load data") 73 | print(e) 74 | return 75 | 76 | var training_loader = DataLoader( 77 | data=train_data.data, labels=train_data.labels, batch_size=batch_size 78 | ) 79 | 80 | # Testing 81 | print("Testing started") 82 | var start = now() 83 | 84 | var correct = 0 85 | for batch in training_loader: 86 | var labels_one_hot = Tensor[dtype](batch.labels.dim(0), 10) 87 | for bb in range(batch.labels.dim(0)): 88 | labels_one_hot[int(bb * 
10 + batch.labels[bb])] = 1.0 89 | 90 | var output = model.inference(batch.data, labels_one_hot)[0] 91 | 92 | fn argmax(tensor: Tensor[dtype], dim: Int) -> Tensor[dtype]: 93 | var result = Tensor[dtype](tensor.dim(0)) 94 | for i in range(tensor.dim(0)): 95 | var max_val = tensor[i * 10] 96 | var max_idx = 0 97 | for j in range(1, 10): 98 | if tensor[i * 10 + j] > max_val: 99 | max_val = tensor[i * 10 + j] 100 | max_idx = j 101 | result[i] = max_idx 102 | 103 | return result 104 | 105 | var pred = argmax(output, dim=1) 106 | 107 | for i in range(batch.labels.dim(0)): 108 | if pred[i] == batch.labels[i]: 109 | correct += 1 110 | 111 | print("Accuracy: ", correct / train_data.data.dim(0) * 100, "%") 112 | print("Testing finished: ", (now() - start) / 1e9, "seconds") 113 | 114 | # model.print_perf_metrics("ms", True) 115 | 116 | model.export_model("./output_model.onnx") -------------------------------------------------------------------------------- /examples/sin_estimate.mojo: -------------------------------------------------------------------------------- 1 | from random import rand 2 | from time.time import monotonic as now 3 | import math 4 | 5 | import basalt.nn as nn 6 | from basalt import Tensor, TensorShape 7 | from basalt import dtype 8 | from basalt import Graph, Symbol, OP 9 | from basalt.utils.tensorutils import fill 10 | 11 | 12 | fn create_simple_nn(batch_size: Int, n_inputs: Int, n_outputs: Int) -> Graph: 13 | var g = Graph() 14 | 15 | var x = g.input(TensorShape(batch_size, n_inputs)) 16 | var y_true = g.input(TensorShape(batch_size, n_outputs)) 17 | 18 | var x1 = nn.Linear(g, x, n_outputs=32) 19 | var x2 = nn.ReLU(g, x1) 20 | var x3 = nn.Linear(g, x2, n_outputs=32) 21 | var x4 = nn.ReLU(g, x3) 22 | var y_pred = nn.Linear(g, x4, n_outputs=n_outputs) 23 | g.out(y_pred) 24 | 25 | var loss = nn.MSELoss(g, y_pred, y_true) 26 | g.loss(loss) 27 | 28 | g.compile() 29 | 30 | return g ^ 31 | 32 | 33 | fn main(): 34 | alias batch_size = 32 35 | alias n_inputs = 1 36 | alias n_outputs = 1 37 | alias learning_rate = 0.01 38 | 39 | alias epochs = 20000 40 | 41 | alias graph = create_simple_nn(batch_size, n_inputs, n_outputs) 42 | 43 | # try: graph.render("operator") 44 | # except: print("Could not render graph") 45 | 46 | var model = nn.Model[graph]() 47 | var optimizer = nn.optim.Adam[graph](model.parameters, lr=learning_rate) 48 | 49 | var x_data = Tensor[dtype](batch_size, n_inputs) 50 | var y_data = Tensor[dtype](batch_size, n_outputs) 51 | 52 | print("Training started") 53 | var start = now() 54 | for i in range(epochs): 55 | rand[dtype](x_data.data(), x_data.num_elements()) 56 | 57 | for j in range(batch_size): 58 | x_data[j] = x_data[j] * 2 - 1 59 | y_data[j] = math.sin(x_data[j]) 60 | 61 | var out = model.forward(x_data, y_data) 62 | 63 | if (i + 1) % 1000 == 0: 64 | print("[", i + 1, "/", epochs, "] \tLoss: ", out[0]) 65 | 66 | optimizer.zero_grad() 67 | model.backward() 68 | optimizer.step() 69 | 70 | print("Training finished: ", (now() - start) / 1e9, "seconds") 71 | -------------------------------------------------------------------------------- /examples/sin_estimate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import optim 4 | import time 5 | 6 | 7 | class SimpleNN(nn.Module): 8 | def __init__(self, n_inputs, n_outputs): 9 | super(SimpleNN, self).__init__() 10 | self.linear1 = nn.Linear(in_features=n_inputs, out_features=32) 11 | self.relu1 = nn.ReLU() 12 | self.linear2 = 
nn.Linear(in_features=32, out_features=32) 13 | self.relu2 = nn.ReLU() 14 | self.linear3 = nn.Linear(in_features=32, out_features=n_outputs) 15 | 16 | def forward(self, x): 17 | x1 = self.linear1(x) 18 | x2 = self.relu1(x1) 19 | x3 = self.linear2(x2) 20 | x4 = self.relu2(x3) 21 | y_pred = self.linear3(x4) 22 | return y_pred 23 | 24 | 25 | if __name__ == "__main__": 26 | batch_size = 32 27 | n_inputs = 1 28 | n_outputs = 1 29 | learning_rate = 0.01 30 | 31 | device = torch.device("cpu") 32 | model = SimpleNN(n_inputs, n_outputs).to(device) 33 | loss_func = nn.MSELoss() 34 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 35 | 36 | x = torch.rand(batch_size, n_inputs).to(device) * 2 - 1 37 | y = torch.sin(x).to(device) 38 | 39 | epochs = 20000 40 | 41 | model.train() 42 | start = time.time() 43 | for i in range(epochs): 44 | x = torch.rand(batch_size, n_inputs).to(device) * 2 - 1 45 | y = torch.sin(x).to(device) 46 | 47 | outputs = model(x) 48 | loss = loss_func(outputs, y) 49 | 50 | # Backward pass 51 | optimizer.zero_grad() 52 | loss.backward() 53 | optimizer.step() 54 | 55 | if (i + 1) % 1000 == 0: 56 | print(f"Epoch [{i + 1}/{epochs}],\t Loss: {loss.item()}") 57 | 58 | print(f"Training time: {time.time() - start:.2f} seconds. Loss: {loss.item()}") 59 | -------------------------------------------------------------------------------- /mojoproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | authors = ["stijn", "benny", "andres"] 3 | channels = ["conda-forge", "https://conda.modular.com/max"] 4 | description = "Basalt is a stand-alone machine learning framework that leverages the power of Mojo." 5 | name = "Basalt" 6 | platforms = ["osx-arm64", "linux-64"] 7 | version = "0.1.0" 8 | 9 | [tasks] 10 | test = { cmd = "magic run mojo test -I . tests" } 11 | test_mojo = { cmd = "magic run mojo test -I . tests/mojo" } 12 | test_python = { cmd = "magic run mojo test -I . 
tests/python" } 13 | 14 | [dependencies] 15 | max = ">=24.6.0,<25" 16 | 17 | [pypi-dependencies] 18 | torch = ">=2.5.1, <3" 19 | torchvision = ">=0.20.1, <0.21" 20 | torchaudio = ">=2.5.1, <3" 21 | onnx = ">=1.17.0, <2" 22 | -------------------------------------------------------------------------------- /python-requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.0 2 | matplotlib==3.8.0 3 | pandas==2.1.1 4 | onnx 5 | netron -------------------------------------------------------------------------------- /tests/__init__.mojo: -------------------------------------------------------------------------------- 1 | from .testing_utils import * 2 | -------------------------------------------------------------------------------- /tests/mojo/test_activations.mojo: -------------------------------------------------------------------------------- 1 | from testing import assert_equal 2 | 3 | from basalt import dtype 4 | from basalt.nn import ( 5 | Tensor, 6 | TensorShape, 7 | Model, 8 | Softmax, 9 | LogSoftmax, 10 | ReLU, 11 | LeakyReLU, 12 | Sigmoid, 13 | Tanh, 14 | ) 15 | from basalt.autograd import Graph, Symbol 16 | from basalt.utils.tensorutils import fill 17 | 18 | from tests import assert_tensors_equal 19 | 20 | 21 | alias Activation = fn (inout g: Graph, input: Symbol) -> Symbol 22 | alias AxisActivation = fn (inout g: Graph, input: Symbol, axis: Int) -> Symbol 23 | alias LeakyReLUActivation = fn ( 24 | inout g: Graph, input: Symbol, negative_slope: Scalar[dtype] 25 | ) -> Symbol 26 | 27 | 28 | fn create_graph[ 29 | shape: TensorShape, 30 | func: AxisActivation, 31 | axis: Int, 32 | ]() -> Graph: 33 | var g = Graph() 34 | var x = g.input(shape) 35 | var activation = func(g, x, axis) 36 | g.out(activation) 37 | return g^ 38 | 39 | 40 | fn create_graph[ 41 | shape: TensorShape, 42 | func: LeakyReLUActivation, 43 | negative_slope: Scalar[dtype], 44 | ]() -> Graph: 45 | var g = Graph() 46 | var x = g.input(shape) 47 | var activation = func(g, x, negative_slope) 48 | g.out(activation) 49 | return g^ 50 | 51 | 52 | fn create_graph[shape: TensorShape, func: Activation]() -> Graph: 53 | var g = Graph() 54 | var x = g.input(shape) 55 | var activation = func(g, x) 56 | g.out(activation) 57 | return g^ 58 | 59 | 60 | fn test_graph[ 61 | shape: TensorShape, 62 | func: AxisActivation, 63 | nodes: Int, 64 | axis: Int, 65 | ](input: Tensor[dtype], expected: Tensor[dtype]) raises: 66 | alias graph = create_graph[shape, func, axis]() 67 | 68 | var model = Model[graph](inference_only=True) 69 | var res = model.inference(input)[0] 70 | 71 | assert_tensors_equal["almost"](res, expected) 72 | assert_equal(len(graph.nodes), nodes) 73 | 74 | 75 | fn test_graph[ 76 | shape: TensorShape, 77 | func: LeakyReLUActivation, 78 | nodes: Int, 79 | negative_slope: Scalar[dtype], 80 | ](input: Tensor[dtype], expected: Tensor[dtype]) raises: 81 | alias graph = create_graph[shape, func, negative_slope]() 82 | 83 | var model = Model[graph](inference_only=True) 84 | var res = model.inference(input)[0] 85 | 86 | assert_tensors_equal["almost"](res, expected) 87 | assert_equal(len(graph.nodes), nodes) 88 | 89 | 90 | # TODO: All these overloads feel redundant. 
Find a way to condense them 91 | fn test_graph[ 92 | shape: TensorShape, 93 | func: Activation, 94 | nodes: Int, 95 | ](input: Tensor[dtype], expected: Tensor[dtype]) raises: 96 | alias graph = create_graph[shape, func]() 97 | 98 | var model = Model[graph](inference_only=True) 99 | var res = model.inference(input)[0] 100 | 101 | assert_tensors_equal["almost", "Tensor equality failed"](res, expected) 102 | assert_equal(len(graph.nodes), nodes, "Node count failed") 103 | 104 | 105 | fn test_SOFTMAX() raises: 106 | alias shape = TensorShape(2, 3, 2) 107 | alias nodes = 5 108 | 109 | var input = Tensor[dtype](shape) 110 | fill(input, 4) 111 | 112 | var expected = Tensor[dtype](shape) 113 | 114 | fill(expected, 0.5) 115 | test_graph[shape, Softmax, nodes, 0](input, expected) 116 | 117 | fill(expected, 1.0 / 3.0) 118 | test_graph[shape, Softmax, nodes, 1](input, expected) 119 | 120 | fill(expected, 0.5) 121 | test_graph[shape, Softmax, nodes, 2](input, expected) 122 | 123 | 124 | fn test_LOGSOFTMAX() raises: 125 | alias shape = TensorShape(2, 3, 2) 126 | alias nodes = 6 127 | 128 | var input = Tensor[dtype](shape) 129 | fill(input, 4) 130 | 131 | var expected = Tensor[dtype](shape) 132 | 133 | fill(expected, -0.69314718) 134 | test_graph[shape, LogSoftmax, nodes, 0](input, expected) 135 | 136 | fill(expected, -1.09861231) 137 | test_graph[shape, LogSoftmax, nodes, 1](input, expected) 138 | 139 | fill(expected, -0.69314718) 140 | test_graph[shape, LogSoftmax, nodes, 2](input, expected) 141 | 142 | 143 | fn test_RELU() raises: 144 | alias shape = TensorShape(2, 3) 145 | alias nodes = 1 146 | 147 | var input = Tensor[dtype](shape) 148 | 149 | for i in range(6): 150 | input[i] = 3 if i < 3 else -3 151 | 152 | var expected = Tensor[dtype](shape) 153 | 154 | for i in range(6): 155 | expected[i] = 3 if i < 3 else 0 156 | 157 | test_graph[shape, ReLU, nodes](input, expected) 158 | 159 | 160 | fn test_LEAKYRELU() raises: 161 | alias negative_slope = Float32(0.1) 162 | 163 | alias shape = TensorShape(2, 3) 164 | alias nodes = 1 165 | 166 | var input = Tensor[dtype](shape) 167 | 168 | for i in range(6): 169 | input[i] = i - 3 170 | 171 | var expected = Tensor[dtype](shape) 172 | 173 | for i in range(6): 174 | expected[i] = i - 3 if i - 3 > 0 else negative_slope * (i - 3) 175 | 176 | test_graph[shape, LeakyReLU, nodes, negative_slope](input, expected) 177 | 178 | 179 | fn test_SIGMOID() raises: 180 | alias shape = TensorShape(2, 3) 181 | alias nodes = 1 182 | 183 | var input = Tensor[dtype](shape) 184 | fill(input, 0) 185 | 186 | var expected = Tensor[dtype](shape) 187 | 188 | fill(expected, 0.5) 189 | test_graph[shape, Sigmoid, nodes](input, expected) 190 | 191 | 192 | fn test_TANH() raises: 193 | alias shape = TensorShape(2, 3) 194 | alias nodes = 1 195 | 196 | var input = Tensor[dtype](shape) 197 | fill(input, 0) 198 | 199 | var expected = Tensor[dtype](shape) 200 | 201 | fill(expected, 0.0) 202 | test_graph[shape, Tanh, nodes](input, expected) 203 | 204 | 205 | fn main(): 206 | try: 207 | test_SOFTMAX() 208 | test_LOGSOFTMAX() 209 | test_RELU() 210 | test_LEAKYRELU() 211 | test_SIGMOID() 212 | test_TANH() 213 | except e: 214 | print("[ERROR] Error in activations") 215 | print(e) 216 | -------------------------------------------------------------------------------- /tests/mojo/test_attributes.mojo: -------------------------------------------------------------------------------- 1 | from testing import assert_equal, assert_true 2 | from utils.index import IndexList 3 | 4 | from basalt.nn import 
TensorShape 5 | from basalt.autograd.attributes import Attribute 6 | 7 | 8 | fn test_attribute_key() raises: 9 | alias a = Attribute(name="test", value=-1) 10 | 11 | assert_true(str(a.name) == "test") 12 | 13 | 14 | fn test_attribute_int() raises: 15 | alias value: Int = 1 16 | alias a = Attribute(name="test", value=value) 17 | 18 | assert_true(a.to_int() == 1) 19 | 20 | 21 | fn test_attribute_string() raises: 22 | alias value: String = "hello" 23 | alias a = Attribute(name="test", value=value) 24 | 25 | assert_true(a.to_string() == value) 26 | 27 | 28 | fn test_attribute_tensor_shape() raises: 29 | alias value: TensorShape = TensorShape(1, 2, 3) 30 | alias a = Attribute(name="test", value=value) 31 | 32 | assert_true(a.to_shape() == value) 33 | 34 | 35 | fn test_attribute_static_int_tuple() raises: 36 | alias value: IndexList[7] = IndexList[7](1, 2, 3, 4, 5, 6, 7) 37 | alias a = Attribute(name="test", value=value) 38 | 39 | assert_true(a.to_static[7]() == value) 40 | 41 | 42 | fn test_attribute_scalar() raises: 43 | fn test_float32() raises: 44 | alias value_a: Float32 = 1.23456 45 | alias a1 = Attribute(name="test", value=value_a) 46 | assert_true( 47 | a1.to_scalar[DType.float32]() == value_a, 48 | "Float32 scalar attribute failed", 49 | ) 50 | 51 | alias value_b: Float32 = 65151 52 | alias a2 = Attribute(name="test", value=value_b) 53 | assert_true( 54 | a2.to_scalar[DType.float32]() == value_b, 55 | "Float32 scalar attribute failed", 56 | ) 57 | 58 | fn test_float_literal() raises: 59 | alias value_c: FloatLiteral = -1.1 60 | alias a3 = Attribute(name="test", value=value_c) 61 | assert_true( 62 | a3.to_scalar[DType.float32]() == value_c, 63 | "FloatLiteral scalar attribute failed", 64 | ) 65 | 66 | fn test_float64() raises: 67 | alias value_a: Float64 = -1.23456 68 | alias a1 = Attribute(name="test", value=value_a) 69 | assert_true( 70 | a1.to_scalar[DType.float64]() == value_a, 71 | "Float64 scalar attribute failed", 72 | ) 73 | 74 | alias value_b: Float64 = 123456 75 | alias a2 = Attribute(name="test", value=value_b) 76 | assert_true( 77 | a2.to_scalar[DType.float64]() == value_b, 78 | "Float64 scalar attribute failed", 79 | ) 80 | 81 | fn test_int32() raises: 82 | alias value_a: Int32 = 666 83 | alias a1 = Attribute(name="test", value=value_a) 84 | assert_true( 85 | a1.to_scalar[DType.int32]() == value_a, 86 | "Int32 scalar attribute failed", 87 | ) 88 | 89 | alias value_b: Int32 = -666 90 | alias a2 = Attribute(name="test", value=value_b) 91 | assert_true( 92 | a2.to_scalar[DType.int32]() == value_b, 93 | "Int32 scalar attribute failed", 94 | ) 95 | 96 | fn test_attribute_small_scalar() raises: 97 | alias value_a: Float32 = 1e-18 98 | alias a = Attribute(name="test", value=value_a) 99 | assert_true( 100 | a.to_scalar[DType.float32]() == value_a, 101 | "SMALL scalar attribute failed", 102 | ) 103 | 104 | fn test_attribute_big_scalar() raises: 105 | alias value_a: Float32 = 1e40 106 | alias a = Attribute(name="test", value=value_a) 107 | assert_true( 108 | a.to_scalar[DType.float32]() == value_a, 109 | "BIG scalar attribute failed", 110 | ) 111 | 112 | test_float32() 113 | test_float_literal() 114 | test_float64() 115 | test_int32() 116 | test_attribute_small_scalar() 117 | test_attribute_big_scalar() 118 | 119 | 120 | fn main(): 121 | try: 122 | test_attribute_key() 123 | test_attribute_int() 124 | test_attribute_string() 125 | test_attribute_tensor_shape() 126 | test_attribute_static_int_tuple() 127 | test_attribute_scalar() 128 | except e: 129 | print("[ERROR] Error in 
attributes") 130 | print(e) 131 | -------------------------------------------------------------------------------- /tests/mojo/test_collection.mojo: -------------------------------------------------------------------------------- 1 | from testing import assert_equal 2 | 3 | from basalt import dtype 4 | from basalt.nn import Tensor, TensorShape 5 | from basalt.autograd import Symbol 6 | from basalt.utils.collection import Collection 7 | from basalt.utils.tensorutils import fill 8 | 9 | from tests import assert_tensors_equal 10 | 11 | 12 | fn test_append_tensors() raises: 13 | alias t1_shape = TensorShape(1, 10) 14 | alias t2_shape = TensorShape(2, 20) 15 | var s1 = Symbol(0, dtype, t1_shape, True) 16 | var s2 = Symbol(1, dtype, t2_shape, True) 17 | 18 | var c = Collection(capacity=2) 19 | assert_equal(c.capacity, 2) 20 | assert_equal(c.size, 0) 21 | 22 | c.append(Tensor[dtype](s1.shape), s1) 23 | assert_equal(c.size, 1) 24 | 25 | c.append(Tensor[dtype](s2.shape), s2) 26 | assert_equal(c.size, 2) 27 | 28 | 29 | fn test_get_tensor_reference() raises: 30 | alias t1_shape = TensorShape(1, 10) 31 | alias t2_shape = TensorShape(2, 20) 32 | var s1 = Symbol(0, dtype, t1_shape, True) 33 | var s2 = Symbol(1, dtype, t2_shape, True) 34 | 35 | var t1 = Tensor[dtype](s1.shape) 36 | var t2 = Tensor[dtype](s2.shape) 37 | fill(t1, 1) 38 | fill(t2, 2) 39 | 40 | var c = Collection(capacity=2) 41 | c.append(t1 ^, s1) 42 | c.append(t2 ^, s2) 43 | 44 | var t1_expected = Tensor[dtype](s1.shape) 45 | var t2_expected = Tensor[dtype](s2.shape) 46 | fill(t1_expected, 1) 47 | fill(t2_expected, 2) 48 | 49 | assert_tensors_equal(c[s1], t1_expected) 50 | assert_tensors_equal(c[s2], t2_expected) 51 | 52 | 53 | fn test_resize_collection() raises: 54 | alias t1_shape = TensorShape(1, 10) 55 | alias t2_shape = TensorShape(2, 20) 56 | alias t3_shape = TensorShape(3, 30) 57 | var s1 = Symbol(0, dtype, t1_shape, True) 58 | var s2 = Symbol(1, dtype, t2_shape, True) 59 | var s3 = Symbol(2, dtype, t3_shape, True) 60 | 61 | var t1 = Tensor[dtype](s1.shape) 62 | var t2 = Tensor[dtype](s2.shape) 63 | var t3 = Tensor[dtype](s3.shape) 64 | fill(t1, 1) 65 | fill(t2, 2) 66 | fill(t3, 3) 67 | 68 | var c = Collection(capacity=1) 69 | assert_equal(c.size, 0) 70 | assert_equal(c.capacity, 1) 71 | 72 | c.append(t1 ^, s1) 73 | assert_equal(c.size, 1) 74 | assert_equal(c.capacity, 1) 75 | 76 | c.append(t2 ^, s2) 77 | assert_equal(c.size, 2) 78 | assert_equal(c.capacity, 2) 79 | 80 | c.append(t3 ^, s3) 81 | assert_equal(c.size, 3) 82 | assert_equal(c.capacity, 4) 83 | 84 | var t1_expected = Tensor[dtype](s1.shape) 85 | var t2_expected = Tensor[dtype](s2.shape) 86 | var t3_expected = Tensor[dtype](s3.shape) 87 | fill(t1_expected, 1) 88 | fill(t2_expected, 2) 89 | fill(t3_expected, 3) 90 | 91 | assert_tensors_equal(c[s1], t1_expected) 92 | assert_tensors_equal(c[s2], t2_expected) 93 | assert_tensors_equal(c[s3], t3_expected) 94 | 95 | 96 | fn test_set_zero() raises: 97 | alias t1_shape = TensorShape(1, 10) 98 | alias t2_shape = TensorShape(2, 20) 99 | var s1 = Symbol(0, dtype, t1_shape, True) 100 | var s2 = Symbol(1, dtype, t2_shape, True) 101 | var t1 = Tensor[dtype](s1.shape) 102 | var t2 = Tensor[dtype](s2.shape) 103 | fill(t1, 1) 104 | fill(t2, 2) 105 | 106 | var c = Collection(capacity=2) 107 | c.append(t1 ^, s1) 108 | c.append(t2 ^, s2) 109 | 110 | var t1_expected = Tensor[dtype](s1.shape) 111 | var t2_expected = Tensor[dtype](s2.shape) 112 | fill(t1_expected, 1) 113 | fill(t2_expected, 2) 114 | assert_tensors_equal(c[s1], t1_expected) 115 | 
assert_tensors_equal(c[s2], t2_expected) 116 | 117 | c.set_zero() 118 | 119 | assert_tensors_equal(c[s1], Tensor[dtype](t1_shape)) 120 | assert_tensors_equal(c[s2], Tensor[dtype](t2_shape)) 121 | 122 | 123 | fn test_operate_on_reference() raises: 124 | alias res_shape = TensorShape(1, 10) 125 | alias t1_shape = TensorShape(1, 10) 126 | var sr = Symbol(0, dtype, t1_shape, True) 127 | var s1 = Symbol(1, dtype, t1_shape, True) 128 | var res = Tensor[dtype](res_shape) 129 | var t1 = Tensor[dtype](s1.shape) 130 | 131 | var c = Collection(capacity=2) 132 | c.append(res ^, sr) 133 | c.append(t1 ^, s1) 134 | 135 | fn some_operation[ 136 | res_shape: TensorShape, t_shape: TensorShape 137 | ](inout res: Tensor[dtype], t1: Tensor[dtype]): 138 | for i in range(res.num_elements()): 139 | res[i] = t1[i] 140 | 141 | for i in range(1, 10): 142 | some_operation[res_shape, t1_shape](c[sr], c[s1]) 143 | fill(c[s1], i) 144 | 145 | var res_expected = Tensor[dtype](res_shape) 146 | var t1_expected = Tensor[dtype](t1_shape) 147 | fill(res_expected, i - 1) 148 | fill(t1_expected, i) 149 | 150 | assert_tensors_equal(c[sr], res_expected) 151 | assert_tensors_equal(c[s1], t1_expected) 152 | 153 | 154 | fn main() raises: 155 | try: 156 | test_append_tensors() 157 | test_get_tensor_reference() 158 | test_resize_collection() 159 | test_set_zero() 160 | test_operate_on_reference() 161 | except e: 162 | print(e) 163 | raise e 164 | -------------------------------------------------------------------------------- /tests/mojo/test_loss.mojo: -------------------------------------------------------------------------------- 1 | from testing import assert_equal, assert_almost_equal 2 | 3 | from basalt import dtype, nelts 4 | from basalt.autograd import Graph, Symbol, OP 5 | from basalt.nn import Model, Tensor, TensorShape, MSELoss, CrossEntropyLoss 6 | from basalt.utils.tensorutils import fill 7 | 8 | 9 | fn test_MSE_perfect() raises: 10 | alias y_pred_shape = TensorShape(2, 10) # batch of 2, 10 classes 11 | alias y_true_shape = TensorShape(2, 10) 12 | 13 | fn create_graph() -> Graph: 14 | var g = Graph() 15 | 16 | var y_pred = g.input(y_pred_shape) 17 | var y_true = g.input(y_true_shape) 18 | 19 | var loss = MSELoss(g, y_pred, y_true) 20 | 21 | g.out(loss) 22 | 23 | return g ^ 24 | 25 | alias graph = create_graph() 26 | assert_equal(len(graph.nodes), 3) 27 | 28 | var y_pred = Tensor[dtype](y_pred_shape) 29 | var y_true = Tensor[dtype](y_true_shape) 30 | 31 | fill(y_pred, 1) 32 | fill(y_true, 1) 33 | 34 | var model = Model[graph](inference_only=True) 35 | 36 | var loss = model.inference(y_pred, y_true)[0] 37 | 38 | assert_equal(loss.dim(0), 1) # MSE summed over all elements 39 | assert_equal(loss[0], 0) # loss is 0 40 | 41 | 42 | fn test_MSE_imperfect() raises: 43 | alias y_pred_shape = TensorShape(1, 10) # batch of 1, 10 classes 44 | alias y_true_shape = TensorShape(1, 10) 45 | 46 | fn create_graph() -> Graph: 47 | var g = Graph() 48 | 49 | var y_pred = g.input(y_pred_shape) 50 | var y_true = g.input(y_true_shape) 51 | 52 | var loss = MSELoss(g, y_pred, y_true) 53 | 54 | g.out(loss) 55 | 56 | return g ^ 57 | 58 | alias graph = create_graph() 59 | assert_equal(len(graph.nodes), 3) 60 | 61 | var y_pred = Tensor[dtype](y_pred_shape) 62 | var y_true = Tensor[dtype](y_true_shape) 63 | 64 | fill(y_pred, 1) 65 | 66 | for i in range(10): 67 | y_true[i] = i 68 | 69 | var model = Model[graph](inference_only=True) 70 | 71 | var loss = model.inference(y_pred, y_true)[0] 72 | 73 | var expected_loss: Scalar[dtype] = 0.0 74 | 75 | for i in 
range(10): 76 | expected_loss += (y_pred[i] - y_true[i]) ** 2 77 | 78 | expected_loss = expected_loss / y_true_shape[1] 79 | 80 | assert_almost_equal(loss[0], expected_loss) 81 | 82 | 83 | fn test_CrossEntropy_perfect() raises: 84 | alias y_pred_shape = TensorShape(2, 3) # batch of 2, 3 classes 85 | alias y_true_shape = TensorShape(2, 3) 86 | 87 | fn create_graph() -> Graph: 88 | var g = Graph() 89 | 90 | var y_pred = g.input(y_pred_shape) 91 | var y_true = g.input(y_true_shape) 92 | 93 | var loss = CrossEntropyLoss(g, y_pred, y_true) 94 | 95 | g.out(loss) 96 | 97 | return g ^ 98 | 99 | alias graph = create_graph() 100 | assert_equal(len(graph.nodes), 9) 101 | 102 | var y_pred = Tensor[dtype](y_pred_shape) 103 | var y_true = Tensor[dtype](y_true_shape) 104 | 105 | y_pred[0 * y_pred.dim(1) + 0] = 0.1 106 | y_pred[0 * y_pred.dim(1) + 1] = 0.2 107 | y_pred[0 * y_pred.dim(1) + 2] = 0.7 108 | y_true[0 * y_true.dim(1) + 0] = 0 109 | y_true[0 * y_true.dim(1) + 1] = 0 110 | y_true[0 * y_true.dim(1) + 2] = 1 111 | 112 | y_pred[1 * y_pred.dim(1) + 0] = 0.7 113 | y_pred[1 * y_pred.dim(1) + 1] = 0.2 114 | y_pred[1 * y_pred.dim(1) + 2] = 0.1 115 | y_true[1 * y_true.dim(1) + 0] = 1 116 | y_true[1 * y_true.dim(1) + 1] = 0 117 | y_true[1 * y_true.dim(1) + 2] = 0 118 | 119 | var model = Model[graph](inference_only=True) 120 | 121 | var loss = model.inference(y_pred, y_true)[0] 122 | 123 | assert_equal(loss.shape(), TensorShape(1)) 124 | assert_almost_equal(loss[0], 0.76794958) 125 | 126 | 127 | fn test_CrossEntropy_imperfect() raises: 128 | alias y_pred_shape = TensorShape(2, 3) # batch of 2, 3 classes 129 | alias y_true_shape = TensorShape(2, 3) 130 | 131 | fn create_graph() -> Graph: 132 | var g = Graph() 133 | 134 | var y_pred = g.input(y_pred_shape) 135 | var y_true = g.input(y_true_shape) 136 | 137 | var loss = CrossEntropyLoss(g, y_pred, y_true) 138 | 139 | g.out(loss) 140 | 141 | return g ^ 142 | 143 | alias graph = create_graph() 144 | 145 | var y_pred = Tensor[dtype](y_pred_shape) 146 | var y_true = Tensor[dtype](y_true_shape) 147 | 148 | y_pred[0 * y_pred.dim(1) + 0] = 0.1 149 | y_pred[0 * y_pred.dim(1) + 1] = 0.2 150 | y_pred[0 * y_pred.dim(1) + 2] = 0.7 151 | y_true[0 * y_true.dim(1) + 0] = 0 152 | y_true[0 * y_true.dim(1) + 1] = 1 153 | y_true[0 * y_true.dim(1) + 2] = 0 154 | 155 | y_pred[1 * y_pred.dim(1) + 0] = 0.7 156 | y_pred[1 * y_pred.dim(1) + 1] = 0.2 157 | y_pred[1 * y_pred.dim(1) + 2] = 0.1 158 | y_true[1 * y_true.dim(1) + 0] = 0 159 | y_true[1 * y_true.dim(1) + 1] = 0 160 | y_true[1 * y_true.dim(1) + 2] = 1 161 | 162 | var model = Model[graph](inference_only=True) 163 | 164 | var loss = model.inference(y_pred, y_true)[0] 165 | 166 | assert_equal(loss.shape(), TensorShape(1)) 167 | assert_almost_equal(loss[0], 1.31794953) 168 | 169 | 170 | fn main(): 171 | try: 172 | test_MSE_perfect() 173 | test_MSE_imperfect() 174 | test_CrossEntropy_perfect() 175 | test_CrossEntropy_imperfect() 176 | except e: 177 | print("[ERROR] Error in loss") 178 | print(e) 179 | -------------------------------------------------------------------------------- /tests/mojo/test_ops.mojo: -------------------------------------------------------------------------------- 1 | from math import exp, log 2 | from utils.index import IndexList 3 | 4 | from basalt import dtype, nelts 5 | from basalt.autograd import OP 6 | from basalt.autograd.attributes import Attribute, AttributeVector 7 | from basalt.utils.tensorutils import fill 8 | from basalt.nn import Tensor, TensorShape 9 | 10 | from tests import test_unary_op, 
test_binary_op, test_ternary_op 11 | 12 | 13 | fn test_ADD() raises: 14 | alias t1_shape = TensorShape(2, 3) 15 | alias t2_shape = TensorShape(2, 3) 16 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 17 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape) 18 | fill(t1, 1.0) 19 | fill(t2, 1.0) 20 | 21 | var expected = Tensor[dtype](2, 3) 22 | fill(expected, 2.0) 23 | 24 | test_binary_op[OP.ADD, t1_shape, t2_shape](t1, t2, expected) 25 | 26 | 27 | fn test_SUB() raises: 28 | alias t1_shape = TensorShape(2, 3) 29 | alias t2_shape = TensorShape(2, 3) 30 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 31 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape) 32 | fill(t1, 2.0) 33 | fill(t2, 1.0) 34 | 35 | var expected = Tensor[dtype](2, 3) 36 | fill(expected, 1.0) 37 | 38 | test_binary_op[OP.SUB, t1_shape, t2_shape](t1, t2, expected) 39 | 40 | 41 | fn test_MUL() raises: 42 | alias t1_shape = TensorShape(2, 3) 43 | alias t2_shape = TensorShape(2, 3) 44 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 45 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape) 46 | fill(t1, 2.0) 47 | fill(t2, 3.0) 48 | 49 | var expected = Tensor[dtype](2, 3) 50 | fill(expected, 6.0) 51 | 52 | test_binary_op[OP.MUL, t1_shape, t2_shape](t1, t2, expected) 53 | 54 | 55 | fn test_DIV() raises: 56 | alias t1_shape = TensorShape(2, 3) 57 | alias t2_shape = TensorShape(2, 3) 58 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 59 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape) 60 | fill(t1, 6.0) 61 | fill(t2, 2.0) 62 | 63 | var expected = Tensor[dtype](2, 3) 64 | fill(expected, 3.0) 65 | 66 | test_binary_op[OP.DIV, t1_shape, t2_shape](t1, t2, expected) 67 | 68 | 69 | fn test_DOT() raises: 70 | alias t1_shape = TensorShape(2, 3) 71 | alias t2_shape = TensorShape(3, 2) 72 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 73 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape) 74 | fill(t1, 1.0) 75 | fill(t2, 2.0) 76 | 77 | var expected = Tensor[dtype](2, 2) 78 | fill(expected, 6.0) 79 | 80 | test_binary_op[OP.DOT, t1_shape, t2_shape](t1, t2, expected) 81 | 82 | 83 | fn test_EXP() raises: 84 | alias t1_shape = TensorShape(2, 3) 85 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 86 | fill(t1, 2.0) 87 | 88 | var expected = Tensor[dtype](2, 3) 89 | fill(expected, exp(SIMD[dtype, 1](2.0))) 90 | 91 | test_unary_op[OP.EXP, t1_shape](t1, expected) 92 | 93 | 94 | fn test_LOG() raises: 95 | alias t1_shape = TensorShape(2, 3) 96 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 97 | fill(t1, 2.0) 98 | 99 | var expected = Tensor[dtype](2, 3) 100 | fill(expected, log(SIMD[dtype, 1](2.0))) 101 | 102 | test_unary_op[OP.LOG, t1_shape](t1, expected) 103 | 104 | 105 | fn test_POW() raises: 106 | alias t1_shape = TensorShape(2, 3) 107 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 108 | fill(t1, 2.0) 109 | 110 | alias t2_shape = TensorShape(1) 111 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape) 112 | t2[0] = 2.0 113 | 114 | var expected = Tensor[dtype](2, 3) 115 | fill(expected, 4.0) 116 | 117 | test_binary_op[OP.POW, t1_shape, t2_shape](t1, t2, expected) 118 | 119 | 120 | fn test_SUM() raises: 121 | alias t1_shape = TensorShape(2, 3, 4) 122 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 123 | fill(t1, 1.0) 124 | 125 | # No axis specified 126 | var expected = Tensor[dtype](1) 127 | fill(expected, 24.0) 128 | test_unary_op[OP.SUM, t1_shape](t1, expected) 129 | 130 | # Test axis 1 131 | alias attrs = AttributeVector(Attribute("axis", 1)) 132 | expected = Tensor[dtype](2, 1, 4) 133 | fill(expected, 3.0) 134 | test_unary_op[OP.SUM, t1_shape, 
attrs](t1, expected) 135 | 136 | 137 | fn test_MAX() raises: 138 | alias t1_shape = TensorShape(2, 3, 2) 139 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 140 | for i in range(t1_shape.num_elements()): 141 | t1[i] = i + 1 142 | 143 | # No axis specified 144 | var expected = Tensor[dtype](1) 145 | fill(expected, t1_shape.num_elements()) 146 | test_unary_op[OP.MAX, t1_shape](t1, expected) 147 | 148 | @parameter 149 | fn fill_tensor[ 150 | size: Int 151 | ](inout tensor: Tensor[dtype], values: IndexList[size]): 152 | for i in range(tensor.num_elements()): 153 | tensor[i] = values[i] 154 | 155 | # Test axis 0 156 | alias attrs = AttributeVector(Attribute("axis", 0)) 157 | var expected_max_axis_0_temp = IndexList[6](7, 8, 9, 10, 11, 12) 158 | expected = Tensor[dtype](1, 3, 2) 159 | fill_tensor(expected, expected_max_axis_0_temp) 160 | test_unary_op[OP.MAX, t1_shape, attrs](t1, expected) 161 | 162 | # Test axis 1 163 | alias attrs_1 = AttributeVector(Attribute("axis", 1)) 164 | var expected_max_axis_1_temp = IndexList[4](5, 6, 11, 12) 165 | expected = Tensor[dtype](2, 1, 2) 166 | fill_tensor(expected, expected_max_axis_1_temp) 167 | test_unary_op[OP.MAX, t1_shape, attrs_1](t1, expected) 168 | 169 | # Test axis 2 170 | alias attrs_2 = AttributeVector(Attribute("axis", 2)) 171 | var expected_max_axis_2_temp = IndexList[6](2, 4, 6, 8, 10, 12) 172 | expected = Tensor[dtype](2, 3, 1) 173 | fill_tensor(expected, expected_max_axis_2_temp) 174 | test_unary_op[OP.MAX, t1_shape, attrs_2](t1, expected) 175 | 176 | 177 | fn test_MEAN() raises: 178 | alias t1_shape = TensorShape(2, 3) 179 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 180 | fill(t1, 5.0) 181 | 182 | # No axis specified 183 | var expected = Tensor[dtype](1) 184 | fill(expected, 5.0) 185 | test_unary_op[OP.MEAN, t1_shape](t1, expected) 186 | 187 | # Test axis 0 188 | alias attrs = AttributeVector(Attribute("axis", 0)) 189 | expected = Tensor[dtype](1, 3) 190 | fill(expected, 5.0) 191 | test_unary_op[OP.MEAN, t1_shape, attrs](t1, expected) 192 | 193 | # Test axis 1 194 | alias attrs_1 = AttributeVector(Attribute("axis", 1)) 195 | expected = Tensor[dtype](2, 1) 196 | fill(expected, 5.0) 197 | test_unary_op[OP.MEAN, t1_shape, attrs_1](t1, expected) 198 | 199 | 200 | fn test_TRANSPOSE() raises: 201 | alias t1_shape = TensorShape(2, 3, 4) 202 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 203 | for i in range(t1_shape.num_elements()): 204 | t1[i] = i + 1 205 | 206 | # Test tranpose (no attributes = reversing the axis by default) 207 | var expected = Tensor[dtype](4, 3, 2) 208 | var expected_strides = expected.strides() 209 | for i in range(t1_shape[0]): 210 | for j in range(t1_shape[1]): 211 | for k in range(t1_shape[2]): 212 | expected[k * expected_strides[0] + j * expected_strides[1] + i] = t1[ 213 | i * t1_shape[1] * t1_shape[2] + j * t1_shape[2] + k 214 | ] 215 | 216 | test_unary_op[OP.TRANSPOSE, t1_shape](t1, expected) 217 | 218 | # Test tranpose 1, 2, 0 219 | alias attrs = AttributeVector(Attribute("axes", TensorShape(1, 2, 0))) 220 | var expected_axis_1 = Tensor[dtype](3, 4, 2) 221 | var expected_axis_1_strides = expected_axis_1.strides() 222 | for i in range(t1_shape[0]): 223 | for j in range(t1_shape[1]): 224 | for k in range(t1_shape[2]): 225 | expected_axis_1[ 226 | j * expected_axis_1_strides[0] + k * expected_axis_1_strides[1] + i 227 | ] = t1[i * t1_shape[1] * t1_shape[2] + j * t1_shape[2] + k] 228 | 229 | test_unary_op[OP.TRANSPOSE, t1_shape, attrs](t1, expected_axis_1) 230 | 231 | 232 | fn test_FLATTEN() raises: 233 | 
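# FLATTEN should collapse the (2, 3, 4) input into a rank-1 tensor of 24 elements; the constant fill below only exercises the shape change, not element ordering.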
alias t1_shape = TensorShape(2, 3, 4) 234 | var t1 = Tensor[dtype](t1_shape) 235 | fill(t1, 1.0) 236 | 237 | var expected = Tensor[dtype](24) 238 | fill(expected, 1.0) 239 | 240 | test_unary_op[OP.FLATTEN, t1_shape](t1, expected) 241 | 242 | 243 | fn test_RESHAPE() raises: 244 | alias t_shape = TensorShape(2, 2, 5) 245 | alias new_shape = TensorShape(2, 10) 246 | 247 | var t = Tensor[dtype](t_shape) 248 | var expected = Tensor[dtype](new_shape) 249 | for i in range(20): 250 | t[i] = i + 1 251 | expected[i] = i + 1 252 | 253 | alias attrs = AttributeVector(Attribute("shape", new_shape)) 254 | test_unary_op[OP.RESHAPE, t_shape, attrs](t, expected) 255 | 256 | 257 | fn test_FMA() raises: 258 | alias t1_shape = TensorShape(2, 3) 259 | alias t2_shape = TensorShape(2, 3) 260 | alias t3_shape = TensorShape(2, 3) 261 | var t1: Tensor[dtype] = Tensor[dtype](t1_shape) 262 | var t2: Tensor[dtype] = Tensor[dtype](t2_shape) 263 | var t3: Tensor[dtype] = Tensor[dtype](t3_shape) 264 | fill(t1, 1.0) 265 | fill(t2, 2.0) 266 | fill(t3, 3.0) 267 | 268 | var expected = Tensor[dtype](2, 3) 269 | fill(expected, 1.0 * 2.0 + 3.0) 270 | 271 | test_ternary_op[OP.FMA, t1_shape, t2_shape, t3_shape](t1, t2, t3, expected) 272 | 273 | 274 | fn main(): 275 | try: 276 | test_ADD() 277 | test_SUB() 278 | test_MUL() 279 | test_DIV() 280 | test_DOT() 281 | test_EXP() 282 | test_LOG() 283 | test_POW() 284 | test_SUM() 285 | test_MAX() 286 | test_MEAN() 287 | test_TRANSPOSE() 288 | test_FLATTEN() 289 | test_RESHAPE() 290 | test_FMA() 291 | except e: 292 | print("[ERROR] Error in ops") 293 | print(e) 294 | -------------------------------------------------------------------------------- /tests/python/test_broadcast_shapes.mojo: -------------------------------------------------------------------------------- 1 | from python.python import Python, PythonObject 2 | from testing import assert_true 3 | 4 | from basalt.nn import Tensor, TensorShape 5 | from basalt.utils.tensorutils import broadcast_shapes 6 | 7 | 8 | fn to_tensor_shape(owned shape: PythonObject) raises -> TensorShape: 9 | var tensor_shape = List[Int]() 10 | for dim in shape: 11 | tensor_shape.append(int(float(dim))) 12 | return TensorShape(tensor_shape) 13 | 14 | 15 | fn np_broadcast_shapes(s1: TensorShape, s2: TensorShape) raises -> TensorShape: 16 | var np = Python.import_module("numpy") 17 | var s1_py: PythonObject = [] 18 | var s2_py: PythonObject = [] 19 | for i in range(s1.rank()): 20 | s1_py += [s1[i]] 21 | for i in range(s2.rank()): 22 | s2_py += [s2[i]] 23 | 24 | var py_shape = np.broadcast_shapes(s1_py, s2_py) 25 | 26 | return to_tensor_shape(py_shape) 27 | 28 | 29 | fn test_broadcast_shapes() raises: 30 | var s1 = TensorShape(3, 5, 2) 31 | var s2 = TensorShape(3, 5, 2) 32 | var s3 = broadcast_shapes(s1, s2) 33 | assert_true(s3 == np_broadcast_shapes(s1, s2)) 34 | 35 | s1 = TensorShape(3, 5, 2) 36 | s2 = TensorShape(1, 2) 37 | s3 = broadcast_shapes(s1, s2) 38 | assert_true(s3 == np_broadcast_shapes(s1, s2)) 39 | 40 | s1 = TensorShape(5, 1) 41 | s2 = TensorShape(3, 5, 1) 42 | s3 = broadcast_shapes(s1, s2) 43 | assert_true(s3 == np_broadcast_shapes(s1, s2)) 44 | 45 | s1 = TensorShape(3, 1, 2) 46 | s2 = TensorShape(3, 5, 2) 47 | s3 = broadcast_shapes(s1, s2) 48 | assert_true(s3 == np_broadcast_shapes(s1, s2)) 49 | 50 | s1 = TensorShape(1, 1, 1) 51 | s2 = TensorShape(3, 5, 2) 52 | s3 = broadcast_shapes(s1, s2) 53 | assert_true(s3 == np_broadcast_shapes(s1, s2)) 54 | 55 | s1 = TensorShape(2) 56 | s2 = TensorShape(3, 5, 2) 57 | s3 = broadcast_shapes(s1, s2) 58 | 
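# Broadcasting rule being checked: shapes are aligned from the trailing dimension and each pair of dims must either match or be 1, so the 1-D shape (2,) above broadcasts with (3, 5, 2) to (3, 5, 2).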
assert_true(s3 == np_broadcast_shapes(s1, s2)) 59 | 60 | s1 = TensorShape() 61 | s2 = TensorShape(3, 5, 2) 62 | s3 = broadcast_shapes(s1, s2) 63 | assert_true(s3 == np_broadcast_shapes(s1, s2)) 64 | 65 | # # Both errors expected 66 | # print("EXPECTED RAISE!") 67 | # try: 68 | # s1 = TensorShape(3, 2, 2) 69 | # s2 = TensorShape(3, 5, 2) 70 | # s3 = broadcast_shapes(s1, s2) 71 | # _ = np_broadcast_shapes(s1, s2) 72 | # except e: 73 | # print("Numpy:", e) 74 | 75 | # print("EXPECTED RAISE!") 76 | # try: 77 | # s1 = TensorShape(3) 78 | # s2 = TensorShape(2) 79 | # s3 = broadcast_shapes(s1, s2) 80 | # _ = np_broadcast_shapes(s1, s2) 81 | # except e: 82 | # print("Numpy:", e) 83 | 84 | 85 | fn test_broadcast_shapes_multiple() raises: 86 | var np = Python.import_module("numpy") 87 | 88 | var s1 = TensorShape(1, 2) 89 | var s2 = TensorShape(3, 1) 90 | var s3 = TensorShape(3, 2) 91 | var res = broadcast_shapes(s1, s2, s3) 92 | var res_np = to_tensor_shape(np.broadcast_shapes((1, 2), (3, 1), (3, 2))) 93 | assert_true(res == res_np) 94 | 95 | s1 = TensorShape(6, 7) 96 | s2 = TensorShape(5, 6, 1) 97 | s3 = TensorShape(7) 98 | var s4 = TensorShape(5, 1, 7) 99 | res = broadcast_shapes(s1, s2, s3, s4) 100 | res_np = to_tensor_shape(np.broadcast_shapes((6, 7), (5, 6, 1), (7), (5, 1, 7))) 101 | assert_true(res == res_np) 102 | 103 | 104 | fn main(): 105 | try: 106 | test_broadcast_shapes() 107 | test_broadcast_shapes_multiple() 108 | except e: 109 | print("[Error] In test broadcasting.") 110 | print(e) 111 | -------------------------------------------------------------------------------- /tests/python/test_dynamic_ops_torch.mojo: -------------------------------------------------------------------------------- 1 | from random import rand 2 | from python.python import Python, PythonObject 3 | 4 | from basalt import dtype, nelts 5 | from basalt.autograd import Graph, Symbol, OP 6 | from basalt.autograd.attributes import Attribute, AttributeVector 7 | from basalt.nn import Model, Tensor, TensorShape 8 | 9 | from tests import ( 10 | assert_tensors_equal, 11 | to_numpy, 12 | to_tensor, 13 | create_graph_concat, 14 | create_graph_split, 15 | ) 16 | 17 | 18 | @value 19 | struct torch_output_cat: 20 | var expected: Tensor[dtype] 21 | var grad_1: Tensor[dtype] 22 | var grad_2: Tensor[dtype] 23 | var grad_3: Tensor[dtype] 24 | 25 | 26 | fn torch_cat( 27 | input_1: Tensor, input_2: Tensor, input_3: Tensor, upper_grad: Tensor, dim: Int 28 | ) -> torch_output_cat: 29 | try: 30 | var py = Python.import_module("builtins") 31 | var torch = Python.import_module("torch") 32 | var np = Python.import_module("numpy") 33 | 34 | var input_1 = torch.from_numpy(to_numpy(input_1)).requires_grad_(True) 35 | var input_2 = torch.from_numpy(to_numpy(input_2)).requires_grad_(True) 36 | var input_3 = torch.from_numpy(to_numpy(input_3)).requires_grad_(True) 37 | 38 | var expected: PythonObject 39 | 40 | var tensors = py.list() 41 | tensors.append(input_1) 42 | tensors.append(input_2) 43 | tensors.append(input_3) 44 | expected = torch.cat(tensors, dim=dim) 45 | 46 | # uppergrad & backwards 47 | var upper_grad = torch.from_numpy(to_numpy(upper_grad)) 48 | _ = expected.backward(upper_grad) 49 | 50 | return torch_output_cat( 51 | to_tensor(expected.detach().numpy()), 52 | to_tensor(input_1.grad.numpy()), 53 | to_tensor(input_2.grad.numpy()), 54 | to_tensor(input_3.grad.numpy()), 55 | ) 56 | 57 | except e: 58 | print("Error importing torch: ", e) 59 | var d = Tensor[dtype](1) 60 | return torch_output_cat(d, d, d, d) 61 | 62 | 63 | fn 
test_CONCAT() raises: 64 | alias t1_shape = TensorShape(11, 3, 17, 19) 65 | alias t2_shape = TensorShape(11, 3, 17, 19) 66 | alias t3_shape = TensorShape(11, 3, 17, 19) 67 | var t1 = Tensor[dtype](t1_shape) 68 | var t2 = Tensor[dtype](t2_shape) 69 | var t3 = Tensor[dtype](t3_shape) 70 | rand(t1.data(), t1.num_elements()) 71 | rand(t2.data(), t2.num_elements()) 72 | rand(t3.data(), t3.num_elements()) 73 | 74 | # default: dim = 0 75 | alias graph = create_graph_concat(t1_shape, t2_shape, t3_shape, dim=0) 76 | var model = Model[graph]() 77 | var res = model.forward(t1, t2, t3) 78 | 79 | alias ug_shape = TensorShape(33, 3, 17, 19) 80 | var ug = Tensor[dtype](ug_shape) 81 | rand(ug.data(), ug.num_elements()) 82 | 83 | var expected_and_grad = torch_cat(t1, t2, t3, ug, dim=0) 84 | model.backward(ug) 85 | 86 | assert_tensors_equal["almost"](res, expected_and_grad.expected) 87 | assert_tensors_equal["almost"]( 88 | model.parameters.grads[graph.nodes[0].inputs[0]], 89 | expected_and_grad.grad_1, 90 | ) 91 | assert_tensors_equal["almost"]( 92 | model.parameters.grads[graph.nodes[0].inputs[1]], 93 | expected_and_grad.grad_2, 94 | ) 95 | assert_tensors_equal["almost"]( 96 | model.parameters.grads[graph.nodes[0].inputs[2]], 97 | expected_and_grad.grad_3, 98 | ) 99 | 100 | # dim = 2 101 | alias graph_2 = create_graph_concat(t1_shape, t2_shape, t3_shape, dim=2) 102 | var model_2 = Model[graph_2]() 103 | var res_2 = model_2.forward(t1, t2, t3) 104 | 105 | alias ug_shape_2 = TensorShape(11, 3, 51, 19) 106 | var ug_2 = Tensor[dtype](ug_shape_2) 107 | rand(ug_2.data(), ug_2.num_elements()) 108 | 109 | var expected_and_grad_2 = torch_cat(t1, t2, t3, ug_2, dim=2) 110 | model_2.backward(ug_2) 111 | 112 | assert_tensors_equal["almost"](res_2, expected_and_grad_2.expected) 113 | assert_tensors_equal["almost"]( 114 | model_2.parameters.grads[graph_2.nodes[0].inputs[0]], 115 | expected_and_grad_2.grad_1, 116 | ) 117 | assert_tensors_equal["almost"]( 118 | model_2.parameters.grads[graph_2.nodes[0].inputs[1]], 119 | expected_and_grad_2.grad_2, 120 | ) 121 | assert_tensors_equal["almost"]( 122 | model_2.parameters.grads[graph_2.nodes[0].inputs[2]], 123 | expected_and_grad_2.grad_3, 124 | ) 125 | 126 | 127 | @value 128 | struct torch_output_split: 129 | var expected1: Tensor[dtype] 130 | var expected2: Tensor[dtype] 131 | var expected3: Tensor[dtype] 132 | var grad: Tensor[dtype] 133 | 134 | 135 | fn torch_split( 136 | input: Tensor, 137 | upper_grad_1: Tensor, 138 | upper_grad_2: Tensor, 139 | upper_grad_3: Tensor, 140 | sections: List[Int], 141 | dim: Int, 142 | ) -> torch_output_split: 143 | try: 144 | var py = Python.import_module("builtins") 145 | var torch = Python.import_module("torch") 146 | var np = Python.import_module("numpy") 147 | 148 | var input = torch.from_numpy(to_numpy(input)).requires_grad_(True) 149 | 150 | var sizes = py.list() 151 | sizes.append(sections[0]) 152 | sizes.append(sections[1]) 153 | sizes.append(sections[2]) 154 | 155 | var chunks: PythonObject = input.split(sizes, dim=dim) 156 | 157 | # uppergrad & backwards 158 | var upper_grad_1 = torch.from_numpy(to_numpy(upper_grad_1)) 159 | var upper_grad_2 = torch.from_numpy(to_numpy(upper_grad_2)) 160 | var upper_grad_3 = torch.from_numpy(to_numpy(upper_grad_3)) 161 | _ = chunks[0].backward(upper_grad_1) 162 | _ = chunks[1].backward(upper_grad_2) 163 | _ = chunks[2].backward(upper_grad_3) 164 | 165 | return torch_output_split( 166 | to_tensor(chunks[0].detach().numpy()), 167 | to_tensor(chunks[1].detach().numpy()), 168 | 
to_tensor(chunks[2].detach().numpy()), 169 | to_tensor(input.grad.numpy()), 170 | ) 171 | 172 | except e: 173 | print("Error importing torch: ", e) 174 | var d = Tensor[dtype](1) 175 | return torch_output_split(d, d, d, d) 176 | 177 | 178 | fn test_SPLIT() raises: 179 | alias t1_shape = TensorShape(11, 3, 17, 19) 180 | var t1 = Tensor[dtype](t1_shape) 181 | rand(t1.data(), t1.num_elements()) 182 | 183 | # default: dim = 0 184 | alias sections = List[Int](3, 6, 2) # 11 185 | alias graph = create_graph_split(t1_shape, sections, dim=0) 186 | var model = Model[graph]() 187 | var results = model.inference(t1) 188 | 189 | alias ug1_shape = TensorShape(3, 3, 17, 19) 190 | alias ug2_shape = TensorShape(6, 3, 17, 19) 191 | alias ug3_shape = TensorShape(2, 3, 17, 19) 192 | var ug1 = Tensor[dtype](ug1_shape) 193 | var ug2 = Tensor[dtype](ug2_shape) 194 | var ug3 = Tensor[dtype](ug3_shape) 195 | rand(ug1.data(), ug1.num_elements()) 196 | rand(ug2.data(), ug2.num_elements()) 197 | rand(ug3.data(), ug3.num_elements()) 198 | 199 | var expected_and_grad = torch_split(t1, ug1, ug2, ug3, sections, dim=0) 200 | model.backward(ug1, ug2, ug3) 201 | 202 | assert_tensors_equal["almost"](results[0], expected_and_grad.expected1) 203 | assert_tensors_equal["almost"](results[1], expected_and_grad.expected2) 204 | assert_tensors_equal["almost"](results[2], expected_and_grad.expected3) 205 | assert_tensors_equal["almost"]( 206 | model.parameters.grads[graph.nodes[0].inputs[0]], 207 | expected_and_grad.grad, 208 | ) 209 | 210 | # dim = 2 211 | alias sections_2 = List[Int](3, 6, 8) # 17 212 | alias graph_2 = create_graph_split(t1_shape, sections_2, dim=2) 213 | var model_2 = Model[graph_2]() 214 | var results_2 = model_2.inference(t1) 215 | 216 | alias ug1_shape_2 = TensorShape(11, 3, 3, 19) 217 | alias ug2_shape_2 = TensorShape(11, 3, 6, 19) 218 | alias ug3_shape_2 = TensorShape(11, 3, 8, 19) 219 | var ug1_2 = Tensor[dtype](ug1_shape_2) 220 | var ug2_2 = Tensor[dtype](ug2_shape_2) 221 | var ug3_2 = Tensor[dtype](ug3_shape_2) 222 | rand(ug1_2.data(), ug1_2.num_elements()) 223 | rand(ug2_2.data(), ug2_2.num_elements()) 224 | rand(ug3_2.data(), ug3_2.num_elements()) 225 | 226 | var expected_and_grad_2 = torch_split(t1, ug1_2, ug2_2, ug3_2, sections_2, dim=2) 227 | model_2.backward(ug1_2, ug2_2, ug3_2) 228 | 229 | assert_tensors_equal["almost"](results_2[0], expected_and_grad_2.expected1) 230 | assert_tensors_equal["almost"](results_2[1], expected_and_grad_2.expected2) 231 | assert_tensors_equal["almost"](results_2[2], expected_and_grad_2.expected3) 232 | assert_tensors_equal["almost"]( 233 | model_2.parameters.grads[graph_2.nodes[0].inputs[0]], expected_and_grad_2.grad 234 | ) 235 | 236 | 237 | fn main(): 238 | print("Running dynamic ops (compare with torch) tests") 239 | try: 240 | test_CONCAT() 241 | test_SPLIT() 242 | except e: 243 | print("[ERROR] Error in dynamic ops (compare with torch)") 244 | print(e) 245 | return 246 | 247 | print("Finished dynamic ops (compare with torch) tests") 248 | -------------------------------------------------------------------------------- /tests/python/test_models_mnist.mojo: -------------------------------------------------------------------------------- 1 | from random import rand 2 | from python import Python 3 | from testing import assert_almost_equal 4 | from utils.index import IndexList 5 | 6 | from basalt import dtype 7 | from basalt.autograd import Graph, OP 8 | from basalt.autograd.attributes import AttributeVector, Attribute 9 | from basalt.nn import ( 10 | Tensor, 11 | 
TensorShape, 12 | Model, 13 | ReLU, 14 | MaxPool2d, 15 | CrossEntropyLoss, 16 | optim, 17 | ) 18 | from basalt.autograd.params import Param 19 | 20 | from tests import assert_tensors_equal, to_numpy, to_tensor 21 | 22 | 23 | fn create_CNN( 24 | batch_size: Int, 25 | conv1_weights: List[Scalar[dtype]], 26 | conv1_bias: List[Scalar[dtype]], 27 | conv2_weights: List[Scalar[dtype]], 28 | conv2_bias: List[Scalar[dtype]], 29 | linear1_weights: List[Scalar[dtype]], 30 | linear1_bias: List[Scalar[dtype]], 31 | ) -> Graph: 32 | var g = Graph() 33 | var x = g.input(TensorShape(batch_size, 1, 28, 28)) 34 | 35 | # conv1 36 | # var x1 = nn.Conv2d(g, x, out_channels=16, kernel_size=5, padding=2) 37 | var c1_w = g.param(TensorShape(16, x.shape[1], 5, 5), init=Param(conv1_weights)) 38 | var c1_b = g.param(TensorShape(16), init=Param(conv1_bias)) 39 | var x1 = g.op( 40 | OP.CONV2D, 41 | x, 42 | c1_w, 43 | c1_b, 44 | attributes=AttributeVector( 45 | Attribute("padding", IndexList[2](2, 2)), 46 | Attribute("stride", IndexList[2](1, 1)), 47 | Attribute("dilation", IndexList[2](1, 1)), 48 | ), 49 | ) 50 | 51 | var x2 = ReLU(g, x1) 52 | var x3 = MaxPool2d(g, x2, kernel_size=2) 53 | 54 | # conv2 55 | # var x4 = nn.Conv2d(g, x3, out_channels=32, kernel_size=5, padding=2) 56 | var c2_w = g.param(TensorShape(32, x3.shape[1], 5, 5), init=Param(conv2_weights)) 57 | var c2_b = g.param(TensorShape(32), init=Param(conv2_bias)) 58 | var x4 = g.op( 59 | OP.CONV2D, 60 | x3, 61 | c2_w, 62 | c2_b, 63 | attributes=AttributeVector( 64 | Attribute("padding", IndexList[2](2, 2)), 65 | Attribute("stride", IndexList[2](1, 1)), 66 | Attribute("dilation", IndexList[2](1, 1)), 67 | ), 68 | ) 69 | 70 | var x5 = ReLU(g, x4) 71 | var x6 = MaxPool2d(g, x5, kernel_size=2) 72 | var x6_shape = x6.shape 73 | var x7 = g.op( 74 | OP.RESHAPE, 75 | x6, 76 | attributes=AttributeVector( 77 | Attribute( 78 | "shape", 79 | TensorShape(x6_shape[0], x6_shape[1] * x6_shape[2] * x6_shape[3]), 80 | ) 81 | ), 82 | ) 83 | 84 | # linear1 85 | # var out = nn.Linear(g, x7, n_outputs=10) 86 | var l1_w = g.param(TensorShape(x7.shape[1], 10), init=Param(linear1_weights)) 87 | var l1_b = g.param(TensorShape(10), init=Param(linear1_bias)) 88 | var res = g.op(OP.DOT, x7, l1_w) 89 | var out = g.op(OP.ADD, res, l1_b) 90 | g.out(out) 91 | 92 | var y_true = g.input(TensorShape(batch_size, 10)) 93 | var loss = CrossEntropyLoss(g, out, y_true) 94 | # var loss = nn.MSELoss(g, out, y_true) 95 | g.loss(loss) 96 | 97 | return g ^ 98 | 99 | 100 | fn run_mojo[ 101 | batch_size: Int, 102 | conv1_weights: List[Scalar[dtype]], 103 | conv1_bias: List[Scalar[dtype]], 104 | conv2_weights: List[Scalar[dtype]], 105 | conv2_bias: List[Scalar[dtype]], 106 | linear1_weights: List[Scalar[dtype]], 107 | linear1_bias: List[Scalar[dtype]], 108 | ]( 109 | epochs: Int, 110 | learning_rate: Float64, 111 | inputs: Tensor[dtype], 112 | labels: Tensor[dtype], 113 | ) -> List[Scalar[dtype]]: 114 | alias graph = create_CNN( 115 | batch_size, 116 | conv1_weights, 117 | conv1_bias, 118 | conv2_weights, 119 | conv2_bias, 120 | linear1_weights, 121 | linear1_bias, 122 | ) 123 | 124 | var model = Model[graph]() 125 | var optim = optim.Adam[graph](model.parameters, lr=learning_rate.cast[dtype]()) 126 | 127 | var losses = List[Scalar[dtype]]() 128 | 129 | for i in range(epochs): 130 | var loss = model.forward(inputs, labels) 131 | 132 | # Backward pass 133 | optim.zero_grad() 134 | model.backward() 135 | optim.step() 136 | 137 | losses.append(loss[0]) 138 | 139 | return losses 140 | 141 | 142 | fn 
run_torch( 143 | epochs: Int, 144 | learning_rate: Float64, 145 | inputs: Tensor, 146 | labels: Tensor, 147 | owned conv1_weights: Tensor, 148 | owned conv1_bias: Tensor, 149 | owned conv2_weights: Tensor, 150 | owned conv2_bias: Tensor, 151 | owned linear1_weights: Tensor, 152 | owned linear1_bias: Tensor, 153 | ) -> List[Scalar[dtype]]: 154 | var out: List[Scalar[dtype]] = List[Scalar[dtype]]() 155 | 156 | try: 157 | var torch = Python.import_module("torch") 158 | var F = Python.import_module("torch.nn.functional") 159 | var np = Python.import_module("numpy") 160 | Python.add_to_path("./tests/python") 161 | var torch_models = Python.import_module("test_models_torch") 162 | 163 | var inputs = torch.from_numpy(to_numpy(inputs)).requires_grad_(True) 164 | var labels = torch.from_numpy(to_numpy(labels)).requires_grad_(True) 165 | 166 | var conv1_weights = torch.from_numpy(to_numpy(conv1_weights)).requires_grad_( 167 | True 168 | ) 169 | var conv1_bias = torch.from_numpy(to_numpy(conv1_bias)).requires_grad_(True) 170 | var conv2_weights = torch.from_numpy(to_numpy(conv2_weights)).requires_grad_( 171 | True 172 | ) 173 | var conv2_bias = torch.from_numpy(to_numpy(conv2_bias)).requires_grad_(True) 174 | var linear1_weights = torch.from_numpy( 175 | to_numpy(linear1_weights) 176 | ).requires_grad_(True) 177 | var linear1_bias = torch.from_numpy(to_numpy(linear1_bias)).requires_grad_(True) 178 | 179 | var cnn = torch_models.CNN( 180 | conv1_weights, 181 | conv1_bias, 182 | conv2_weights, 183 | conv2_bias, 184 | linear1_weights, 185 | linear1_bias, 186 | ) 187 | 188 | var loss_func = torch_models.CrossEntropyLoss2() 189 | # var loss_func = torch.nn.CrossEntropyLoss() 190 | var optimizer = torch.optim.Adam(cnn.parameters(), learning_rate) 191 | 192 | for i in range(epochs): 193 | var output = cnn.forward(inputs) 194 | var loss = loss_func(output, labels) 195 | 196 | _ = optimizer.zero_grad() 197 | _ = loss.backward() 198 | _ = optimizer.step() 199 | 200 | out.append(to_tensor(loss)[0]) 201 | 202 | return out 203 | 204 | except e: 205 | print("Error importing torch") 206 | print(e) 207 | return out 208 | 209 | 210 | fn create_weights(num_elements: Int, zero: Bool) -> List[Scalar[dtype]]: 211 | var weights = List[Scalar[dtype]](capacity=num_elements) 212 | for i in range(num_elements): 213 | if zero: 214 | weights.append(Scalar[dtype](0.0)) 215 | else: 216 | weights.append(Scalar[dtype](0.02)) 217 | return weights ^ 218 | 219 | 220 | fn dv_to_tensor(dv: List[Scalar[dtype]], shape: TensorShape) -> Tensor[dtype]: 221 | var t = Tensor[dtype](shape) 222 | if t.num_elements() != len(dv): 223 | print("[WARNING] tensor and dv not the shame shape") 224 | for i in range(t.num_elements()): 225 | t[i] = dv[i] 226 | return t ^ 227 | 228 | 229 | fn main(): 230 | alias learning_rate = 1e-3 231 | alias epochs = 100 232 | alias batch_size = 4 233 | 234 | var inputs = Tensor[dtype](batch_size, 1, 28, 28) 235 | rand[dtype](inputs.data(), inputs.num_elements()) 236 | var labels = Tensor[dtype](batch_size, 10) # one-hot encoded (probabilities) 237 | for i in range(4): 238 | labels[i * 10 + i] = 1.0 239 | 240 | alias cv1_w_shape = TensorShape(16, 1, 5, 5) 241 | alias conv1_weights = create_weights(cv1_w_shape.num_elements(), zero=False) 242 | alias cv1_b_shape = TensorShape(16) 243 | alias conv1_bias = create_weights(16, zero=True) 244 | 245 | alias cv2_w_shape = TensorShape(32, 16, 5, 5) 246 | alias conv2_weights = create_weights(cv2_w_shape.num_elements(), zero=False) 247 | alias cv2_b_shape = TensorShape(32) 248 | 
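# Note on the linear layer size a few lines below: with 5x5 kernels, padding=2 and stride=1 the convs preserve the 28x28 spatial size, and each MaxPool2d(kernel_size=2) halves it (28 -> 14 -> 7), so the flattened feature size is 32 * 7 * 7.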
alias conv2_bias = create_weights(32, zero=True) 249 | 250 | alias l1_w_shape = TensorShape(32 * 7 * 7, 10) 251 | alias linear1_weights = create_weights(l1_w_shape.num_elements(), zero=False) 252 | alias l1_b_shape = TensorShape(10) 253 | alias linear1_bias = create_weights(10, zero=True) 254 | 255 | var losses_mojo = run_mojo[ 256 | batch_size, 257 | conv1_weights, 258 | conv1_bias, 259 | conv2_weights, 260 | conv2_bias, 261 | linear1_weights, 262 | linear1_bias, 263 | ]( 264 | epochs, 265 | learning_rate, 266 | inputs, 267 | labels, 268 | ) 269 | 270 | var losses_torch = run_torch( 271 | epochs, 272 | learning_rate, 273 | inputs, 274 | labels, 275 | dv_to_tensor(conv1_weights, cv1_w_shape), 276 | dv_to_tensor(conv1_bias, cv1_b_shape), 277 | dv_to_tensor(conv2_weights, cv2_w_shape), 278 | dv_to_tensor(conv2_bias, cv2_b_shape), 279 | dv_to_tensor(linear1_weights, l1_w_shape), 280 | dv_to_tensor(linear1_bias, l1_b_shape), 281 | ) 282 | 283 | for i in range(epochs): 284 | print("loss_mojo: ", losses_mojo[i], " loss_torch: ", losses_torch[i]) 285 | 286 | for i in range(epochs): 287 | var loss_mojo = losses_mojo[i] 288 | var loss_torch = losses_torch[i] 289 | print("loss_mojo: ", loss_mojo, " loss_torch: ", loss_torch) 290 | try: 291 | assert_almost_equal(loss_mojo, loss_torch, rtol=1e-5) 292 | except e: 293 | print("Losses not equal") 294 | print(e) 295 | break 296 | -------------------------------------------------------------------------------- /tests/python/test_models_regression.mojo: -------------------------------------------------------------------------------- 1 | from random import rand 2 | from python import Python 3 | from utils.numerics import max_finite 4 | from testing import assert_almost_equal 5 | 6 | from basalt import dtype 7 | from basalt.autograd import Graph, OP 8 | from basalt.nn import Tensor, TensorShape, Model, MSELoss, optim 9 | from basalt.utils.rand_utils import MersenneTwister 10 | from basalt.autograd.params import Param 11 | 12 | from tests import to_numpy, to_tensor 13 | 14 | 15 | fn create_linear_regression( 16 | batch_size: Int, 17 | n_outputs: Int, 18 | linear1_weights: List[Scalar[dtype]], 19 | linear1_bias: List[Scalar[dtype]], 20 | ) -> Graph: 21 | var g = Graph() 22 | var x = g.input(TensorShape(batch_size, 13)) 23 | 24 | # linear1 25 | # var out = nn.Linear(g, x, n_outputs=1) 26 | var l1_w = g.param(TensorShape(13, n_outputs), init=Param(linear1_weights)) 27 | var l1_b = g.param(TensorShape(n_outputs), init=Param(linear1_bias)) 28 | var res = g.op(OP.DOT, x, l1_w) 29 | var out = g.op(OP.ADD, res, l1_b) 30 | g.out(out) 31 | 32 | var y_true = g.input(TensorShape(batch_size, n_outputs)) 33 | var loss = MSELoss(g, out, y_true) 34 | g.loss(loss) 35 | 36 | return g ^ 37 | 38 | 39 | fn run_mojo[ 40 | batch_size: Int, 41 | n_outputs: Int, 42 | linear1_weights: List[Scalar[dtype]], 43 | linear1_bias: List[Scalar[dtype]], 44 | ]( 45 | epochs: Int, 46 | learning_rate: Float64, 47 | inputs: Tensor[dtype], 48 | labels: Tensor[dtype], 49 | ) -> List[Scalar[dtype]]: 50 | alias graph = create_linear_regression( 51 | batch_size, 52 | n_outputs, 53 | linear1_weights, 54 | linear1_bias, 55 | ) 56 | 57 | var model = Model[graph]() 58 | var optim = optim.Adam[graph](model.parameters, lr=learning_rate.cast[dtype]()) 59 | 60 | var losses = List[Scalar[dtype]]() 61 | 62 | for i in range(epochs): 63 | var loss = model.forward(inputs, labels) 64 | 65 | # Backward pass 66 | optim.zero_grad() 67 | model.backward() 68 | optim.step() 69 | 70 | losses.append(loss[0]) 71 | 72 | return 
losses 73 | 74 | 75 | fn run_torch( 76 | epochs: Int, 77 | learning_rate: Float64, 78 | inputs: Tensor, 79 | labels: Tensor, 80 | owned linear1_weights: Tensor, 81 | owned linear1_bias: Tensor, 82 | ) -> List[Scalar[dtype]]: 83 | var out: List[Scalar[dtype]] = List[Scalar[dtype]]() 84 | 85 | try: 86 | var torch = Python.import_module("torch") 87 | var F = Python.import_module("torch.nn.functional") 88 | var np = Python.import_module("numpy") 89 | Python.add_to_path("./tests/python") 90 | var torch_models = Python.import_module("test_models_torch") 91 | 92 | var inputs = torch.from_numpy(to_numpy(inputs)).requires_grad_(True) 93 | var labels = torch.from_numpy(to_numpy(labels)).requires_grad_(True) 94 | 95 | var linear1_weights = torch.from_numpy( 96 | to_numpy(linear1_weights) 97 | ).requires_grad_(True) 98 | var linear1_bias = torch.from_numpy(to_numpy(linear1_bias)).requires_grad_(True) 99 | 100 | var regression = torch_models.LinearRegression( 101 | linear1_weights, 102 | linear1_bias, 103 | ) 104 | 105 | var loss_func = torch_models.MSELoss() 106 | var optimizer = torch.optim.Adam(regression.parameters(), learning_rate) 107 | 108 | for i in range(epochs): 109 | var output = regression.forward(inputs) 110 | var loss = loss_func(output, labels) 111 | 112 | _ = optimizer.zero_grad() 113 | _ = loss.backward() 114 | _ = optimizer.step() 115 | 116 | out.append(to_tensor(loss)[0].cast[dtype]()) 117 | 118 | return out 119 | 120 | except e: 121 | print("Error importing torch") 122 | print(e) 123 | return out 124 | 125 | 126 | fn create_weights(num_elements: Int, zero: Bool) -> List[Scalar[dtype]]: 127 | var prng = MersenneTwister(123456) 128 | var weights = List[Scalar[dtype]](capacity=num_elements) 129 | for i in range(num_elements): 130 | if zero: 131 | weights.append(Scalar[dtype](0.0)) 132 | else: 133 | var rand_float = prng.next().cast[dtype]() / max_finite[DType.int32]().cast[ 134 | dtype 135 | ]() 136 | weights.append(Scalar[dtype](rand_float / 10)) 137 | return weights ^ 138 | 139 | 140 | fn dv_to_tensor(dv: List[Scalar[dtype]], shape: TensorShape) -> Tensor[dtype]: 141 | var t = Tensor[dtype](shape) 142 | if t.num_elements() != len(dv): 143 | print("[WARNING] tensor and dv not the same shape") 144 | for i in range(t.num_elements()): 145 | t[i] = dv[i] 146 | return t ^ 147 | 148 | 149 | fn main(): 150 | alias learning_rate = 1e-3 151 | alias epochs = 100 152 | alias batch_size = 64 153 | alias n_outputs = 10 154 | 155 | var inputs = Tensor[dtype](batch_size, 13) 156 | rand[dtype](inputs.data(), inputs.num_elements()) 157 | var labels = Tensor[dtype](batch_size, n_outputs) 158 | for i in range(batch_size): 159 | for j in range(n_outputs): 160 | labels[i * n_outputs + j] = 1 161 | 162 | alias l1_w_shape = TensorShape(13, n_outputs) 163 | alias linear1_weights = create_weights(l1_w_shape.num_elements(), zero=False) 164 | alias l1_b_shape = TensorShape(n_outputs) 165 | alias linear1_bias = create_weights(l1_b_shape.num_elements(), zero=False) 166 | 167 | var losses_mojo = run_mojo[batch_size, n_outputs, linear1_weights, linear1_bias,]( 168 | epochs, 169 | learning_rate, 170 | inputs, 171 | labels, 172 | ) 173 | 174 | var losses_torch = run_torch( 175 | epochs, 176 | learning_rate, 177 | inputs, 178 | labels, 179 | dv_to_tensor(linear1_weights, l1_w_shape), 180 | dv_to_tensor(linear1_bias, l1_b_shape), 181 | ) 182 | 183 | var success = True 184 | for i in range(epochs): 185 | var loss_mojo = losses_mojo[i] 186 | var loss_torch = losses_torch[i] 187 | # print("loss_mojo: ", loss_mojo, " 
loss_torch: ", loss_torch) 188 | try: 189 | assert_almost_equal(loss_mojo, loss_torch, rtol=1e-4) 190 | except e: 191 | print("Losses not equal") 192 | print(e) 193 | success = False 194 | break 195 | 196 | if success: 197 | print("SUCCESS: All losses in Linear Regression model are equal.") 198 | -------------------------------------------------------------------------------- /tests/python/test_models_sin_estimate.mojo: -------------------------------------------------------------------------------- 1 | from random import rand 2 | from python import Python 3 | from utils.numerics import max_finite 4 | from testing import assert_almost_equal 5 | import math 6 | 7 | from basalt import dtype 8 | from basalt.autograd import Graph, OP 9 | from basalt.nn import Tensor, TensorShape, Model, ReLU, MSELoss, optim 10 | from basalt.utils.rand_utils import MersenneTwister 11 | from basalt.autograd.params import Param 12 | 13 | from tests import to_numpy, to_tensor 14 | 15 | 16 | fn create_simple_nn( 17 | batch_size: Int, 18 | linear1_weights: List[Scalar[dtype]], 19 | linear1_bias: List[Scalar[dtype]], 20 | linear2_weights: List[Scalar[dtype]], 21 | linear2_bias: List[Scalar[dtype]], 22 | linear3_weights: List[Scalar[dtype]], 23 | linear3_bias: List[Scalar[dtype]], 24 | ) -> Graph: 25 | var g = Graph() 26 | 27 | var x = g.input(TensorShape(batch_size, 1)) 28 | var y_true = g.input(TensorShape(batch_size, 1)) 29 | 30 | # Linear 1: nn.Linear(g, x, n_outputs=32) 31 | var l1_w = g.param(TensorShape(1, 32), init=Param(linear1_weights)) 32 | var l1_b = g.param(TensorShape(32), init=Param(linear1_bias)) 33 | var res_1 = g.op(OP.DOT, x, l1_w) 34 | var x1 = g.op(OP.ADD, res_1, l1_b) 35 | 36 | # ReLU 1 37 | var x2 = ReLU(g, x1) 38 | 39 | # Linear 2: nn.Linear(g, x2, n_outputs=32) 40 | var l2_w = g.param(TensorShape(32, 32), init=Param(linear2_weights)) 41 | var l2_b = g.param(TensorShape(32), init=Param(linear2_bias)) 42 | var res_2 = g.op(OP.DOT, x2, l2_w) 43 | var x3 = g.op(OP.ADD, res_2, l2_b) 44 | 45 | # ReLU 2 46 | var x4 = ReLU(g, x3) 47 | 48 | # Linear 3: nn.Linear(g, x4, n_outputs=1) 49 | var l3_w = g.param(TensorShape(32, 1), init=Param(linear3_weights)) 50 | var l3_b = g.param(TensorShape(1), init=Param(linear3_bias)) 51 | var res_3 = g.op(OP.DOT, x4, l3_w) 52 | var y_pred = g.op(OP.ADD, res_3, l3_b) 53 | g.out(y_pred) 54 | 55 | var loss = MSELoss(g, y_pred, y_true) 56 | g.loss(loss) 57 | 58 | return g ^ 59 | 60 | 61 | fn run_mojo[ 62 | batch_size: Int, 63 | linear1_weights: List[Scalar[dtype]], 64 | linear1_bias: List[Scalar[dtype]], 65 | linear2_weights: List[Scalar[dtype]], 66 | linear2_bias: List[Scalar[dtype]], 67 | linear3_weights: List[Scalar[dtype]], 68 | linear3_bias: List[Scalar[dtype]], 69 | ]( 70 | epochs: Int, 71 | learning_rate: Float64, 72 | inputs: Tensor[dtype], 73 | labels: Tensor[dtype], 74 | ) -> List[Scalar[dtype]]: 75 | alias graph = create_simple_nn( 76 | batch_size, 77 | linear1_weights, 78 | linear1_bias, 79 | linear2_weights, 80 | linear2_bias, 81 | linear3_weights, 82 | linear3_bias, 83 | ) 84 | 85 | var model = Model[graph]() 86 | var optim = optim.Adam[graph](model.parameters, lr=learning_rate.cast[dtype]()) 87 | 88 | var losses = List[Scalar[dtype]]() 89 | 90 | for i in range(epochs): 91 | var loss = model.forward(inputs, labels) 92 | 93 | # Backward pass 94 | optim.zero_grad() 95 | model.backward() 96 | optim.step() 97 | 98 | losses.append(loss[0]) 99 | 100 | return losses 101 | 102 | 103 | fn run_torch( 104 | epochs: Int, 105 | learning_rate: Float64, 106 | inputs: 
Tensor, 107 | labels: Tensor, 108 | owned linear1_weights: Tensor, 109 | owned linear1_bias: Tensor, 110 | owned linear2_weights: Tensor, 111 | owned linear2_bias: Tensor, 112 | owned linear3_weights: Tensor, 113 | owned linear3_bias: Tensor, 114 | ) -> List[Scalar[dtype]]: 115 | var out: List[Scalar[dtype]] = List[Scalar[dtype]]() 116 | 117 | try: 118 | var torch = Python.import_module("torch") 119 | var F = Python.import_module("torch.nn.functional") 120 | var np = Python.import_module("numpy") 121 | Python.add_to_path("./tests/python") 122 | var torch_models = Python.import_module("test_models_torch") 123 | 124 | var inputs = torch.from_numpy(to_numpy(inputs)).requires_grad_(True) 125 | var labels = torch.from_numpy(to_numpy(labels)).requires_grad_(True) 126 | 127 | var linear1_weights = torch.from_numpy( 128 | to_numpy(linear1_weights) 129 | ).requires_grad_(True) 130 | var linear1_bias = torch.from_numpy(to_numpy(linear1_bias)).requires_grad_(True) 131 | var linear2_weights = torch.from_numpy( 132 | to_numpy(linear2_weights) 133 | ).requires_grad_(True) 134 | var linear2_bias = torch.from_numpy(to_numpy(linear2_bias)).requires_grad_(True) 135 | var linear3_weights = torch.from_numpy( 136 | to_numpy(linear3_weights) 137 | ).requires_grad_(True) 138 | var linear3_bias = torch.from_numpy(to_numpy(linear3_bias)).requires_grad_(True) 139 | 140 | var regression = torch_models.SimpleNN( 141 | linear1_weights, 142 | linear1_bias, 143 | linear2_weights, 144 | linear2_bias, 145 | linear3_weights, 146 | linear3_bias, 147 | ) 148 | 149 | var loss_func = torch_models.MSELoss() 150 | var optimizer = torch.optim.Adam(regression.parameters(), learning_rate) 151 | 152 | for i in range(epochs): 153 | var output = regression.forward(inputs) 154 | var loss = loss_func(output, labels) 155 | 156 | _ = optimizer.zero_grad() 157 | _ = loss.backward() 158 | _ = optimizer.step() 159 | 160 | out.append(to_tensor(loss)[0].cast[dtype]()) 161 | 162 | return out 163 | 164 | except e: 165 | print("Error importing torch") 166 | print(e) 167 | return out 168 | 169 | 170 | fn create_weights(num_elements: Int, zero: Bool) -> List[Scalar[dtype]]: 171 | var prng = MersenneTwister(123456) 172 | var weights = List[Scalar[dtype]](capacity=num_elements) 173 | for i in range(num_elements): 174 | if zero: 175 | weights.append(Scalar[dtype](0.0)) 176 | else: 177 | var rand_float = prng.next().cast[dtype]() / max_finite[DType.int32]().cast[ 178 | dtype 179 | ]() 180 | weights.append(Scalar[dtype](rand_float / 10)) 181 | return weights ^ 182 | 183 | 184 | fn dv_to_tensor(dv: List[Scalar[dtype]], shape: TensorShape) -> Tensor[dtype]: 185 | var t = Tensor[dtype](shape) 186 | if t.num_elements() != len(dv): 187 | print("[WARNING] tensor and dv not the same shape") 188 | for i in range(t.num_elements()): 189 | t[i] = dv[i] 190 | return t ^ 191 | 192 | 193 | fn main(): 194 | alias learning_rate = 1e-3 195 | alias epochs = 100 196 | alias batch_size = 64 197 | alias n_outputs = 10 198 | 199 | var x_data = Tensor[dtype](batch_size, 1) 200 | rand[dtype](x_data.data(), x_data.num_elements()) 201 | var y_data = Tensor[dtype](batch_size, 1) 202 | for j in range(batch_size): 203 | x_data[j] = x_data[j] * 2 - 1 204 | y_data[j] = math.sin(x_data[j]) 205 | 206 | alias l1_w_shape = TensorShape(1, 32) 207 | alias l1_b_shape = TensorShape(32) 208 | alias l2_w_shape = TensorShape(32, 32) 209 | alias l2_b_shape = TensorShape(32) 210 | alias l3_w_shape = TensorShape(32, 1) 211 | alias l3_b_shape = TensorShape(1) 212 | 213 | alias linear1_weights = 
create_weights(l1_w_shape.num_elements(), zero=False) 214 | alias linear1_bias = create_weights(l1_b_shape.num_elements(), zero=False) 215 | alias linear2_weights = create_weights(l2_w_shape.num_elements(), zero=False) 216 | alias linear2_bias = create_weights(l2_b_shape.num_elements(), zero=False) 217 | alias linear3_weights = create_weights(l3_w_shape.num_elements(), zero=False) 218 | alias linear3_bias = create_weights(l3_b_shape.num_elements(), zero=False) 219 | 220 | var losses_mojo = run_mojo[ 221 | batch_size, 222 | linear1_weights, 223 | linear1_bias, 224 | linear2_weights, 225 | linear2_bias, 226 | linear3_weights, 227 | linear3_bias, 228 | ](epochs, learning_rate, x_data, y_data) 229 | 230 | var losses_torch = run_torch( 231 | epochs, 232 | learning_rate, 233 | x_data, 234 | y_data, 235 | dv_to_tensor(linear1_weights, l1_w_shape), 236 | dv_to_tensor(linear1_bias, l1_b_shape), 237 | dv_to_tensor(linear2_weights, l2_w_shape), 238 | dv_to_tensor(linear2_bias, l2_b_shape), 239 | dv_to_tensor(linear3_weights, l3_w_shape), 240 | dv_to_tensor(linear3_bias, l3_b_shape), 241 | ) 242 | 243 | var success = True 244 | for i in range(epochs): 245 | var loss_mojo = losses_mojo[i] 246 | var loss_torch = losses_torch[i] 247 | # print("loss_mojo: ", loss_mojo, " loss_torch: ", loss_torch) 248 | try: 249 | assert_almost_equal(loss_mojo, loss_torch, rtol=1e-4) 250 | except e: 251 | print("Losses not equal") 252 | print(e) 253 | success = False 254 | break 255 | 256 | if success: 257 | print("SUCCESS: All losses in Sin estimate model are equal.") 258 | -------------------------------------------------------------------------------- /tests/python/test_models_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LinearRegression(nn.Module): 7 | def __init__(self, linear1_weights, linear1_bias): 8 | super(LinearRegression, self).__init__() 9 | 10 | self.linear1_weights = nn.Parameter(linear1_weights) 11 | self.linear1_bias = nn.Parameter(linear1_bias) 12 | 13 | def forward(self, x): 14 | output = F.linear(x, self.linear1_weights.T, self.linear1_bias) 15 | return output 16 | 17 | 18 | class MSELoss(nn.Module): 19 | def __init__(self): 20 | super(MSELoss, self).__init__() 21 | 22 | def forward(self, output, target): 23 | loss = F.mse_loss(output, target) 24 | return loss 25 | 26 | 27 | class CrossEntropyLoss(nn.Module): 28 | def __init__(self): 29 | super(CrossEntropyLoss, self).__init__() 30 | 31 | def forward(self, output, target): 32 | loss = -torch.sum(target * torch.log(output)) / output.size(0) 33 | return loss 34 | 35 | 36 | # Implement the class for crossentropy loss with logsoftmax 37 | class CrossEntropyLoss2(nn.Module): 38 | def __init__(self): 39 | super(CrossEntropyLoss2, self).__init__() 40 | 41 | def forward(self, output, target): 42 | loss = -torch.sum(target * F.log_softmax(output, dim=1)) / output.size(0) 43 | return loss 44 | 45 | 46 | class CNN(nn.Module): 47 | def __init__( 48 | self, 49 | conv1_weights, 50 | conv1_bias, 51 | conv2_weights, 52 | conv2_bias, 53 | linear1_weights, 54 | linear1_bias, 55 | ): 56 | super(CNN, self).__init__() 57 | 58 | self.conv1_weights = nn.Parameter(conv1_weights) 59 | self.conv1_bias = nn.Parameter(conv1_bias) 60 | self.conv2_weights = nn.Parameter(conv2_weights) 61 | self.conv2_bias = nn.Parameter(conv2_bias) 62 | self.linear1_weights = nn.Parameter(linear1_weights) 63 | self.linear1_bias = nn.Parameter(linear1_bias) 
64 | 65 | def forward(self, x): 66 | x = F.conv2d(x, self.conv1_weights, self.conv1_bias, stride=1, padding=2) 67 | x = F.relu(x) 68 | x = F.max_pool2d(x, 2) 69 | x = F.conv2d(x, self.conv2_weights, self.conv2_bias, stride=1, padding=2) 70 | x = F.relu(x) 71 | x = F.max_pool2d(x, 2) 72 | x = x.view(x.size(0), -1) 73 | output = F.linear(x, self.linear1_weights.T, self.linear1_bias) 74 | return output 75 | 76 | def print_grads(self): 77 | print("\nCONV1 WEIGHTS", self.conv1_weights.grad.shape) 78 | print(self.conv1_weights.grad) 79 | 80 | print("\nCONV1 BIAS", self.conv1_bias.grad.shape) 81 | print(self.conv1_bias.grad) 82 | 83 | print("\nCONV2 WEIGHTS", self.conv2_weights.grad.shape) 84 | print(self.conv2_weights.grad) 85 | 86 | print("\nCONV2 BIAS", self.conv2_bias.grad.shape) 87 | print(self.conv2_bias.grad) 88 | 89 | print("\nLINEAR1 WEIGHTS", self.linear1_weights.grad.shape) 90 | print(self.linear1_weights.grad) 91 | 92 | print("\nLINEAR1 BIAS", self.linear1_bias.grad.shape) 93 | print(self.linear1_bias.grad) 94 | 95 | 96 | class SimpleNN(nn.Module): 97 | def __init__( 98 | self, 99 | linear1_weights, 100 | linear1_bias, 101 | linear2_weights, 102 | linear2_bias, 103 | linear3_weights, 104 | linear3_bias, 105 | ): 106 | super(SimpleNN, self).__init__() 107 | 108 | self.linear1_weights = nn.Parameter(linear1_weights) 109 | self.linear1_bias = nn.Parameter(linear1_bias) 110 | self.linear2_weights = nn.Parameter(linear2_weights) 111 | self.linear2_bias = nn.Parameter(linear2_bias) 112 | self.linear3_weights = nn.Parameter(linear3_weights) 113 | self.linear3_bias = nn.Parameter(linear3_bias) 114 | 115 | self.relu1 = nn.ReLU() 116 | self.relu2 = nn.ReLU() 117 | 118 | def forward(self, x): 119 | x1 = F.linear(x, self.linear1_weights.T, self.linear1_bias) 120 | x2 = self.relu1(x1) 121 | x3 = F.linear(x2, self.linear2_weights.T, self.linear2_bias) 122 | x4 = self.relu2(x3) 123 | y_pred = F.linear(x4, self.linear3_weights.T, self.linear3_bias) 124 | return y_pred 125 | -------------------------------------------------------------------------------- /tests/python/test_pool.mojo: -------------------------------------------------------------------------------- 1 | from random import rand 2 | from python.python import Python 3 | from testing import assert_equal 4 | from utils.index import IndexList 5 | 6 | from basalt import dtype, nelts 7 | from basalt.autograd import Graph, OP 8 | from basalt.autograd.ops.pool import MAXPOOL2D 9 | from basalt.autograd.ops.conv import get_result_shape 10 | from basalt.autograd.attributes import Attribute, AttributeVector 11 | from basalt.nn import Tensor, TensorShape, Model 12 | 13 | from tests import assert_tensors_equal, to_numpy, to_tensor 14 | 15 | 16 | @value 17 | struct torch_maxpool2d_output: 18 | var expected: Tensor[dtype] 19 | var expected_grad: Tensor[dtype] 20 | 21 | 22 | fn torch_maxpool2d( 23 | inputs: Tensor, 24 | kernel_size: IndexList[2], 25 | padding: IndexList[2], 26 | stride: IndexList[2], 27 | dilation: IndexList[2], 28 | upper_grad: Tensor, 29 | ) -> torch_maxpool2d_output: 30 | var out: torch_maxpool2d_output 31 | 32 | try: 33 | var torch = Python.import_module("torch") 34 | var F = Python.import_module("torch.nn.functional") 35 | var np = Python.import_module("numpy") 36 | 37 | var inputs = torch.from_numpy(to_numpy(inputs)).requires_grad_(True) 38 | 39 | var expected = F.max_pool2d( 40 | inputs, 41 | (kernel_size[0], kernel_size[1]), 42 | (stride[0], stride[1]), 43 | (padding[0], padding[1]), 44 | (dilation[0], dilation[1]), 45 | ) 46 | 47 
| # uppergrad & backwards 48 | var upper_grad = torch.from_numpy(to_numpy(upper_grad)) 49 | _ = expected.backward(upper_grad) 50 | 51 | # expected 52 | out = torch_maxpool2d_output( 53 | to_tensor(expected.detach().numpy()), to_tensor(inputs.grad.numpy()) 54 | ) 55 | return out 56 | 57 | except: 58 | print("Error in torch_maxpool2d") 59 | var d = Tensor[dtype](1) 60 | var out = torch_maxpool2d_output(d, d) 61 | return out 62 | 63 | 64 | fn test_pool_forward[ 65 | input_shape: TensorShape, 66 | kernel_size: IndexList[2], 67 | padding: IndexList[2], 68 | stride: IndexList[2], 69 | dilation: IndexList[2], 70 | ](inputs: Tensor[dtype]) raises: 71 | fn create_graph() -> Graph: 72 | var g = Graph() 73 | var inp = g.input(input_shape) 74 | 75 | var res = g.op( 76 | OP.MAXPOOL2D, 77 | inp, 78 | attributes=AttributeVector( 79 | Attribute("kernel_size", kernel_size), 80 | Attribute("padding", padding), 81 | Attribute("stride", stride), 82 | Attribute("dilation", dilation), 83 | ), 84 | ) 85 | g.out(res) 86 | 87 | return g ^ 88 | 89 | alias graph = create_graph() 90 | assert_equal(len(graph.nodes), 1) 91 | 92 | var model = Model[graph](inference_only=True) 93 | var res = model.inference(inputs)[0] 94 | 95 | var torch_out = torch_maxpool2d( 96 | inputs, 97 | kernel_size=kernel_size, 98 | padding=padding, 99 | stride=stride, 100 | dilation=dilation, 101 | upper_grad=Tensor[dtype](res.shape()), 102 | ) 103 | 104 | assert_tensors_equal(res, torch_out.expected) 105 | 106 | 107 | fn test_forward_1() raises: 108 | # padding=2, stride=1, dilation=1 109 | # input shape: (4, 1, 28, 28) kernel size: (5, 5) 110 | alias kernel_size = 5 111 | alias padding = 2 112 | alias stride = 1 113 | alias dilation = 1 114 | alias input_shape = TensorShape(4, 1, 28, 28) 115 | var inputs = Tensor[dtype](input_shape) 116 | rand[dtype](inputs.data(), inputs.num_elements()) 117 | 118 | test_pool_forward[input_shape, kernel_size, padding, stride, dilation](inputs) 119 | 120 | 121 | fn test_forward_2() raises: 122 | # padding=0, stride=1, dilation=1 123 | # input shape: (4, 1, 32, 17) kernel size: (2, 2) 124 | alias kernel_size = IndexList[2](2, 2) 125 | alias padding = 0 126 | alias stride = 1 127 | alias dilation = 1 128 | alias input_shape = TensorShape(4, 1, 32, 17) 129 | var inputs = Tensor[dtype](input_shape) 130 | rand[dtype](inputs.data(), inputs.num_elements()) 131 | 132 | test_pool_forward[input_shape, kernel_size, padding, stride, dilation](inputs) 133 | 134 | 135 | fn test_forward_3() raises: 136 | # padding=(3, 1), stride=(2, 3), dilation=(2, 3) 137 | # input shape: (4, 3, 32, 17) kernel size: (6, 6) 138 | alias kernel_size = IndexList[2](6, 6) 139 | alias padding = IndexList[2](3, 1) 140 | alias stride = IndexList[2](2, 3) 141 | alias dilation = IndexList[2](2, 3) 142 | alias input_shape = TensorShape(4, 3, 32, 17) 143 | var inputs = Tensor[dtype](input_shape) 144 | rand[dtype](inputs.data(), inputs.num_elements()) 145 | 146 | test_pool_forward[input_shape, kernel_size, padding, stride, dilation](inputs) 147 | 148 | 149 | fn test_pool_backward[ 150 | ug_shape: TensorShape, 151 | input_shape: TensorShape, 152 | kernel_size: IndexList[2], 153 | padding: IndexList[2], 154 | stride: IndexList[2], 155 | dilation: IndexList[2], 156 | ](ug: Tensor[dtype], inputs: Tensor[dtype]) raises: 157 | alias attributes = AttributeVector( 158 | Attribute("kernel_size", kernel_size), 159 | Attribute("padding", padding), 160 | Attribute("stride", stride), 161 | Attribute("dilation", dilation), 162 | ) 163 | 164 | var grad = 
MAXPOOL2D.backward[ug_shape, input_shape, attributes](ug, inputs) 165 | 166 | var torch_out = torch_maxpool2d( 167 | inputs, 168 | kernel_size=kernel_size, 169 | padding=padding, 170 | stride=stride, 171 | dilation=dilation, 172 | upper_grad=ug, 173 | ) 174 | 175 | assert_tensors_equal["almost"](grad, torch_out.expected_grad) 176 | 177 | 178 | fn test_backward_1() raises: 179 | # padding=2, stride=1, dilation=1 180 | # input shape: (4, 1, 28, 28) kernel size: (5, 5) 181 | alias kernel_size = 5 182 | alias padding = 2 183 | alias stride = 1 184 | alias dilation = 1 185 | alias input_shape = TensorShape(4, 1, 28, 28) 186 | var inputs = Tensor[dtype](input_shape) 187 | rand[dtype](inputs.data(), inputs.num_elements()) 188 | 189 | # uppergrad 190 | alias res = get_result_shape( 191 | input_shape, TensorShape(kernel_size, kernel_size), padding, stride, dilation 192 | ) 193 | alias ug_shape = TensorShape(input_shape[0], input_shape[1], res[0], res[1]) 194 | var ug = Tensor[dtype](ug_shape) 195 | rand[dtype](ug.data(), ug.num_elements()) 196 | 197 | test_pool_backward[ug_shape, input_shape, kernel_size, padding, stride, dilation]( 198 | ug, inputs 199 | ) 200 | 201 | 202 | fn test_backward_2() raises: 203 | # padding=0, stride=1, dilation=1 204 | # input shape: (4, 1, 32, 17) kernel size: (2, 2) 205 | alias kernel_size = 2 206 | alias padding = 0 207 | alias stride = 1 208 | alias dilation = 1 209 | alias input_shape = TensorShape(4, 1, 32, 17) 210 | var inputs = Tensor[dtype](input_shape) 211 | rand[dtype](inputs.data(), inputs.num_elements()) 212 | 213 | # uppergrad 214 | alias res = get_result_shape( 215 | input_shape, TensorShape(kernel_size, kernel_size), padding, stride, dilation 216 | ) 217 | alias ug_shape = TensorShape(input_shape[0], input_shape[1], res[0], res[1]) 218 | var ug = Tensor[dtype](ug_shape) 219 | rand[dtype](ug.data(), ug.num_elements()) 220 | 221 | test_pool_backward[ug_shape, input_shape, kernel_size, padding, stride, dilation]( 222 | ug, inputs 223 | ) 224 | 225 | 226 | fn test_backward_3() raises: 227 | # padding=(3, 1), stride=(2, 3), dilation=(2, 3) 228 | # input shape: (4, 3, 32, 17) kernel size: (6, 6) 229 | alias kernel_size = IndexList[2](6, 6) 230 | alias padding = IndexList[2](3, 1) 231 | alias stride = IndexList[2](2, 3) 232 | alias dilation = IndexList[2](2, 3) 233 | alias input_shape = TensorShape(4, 3, 32, 17) 234 | var inputs = Tensor[dtype](input_shape) 235 | rand[dtype](inputs.data(), inputs.num_elements()) 236 | 237 | # uppergrad 238 | alias kernel_size_static: IndexList[2] = kernel_size 239 | alias res = get_result_shape( 240 | input_shape, TensorShape(kernel_size_static), padding, stride, dilation 241 | ) 242 | alias ug_shape = TensorShape(input_shape[0], input_shape[1], res[0], res[1]) 243 | var ug = Tensor[dtype](ug_shape) 244 | rand[dtype](ug.data(), ug.num_elements()) 245 | 246 | test_pool_backward[ug_shape, input_shape, kernel_size, padding, stride, dilation]( 247 | ug, inputs 248 | ) 249 | 250 | 251 | fn main(): 252 | try: 253 | test_forward_1() 254 | test_forward_2() 255 | test_forward_3() 256 | test_backward_1() 257 | test_backward_2() 258 | test_backward_3() 259 | except e: 260 | print("[Error] Error in MaxPool2D") 261 | print(e) 262 | -------------------------------------------------------------------------------- /tests/testing_utils.mojo: -------------------------------------------------------------------------------- 1 | from python.python import Python 2 | from collections import OptionalReg 3 | from testing import assert_equal, 
assert_almost_equal 4 | 5 | from basalt import dtype 6 | from basalt.autograd import Graph, OP 7 | from basalt.autograd.ops.ops import backward_op 8 | from basalt.autograd.attributes import AttributeVector 9 | from basalt.nn import Tensor, TensorShape, Model 10 | from basalt.utils.tensor_creation_utils import to_numpy, to_tensor 11 | 12 | 13 | # The below regex should be used to convert deprecated calls 14 | # assert_tensors_equal\(([^,]+),\s*([^,]+),\s*"([^"]+)"\) 15 | # assert_tensors_equal["$3"]($1, $2) 16 | fn assert_tensors_equal[ 17 | mode: String = "exact", msg: String = "Error" 18 | ](t1: Tensor[dtype], t2: Tensor[dtype]) raises: 19 | constrained[ 20 | mode == "exact" or mode == "almost", "Mode must be either 'exact' or 'almost'" 21 | ]() 22 | 23 | assert_equal(t1.shape(), t2.shape(), "Tensor shape mismatch") 24 | 25 | for i in range(t1.num_elements()): 26 | if mode == "almost": 27 | assert_almost_equal(t1[i], t2[i], rtol=1e-5, atol=1e-5, msg=msg) 28 | else: 29 | assert_equal(t1[i], t2[i], msg=msg) 30 | 31 | 32 | fn test_unary_op[ 33 | op: OP, t1_shape: TensorShape, attrs: OptionalReg[AttributeVector] = None 34 | ](t1: Tensor[dtype], expected: Tensor[dtype]) raises: 35 | fn create_graph() -> Graph: 36 | var g = Graph() 37 | var t1 = g.input(t1_shape) 38 | 39 | if attrs: 40 | var res = g.op(op, t1, attributes=attrs.value()) 41 | g.out(res) 42 | return g ^ 43 | else: 44 | var res = g.op(op, t1) 45 | g.out(res) 46 | return g ^ 47 | 48 | alias graph = create_graph() 49 | assert_equal(len(graph.nodes), 1) 50 | 51 | var model = Model[graph](inference_only=True) 52 | var res = model.inference(t1)[0] 53 | 54 | assert_tensors_equal["almost"](res, expected) 55 | 56 | 57 | fn test_binary_op[ 58 | op: OP, 59 | t1_shape: TensorShape, 60 | t2_shape: TensorShape, 61 | attrs: OptionalReg[AttributeVector] = None, 62 | ](t1: Tensor[dtype], t2: Tensor[dtype], expected: Tensor[dtype]) raises: 63 | fn create_graph() -> Graph: 64 | var g = Graph() 65 | var t1 = g.input(t1_shape) 66 | var t2 = g.input(t2_shape) 67 | 68 | if attrs: 69 | var res = g.op(op, t1, t2, attributes=attrs.value()) 70 | g.out(res) 71 | return g ^ 72 | else: 73 | var res = g.op(op, t1, t2) 74 | g.out(res) 75 | return g ^ 76 | 77 | alias graph = create_graph() 78 | assert_equal(len(graph.nodes), 1) 79 | 80 | var model = Model[graph](inference_only=True) 81 | var res = model.inference(t1, t2)[0] 82 | 83 | assert_tensors_equal["almost"](res, expected) 84 | 85 | 86 | fn test_ternary_op[ 87 | op: OP, t1_shape: TensorShape, t2_shape: TensorShape, t3_shape: TensorShape 88 | ]( 89 | t1: Tensor[dtype], t2: Tensor[dtype], t3: Tensor[dtype], expected: Tensor[dtype] 90 | ) raises: 91 | @parameter 92 | fn create_graph() -> Graph: 93 | var g = Graph() 94 | var t1 = g.input(t1_shape) 95 | var t2 = g.input(t2_shape) 96 | var t3 = g.input(t3_shape) 97 | 98 | var res = g.op(op, t1, t2, t3) 99 | g.out(res) 100 | 101 | return g ^ 102 | 103 | alias graph = create_graph() 104 | assert_equal(len(graph.nodes), 1) 105 | 106 | var model = Model[graph](inference_only=True) 107 | var res = model.inference(t1, t2, t3)[0] 108 | 109 | assert_tensors_equal["almost"](res, expected) 110 | 111 | 112 | fn test_unary_op_backward[ 113 | op: OP, 114 | t1_shape: TensorShape, 115 | ug_shape: TensorShape, 116 | attrs: AttributeVector = AttributeVector(), 117 | ](t1: Tensor[dtype], ug: Tensor[dtype], grad_1_expected: Tensor[dtype],) raises: 118 | var grad_1 = Tensor[dtype](t1_shape) 119 | backward_op[0, op, ug_shape, t1_shape, attrs](ug, t1, grad_1) 120 | 
assert_tensors_equal["almost"](grad_1, grad_1_expected) 121 | 122 | 123 | fn test_binary_op_backward[ 124 | op: OP, 125 | t1_shape: TensorShape, 126 | t2_shape: TensorShape, 127 | ug_shape: TensorShape, 128 | attrs: AttributeVector = AttributeVector(), 129 | ]( 130 | t1: Tensor[dtype], 131 | t2: Tensor[dtype], 132 | ug: Tensor[dtype], 133 | grad_1_expected: Tensor[dtype], 134 | grad_2_expected: Tensor[dtype], 135 | ) raises: 136 | var grad_1 = Tensor[dtype](t1_shape) 137 | backward_op[0, op, ug_shape, t1_shape, t2_shape, attrs](ug, t1, t2, grad_1) 138 | assert_tensors_equal["almost"](grad_1, grad_1_expected) 139 | 140 | var grad_2 = Tensor[dtype](t2_shape) 141 | backward_op[1, op, ug_shape, t1_shape, t2_shape, attrs](ug, t1, t2, grad_2) 142 | assert_tensors_equal["almost"](grad_2, grad_2_expected) 143 | 144 | 145 | fn test_ternary_op_backward[ 146 | op: OP, 147 | t1_shape: TensorShape, 148 | t2_shape: TensorShape, 149 | t3_shape: TensorShape, 150 | ug_shape: TensorShape, 151 | attrs: AttributeVector = AttributeVector(), 152 | ]( 153 | t1: Tensor[dtype], 154 | t2: Tensor[dtype], 155 | t3: Tensor[dtype], 156 | ug: Tensor[dtype], 157 | grad_1_expected: Tensor[dtype], 158 | grad_2_expected: Tensor[dtype], 159 | grad_3_expected: Tensor[dtype], 160 | ) raises: 161 | var grad_1 = Tensor[dtype](t1_shape) 162 | backward_op[0, op, ug_shape, t1_shape, t2_shape, t3_shape, attrs]( 163 | ug, t1, t2, t3, grad_1 164 | ) 165 | assert_tensors_equal["almost"](grad_1, grad_1_expected) 166 | 167 | var grad_2 = Tensor[dtype](t2_shape) 168 | backward_op[1, op, ug_shape, t1_shape, t2_shape, t3_shape, attrs]( 169 | ug, t1, t2, t3, grad_2 170 | ) 171 | assert_tensors_equal["almost"](grad_2, grad_2_expected) 172 | 173 | var grad_3 = Tensor[dtype](t3_shape) 174 | backward_op[2, op, ug_shape, t1_shape, t2_shape, t3_shape, attrs]( 175 | ug, t1, t2, t3, grad_3 176 | ) 177 | assert_tensors_equal["almost"](grad_3, grad_3_expected) 178 | 179 | 180 | fn create_graph_concat( 181 | t1_shape: TensorShape, t2_shape: TensorShape, t3_shape: TensorShape, dim: Int 182 | ) -> Graph: 183 | # Testing with 3 operands 184 | var g = Graph() 185 | var t1 = g.input(t1_shape, trainable=True) 186 | var t2 = g.input(t2_shape, trainable=True) 187 | var t3 = g.input(t3_shape, trainable=True) 188 | var res = g.concat(t1, t2, t3, dim=dim) 189 | g.out(res) 190 | g.loss(res) 191 | return g ^ 192 | 193 | 194 | fn create_graph_split(t_shape: TensorShape, sections: List[Int], dim: Int) -> Graph: 195 | var g = Graph() 196 | var t = g.input(t_shape, trainable=True) 197 | var results = g.split(t, sections=sections, dim=dim) 198 | for i in range(len(sections)): 199 | g.out(results[i]) 200 | g.loss(results[0]) # Any one 201 | return g ^ 202 | --------------------------------------------------------------------------------
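For reference, the helpers in `tests/testing_utils.mojo` above are intended to be called from the unit tests under `tests/mojo/`. Below is a minimal usage sketch, not a file from the repository: it assumes that `OP.ADD` performs elementwise addition and that `test_binary_op` is re-exported through the `tests` package in the same way `assert_tensors_equal` is imported by the test files shown above.

```
from basalt import dtype
from basalt.autograd import OP
from basalt.nn import Tensor, TensorShape

from tests import test_binary_op  # assumed re-export, like assert_tensors_equal


fn test_add_simple() raises:
    alias t1_shape = TensorShape(2, 3)
    alias t2_shape = TensorShape(2, 3)

    var t1 = Tensor[dtype](t1_shape)
    var t2 = Tensor[dtype](t2_shape)
    var expected = Tensor[dtype](t1_shape)

    # Fill the operands and the expected elementwise result.
    for i in range(t1.num_elements()):
        t1[i] = 1
        t2[i] = 2
        expected[i] = 3

    # Builds a single-node graph for OP.ADD at compile time, runs inference,
    # and compares the result against `expected` with assert_tensors_equal["almost"].
    test_binary_op[OP.ADD, t1_shape, t2_shape](t1, t2, expected)


fn main():
    try:
        test_add_simple()
    except e:
        print("[Error] Error in example test")
        print(e)
```

The same pattern extends to `test_unary_op`, `test_ternary_op`, and the `*_backward` variants, which additionally take the upper-gradient shape and the expected gradient tensors for each operand.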