├── .gitignore ├── LICENSE ├── src ├── utility.zig ├── permutate.zig ├── root.zig ├── sizes_and_strides.zig ├── stack_allocator.zig ├── tensor.zig ├── expression_parsing.zig ├── linear_caching_allocator.zig ├── tensor_factory.zig └── tensor_ops.zig └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | zig-cache/ 2 | .zig-cache/ 3 | zig-out/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 andrewCodeDev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/utility.zig: -------------------------------------------------------------------------------- 1 | /// src/utility.zig 2 | const std = @import("std"); 3 | const tensor = @import("./tensor.zig"); 4 | const builtin = @import("builtin"); 5 | 6 | // use this to compile out safety checks 7 | // and enforce invariants for debug builds. 
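// A typical use of this flag (see swapSizesAndStrides in tensor.zig) is to wrap
// debug-only invariant checks so they compile away outside of Debug builds,
// for example (illustrative sketch only):
//
//     if (comptime Util.debug) {
//         std.debug.assert(capacity_a == capacity_b);
//     }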
8 | pub const debug: bool = (builtin.mode == .Debug); 9 | 10 | pub fn arrayProduct(comptime rank: usize, comptime T: type, values: *const [rank]T) T { 11 | const s: @Vector(rank, T) = values.*; 12 | return @reduce(std.builtin.ReduceOp.Mul, s); 13 | } 14 | 15 | pub fn arraySum(comptime rank: usize, comptime T: type, values: *const [rank]T) T { 16 | const s: @Vector(rank, T) = values.*; 17 | return @reduce(std.builtin.ReduceOp.Sum, s); 18 | } 19 | 20 | pub fn sliceProduct(comptime T: type, values: []const T) T { 21 | if (values.len == 0) { 22 | return 0; 23 | } 24 | var total: T = 1; 25 | for (values) |n| { 26 | total *= n; 27 | } 28 | return total; 29 | } 30 | 31 | pub fn sliceSum(comptime T: type, values: []const T) T { 32 | var total: T = 0; 33 | for (values) |n| { 34 | total += n; 35 | } 36 | return total; 37 | } 38 | 39 | test "basic tensor access" { 40 | var data = [9]i32{ 1, 2, 3, 4, 5, 6, 7, 8, 9 }; 41 | var X = tensor.Tensor(i32, 2, tensor.Rowwise).init(&data, .{ 3, 3 }); 42 | const x = X.getValue(.{ 0, 2 }); 43 | try std.testing.expect(x == 3); 44 | } 45 | -------------------------------------------------------------------------------- /src/permutate.zig: -------------------------------------------------------------------------------- 1 | // make this an enum at some point 2 | pub const SizeAndStride = @import("./sizes_and_strides.zig").SizeAndStride; 3 | pub const SizesAndStrides = @import("./sizes_and_strides.zig").SizesAndStrides; 4 | pub const SizesType = SizeAndStride.ValueType; 5 | const OrderType = @import("./sizes_and_strides.zig").OrderType; 6 | 7 | const permutateParse = @import("./expression_parsing.zig").permutateParse; 8 | 9 | pub fn permutate(comptime rank: usize, comptime order: OrderType, comptime str: []const u8, ss: *SizesAndStrides(rank, order)) void { 10 | const permutation = comptime permutateParse(rank, str); 11 | 12 | var tmp: SizesAndStrides(rank, order) = undefined; 13 | 14 | var i: usize = 0; 15 | for (permutation) |p| { 16 | tmp.setSizeAndStride(i, ss.getSizeAndStride(p)); 17 | tmp.permutation[i] = p; 18 | i += 1; 19 | } 20 | ss.* = tmp; 21 | } 22 | 23 | test "Permutation" { 24 | const expectEqual = @import("std").testing.expectEqual; 25 | 26 | var ss = SizesAndStrides(3, OrderType.rowwise).init(.{ 10, 20, 30 }); 27 | 28 | try expectEqual(ss.permutation[0], 0); 29 | try expectEqual(ss.permutation[1], 1); 30 | try expectEqual(ss.permutation[2], 2); 31 | try expectEqual(ss.sizes[0], 10); 32 | try expectEqual(ss.sizes[1], 20); 33 | try expectEqual(ss.sizes[2], 30); 34 | 35 | permutate(3, OrderType.rowwise, "ijk->kji", &ss); 36 | 37 | try expectEqual(ss.permutation[0], 2); 38 | try expectEqual(ss.permutation[1], 1); 39 | try expectEqual(ss.permutation[2], 0); 40 | try expectEqual(ss.sizes[0], 30); 41 | try expectEqual(ss.sizes[1], 20); 42 | try expectEqual(ss.sizes[2], 10); 43 | } 44 | -------------------------------------------------------------------------------- /src/root.zig: -------------------------------------------------------------------------------- 1 | // Zein interface file - import this file directly into your project to begin using Zein. 2 | 3 | // import core SizesAndStrides version 4 | const SizesAndStridesVersion = @import("./sizes_and_strides.zig"); 5 | pub const SizeAndStride = SizesAndStridesVersion.SizeAndStride; 6 | pub const SizesAndStrides = SizesAndStridesVersion.SizesAndStrides; 7 | 8 | // import core tensor version... this can be swapped for different tensor implementations. 
9 | const TensorVersion = @import("./tensor.zig"); 10 | pub const Tensor = TensorVersion.Tensor; 11 | pub const TensorError = TensorVersion.TensorError; 12 | pub const Rowwise = TensorVersion.Rowwise; 13 | pub const Colwise = TensorVersion.Colwise; 14 | 15 | // import core TensorFactory version 16 | const TensorFactoryVersion = @import("./tensor_factory.zig"); 17 | pub const TensorFactory = TensorFactoryVersion.TensorFactory; 18 | pub const AllocatorError = TensorFactoryVersion.AllocatorError; 19 | 20 | // import core TensorOps version 21 | const TensorOpsVersion = @import("./tensor_ops.zig"); 22 | 23 | pub const sum = TensorOpsVersion.sum; 24 | pub const product = TensorOpsVersion.product; 25 | pub const min = TensorOpsVersion.min; 26 | pub const max = TensorOpsVersion.max; 27 | pub const contraction = TensorOpsVersion.contraction; 28 | pub const scale = TensorOpsVersion.scale; 29 | pub const bias = TensorOpsVersion.bias; 30 | pub const add = TensorOpsVersion.add; 31 | pub const mul = TensorOpsVersion.mul; 32 | pub const sub = TensorOpsVersion.sub; 33 | pub const absmax = TensorOpsVersion.absmax; 34 | pub const absmin = TensorOpsVersion.absmin; 35 | 36 | pub const quantize = TensorOpsVersion.quantize; 37 | pub const unquantize = TensorOpsVersion.unquantize; 38 | 39 | pub const fill = TensorOpsVersion.fill; 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ZEIN 2 | 3 | Zig-based implementation of general-rank tensors! [1, 64) 4 | 5 | ## Importing ZEIN 6 | 7 | 1. Fetch ZEIN: 8 | 9 | `zig fetch --save git+https://github.com/andrewCodeDev/ZEIN#main` 10 | 11 | 2. Add the module in your `build.zig`: 12 | 13 | ```zig 14 | const zein = b.dependency("ZEIN", .{ 15 | .target = target, 16 | .optimize = optimize, 17 | }); 18 | 19 | exe.root_module.addImport("zein", zein.module("ZEIN")); 20 | ``` 21 | You can now add `const zein = @import("zein");` to your file. 22 | 23 | ## Using Tensor Objects 24 | 25 | Tensors can be created in the following way: 26 | 27 | ```zig 28 | // initialize underlying tensor memory: 29 | var data = [9]i32{ 1, 2, 3, 4, 5, 6, 7, 8, 9 }; 30 | 31 | // create a rank 2, 3x3, Rowwise tensor of i32 from data: 32 | var X = zein.Tensor(i32, 2, zein.Rowwise).init(&data, .{ 3, 3 }); 33 | 34 | const x = X.getValue(.{0, 2}); // access value 3... 35 | ``` 36 | 37 | ## Allocating Tensor Data 38 | 39 | The TensorFactory offers the ability to track and free allocations: 40 | 41 | ```zig 42 | var factory = zein.TensorFactory(f32).init(.{ 43 | .system_allocator = your_allocator, // for TensorFactory components 44 | .tensor_allocator = your_allocator, // for TensorFactory value data 45 | }); 46 | 47 | // Begin tracking tensor allocations (default is no-tracking): 48 | factory.tracking(.start); 49 | 50 | // Stop tracking tensor allocations (does not free tensors): 51 | factory.tracking(.stop); 52 | 53 | // Free tracked tensor allocations (no-op if no tensors are tracked): 54 | factory.tracking(.free); 55 | 56 | // Deinit will free the allocator and currently tracked tensors: 57 | factory.deinit(); 58 | ``` 59 | 60 | ```zig 61 | // Assign a new tensor from allocator: 62 | var Y = try factory.allocTensor(2, zein.Rowwise, .{ 10, 10 }); 63 | ``` 64 | 65 | ```zig 66 | // Assign memory into existing tensor: 67 | var X = zein.Tensor(f32, 2, zein.Rowwise).init(null, .{ 10, 10 }); 68 | try factory.allocToTensor(&X); // alloc 100 elements...
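// If tracking was started earlier, the values allocated above can later be
// released in bulk (a no-op when nothing is tracked):
factory.tracking(.free);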
69 | ``` 70 | 71 | ## Tensor Operations 72 | 73 | Tensor operations are in the form of either _Free Functions_ or _Factory Functions_: 74 | 75 | - Free Functions require operands and the destination tensor. 76 | 77 | - Factory Functions use operands to create the destination tensor. 78 | 79 | The operations use compile time strings as einsum notation: 80 | 81 | ```zig 82 | // Collapse tensor values using contraction: 83 | zein.contraction("ijk->ji", &x, &y); // free function - assign to existing memory 84 | var y = factory.contraction("ijk->ji", &x); // factory function - allocate new memory 85 | ``` 86 | 87 | ```zig 88 | // Elementary binary functions (add, multiply): 89 | zein.add(&x, &y, &z); // free function - assign to existing memory 90 | var z = factory.add(&x, &y); // factory function - allocate new memory 91 | ``` 92 | 93 | ```zig 94 | // Transpose/permutate tensor views (does not modify underlying data). 95 | var y = x.permutate("ijk->kji"); 96 | ``` 97 | 98 | ```zig 99 | // Elementary vectorized reduction functions (sum, product, min, max): 100 | const a = zein.sum(&x); 101 | const b = zein.product(&x); 102 | const c = zein.max(&x); 103 | const d = zein.min(&x); 104 | ``` 105 | 106 | ## Using the Zein library 107 | 108 | The main `src/root.zig` file provides the interface for the library implementation. 109 | 110 | ## Memory Ownership and Viewership 111 | 112 | Currently, tensor permutations only change the indexing of a tensor - they do not 113 | invalidate underlying memory. If the user chooses to use the TensorFactory, 114 | it will track allocations and delete them automatically when calling deinit. 115 | V1 is only tested in single-threaded environments - thread safety with allocators 116 | will be coming in a later version! 117 | 118 | ## Additional functionality coming soon. 119 | 120 | This library is still in the beginning phases. If you want to contribute, please 121 | contact me! This is a big job and I'll take the help! 122 | -------------------------------------------------------------------------------- /src/sizes_and_strides.zig: -------------------------------------------------------------------------------- 1 | // Another implementation of this, similar to PyTorch's c10, is to make a union with a dynamic 2 | // memory member variable that allows for extending the tensor modes beyond the static 3 | // storage size. Unfortunately, that incurs the cost of checking which member is in use. 4 | 5 | // A potential workaround is to return a slice (or some reference object) and use that. 6 | // That is cumbersome though, especially for internal implementation details.
7 | 8 | pub const OrderType = enum { 9 | rowwise, // rank > 0 10 | colwise, // rank > 0 11 | }; 12 | 13 | pub const Rowwise = OrderType.rowwise; 14 | pub const Colwise = OrderType.colwise; 15 | 16 | const SizesType = u32; 17 | 18 | pub const SizeAndStride = struct { 19 | pub const ValueType = SizesType; 20 | size: ValueType = 0, 21 | stride: ValueType = 0, 22 | }; 23 | 24 | ///////////////////////////////////////////////////////// 25 | // Split SizeAndStrides into a contiguous segmented array 26 | 27 | inline fn unpackOptionalSizes(comptime rank: usize, sizes: ?[rank]SizesType) [rank]SizesType { 28 | if (sizes) |data| { 29 | return data; 30 | } else { 31 | var data: [rank]SizeAndStride.ValueType = undefined; 32 | @memset(&data, 0); 33 | return data; 34 | } 35 | } 36 | 37 | fn inferStridesFromSizes(comptime rank: usize, comptime order: OrderType, sizes: ?[rank]SizesType) [rank]SizesType { 38 | var strides: [rank]SizesType = undefined; 39 | 40 | if (rank == 1) { 41 | strides[0] = 1; 42 | return strides; 43 | } 44 | 45 | if (sizes) |data| { 46 | strides = data; 47 | 48 | if (order == OrderType.rowwise) { 49 | var i: usize = (rank - 1); 50 | var n: SizesType = 1; 51 | 52 | while (i > 0) : (i -= 1) { 53 | strides[i] = n; 54 | n *= data[i]; 55 | } 56 | strides[0] = n; 57 | } else { 58 | var i: usize = 0; 59 | var n: SizesType = 1; 60 | 61 | while (i < (rank - 1)) : (i += 1) { 62 | strides[i] = n; 63 | n *= data[i]; 64 | } 65 | strides[rank - 1] = n; 66 | } 67 | } else { 68 | @memset(&strides, 0); // zero seems like a sensible default... 69 | } 70 | return strides; 71 | } 72 | 73 | pub fn defaultPermutation(comptime rank: usize) [rank]SizesType { 74 | var tmp: [rank]SizesType = undefined; 75 | var i: SizesType = 0; 76 | while (i < rank) : (i += 1) { 77 | tmp[i] = i; 78 | } 79 | return tmp; 80 | } 81 | 82 | ///////////////////////////////////////// 83 | // SizesAndStrides Struct Implementation 84 | 85 | pub fn SizesAndStrides(comptime rank: usize, comptime order: OrderType) type { 86 | return struct { 87 | const Rank = rank; 88 | 89 | const Self = @This(); 90 | 91 | const SelfPtr = *Self; 92 | 93 | const Order = order; 94 | 95 | const ConstSelfPtr = *const Self; 96 | 97 | pub const ValueType = SizesType; 98 | 99 | sizes: [Rank]ValueType = undefined, 100 | strides: [Rank]ValueType = undefined, 101 | permutation: [Rank]ValueType = undefined, 102 | 103 | pub fn init(sizes: ?[Rank]ValueType) Self { 104 | return Self{ .sizes = unpackOptionalSizes(Rank, sizes), .strides = inferStridesFromSizes(Rank, Order, sizes), .permutation = defaultPermutation(Rank) }; 105 | } 106 | 107 | //// pairwise setters/getter 108 | pub fn getSizeAndStride(self: ConstSelfPtr, i: usize) SizeAndStride { 109 | return .{ .size = self.sizes[i], .stride = self.strides[i] }; 110 | } 111 | pub fn setSizeAndStride(self: SelfPtr, i: usize, pair: SizeAndStride) void { 112 | self.sizes[i] = pair.size; 113 | self.strides[i] = pair.stride; 114 | } 115 | }; 116 | } 117 | 118 | ///////////////////////////////// 119 | //////////// TESTING //////////// 120 | 121 | test "Rowwise/Colwise Ordering" { 122 | const std = @import("std"); 123 | 124 | { //////////////////////////////////////////// 125 | const s1 = SizesAndStrides(3, Rowwise).init(.{ 3, 2, 2 }); 126 | try std.testing.expect(s1.sizes[0] == 3); 127 | try std.testing.expect(s1.sizes[1] == 2); 128 | try std.testing.expect(s1.sizes[2] == 2); 129 | try std.testing.expect(s1.strides[0] == 4); 130 | try std.testing.expect(s1.strides[1] == 2); 131 | try std.testing.expect(s1.strides[2] == 1); 
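        // rowwise: each stride is the product of the sizes to its right,
        // so sizes .{ 3, 2, 2 } give strides .{ 2 * 2, 2, 1 } == .{ 4, 2, 1 }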
132 | } 133 | { //////////////////////////////////////////// 134 | const s1 = SizesAndStrides(3, Colwise).init(.{ 3, 2, 2 }); 135 | try std.testing.expect(s1.sizes[0] == 3); 136 | try std.testing.expect(s1.sizes[1] == 2); 137 | try std.testing.expect(s1.sizes[2] == 2); 138 | try std.testing.expect(s1.strides[0] == 1); 139 | try std.testing.expect(s1.strides[1] == 3); 140 | try std.testing.expect(s1.strides[2] == 6); 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /src/stack_allocator.zig: -------------------------------------------------------------------------------- 1 | /////////////////////////////////////////////////////////////// 2 | //// Motivation and Explanation for StackAllocator //////////// 3 | 4 | // The StackAllocator has a contiguous memory buffer (size in bytes) 5 | // that it attempts to utilize before deferring to its backing_allocator. 6 | // 7 | // It can only roll back the used capacity if what is being freed was the 8 | // last thing to be allocated (like a typical stack, first-in-last-out). 9 | // 10 | // To free all of the memory from the stack, deallocate the items in 11 | // the reverse order of their allocation. 12 | // 13 | // Resize will only work if you are attempting to resize the last 14 | // allocated item (item on top of the stack). 15 | // 16 | // If you overflow the stack, the StackAllocator will defer to using 17 | // its backing_allocator. 18 | // 19 | 20 | const std = @import("std"); 21 | 22 | pub fn StackBuffer(comptime size: usize) type { 23 | return struct { 24 | const Self = @This(); 25 | const Size = size; 26 | 27 | items: [Size]u8 = undefined, 28 | used: usize = 0, 29 | 30 | pub fn withdraw(self: *Self, n: usize) ?[]u8 { 31 | if ((n + self.used) <= self.items.len) { 32 | const data = self.items[self.used ..
self.used + n]; 33 | self.used += n; 34 | return data; 35 | } 36 | return null; 37 | } 38 | 39 | pub inline fn owns(self: *const Self, data: []u8) bool { 40 | const lhs = @intFromPtr(&self.items[0]); 41 | const rhs = @intFromPtr(&self.items[self.items.len - 1]); 42 | const ptr = @intFromPtr(data.ptr); 43 | return (lhs <= ptr) and (ptr <= rhs); 44 | } 45 | 46 | pub inline fn isTop(self: *const Self, data: []u8) bool { 47 | // can only pop values off the top of the stack 48 | if (self.used < data.len) { 49 | return false; 50 | } 51 | // check to see if we can back up the values 52 | return (@intFromPtr(&self.items[self.used - data.len]) == @intFromPtr(data.ptr)); 53 | } 54 | 55 | pub fn canResize(self: *const Self, data: []u8, n: usize) bool { 56 | // can only resize values at the top of the stack 57 | if (!self.isTop(data)) { 58 | return false; 59 | } 60 | const old_used = self.used - data.len; 61 | const new_used = old_used + n; 62 | return new_used <= Size; 63 | } 64 | 65 | pub fn deposit(self: *Self, data: []u8) bool { 66 | if (!self.owns(data)) { 67 | return false; 68 | } 69 | // check to see if we can back up the values 70 | if (self.isTop(data)) { 71 | self.used -= data.len; 72 | } 73 | return true; 74 | } 75 | }; 76 | } 77 | 78 | //////////////////////////////////////////////////////// 79 | //////// StackAllocator Implementation ///////////////// 80 | 81 | pub fn StackAllocator(comptime size: usize) type { 82 | return struct { 83 | const Self = @This(); 84 | const Size = size; 85 | 86 | stack_buffer: StackBuffer(Size), 87 | backing_allocator: std.mem.Allocator, 88 | 89 | // TODO: Create a dummy mutex that can be swapped via policy 90 | mutex: std.Thread.Mutex = std.Thread.Mutex{}, 91 | 92 | pub fn init(backing_allocator: std.mem.Allocator) Self { 93 | return Self{ 94 | .backing_allocator = backing_allocator, 95 | .stack_buffer = .{}, 96 | }; 97 | } 98 | 99 | pub fn allocator(self: *Self) std.mem.Allocator { 100 | return .{ 101 | .ptr = self, 102 | .vtable = &.{ 103 | .alloc = alloc, 104 | .resize = resize, 105 | .free = free, 106 | }, 107 | }; 108 | } 109 | 110 | pub fn alloc(ctx: *anyopaque, len: usize, log2_ptr_align: u8, ret_addr: usize) ?[*]u8 { 111 | const self: *Self = @ptrCast(@alignCast(ctx)); 112 | 113 | self.mutex.lock(); 114 | 115 | defer self.mutex.unlock(); 116 | 117 | if (self.stack_buffer.withdraw(len)) |data| { 118 | return data.ptr; 119 | } 120 | return self.backing_allocator.rawAlloc(len, log2_ptr_align, ret_addr); 121 | } 122 | 123 | pub fn resize( 124 | ctx: *anyopaque, 125 | old_mem: []u8, 126 | log2_align: u8, 127 | new_len: usize, 128 | ret_addr: usize, 129 | ) bool { 130 | const self: *Self = @ptrCast(@alignCast(ctx)); 131 | 132 | self.mutex.lock(); 133 | 134 | defer self.mutex.unlock(); 135 | 136 | if (!self.stack_buffer.owns(old_mem)) { 137 | return self.backing_allocator.rawResize(old_mem, log2_align, new_len, ret_addr); 138 | } 139 | return self.stack_buffer.canResize(old_mem, new_len); 140 | } 141 | 142 | pub fn free( 143 | ctx: *anyopaque, 144 | old_mem: []u8, 145 | log2_align: u8, 146 | ret_addr: usize, 147 | ) void { 148 | const self: *Self = @ptrCast(@alignCast(ctx)); 149 | 150 | self.mutex.lock(); 151 | 152 | defer self.mutex.unlock(); 153 | 154 | // if we do not own the memory, we'll try 155 | // to free it using the backing allocator 156 | if (!self.stack_buffer.deposit(old_mem)) { 157 | self.backing_allocator.rawFree(old_mem, log2_align, ret_addr); 158 | } 159 | } 160 | }; 161 | } 162 | 163 | 
///////////////////////////////////////////////////////// 164 | /////// StackAllocator Testing Section ////////////////// 165 | 166 | test "basic stack properties" { 167 | var GPA = std.heap.GeneralPurposeAllocator(.{}){}; 168 | var stack_allocator = StackAllocator(100).init(GPA.allocator()); 169 | var allocator = stack_allocator.allocator(); 170 | 171 | { // reverse-order stack popping 172 | const a = try allocator.alloc(u8, 10); 173 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 10); 174 | const b = try allocator.alloc(u8, 10); 175 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 20); 176 | 177 | allocator.free(b); 178 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 10); 179 | allocator.free(a); 180 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 0); 181 | } 182 | 183 | { // unordered stack popping 184 | const a = try allocator.alloc(u8, 10); 185 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 10); 186 | const b = try allocator.alloc(u8, 10); 187 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 20); 188 | 189 | allocator.free(a); 190 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 20); 191 | allocator.free(b); 192 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 10); 193 | } 194 | 195 | if (GPA.deinit() == .leak) @panic("MEMORY LEAK DETECTED!!"); 196 | } 197 | 198 | test "basic stack resize" { 199 | var GPA = std.heap.GeneralPurposeAllocator(.{}){}; 200 | var stack_allocator = StackAllocator(100).init(GPA.allocator()); 201 | var allocator = stack_allocator.allocator(); 202 | 203 | { // resize checking 204 | const a = try allocator.alloc(u8, 10); 205 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 10); 206 | const b = try allocator.alloc(u8, 10); 207 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 20); 208 | 209 | // a cannot resize because it is not on the top of the stack 210 | try std.testing.expect(!allocator.resize(a, 20)); 211 | 212 | // b can resize because it is on the top of the stack 213 | try std.testing.expect(allocator.resize(b, 20)); 214 | 215 | // b should be able to take the remaining memory 216 | try std.testing.expect(allocator.resize(b, 90)); 217 | 218 | // b should not be able to take more than remainder 219 | try std.testing.expect(!allocator.resize(b, 91)); 220 | } 221 | 222 | if (GPA.deinit() == .leak) @panic("MEMORY LEAK DETECTED!!"); 223 | } 224 | 225 | test "stack-overflow allocation" { 226 | var GPA = std.heap.GeneralPurposeAllocator(.{}){}; 227 | var stack_allocator = StackAllocator(100).init(GPA.allocator()); 228 | var allocator = stack_allocator.allocator(); 229 | 230 | { // overflow the full memory stack 231 | const a = try allocator.alloc(u8, 100); 232 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 100); 233 | 234 | const b = try allocator.alloc(u8, 100); 235 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 100); 236 | 237 | allocator.free(a); 238 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 0); 239 | allocator.free(b); 240 | try std.testing.expectEqual(stack_allocator.stack_buffer.used, 0); 241 | } 242 | 243 | if (GPA.deinit() == .leak) @panic("MEMORY LEAK DETECTED!!"); 244 | } 245 | -------------------------------------------------------------------------------- /src/tensor.zig: -------------------------------------------------------------------------------- 1 | // Here we find the heart of Zein - Tensors. 
Before proceeding, please read the following: 2 | 3 | /////////////////////////////////// 4 | // DESIGN PHILOSOPHY (June 3, 2023) 5 | 6 | // MEMORY, OWNERSHIP, AND REFERENCING // 7 | 8 | // There is no plan to make a distinction between a tensor and a "view" of a tensor. 9 | // Tensors here are, by design, a way to view data. As such, a different "tensored" view 10 | // of the same data is just another tensor that shares underlying memory. 11 | 12 | // !!! THIS STRONGLY IMPLIES THAT TENSORS DO NOT *OWN* DATA, THEY VIEW IT !!! 13 | 14 | // If anything can be said to "own" memory, it is the allocator. Allocators are going 15 | // to play an important role in this library (as they do in Zig more generally). 16 | 17 | // To create a tensor that has initialized memory is the job of a factory. 18 | // The design of such a tensor factory, as it were, will be handled in a source 19 | // file dedicated to that exact job. It is very important that we do not cross 20 | // responsibilities in this system. 21 | 22 | // TENSORS AS THEY RELATE TO ARRAYS // 23 | 24 | // Because of the design decisions outlined above, users should be able to easily 25 | // make a tensor with their desired dimensions to wrap existing arrays and manipulate 26 | // them as if they were tensors themselves. This means that a tensor can act like 27 | // an adapter to already existing memory. 28 | 29 | // Because of this, there is not a current plan to enforce that tensors must be of 30 | // one type or another. It is my hope to provide a generic tensor-based interface 31 | // that can be used on a variety of objects at the user's own risk. 32 | 33 | // At some point, it may be important to then provide a generic functional interface 34 | // to support further use cases such as generically holding objects that users 35 | // create themselves. While this is an interesting goal, the scope of V1 is currently 36 | // focused on integer and floating point numbers. User-provided types will have to 37 | // be reviewed as time goes forward. 38 | 39 | const std = @import("std"); 40 | 41 | const Util = @import("utility.zig"); 42 | 43 | // STD import files... 44 | const ReduceOp = @import("std").builtin.ReduceOp; 45 | 46 | const arrayProduct = @import("./utility.zig").arrayProduct; 47 | 48 | // Zein import files... 49 | pub const SizeAndStride = @import("./sizes_and_strides.zig").SizeAndStride; 50 | pub const SizesAndStrides = @import("./sizes_and_strides.zig").SizesAndStrides; 51 | pub const OrderType = @import("./sizes_and_strides.zig").OrderType; 52 | pub const Rowwise = @import("./sizes_and_strides.zig").Rowwise; 53 | pub const Colwise = @import("./sizes_and_strides.zig").Colwise; 54 | const Permutate = @import("./permutate.zig"); 55 | 56 | // Tensor Utilities... 57 | pub const TensorError = error{ InvalidTensorLayout, InvalidPermutation, AllocSizeMismatch, CapacityMismatch, RankMismatch }; 58 | 59 | pub inline fn computeTensorIndex( 60 | comptime rank: usize, 61 | comptime size_type: type, 62 | strides: []const size_type, 63 | indices: []const size_type 64 | ) size_type { 65 | return switch(rank) { 66 | 1 => indices[0], // direct index...
just an array 67 | 2 => indices[0] * strides[0] + indices[1] * strides[1], 68 | else => blk: { // inner product between indices and strides 69 | const s: @Vector(rank, size_type) = strides[0..rank].*; 70 | const i: @Vector(rank, size_type) = indices[0..rank].*; 71 | break :blk @reduce(ReduceOp.Add, s * i); 72 | }, 73 | }; 74 | } 75 | 76 | /////////////////////////// 77 | // Tensor Implementation // 78 | 79 | pub fn Tensor(comptime value_type: type, comptime rank: usize, comptime order: OrderType) type { 80 | if (63 < rank) { 81 | @compileError("Tensors of rank 64 or greater are not supported."); 82 | } 83 | 84 | if (0 == rank) { 85 | @compileError("Tensors of rank zero are not supported."); 86 | } 87 | 88 | return struct { 89 | pub const Rank = rank; 90 | 91 | pub const Order = order; 92 | 93 | pub const SizesType = SizeAndStride.ValueType; 94 | 95 | pub const ValueType = value_type; 96 | 97 | pub const ValueSlice = []ValueType; 98 | 99 | pub const SizesAndStridesType = SizesAndStrides(Rank, Order); 100 | 101 | const Self = @This(); 102 | 103 | const SelfPtr = *Self; 104 | 105 | const ConstSelfPtr = *const Self; 106 | 107 | values: ValueSlice, 108 | 109 | sizes_and_strides: SizesAndStridesType, 110 | 111 | pub fn init(values: ?ValueSlice, sizes: ?[Rank]SizesType) Self { 112 | return Self{ 113 | .values = if (values) |vs| (vs) else &[_]ValueType{}, 114 | .sizes_and_strides = SizesAndStridesType.init(sizes), 115 | }; 116 | } 117 | 118 | pub fn sliceSizes(self: ConstSelfPtr, i: usize, j: usize) []const SizesType { 119 | return &self.sizes_and_strides.sizes[i..j]; 120 | } 121 | pub fn sliceStrides(self: ConstSelfPtr, i: usize, j: usize) []const SizesType { 122 | return &self.sizes_and_strides.strides[i..j]; 123 | } 124 | pub fn slicePermutation(self: ConstSelfPtr, i: usize, j: usize) []const SizesType { 125 | return &self.sizes_and_strides.permutation[i..j]; 126 | } 127 | 128 | pub fn getSizes(self: ConstSelfPtr) []const SizesType { 129 | return &self.sizes_and_strides.sizes; 130 | } 131 | pub fn getStrides(self: ConstSelfPtr) []const SizesType { 132 | return &self.sizes_and_strides.strides; 133 | } 134 | pub fn getPermutation(self: ConstSelfPtr) []const SizesType { 135 | return &self.sizes_and_strides.permutation; 136 | } 137 | 138 | pub fn valueCapacity(self: ConstSelfPtr) usize { 139 | return arrayProduct(Rank, SizesType, &self.sizes_and_strides.sizes); 140 | } 141 | 142 | pub fn valueSize(self: ConstSelfPtr) usize { 143 | return self.values.len; 144 | } 145 | 146 | pub fn isValid(self: ConstSelfPtr) bool { 147 | return self.valueSize() != 0 and self.valueSize() == self.valueCapacity(); 148 | } 149 | 150 | pub fn swap(self: SelfPtr, other: SelfPtr) void { 151 | self.swapValues(other); 152 | self.swapSizesAndStrides(other); 153 | } 154 | 155 | pub fn swapValues(self: SelfPtr, other: SelfPtr) void { 156 | // to assure that sizes and strides are not 157 | // invalidated, we check size and capacity 158 | std.debug.assert(self.valueSize() == other.valueSize()); 159 | std.debug.assert(self.isValid() and other.isValid()); 160 | 161 | const values = self.values; 162 | self.values = other.values; 163 | other.values = values; 164 | } 165 | 166 | pub fn swapSizesAndStrides(self: SelfPtr, other: SelfPtr) void { 167 | // we only want to compute these once... 
168 | 169 | if (comptime Util.debug) { 170 | const capacity_a = self.valueCapacity(); 171 | const capacity_b = other.valueCapacity(); 172 | // tensors can have different SizesAndStrides 173 | // and still share the total value capcity 174 | std.debug.assert(capacity_a == capacity_b); 175 | // check that both tensors are at capacity without additional computation 176 | std.debug.assert( 177 | self.valueSize() == capacity_a and other.valueSize() == capacity_b 178 | ); 179 | } 180 | 181 | // there is probably a faster way to do this 182 | const tmp = self.sizes_and_strides; 183 | self.sizes_and_strides = other.sizes_and_strides; 184 | other.sizes_and_strides = tmp; 185 | } 186 | 187 | pub fn permutate(self: SelfPtr, comptime expression: []const u8) Self { 188 | // create a permutated tensor that shares the same underlying memory 189 | std.debug.assert(self.isValid()); 190 | 191 | var tmp = self.*; // share values 192 | Permutate.permutate(Rank, Order, expression, &tmp.sizes_and_strides); 193 | return tmp; 194 | } 195 | 196 | pub fn getValue(self: ConstSelfPtr, indices: [rank]SizesType) ValueType { 197 | const n = computeTensorIndex(Rank, SizesType, self.getStrides(), &indices); 198 | return self.values[n]; 199 | } 200 | 201 | pub fn setValue(self: ConstSelfPtr, value: ValueType, indices: [rank]SizesType) void { 202 | const n = computeTensorIndex(Rank, SizesType, self.getStrides(), &indices); 203 | self.values[n] = value; 204 | } 205 | 206 | pub inline fn getSize(self: ConstSelfPtr, i: usize) SizesType { 207 | return self.sizes_and_strides.sizes[i]; 208 | } 209 | 210 | pub inline fn getStride(self: ConstSelfPtr, i: usize) SizesType { 211 | return self.sizes_and_strides.strides[i]; 212 | } 213 | }; 214 | } 215 | 216 | test "Initialization" { 217 | const expect = std.testing.expect; 218 | 219 | var x = Tensor(u32, 3, Rowwise).init(null, .{ 10, 20, 30 }); 220 | 221 | const total: usize = 10 * 20 * 30; 222 | 223 | try expect(total == x.valueCapacity()); 224 | } 225 | 226 | test "Tensor Swapping" { 227 | const expect = std.testing.expect; 228 | 229 | const x_values = try std.heap.page_allocator.alloc(i8, 100); 230 | defer std.heap.page_allocator.free(x_values); 231 | 232 | const y_values = try std.heap.page_allocator.alloc(i8, 100); 233 | defer std.heap.page_allocator.free(y_values); 234 | 235 | var x = Tensor(i8, 2, Rowwise).init(x_values, .{ 10, 10 }); 236 | var y = Tensor(i8, 2, Rowwise).init(y_values, .{ 10, 10 }); 237 | 238 | x.swap(&y); 239 | 240 | try expect(x.values.ptr == y_values.ptr); 241 | try expect(y.values.ptr == x_values.ptr); 242 | 243 | const total: usize = 10 * 10; 244 | 245 | try expect(total == x.valueCapacity()); 246 | try expect(total == y.valueCapacity()); 247 | } 248 | 249 | test "Tensor Transpose" { 250 | const expect = std.testing.expect; 251 | 252 | var data = [9]i32{ 1, 2, 3, 4, 5, 6, 7, 8, 9 }; 253 | 254 | var x = Tensor(i32, 2, Rowwise).init(&data, .{ 3, 3 }); 255 | 256 | try expect(x.isValid()); 257 | 258 | try expect(x.getValue(.{ 0, 0 }) == 1); 259 | try expect(x.getValue(.{ 0, 1 }) == 2); 260 | try expect(x.getValue(.{ 0, 2 }) == 3); 261 | try expect(x.getValue(.{ 1, 0 }) == 4); 262 | try expect(x.getValue(.{ 1, 1 }) == 5); 263 | try expect(x.getValue(.{ 1, 2 }) == 6); 264 | try expect(x.getValue(.{ 2, 0 }) == 7); 265 | try expect(x.getValue(.{ 2, 1 }) == 8); 266 | try expect(x.getValue(.{ 2, 2 }) == 9); 267 | 268 | var y = x.permutate("ij->ji"); 269 | 270 | try expect(y.getValue(.{ 0, 0 }) == 1); 271 | try expect(y.getValue(.{ 0, 1 }) == 4); 272 | try 
expect(y.getValue(.{ 0, 2 }) == 7); 273 | try expect(y.getValue(.{ 1, 0 }) == 2); 274 | try expect(y.getValue(.{ 1, 1 }) == 5); 275 | try expect(y.getValue(.{ 1, 2 }) == 8); 276 | try expect(y.getValue(.{ 2, 0 }) == 3); 277 | try expect(y.getValue(.{ 2, 1 }) == 6); 278 | try expect(y.getValue(.{ 2, 2 }) == 9); 279 | } 280 | -------------------------------------------------------------------------------- /src/expression_parsing.zig: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////// 2 | // Expression Parsing for Einsum style string expressions. 3 | 4 | // Currently, the expression parser does not tolerate 5 | // whitespace in expressions. This will be reviewed 6 | // at a later date, but currently is not required to 7 | // create well-formed strings. 8 | 9 | // parser utility functions. These functions are intended 10 | // to be executed at comptime. 11 | 12 | const SizesType = @import("./sizes_and_strides.zig").SizeAndStride.ValueType; 13 | 14 | pub fn between(comptime value: u8, comptime lower: u8, comptime upper: u8) bool { 15 | return lower <= value and value <= upper; 16 | } 17 | 18 | pub fn isAlpha(comptime value: u8) bool { 19 | return between(value, 65, 90) or between(value, 97, 122); // [91, 96] are: [\]^_` 20 | } 21 | 22 | pub fn allAlpha(comptime str: []const u8) bool { 23 | comptime var i: usize = 0; 24 | inline while (i < str.len) : (i += 1) { 25 | if (!isAlpha(str[i])) { 26 | return false; 27 | } 28 | } 29 | return true; 30 | } 31 | 32 | pub fn contains(comptime char: u8, comptime string: []const u8) bool { 33 | comptime var i: usize = 0; 34 | inline while (i < string.len) : (i += 1) { 35 | if (char == string[i]) { 36 | return true; 37 | } 38 | } 39 | return false; 40 | } 41 | 42 | // check that a permutation is both full and accounted for 43 | pub fn isPermutation(comptime source: []const u8, comptime target: []const u8) bool { 44 | if (source.len != target.len) { 45 | return false; 46 | } 47 | if (source.len == 0) { // the empty set is a permutation of itself 48 | return true; 49 | } 50 | // create mask for proper permutation 51 | const full: usize = (1 << source.len) - 1; 52 | comptime var i_mask: usize = 0; 53 | comptime var j_mask: usize = 0; 54 | 55 | comptime var i: usize = 0; 56 | comptime var j: usize = 0; 57 | inline while (i < source.len) : ({ 58 | i += 1; 59 | j = 0; 60 | }) { 61 | inline while (j < target.len) : (j += 1) { 62 | if (source[i] == target[j]) { 63 | i_mask |= (1 << i); 64 | j_mask |= (1 << j); 65 | } 66 | } 67 | } 68 | return i_mask == j_mask and i_mask == full; 69 | } 70 | 71 | pub fn countUniqueAlpha(comptime string: []const u8) usize { 72 | comptime var n: u64 = 0; 73 | comptime var i: usize = 0; 74 | inline while (i < string.len) : (i += 1) { 75 | if (isAlpha(string[i])) { 76 | n |= (1 << (string[i] - 65)); 77 | } 78 | } 79 | return @popCount(n); 80 | } 81 | 82 | pub fn uniqueAlpha(comptime string: []const u8) [countUniqueAlpha(string)]u8 { 83 | const N = comptime countUniqueAlpha(string); 84 | comptime var i: usize = 0; 85 | comptime var j: usize = 0; 86 | comptime var chars: [N]u8 = .{0} ** N; 87 | inline while (i < string.len) : (i += 1) { 88 | if (comptime isAlpha(string[i]) and !contains(string[i], &chars)) { 89 | chars[j] = string[i]; 90 | j += 1; 91 | } 92 | } 93 | return chars; 94 | } 95 | 96 | const ArrowOp = struct { 97 | tail: usize = 0, 98 | head: usize = 0, 99 | }; 100 | 101 | pub fn findArrowOp(str: []const u8) ArrowOp { 102 | // 
reference for arrow operator 103 | const arrow: []const u8 = "->"; 104 | 105 | comptime var head: usize = 0; 106 | comptime var tail: usize = 0; 107 | comptime var index: usize = 0; 108 | inline while (index < str.len) : (index += 1) { 109 | if (str[index] == arrow[0]) { 110 | tail = index; 111 | } 112 | if (str[index] == arrow[1]) { 113 | head = index; 114 | } 115 | } 116 | if ((tail + 1) != head) { 117 | @compileError("Malformed arrow operator: " ++ str); 118 | } 119 | if (tail == 0 or head > (str.len - 2)) { 120 | @compileError("Arrow must be used as infix operator: " ++ str); 121 | } 122 | return ArrowOp{ .tail = tail, .head = head }; 123 | } 124 | 125 | pub fn findCommaOp(str: []const u8) usize { 126 | comptime var comma: usize = 0; 127 | comptime var index: usize = 0; 128 | inline while (index < str.len) : (index += 1) { 129 | if (str[index] == ","[0]) { 130 | comma = index; 131 | break; 132 | } 133 | } 134 | if (comma == 0 or comma >= (str.len - 1)) { 135 | @compileError("Comma must be used as infix operator: " ++ str); 136 | } 137 | return comma; 138 | } 139 | 140 | pub fn permutateParse(comptime Rank: usize, comptime str: []const u8) [Rank]SizesType { 141 | const arrow = comptime findArrowOp(str); 142 | const lhs = str[0..arrow.tail]; 143 | const rhs = str[arrow.head + 1 ..]; 144 | 145 | if (lhs.len != Rank) { 146 | @compileError("Left operand is not equal to the rank: " ++ lhs); 147 | } 148 | if (rhs.len != Rank) { 149 | @compileError("Right operand is not equal to the rank: " ++ rhs); 150 | } 151 | if (!comptime allAlpha(lhs)) { 152 | @compileError("Non-alphabetical character found in: " ++ lhs); 153 | } 154 | if (!comptime allAlpha(rhs)) { 155 | @compileError("Non-alphabetical character found in: " ++ rhs); 156 | } 157 | if (!comptime isPermutation(lhs, rhs)) { 158 | @compileError("Permutate requires left and right operands to be permutations of each other: " ++ str); 159 | } 160 | 161 | //////////////////////////////////////// 162 | // build permutation contraction indices 163 | 164 | comptime var i: usize = 0; 165 | comptime var j: usize = 0; 166 | comptime var indices: [Rank]SizesType = undefined; 167 | 168 | inline while (i < Rank) : ({ 169 | i += 1; 170 | j = 0; 171 | }) { 172 | inline while (j < Rank) : (j += 1) { 173 | if (rhs[i] == lhs[j]) { 174 | indices[i] = j; 175 | break; 176 | } 177 | } 178 | } 179 | return indices; 180 | } 181 | 182 | // Contraction parsing expects strings of the form: 183 | // 184 | // example: ijk->jk 185 | // 186 | // The left operand must contain more indices than the 187 | // right operand (the dropped indices are the ones being contracted). 188 | // 189 | // The left and right operands must contain only alphabetic characters.
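//
// As a rough illustration of the intended mapping (not taken from a test):
// contractionParse(3, 2, "ijk->jk") pairs 'j' and 'k' between the operands,
// so the plan's rhs indices should come out as .{ 0, 1 } and its lhs indices
// as .{ 1, 2, 0 }, with the unmatched 'i' placed after the matched indices.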
190 | 191 | pub fn contractedRank(comptime str: []const u8) usize { 192 | return (str.len - (comptime findArrowOp(str)).head) - 1; 193 | } 194 | 195 | pub fn ContractionPlan(comptime lRank: usize, comptime rRank: usize) type { 196 | return struct { 197 | lhs: [lRank]SizesType = undefined, 198 | rhs: [rRank]SizesType = undefined, 199 | }; 200 | } 201 | 202 | pub fn contractionParse(comptime lRank: usize, comptime rRank: usize, comptime str: []const u8) ContractionPlan(lRank, rRank) { 203 | comptime var index: usize = 0; 204 | 205 | const arrow = comptime findArrowOp(str); 206 | const lhs = str[0..arrow.tail]; 207 | const rhs = str[arrow.head + 1 ..]; 208 | 209 | if (lhs.len == 0) { 210 | @compileError("Empty left-side operand: " ++ str); 211 | } 212 | if (rhs.len == 0) { 213 | @compileError("Empty right-side operand: " ++ str); 214 | } 215 | if (lhs.len != lRank) { 216 | @compileError("Provided indices do not match left-side operand rank: " ++ lhs); 217 | } 218 | if (rhs.len != rRank) { 219 | @compileError("Provided indices do not match right-side operand rank: " ++ rhs); 220 | } 221 | if (!comptime allAlpha(lhs)) { 222 | @compileError("Non-alphabetical character found in: " ++ lhs); 223 | } 224 | if (!comptime allAlpha(rhs)) { 225 | @compileError("Non-alphabetical character found in: " ++ rhs); 226 | } 227 | 228 | //////////////////////////////////////// 229 | // build permutation contraction indices 230 | 231 | comptime var x_indices: [lhs.len]u32 = undefined; 232 | comptime var y_indices: [rhs.len]u32 = undefined; 233 | comptime var remainder: [lhs.len + rhs.len]u32 = undefined; 234 | comptime var char: u8 = undefined; 235 | comptime var match: u32 = 0; 236 | comptime var rhs_i: u32 = 0; 237 | comptime var rem_i: u32 = 0; 238 | comptime var found: bool = false; 239 | 240 | index = 0; 241 | inline while (index < lhs.len) : (index += 1) { 242 | 243 | // matched + unmatched = total 244 | if (match == rhs.len and rem_i == remainder.len) { 245 | break; 246 | } 247 | 248 | char = lhs[index]; 249 | 250 | found = false; 251 | 252 | // try to match the current char 253 | // in both rhs and lhs operands 254 | 255 | rhs_i = 0; 256 | inline while (rhs_i < rhs.len) : (rhs_i += 1) { 257 | if (rhs[rhs_i] == char) { 258 | x_indices[match] = index; 259 | y_indices[match] = rhs_i; 260 | found = true; 261 | match += 1; 262 | break; 263 | } 264 | } 265 | 266 | // if no match, add to remainder 267 | 268 | if (!found) { 269 | remainder[rem_i] = index; 270 | rem_i += 1; 271 | } 272 | } 273 | 274 | if (match != rhs.len) { 275 | @compileError("Unmatched dimensions between operands:" ++ str); 276 | } 277 | 278 | rem_i = 0; 279 | index = rhs.len; 280 | inline while (index < lhs.len) : ({ 281 | index += 1; 282 | rem_i += 1; 283 | }) { 284 | x_indices[index] = remainder[rem_i]; 285 | } 286 | return ContractionPlan(lRank, rRank){ .lhs = x_indices, .rhs = y_indices }; 287 | } 288 | 289 | /////////////////////// 290 | //// Inner-Product //// 291 | 292 | pub fn InnerProductPlan(comptime N: usize) type { 293 | const pass_flag: usize = 9999; 294 | 295 | return struct { 296 | pass: usize = pass_flag, 297 | x_perm: [N]usize = .{pass_flag} ** N, 298 | y_perm: [N]usize = .{pass_flag} ** N, 299 | z_perm: [N]usize = .{pass_flag} ** N, 300 | s_ctrl: [N]usize = .{pass_flag} ** N, 301 | total: usize = N, 302 | }; 303 | } 304 | 305 | pub fn innerProductParse(comptime XRank: usize, comptime YRank: usize, comptime ZRank: usize, comptime expression: []const u8) InnerProductPlan(countUniqueAlpha(expression)) { 306 | const arrow = 
comptime findArrowOp(expression); 307 | const comma = comptime findCommaOp(expression); 308 | 309 | if (comma >= (arrow.tail - 1)) { 310 | @compileError("Comma operator must come before left operand: " ++ expression); 311 | } 312 | 313 | const lhs = expression[0..comma]; 314 | const rhs = expression[comma + 1 .. arrow.tail]; 315 | const out = expression[arrow.head + 1 ..]; 316 | 317 | if (lhs.len == 0) { 318 | @compileError("Empty left-side operand: " ++ expression); 319 | } 320 | if (rhs.len == 0) { 321 | @compileError("Empty right-side operand: " ++ expression); 322 | } 323 | if (out.len == 0) { 324 | @compileError("Empty expression result: " ++ expression); 325 | } 326 | if (lhs.len != XRank) { 327 | @compileError("Provided indices do not match left-side operand rank: " ++ lhs); 328 | } 329 | if (rhs.len != YRank) { 330 | @compileError("Provided indices do not match right-side operand rank: " ++ rhs); 331 | } 332 | if (out.len != ZRank) { 333 | @compileError("Provided indices do not match result rank: " ++ out); 334 | } 335 | if (!comptime allAlpha(lhs)) { 336 | @compileError("Non-alphabetical character found in: " ++ lhs); 337 | } 338 | if (!comptime allAlpha(rhs)) { 339 | @compileError("Non-alphabetical character found in: " ++ rhs); 340 | } 341 | if (!comptime allAlpha(out)) { 342 | @compileError("Non-alphabetical character found in: " ++ out); 343 | } 344 | 345 | //////////////////////////////////////// 346 | // build inner product control indices 347 | 348 | const N = countUniqueAlpha(expression); 349 | 350 | comptime var plan = InnerProductPlan(N){}; 351 | 352 | // loop index variables 353 | comptime var i = 0; 354 | comptime var j = 0; 355 | const chars = uniqueAlpha(expression); 356 | 357 | i = 0; 358 | inline while (i < N) : (i += 1) { 359 | j = 0; 360 | inline while (j < lhs.len) : (j += 1) { 361 | if (lhs[j] == chars[i]) { 362 | plan.x_perm[i] = j; 363 | plan.s_ctrl[i] = 0; 364 | } 365 | } 366 | j = 0; 367 | inline while (j < rhs.len) : (j += 1) { 368 | if (rhs[j] == chars[i]) { 369 | plan.y_perm[i] = j; 370 | plan.s_ctrl[i] = 1; 371 | } 372 | } 373 | j = 0; 374 | inline while (j < out.len) : (j += 1) { 375 | if (out[j] == chars[i]) { 376 | plan.z_perm[i] = j; 377 | } 378 | } 379 | } 380 | return plan; 381 | } 382 | 383 | pub fn OuterProductPlan(comptime N: usize) type { 384 | const pass_flag: usize = 9999; 385 | 386 | return struct { 387 | pass: usize = pass_flag, 388 | x_perm: [N]usize = .{pass_flag} ** N, 389 | y_perm: [N]usize = .{pass_flag} ** N, 390 | z_perm: [N]usize = .{pass_flag} ** N, 391 | total: usize = N, 392 | }; 393 | } 394 | 395 | pub fn outerProductParse(comptime XRank: usize, comptime YRank: usize, comptime ZRank: usize, comptime expression: []const u8) OuterProductPlan(countUniqueAlpha(expression)) { 396 | const arrow = comptime findArrowOp(expression); 397 | const comma = comptime findCommaOp(expression); 398 | 399 | if (comma >= (arrow.tail - 1)) { 400 | @compileError("Comma operator must come before left operand: " ++ expression); 401 | } 402 | 403 | const lhs = expression[0..comma]; 404 | const rhs = expression[comma + 1 .. 
arrow.tail]; 405 | const out = expression[arrow.head + 1 ..]; 406 | 407 | if (lhs.len == 0) { 408 | @compileError("Empty left-side operand: " ++ expression); 409 | } 410 | if (rhs.len == 0) { 411 | @compileError("Empty right-side operand: " ++ expression); 412 | } 413 | if (out.len == 0) { 414 | @compileError("Empty expression result: " ++ expression); 415 | } 416 | if (lhs.len != XRank) { 417 | @compileError("Provided indices do not match left-side operand rank: " ++ lhs); 418 | } 419 | if (rhs.len != YRank) { 420 | @compileError("Provided indices do not match right-side operand rank: " ++ rhs); 421 | } 422 | if (out.len != ZRank) { 423 | @compileError("Provided indices do not match result rank: " ++ out); 424 | } 425 | if (!comptime allAlpha(lhs)) { 426 | @compileError("Non-alphabetical character found in: " ++ lhs); 427 | } 428 | if (!comptime allAlpha(rhs)) { 429 | @compileError("Non-alphabetical character found in: " ++ rhs); 430 | } 431 | if (!comptime allAlpha(out)) { 432 | @compileError("Non-alphabetical character found in: " ++ out); 433 | } 434 | 435 | //////////////////////////////////////// 436 | // build inner product control indices 437 | 438 | const N = countUniqueAlpha(expression); 439 | 440 | comptime var plan = OuterProductPlan(N){}; 441 | 442 | // loop index variables 443 | comptime var i = 0; 444 | comptime var j = 0; 445 | const chars = uniqueAlpha(expression); 446 | 447 | i = 0; 448 | inline while (i < N) : (i += 1) { 449 | j = 0; 450 | inline while (j < lhs.len) : (j += 1) { 451 | if (lhs[j] == chars[i]) { 452 | plan.x_perm[i] = j; 453 | } 454 | } 455 | j = 0; 456 | inline while (j < rhs.len) : (j += 1) { 457 | if (rhs[j] == chars[i]) { 458 | plan.y_perm[i] = j; 459 | } 460 | } 461 | j = 0; 462 | inline while (j < out.len) : (j += 1) { 463 | if (out[j] == chars[i]) { 464 | plan.z_perm[i] = j; 465 | } 466 | } 467 | } 468 | return plan; 469 | } 470 | -------------------------------------------------------------------------------- /src/linear_caching_allocator.zig: -------------------------------------------------------------------------------- 1 | 2 | /////////////////////////////////////////////////////////////// 3 | //// Motivation and Explanation for LinearCachingAllocator //// 4 | 5 | // -- General Introduction -- 6 | // 7 | // Even though this allocator employs a custom binary search 8 | // similar to lower-bound lookup, it can have worst case linear 9 | // performance. 10 | // 11 | // Most importantly, it has a linear growth factor - for 12 | // every new allocation not currently cached, it increases 13 | // the cache size by one. 14 | // 15 | // Likewise, the caching allocator needs to scan for unused 16 | // blocks once it locates a segment of the cache that can 17 | // fulfill the size request. In the worst case, this is 18 | // O(N), as each element could be the same size and all blocks 19 | // could currently be in use. 20 | // 21 | // However, in a fresh cache where all blocks are free, the lookup 22 | // will be O(log(N)). For each "hole" that is added via blocks 23 | // being marked as used, we could assume that we will have to advance 24 | // beyond said hole to find the next free block. In general, this 25 | // means our searches will be O(log(N)) + n where n is usually 26 | // signficantly less than N, but at worst is equal. 
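// (Illustrative example: with eight equally sized free blocks, a request for
// that block size starts its scan at index zero and returns immediately; if
// the first five blocks were already marked used, the same request has to
// walk past those five holes before reaching a free block.)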
27 | // 28 | // For optimal caching behavior, we want: 29 | // 30 | // -- Enough cached memory to satisfy a range of requests 31 | // 32 | // -- Frequent check-ins from used blocks to restore holes 33 | // 34 | // -- Why use an array of independent []u8 slices? -- 35 | // 36 | // A typical implementation strategy for free-list style allocators 37 | // (or cascading allocators more generally) is to embed a link 38 | // next to the allocation itself (an intrusive link). This works 39 | // where it is safe to assume that all memory will be accessible 40 | // from a common memory pool. However, this allocator supports 41 | // caching memory to devices other than host memory. In other words, 42 | // we could create a segmentation fault trying to read links on 43 | // a different device. 44 | // 45 | // Device memory may also need to be allocated along more peculiar bounds 46 | // and accessing them like they were host memory can cause segmentation 47 | // faults even on valid elements. We therefore rely on the backing_allocator 48 | // to return proper alignment by default. 49 | // 50 | // -- Intended use cases and assumptions -- 51 | // 52 | // This allocator was designed around a certain set of assumptions: 53 | // 54 | // 1. ASSUMES: that batch-style free is desirable. The cache can be 55 | // dumped all at once using the "clear" function. Likewise, the cache 56 | // can be primed by using the "addToCache" function to preallocate 57 | // memory - "warming" the cache before use. 58 | // 59 | // 2. ASSUMES: that new allocations can be predicted via the size of old 60 | // allocations. This prevents the cache from continuing to grow 61 | // linearly. 62 | // 63 | // 3. ASSUMES: that alloc and free will be called cyclically. There 64 | // is no benefit to using this allocator in a program that only 65 | // calls alloc and free once for a given item. 66 | // 67 | // NOTE: the backing_allocator can be reassigned so that different 68 | // underlying allocators can be used. By default, it is the page_allocator. 69 | 70 | const std = @import("std"); 71 | 72 | const OrderedCache = struct { 73 | 74 | const Self = @This(); 75 | 76 | const CacheType = std.ArrayList(CacheBlock); 77 | 78 | const CacheBlock = struct { 79 | data: []u8, 80 | used: bool, 81 | alignment: u8, 82 | }; 83 | 84 | buffer: CacheType, 85 | 86 | pub fn init(allocator: std.mem.Allocator) Self { 87 | return Self{ .buffer = CacheType.init(allocator) }; 88 | } 89 | pub fn deinit(self: *Self, allocator: *std.mem.Allocator) void { 90 | for(0..self.size()) |i| { 91 | allocator.free(self.itemData(i)); 92 | } 93 | self.buffer.deinit(); 94 | } 95 | pub fn clear(self: *Self, allocator: *std.mem.Allocator) void { 96 | for(0..self.size()) |i| { 97 | allocator.free(self.itemData(i)); 98 | } 99 | // Calling resize will test for capacity and then 100 | // set the length to the new size. Since we're only 101 | // going to zero, we don't need to check for capacity.
102 | self.buffer.items.len = 0; 103 | } 104 | 105 | pub inline fn size(self: *const Self) usize { 106 | return self.buffer.items.len; 107 | } 108 | inline fn itemUsed(self: *const Self, i: usize) bool { 109 | return self.buffer.items[i].used; 110 | } 111 | inline fn itemData(self: *const Self, i: usize) []u8 { 112 | return self.buffer.items[i].data; 113 | } 114 | inline fn itemAlignment(self: *const Self, i: usize) u8 { 115 | return self.buffer.items[i].alignment; 116 | } 117 | inline fn itemSize(self: *const Self, i: usize) usize { 118 | return self.buffer.items[i].data.len; 119 | } 120 | inline fn setUsed(self: *const Self, i: usize, used: bool) void { 121 | self.buffer.items[i].used = used; 122 | } 123 | 124 | pub fn lowerBoundSize(self: *const Self, n: usize) usize { 125 | var len = self.size(); 126 | var idx: usize = 0; 127 | while (len > 0) { 128 | const half = (len >> 1); 129 | const mid = half + idx; 130 | if (self.itemSize(mid) < n) { 131 | idx = mid + 1; 132 | len = (len - half) - 1; 133 | } else { 134 | len = half; 135 | } 136 | } 137 | return idx; 138 | } 139 | 140 | fn scanForUnused(self: *const Self, idx: usize, n: usize, alignment: u8) ?[]u8 { 141 | 142 | // heuristic: requests cannot grab allocations greater than 2x their size 143 | const limit = n <<| 1; 144 | 145 | var i = idx; 146 | while ((i < self.size()) and (self.itemSize(i) <= limit)) : (i += 1) { 147 | if (!self.itemUsed(i) and (self.itemAlignment(i) >= alignment)) { 148 | self.setUsed(i, true); 149 | return self.itemData(i); 150 | } 151 | } 152 | return null; 153 | } 154 | 155 | pub fn locateMemory(self: *const Self, data: []u8) ?usize { 156 | 157 | // If this function succeeds, it returns the 158 | // index of the cache entry whose data pointer 159 | // matches the data argument 160 | 161 | if ((self.size() == 0) or (data.len == 0)) { 162 | return null; 163 | } 164 | 165 | const limit = data.len <<| 1; 166 | 167 | var i = if (data.len <= self.itemSize(0)) 0 else self.lowerBoundSize(data.len); 168 | 169 | while ((i < self.size()) and (self.itemSize(i) <= limit)) : (i += 1) { 170 | if(self.itemData(i).ptr == data.ptr) { 171 | return i; 172 | } 173 | } 174 | return null; 175 | } 176 | 177 | pub fn withdraw(self: *Self, n: usize, alignment: u8) ?[]u8 { 178 | 179 | if ((self.size() == 0) or (n == 0)) { 180 | return null; 181 | } 182 | 183 | // Worst case guard -- if binary search finds that 184 | // element zero is a candidate, we'll search the 185 | // entire cache for direct O(N) performance. 186 | 187 | if (n <= self.itemSize(0)) { 188 | return self.scanForUnused(0, n, alignment); 189 | } 190 | 191 | // Check if cache can support size request. 192 | if (n > self.itemSize(self.size() - 1)) { 193 | return null; 194 | } 195 | 196 | // Begin scanning from first candidate index. 197 | return self.scanForUnused(self.lowerBoundSize(n), n, alignment); 198 | } 199 | 200 | pub fn deposit(self: *Self, data: []u8, alignment: u8) !void { 201 | 202 | if (data.len == 0) { 203 | return; 204 | } 205 | 206 | // Find lowest equal size index first... 207 | const idx = self.lowerBoundSize(data.len); 208 | 209 | const limit = data.len <<| 1; 210 | 211 | // From there, scan up the cache to see if 212 | // we have already encountered this pointer. 213 | // If so, mark it as unused (free) and return.
214 | 215 | var i = idx; 216 | while ((i < self.size()) and (self.itemSize(i) <= limit)) : (i += 1) { 217 | if(self.itemData(i).ptr == data.ptr) { 218 | return self.setUsed(i, false); 219 | } 220 | } 221 | 222 | // insert is capcity checked -- add to cache 223 | try self.buffer.insert( 224 | idx, .{ .data = data, .used = false, .alignment = alignment } 225 | ); 226 | } 227 | 228 | pub fn reserve( 229 | self: *Self, 230 | entries: usize, 231 | ) !void { 232 | return self.buffer.ensureTotalCapacityPrecise(entries); 233 | } 234 | }; 235 | 236 | //////////////////////////////////////////////////////// 237 | //////// LinearCachingAllocator Implementation /////////////// 238 | 239 | const Config = struct { 240 | mutex: bool = false, 241 | }; 242 | 243 | const DummyMutex = struct { 244 | fn lock(_: *DummyMutex) void {} 245 | fn unlock(_: *DummyMutex) void {} 246 | }; 247 | 248 | pub fn LinearCachingAllocator(comptime config: Config) type { 249 | 250 | return struct { 251 | 252 | const Self = @This(); 253 | 254 | const MutexType = if (config.mutex) std.Thread.Mutex else DummyMutex; 255 | 256 | cache: OrderedCache, 257 | 258 | backing_allocator: std.mem.Allocator, 259 | 260 | mutex: MutexType = .{ }, 261 | 262 | pub fn init(fallback: std.mem.Allocator) Self { 263 | return Self { 264 | .backing_allocator = fallback, 265 | .cache = OrderedCache.init(fallback) 266 | }; 267 | } 268 | 269 | pub fn clear(self: *Self) void { 270 | self.cache.clear(&self.backing_allocator); 271 | } 272 | 273 | pub fn deinit(self: *Self) void { 274 | self.cache.deinit(&self.backing_allocator); 275 | } 276 | 277 | pub fn allocator(self: *Self) std.mem.Allocator { 278 | return .{ 279 | .ptr = self, 280 | .vtable = &.{ 281 | .alloc = alloc, 282 | .resize = resize, 283 | .free = free, 284 | }, 285 | }; 286 | } 287 | 288 | pub fn alloc( 289 | ctx: *anyopaque, 290 | len: usize, 291 | log2_align: u8, 292 | ret_addr: usize 293 | ) ?[*]u8 { 294 | const self: *Self = @ptrCast(@alignCast(ctx)); 295 | 296 | self.mutex.lock(); 297 | 298 | defer self.mutex.unlock(); 299 | 300 | if(self.cache.withdraw(len, log2_align)) |data| { 301 | return data.ptr; 302 | } 303 | return self.backing_allocator.rawAlloc(len, log2_align, ret_addr); 304 | } 305 | 306 | pub fn resize( 307 | ctx: *anyopaque, 308 | old_mem: []u8, 309 | log2_align: u8, 310 | new_len: usize, 311 | ret_addr: usize, 312 | ) bool { 313 | const self: *Self = @ptrCast(@alignCast(ctx)); 314 | 315 | self.mutex.lock(); 316 | 317 | defer self.mutex.unlock(); 318 | 319 | // locate pointer in cache (if exists) 320 | if (self.cache.locateMemory(old_mem)) |idx| { 321 | 322 | var data = self.cache.itemData(idx); 323 | 324 | if (self.backing_allocator.rawResize(data, log2_align, new_len, ret_addr)) { 325 | 326 | data = self.cache.buffer.orderedRemove(idx).data; 327 | 328 | // The only reason this would fail is because 329 | // the buffer allocator couldn't resize the array. 330 | // We know, however, that the capacity of the array 331 | // is already large enough for this insertion. 
332 | 333 | data.len = new_len; 334 | 335 | self.cache.deposit(data, log2_align) catch unreachable; 336 | 337 | return true; 338 | } 339 | } 340 | return false; 341 | } 342 | 343 | pub fn free( 344 | ctx: *anyopaque, 345 | old_mem: []u8, 346 | log2_align: u8, 347 | ret_addr: usize, 348 | ) void { 349 | const self: *Self = @ptrCast(@alignCast(ctx)); 350 | 351 | self.mutex.lock(); 352 | 353 | defer self.mutex.unlock(); 354 | 355 | self.cache.deposit(old_mem, log2_align) catch { 356 | self.backing_allocator.rawFree(old_mem, log2_align, ret_addr); 357 | }; 358 | } 359 | 360 | pub fn reserve( 361 | self: *Self, 362 | entries: usize, 363 | ) !void { 364 | return self.cache.reserve(entries); 365 | } 366 | }; 367 | } 368 | 369 | ///////////////////////////////////////////////////////// 370 | /////// OrderedCache Testing Section //////////////////// 371 | 372 | fn ensureWeakOrdering(buffer: *const OrderedCache) bool { 373 | for(0..(buffer.size() - 1)) |i| { 374 | if(buffer.itemSize(i) > buffer.itemSize(i + 1)) { 375 | return false; 376 | } 377 | } 378 | return true; 379 | } 380 | 381 | test "OrderedCache: ensure weak-ordering" { 382 | 383 | const GPA = @import("std").heap.GeneralPurposeAllocator(.{}); 384 | const rand = @import("std").rand; 385 | 386 | var gpa = GPA{}; 387 | var allocator = gpa.allocator(); 388 | var buffer = OrderedCache.init(allocator); 389 | var PCG = rand.Pcg.init(42); 390 | var pcg = PCG.random(); 391 | 392 | 393 | defer { 394 | buffer.deinit(&allocator); 395 | if (gpa.deinit() == .leak) { 396 | @panic("LEAK DETECTED"); 397 | } 398 | } 399 | 400 | // Create randomly sized allocations, deposit them 401 | // and then clear, rinse, repeat. Currently, this 402 | // is run 10 * (100 + 100) times, so 2000 items. 403 | 404 | for(0..10) |_| { 405 | // some repeat elements 406 | for(0..100) |_| { 407 | var n = pcg.int(usize) % 100; 408 | n = if(n == 0) 1 else n; 409 | try buffer.deposit(try allocator.alloc(u8, n), @alignOf(u8)); 410 | } 411 | try std.testing.expectEqual(buffer.size(), 100); 412 | try std.testing.expect(ensureWeakOrdering(&buffer)); 413 | buffer.clear(&allocator); 414 | try std.testing.expectEqual(buffer.size(), 0); 415 | 416 | // many repeat elements 417 | for(0..100) |_| { 418 | var n = pcg.int(usize) % 10; 419 | n = if(n == 0) 1 else n; 420 | try buffer.deposit(try allocator.alloc(u8, n), @alignOf(u8)); 421 | } 422 | 423 | try std.testing.expectEqual(buffer.size(), 100); 424 | try std.testing.expect(ensureWeakOrdering(&buffer)); 425 | buffer.clear(&allocator); 426 | try std.testing.expectEqual(buffer.size(), 0); 427 | } 428 | } 429 | 430 | test "OrderedCache: basic heuristic testing" { 431 | 432 | var gpa = std.heap.GeneralPurposeAllocator(.{}){ }; 433 | var allocator = gpa.allocator(); 434 | var buffer = OrderedCache.init(allocator); 435 | 436 | defer { 437 | buffer.deinit(&allocator); 438 | if (gpa.deinit() == .leak) { 439 | @panic("LEAK DETECTED"); 440 | } 441 | } 442 | 443 | // Say you have items A, B, C. 444 | // 445 | // A wants 100 bytes 446 | // 447 | // B wants 300 bytes 448 | // 449 | // Then both A and B surrender their memory… so we now have cached { 100, 300 } 450 | // 451 | // Now let’s say that A is followed by C, so it’s asking for { 100, 100 }… but the cache only has { 100, 300 } 452 | // 453 | // Now if B comes back and wants memory, it’ll ask for 300 bytes again. We’re empty so we have to allocate… now we have { 100, 300, 300 }. 
454 | // 455 | // Instead, if we forced C to allocate when it asked for 100 bytes and we were empty, we would end up with { 100, 100, 300 } which is ideal. 456 | // 457 | // So the heuristic just has to make sure that the actual requests are as close to what ends up in the cache… something like: 458 | // 459 | // optimize min: |sum(actual) - sum(cached)| 460 | 461 | const request1: usize = 100; 462 | const request2: usize = 300; 463 | const request3: usize = 100; 464 | 465 | { 466 | // allocate first two requests 467 | const a = try allocator.alloc(u8, request1); 468 | const b = try allocator.alloc(u8, request2); 469 | 470 | // deposit requests into cache (simulating free) 471 | try buffer.deposit(a, @alignOf(u8)); 472 | try buffer.deposit(b, @alignOf(u8)); 473 | } 474 | 475 | // cache now contains { 100, 300 } 476 | try std.testing.expectEqual(buffer.itemSize(0), 100); 477 | try std.testing.expectEqual(buffer.itemSize(1), 300); 478 | 479 | { 480 | // request memory { 100, 100, 300 } 481 | const a = buffer.withdraw(request1, @alignOf(u8)) orelse try allocator.alloc(u8, request1); 482 | const c = buffer.withdraw(request3, @alignOf(u8)) orelse try allocator.alloc(u8, request3); 483 | const b = buffer.withdraw(request2, @alignOf(u8)) orelse try allocator.alloc(u8, request2); 484 | 485 | // deposit requests into cache (simulating free) 486 | try buffer.deposit(a, @alignOf(u8)); 487 | try buffer.deposit(b, @alignOf(u8)); 488 | try buffer.deposit(c, @alignOf(u8)); 489 | } 490 | // cache should contain { 100, 100, 300 } 491 | try std.testing.expectEqual(buffer.itemSize(0), 100); 492 | try std.testing.expectEqual(buffer.itemSize(1), 100); 493 | try std.testing.expectEqual(buffer.itemSize(2), 300); 494 | } 495 | 496 | ///////////////////////////////////////////////////////// 497 | /////// LinearCachingAllocator Testing Section ///////////// 498 | 499 | test "LinearCachingAllocator: buffer size" { 500 | 501 | const TypeA = struct { 502 | x: usize = 0 503 | }; 504 | 505 | var caching_allocator = LinearCachingAllocator(.{}).init(std.heap.page_allocator); 506 | 507 | defer caching_allocator.deinit(); 508 | 509 | var allocator = caching_allocator.allocator(); 510 | 511 | const a = try allocator.alloc(TypeA, 10); 512 | 513 | allocator.free(a); 514 | 515 | try std.testing.expectEqual(caching_allocator.cache.size(), 1); 516 | } 517 | 518 | test "LinearCachingAllocator: cache utilization" { 519 | 520 | const TypeA = struct { 521 | x: usize = 0 522 | }; 523 | 524 | var caching_allocator = LinearCachingAllocator(.{}).init(std.heap.page_allocator); 525 | 526 | defer caching_allocator.deinit(); 527 | 528 | var allocator = caching_allocator.allocator(); 529 | 530 | const a = try allocator.alloc(TypeA, 10); 531 | 532 | const b = a; 533 | 534 | allocator.free(a); 535 | 536 | const c = try allocator.alloc(TypeA, 10); 537 | 538 | try std.testing.expect(b.ptr == c.ptr); 539 | } 540 | 541 | test "LinearCachingAllocator: alignment" { 542 | 543 | const TypeA = struct { 544 | x: usize = 0 545 | }; 546 | const TypeB = struct { 547 | x: usize = 0, 548 | y: bool = false 549 | }; 550 | 551 | { 552 | // ensure that log2 alignment is different... 
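// (On a typical 64-bit target, TypeA occupies 64 bits, so its ceiling
// power-of-two is 64 and its log2 is 6; TypeB needs more than 64 bits and
// rounds up to 128, giving a log2 of 7.)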
553 | const align_a = std.math.ceilPowerOfTwoAssert(usize, @bitSizeOf(TypeA)); 554 | const align_b = std.math.ceilPowerOfTwoAssert(usize, @bitSizeOf(TypeB)); 555 | const log2_a = std.math.log2(align_a); 556 | const log2_b = std.math.log2(align_b); 557 | try std.testing.expect(log2_a < log2_b); 558 | } 559 | 560 | var caching_allocator = LinearCachingAllocator(.{}).init(std.heap.page_allocator); 561 | 562 | defer caching_allocator.deinit(); 563 | 564 | var allocator = caching_allocator.allocator(); 565 | 566 | const a = try allocator.alloc(TypeA, 10); 567 | try std.testing.expectEqual(a.len, 10); 568 | 569 | allocator.free(a); 570 | 571 | const b = try allocator.alloc(TypeB, 4); 572 | try std.testing.expectEqual(b.len, 4); 573 | 574 | try std.testing.expect(@intFromPtr(a.ptr) == @intFromPtr(b.ptr)); 575 | 576 | allocator.free(b); 577 | 578 | try std.testing.expectEqual(caching_allocator.cache.size(), 1); 579 | 580 | // attempt to iterate through items 581 | 582 | for(b) |*item| { 583 | item.x = 0; 584 | item.y = false; 585 | } 586 | } 587 | 588 | test "LinearCachingAllocator: resize" { 589 | 590 | // So testing resize is tough. Resizes can "fail" 591 | // legitimately. That's why they return a bool and 592 | // not an error. Unfortunatecan have worst case linear 593 | // performance 594 | 595 | // That said, we can keep a few things in mind: 596 | 597 | // 1. The resize function dispatches to the 598 | // backing_allocator.rawResize function. 599 | // Therfore, we would technically be 600 | // testing that ultimately. 601 | 602 | // 2. Because of 1, the only way we can be 603 | // the source of failure is by either 604 | // failing to find the memory in cache, 605 | // identifying the wrong memory, failing 606 | // to deposit the memory, or leaking the 607 | // memory after resize. 608 | 609 | // At this point, the deposit function is well 610 | // established. So we need to show that the 611 | // search function locateMemory identifies the 612 | // correct memory in cache, and returns null 613 | // on invalid requests. 614 | 615 | const rand = @import("std").rand; 616 | 617 | const TypeA = struct { 618 | x: usize = 0 619 | }; 620 | 621 | var caching_allocator = LinearCachingAllocator(.{}).init(std.heap.page_allocator); 622 | 623 | defer caching_allocator.deinit(); 624 | 625 | var allocator = caching_allocator.allocator(); 626 | 627 | var PCG = rand.Pcg.init(42); 628 | var pcg = PCG.random(); 629 | 630 | // To test locateMemory, we'll allocate in 100 631 | // elements and force it to find the element after 632 | // depositing it. 633 | 634 | for (0..100) |_| { 635 | 636 | var n = pcg.int(usize) % 100; 637 | 638 | n = if(n == 0) 1 else n; 639 | 640 | const data = try allocator.alloc(TypeA, n); 641 | 642 | // deposit into cache... 643 | allocator.free(data); 644 | 645 | const check: []u8 = std.mem.sliceAsBytes(data); 646 | 647 | // lookup memory in allocator cache... 648 | const index = caching_allocator.cache.locateMemory(check); 649 | 650 | // null means we didn't find it. 651 | try std.testing.expect(index != null); 652 | 653 | const item = caching_allocator.cache.itemData(index.?); 654 | 655 | // ensure that it is the same data. 656 | try std.testing.expect(@intFromPtr(check.ptr) == @intFromPtr(item.ptr)); 657 | } 658 | 659 | // we need to be beyond the heuristic to test. 660 | const data = try allocator.alloc(TypeA, 300); 661 | 662 | { // check that un-cached memory isn't "found". 
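// (This allocation came straight from the backing allocator and has not
// been freed yet, so locateMemory has nothing in the cache to match.)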
663 | const check: []u8 = std.mem.sliceAsBytes(data); 664 | const index = caching_allocator.cache.locateMemory(check); 665 | try std.testing.expect(index == null); 666 | } 667 | 668 | // deposit into cache... 669 | allocator.free(data); 670 | 671 | { // check that cached memory is found. 672 | const check: []u8 = std.mem.sliceAsBytes(data); 673 | const index = caching_allocator.cache.locateMemory(check); 674 | try std.testing.expect(index != null); 675 | const item = caching_allocator.cache.itemData(index.?); 676 | try std.testing.expect(@intFromPtr(check.ptr) == @intFromPtr(item.ptr)); 677 | } 678 | } 679 | -------------------------------------------------------------------------------- /src/tensor_factory.zig: -------------------------------------------------------------------------------- 1 | // TensorFactory Implementation file. Before proceeding, please read the following: 2 | 3 | /////////////////////////////////// 4 | // DESIGN PHILOSOPHY (June 5, 2023) 5 | 6 | // The TensorFactory is an adapter around a provided allocator type. 7 | // It is primarily here to ensure provide automatic sizing to a given tensor. 8 | // In the future, the TensorFactory will handle things like concatenation 9 | // because that is ultimately a memory operation. 10 | 11 | // Fundamentally, this class can be avoided if you intend to use your own 12 | // allocations to assign to tensor values. The allocatoins will still be 13 | // checked if using the default functions. 14 | 15 | // Allocators still need to have the deinit() function called as per usual. 16 | 17 | // Zein import files... 18 | const std = @import("std"); 19 | const Allocator = @import("std").mem.Allocator; 20 | const Tensor = @import("./tensor.zig").Tensor; 21 | const TensorError = @import("./tensor.zig").TensorError; 22 | const SizesAndStridesType = @import("./sizes_and_strides.zig").SizeAndStride.ValueType; 23 | 24 | const SizeAndStride = @import("./sizes_and_strides.zig").SizeAndStride; 25 | const SizesAndStrides = @import("./sizes_and_strides.zig").SizesAndStrides; 26 | const OrderType = @import("./sizes_and_strides.zig").OrderType; 27 | const Rowwise = @import("./sizes_and_strides.zig").Rowwise; 28 | const Colwise = @import("./sizes_and_strides.zig").Colwise; 29 | const Ops = @import("./tensor_ops.zig"); 30 | const OpsError = @import("./tensor_ops.zig").OpsError; 31 | const contractionParse = @import("./expression_parsing.zig").contractionParse; 32 | const innerProductParse = @import("./expression_parsing.zig").innerProductParse; 33 | const contractedRank = @import("./expression_parsing.zig").contractedRank; 34 | const sliceProduct = @import("./utility.zig").sliceProduct; 35 | 36 | const LinearCachingAllocator = @import("./linear_caching_allocator.zig").LinearCachingAllocator(.{}); 37 | 38 | pub const AllocatorError = error{ UnknownObject, TensorSizeZero, TensorHasAlloc, WrongAllocator, IndexAlreadyFreed, InvalidIndex }; 39 | 40 | // used to keep track of tensor allocations 41 | const ArrayList = @import("std").ArrayList; 42 | 43 | const Mutex = @import("std").Thread.Mutex; 44 | 45 | // this is a current work around until more work can 46 | // go into the allocator itself - right now, it uses 47 | // an ArrayList of allocations to free memory. Unfortunately, 48 | // that means that it also needs an allocator itself. 49 | // at some point, this should be replaced, but it 50 | // works for now. I'm setting the BufferSize to something 51 | // large enough to handle anything reasonable (and then some...) 
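// A minimal usage sketch, mirroring the tests at the bottom of this file
// (the GeneralPurposeAllocator is just one possible choice for both the
// system and tensor allocators):
//
//     var gpa = std.heap.GeneralPurposeAllocator(.{}){};
//     var factory = TensorFactory(f32).init(.{
//         .system_allocator = gpa.allocator(),
//         .tensor_allocator = gpa.allocator(),
//     });
//     defer factory.deinit();
//
//     factory.tracking(.start); // tracked allocations are freed by deinit
//     var x = try factory.allocTensor(2, Rowwise, .{ 10, 10 });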
52 | 53 | // GPA for null initialized tensor allocators. 54 | const LCA = LinearCachingAllocator; 55 | 56 | // A number large enough that it shouldn't matter. 57 | const BufferSize = 100; 58 | var BufferMutex = Mutex{}; 59 | var LCABuffer: [BufferSize]?LCA = undefined; 60 | var LCAUsed: usize = 0; 61 | 62 | const TrackingMode = enum { start, stop, free }; 63 | 64 | // let m = current mode; 65 | // if m == start: 66 | // free: dealocate memory, m -> free 67 | // stop: no-op, m -> stop 68 | // 69 | // if m == stop: 70 | // free: dealocate memory, m -> free 71 | // start: no-op, m -> start 72 | // 73 | // if m == free: 74 | // start: no-op, m -> start 75 | // stop: no-op, m -> free 76 | 77 | const TensorFactoryConfig = struct { 78 | system_allocator: Allocator, 79 | tensor_allocator: Allocator, 80 | }; 81 | 82 | pub fn TensorFactory(comptime value_type: type) type { 83 | return struct { 84 | const Self = @This(); 85 | 86 | const SelfPtr = *Self; 87 | 88 | const ConstSelfPtr = *const Self; 89 | 90 | const ValueType = value_type; 91 | 92 | const ValueSlice = []ValueType; 93 | 94 | const SizesType = SizesAndStridesType; 95 | 96 | const TrackingData = ArrayList([]ValueType); 97 | 98 | tensor_allocator: Allocator, 99 | system_allocator: Allocator, 100 | 101 | tracking_data: TrackingData, 102 | tracking_mode: TrackingMode, 103 | 104 | pub fn init(config: TensorFactoryConfig) Self { 105 | return Self{ 106 | .tensor_allocator = config.tensor_allocator, 107 | .system_allocator = config.system_allocator, 108 | .tracking_data = TrackingData.init(config.system_allocator), 109 | .tracking_mode = TrackingMode.free, 110 | }; 111 | } 112 | 113 | pub fn deinit(self: SelfPtr) void { 114 | self.tracking(.free); 115 | self.tracking_data.deinit(); 116 | } 117 | 118 | /////////////////////////////////// 119 | // private allocation functions /// 120 | 121 | fn allocValues(self: SelfPtr, size: usize) !ValueSlice { 122 | const alloc = try self.tensor_allocator.alloc(ValueType, size); 123 | 124 | if (self.tracking_mode == .start) { 125 | try self.tracking_data.append(alloc); 126 | } 127 | return alloc; 128 | } 129 | 130 | fn freeValues(self: SelfPtr, values: ValueSlice) !void { 131 | if (self.tracking_mode == .start) { 132 | for (self.tracking_data.items) |data| { 133 | if (values.ptr == data.ptr) { 134 | return; 135 | } 136 | } 137 | try self.tracking_data.append(values); 138 | } else { 139 | self.tensor_allocator.free(values); 140 | } 141 | } 142 | 143 | /////////////////////////////////// 144 | // Change the tracking mode /////// 145 | 146 | pub fn tracking(self: SelfPtr, mode: TrackingMode) void { 147 | if (self.tracking_mode == .free and mode == .stop) { 148 | return; // free is inherently not tracking, so stay free 149 | } 150 | if ((self.tracking_mode == .start or self.tracking_mode == .stop) and mode == .free) { 151 | while (self.tracking_data.items.len > 0) { 152 | self.tensor_allocator.free(self.tracking_data.pop()); 153 | } 154 | } 155 | self.tracking_mode = mode; 156 | } 157 | 158 | ////////////////////////////////// 159 | // Tensor Allocation functions /// 160 | 161 | pub fn allocToTensor(self: SelfPtr, tensor: anytype) !void { 162 | if (tensor.*.valueSize() != 0) { 163 | return AllocatorError.TensorHasAlloc; 164 | } 165 | tensor.values = try self.allocValues(tensor.valueCapacity()); 166 | } 167 | 168 | pub fn freeFromTensor(self: SelfPtr, tensor: anytype) !void { 169 | try self.freeValues(tensor.*.values); 170 | tensor.values = &[_]ValueType{}; 171 | } 172 | 173 | pub fn allocTensor(self: SelfPtr, 
comptime rank: usize, comptime order: OrderType, sizes: [rank]SizesType) !Tensor(ValueType, rank, order) { 174 | const size = sliceProduct(SizesType, &sizes); 175 | 176 | if (size == 0) { 177 | return AllocatorError.TensorSizeZero; 178 | } 179 | const alloc = try self.allocValues(size); 180 | 181 | return Tensor(ValueType, rank, order){ 182 | .values = alloc, 183 | .sizes_and_strides = SizesAndStrides(rank, order).init(sizes), 184 | }; 185 | } 186 | 187 | pub fn copyTensor(self: SelfPtr, tensor: anytype) !@TypeOf(tensor.*) { 188 | const T = @TypeOf(tensor.*); 189 | 190 | const alloc = try self.allocValues(tensor.valueSize()); 191 | 192 | @memcpy(alloc, tensor.values); 193 | 194 | return T{ 195 | .values = alloc, 196 | .sizes_and_strides = tensor.sizes_and_strides, 197 | }; 198 | } 199 | 200 | ///////////////////////////// 201 | // Factory Math Functions /// 202 | 203 | pub fn add(self: SelfPtr, x: anytype, y: anytype) !@TypeOf(x.*) { 204 | var z = try self.allocTensor(@TypeOf(x.*).Rank, @TypeOf(x.*).Order, x.sizes_and_strides.sizes); 205 | Ops.add(x, y, &z); 206 | return z; 207 | } 208 | 209 | pub fn sub(self: SelfPtr, x: anytype, y: anytype) !@TypeOf(x.*) { 210 | var z = try self.allocTensor(@TypeOf(x.*).Rank, @TypeOf(x.*).Order, x.sizes_and_strides.sizes); 211 | Ops.sub(x, y, &z); 212 | return z; 213 | } 214 | 215 | pub fn mul(self: SelfPtr, x: anytype, y: anytype) !@TypeOf(x.*) { 216 | var z = try self.allocTensor(@TypeOf(x.*).Rank, @TypeOf(x.*).Order, x.sizes_and_strides.sizes); 217 | Ops.mul(x, y, &z); 218 | return z; 219 | } 220 | 221 | pub fn bias(self: SelfPtr, x: anytype, b: @TypeOf(x.*).ValueType) !@TypeOf(x.*) { 222 | var y = try self.allocTensor(@TypeOf(x.*).Rank, @TypeOf(x.*).Order, x.sizes_and_strides.sizes); 223 | Ops.bias(x, &y, b); 224 | return y; 225 | } 226 | 227 | pub fn scale(self: SelfPtr, x: anytype, s: @TypeOf(x.*).ValueType) !@TypeOf(x.*) { 228 | if (!x.isValid()) { 229 | return TensorError.InvalidTensorLayout; 230 | } 231 | var y = try self.allocTensor(@TypeOf(x.*).Rank, @TypeOf(x.*).Order, x.sizes_and_strides.sizes); 232 | Ops.scale(x, &y, s); 233 | return y; 234 | } 235 | 236 | pub fn contraction(self: SelfPtr, comptime expression: []const u8, x: anytype) !Tensor(ValueType, contractedRank(expression), @TypeOf(x.*).Order) { 237 | std.debug.assert(x.isValid()); 238 | 239 | const XRank = @TypeOf(x.*).Rank; 240 | const YRank = comptime contractedRank(expression); 241 | const ip = comptime contractionParse(XRank, YRank, expression); 242 | 243 | var y_ss: [YRank]SizesType = undefined; 244 | { 245 | var i: usize = 0; 246 | while (i < YRank) : (i += 1) { 247 | y_ss[i] = x.sizes_and_strides.sizes[ip.lhs[i]]; 248 | } 249 | } 250 | var y = try self.allocTensor(YRank, @TypeOf(x.*).Order, y_ss); 251 | 252 | var xc: [XRank]SizesType = undefined; 253 | var yc: [YRank]SizesType = undefined; 254 | 255 | @memset(y.values, 0); 256 | 257 | @call(.always_inline, Ops.recursiveContraction, .{ ValueType, SizesType, XRank, YRank, ip.lhs, ip.rhs, 0, x, &y, &xc, &yc }); 258 | return y; 259 | } 260 | 261 | pub fn innerProduct(self: SelfPtr, comptime expression: []const u8, x: anytype, y: anytype) !Tensor(ValueType, contractedRank(expression), @TypeOf(x.*).Order) { 262 | std.debug.assert(x.isValid() and y.isValid()); 263 | 264 | const XRank = @TypeOf(x.*).Rank; 265 | const YRank = @TypeOf(y.*).Rank; 266 | const ZRank = comptime contractedRank(expression); 267 | const plan = comptime innerProductParse(XRank, YRank, ZRank, expression); 268 | 269 | var z_sizes: [ZRank]SizesType = undefined; 
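// Build the output sizes from the parsed plan: contracted indices
// (marked with plan.pass) are skipped, and each remaining output axis
// takes its size from x or y depending on plan.s_ctrl.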
270 | { 271 | var i: usize = 0; 272 | while (i < plan.total) : (i += 1) { 273 | if (plan.z_perm[i] == plan.pass) { 274 | continue; 275 | } else if (plan.s_ctrl[i] == 0) { 276 | z_sizes[plan.z_perm[i]] = x.getSize(plan.x_perm[i]); 277 | } else { 278 | z_sizes[plan.z_perm[i]] = y.getSize(plan.y_perm[i]); 279 | } 280 | } 281 | } 282 | 283 | var z = try self.allocTensor(ZRank, @TypeOf(x.*).Order, z_sizes); 284 | 285 | var x_i: [XRank]SizesType = undefined; 286 | var y_i: [YRank]SizesType = undefined; 287 | var z_i: [ZRank]SizesType = undefined; 288 | 289 | @memset(z.values, 0); 290 | 291 | @call(.always_inline, Ops.recursiveInnerProduct, .{ @TypeOf(x.*).ValueType, SizesType, 0, plan, x, y, &z, &x_i, &y_i, &z_i }); 292 | 293 | return z; 294 | } 295 | }; 296 | } 297 | 298 | test "Allocate and Free" { 299 | 300 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 301 | 302 | const expect = std.testing.expect; 303 | 304 | var factory = TensorFactory(f32).init(.{ 305 | .system_allocator = gpa.allocator(), 306 | .tensor_allocator = gpa.allocator(), 307 | }); 308 | 309 | defer { 310 | factory.deinit(); 311 | if (gpa.deinit() == .leak) @panic("!!! LEAK DETECTED !!!"); 312 | } 313 | 314 | ///////////////////////////////////////// 315 | { // assign into to tensor ////////////// 316 | var X = Tensor(f32, 2, Rowwise).init(null, .{ 10, 10 }); 317 | 318 | // create 100 elements... 10x10 319 | try factory.allocToTensor(&X); 320 | try expect(X.valueSize() == 100); 321 | 322 | // tensor slice should be reset 323 | try factory.freeFromTensor(&X); 324 | try expect(X.valueSize() == 0); 325 | } 326 | ///////////////////////////////////////// 327 | { // assign directly to tensor ////////// 328 | var X = try factory.allocTensor(2, Rowwise, .{ 10, 10 }); 329 | 330 | // create 100 elements... 10x10 331 | try expect(X.valueSize() == 100); 332 | 333 | // tensor slice should be reset 334 | try factory.freeFromTensor(&X); 335 | try expect(X.valueSize() == 0); 336 | } 337 | 338 | factory.tracking(.start); // beging tracking allocations 339 | 340 | /////////////////////////////////////// 341 | { // assign directly to tensor ////////// 342 | var X = try factory.allocTensor(2, Rowwise, .{ 10, 10 }); 343 | var Y = try factory.copyTensor(&X); 344 | 345 | // create 100 elements... 10x10 346 | try expect(X.valueSize() == 100); 347 | try expect(Y.valueSize() == 100); 348 | 349 | try factory.freeFromTensor(&X); 350 | // do not free y... use deinit 351 | 352 | // tensor slice should be reset 353 | try expect(X.valueSize() == 0); 354 | } 355 | 356 | // make 3 tensors and do not free them 357 | var X = try factory.allocTensor(2, Rowwise, .{ 10, 10 }); 358 | var Y = try factory.allocTensor(2, Rowwise, .{ 10, 10 }); 359 | var Z = try factory.allocTensor(2, Rowwise, .{ 10, 10 }); 360 | 361 | // trivial operation to avoid compile error 362 | X.setValue(3, .{ 0, 1 }); 363 | Y.setValue(3, .{ 0, 1 }); 364 | Z.setValue(3, .{ 0, 1 }); 365 | } 366 | 367 | test "vectorized reduce" { 368 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 369 | 370 | var factory = TensorFactory(f32).init(.{ 371 | .system_allocator = gpa.allocator(), 372 | .tensor_allocator = gpa.allocator(), 373 | }); 374 | 375 | defer { 376 | factory.deinit(); 377 | if (gpa.deinit() == .leak) @panic("!!! 
LEAK DETECTED !!!"); 378 | } 379 | 380 | factory.tracking(.start); 381 | 382 | const x = try factory.allocTensor(2, Rowwise, .{ 100, 100 }); 383 | 384 | @memset(x.values, 1); 385 | 386 | { // reduce sum of 10'000 elements 387 | const y = Ops.sum(&x); 388 | try std.testing.expectEqual(y, 10000); 389 | } 390 | { // reduce product of 10'000 elements 391 | const y = Ops.product(&x); 392 | try std.testing.expectEqual(y, 1); 393 | } 394 | { // reduce max of 10'000 elements 395 | x.setValue(999, .{ 24, 62 }); 396 | const y = Ops.max(&x); 397 | try std.testing.expectEqual(y, 999); 398 | } 399 | { // reduce max of 10'000 elements 400 | x.setValue(-999, .{ 92, 10 }); 401 | const y = Ops.min(&x); 402 | try std.testing.expectEqual(y, -999); 403 | } 404 | } 405 | 406 | test "contraction" { 407 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 408 | 409 | var factory = TensorFactory(f32).init(.{ 410 | .system_allocator = gpa.allocator(), 411 | .tensor_allocator = gpa.allocator(), 412 | }); 413 | 414 | defer { 415 | factory.deinit(); 416 | if (gpa.deinit() == .leak) @panic("!!! LEAK DETECTED !!!"); 417 | } 418 | 419 | factory.tracking(.start); 420 | 421 | var x = try factory.allocTensor(3, Rowwise, .{ 3, 4, 3 }); 422 | 423 | @memset(x.values, 1); 424 | 425 | const y = try factory.contraction("ijk->i", &x); 426 | 427 | try std.testing.expectEqual(y.values[0], 12); 428 | try std.testing.expectEqual(y.values[1], 12); 429 | try std.testing.expectEqual(y.values[2], 12); 430 | 431 | const z = try factory.contraction("ijk->j", &x); 432 | 433 | try std.testing.expectEqual(z.values[0], 9); 434 | try std.testing.expectEqual(z.values[1], 9); 435 | try std.testing.expectEqual(z.values[2], 9); 436 | try std.testing.expectEqual(z.values[3], 9); 437 | } 438 | 439 | test "contraction 2" { 440 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 441 | 442 | var factory = TensorFactory(f32).init(.{ 443 | .system_allocator = gpa.allocator(), 444 | .tensor_allocator = gpa.allocator(), 445 | }); 446 | 447 | defer { 448 | factory.deinit(); 449 | if (gpa.deinit() == .leak) @panic("!!! 
LEAK DETECTED !!!"); 450 | } 451 | 452 | factory.tracking(.start); 453 | 454 | var x = try factory.allocTensor(3, Rowwise, .{ 3, 4, 3 }); 455 | var y = try factory.allocTensor(2, Rowwise, .{ 3, 4 }); 456 | var z = try factory.allocTensor(2, Rowwise, .{ 4, 3 }); 457 | 458 | Ops.fill(&x, 1, 1); 459 | 460 | Ops.contraction("ijk->ij", &x, &y); 461 | 462 | try std.testing.expectEqual(y.values[0], 6); 463 | try std.testing.expectEqual(y.values[1], 15); 464 | try std.testing.expectEqual(y.values[2], 24); 465 | try std.testing.expectEqual(y.values[3], 33); 466 | try std.testing.expectEqual(y.values[4], 42); 467 | try std.testing.expectEqual(y.values[5], 51); 468 | try std.testing.expectEqual(y.values[6], 60); 469 | try std.testing.expectEqual(y.values[7], 69); 470 | try std.testing.expectEqual(y.values[8], 78); 471 | try std.testing.expectEqual(y.values[9], 87); 472 | try std.testing.expectEqual(y.values[10], 96); 473 | try std.testing.expectEqual(y.values[11], 105); 474 | 475 | Ops.contraction("ijk->ji", &x, &z); 476 | 477 | try std.testing.expectEqual(z.values[0], 6); 478 | try std.testing.expectEqual(z.values[1], 42); 479 | try std.testing.expectEqual(z.values[2], 78); 480 | try std.testing.expectEqual(z.values[3], 15); 481 | try std.testing.expectEqual(z.values[4], 51); 482 | try std.testing.expectEqual(z.values[5], 87); 483 | try std.testing.expectEqual(z.values[6], 24); 484 | try std.testing.expectEqual(z.values[7], 60); 485 | try std.testing.expectEqual(z.values[8], 96); 486 | try std.testing.expectEqual(z.values[9], 33); 487 | try std.testing.expectEqual(z.values[10], 69); 488 | try std.testing.expectEqual(z.values[11], 105); 489 | 490 | Ops.contraction("ijk->jk", &x, &z); 491 | 492 | try std.testing.expectEqual(z.values[0], 39); 493 | try std.testing.expectEqual(z.values[1], 42); 494 | try std.testing.expectEqual(z.values[2], 45); 495 | try std.testing.expectEqual(z.values[3], 48); 496 | try std.testing.expectEqual(z.values[4], 51); 497 | try std.testing.expectEqual(z.values[5], 54); 498 | try std.testing.expectEqual(z.values[6], 57); 499 | try std.testing.expectEqual(z.values[7], 60); 500 | try std.testing.expectEqual(z.values[8], 63); 501 | try std.testing.expectEqual(z.values[9], 66); 502 | try std.testing.expectEqual(z.values[10], 69); 503 | try std.testing.expectEqual(z.values[11], 72); 504 | } 505 | 506 | test "inner product 1" { 507 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 508 | 509 | var factory = TensorFactory(f32).init(.{ 510 | .system_allocator = gpa.allocator(), 511 | .tensor_allocator = gpa.allocator(), 512 | }); 513 | 514 | defer { 515 | factory.deinit(); 516 | if (gpa.deinit() == .leak) @panic("!!! 
LEAK DETECTED !!!"); 517 | } 518 | 519 | 520 | factory.tracking(.start); 521 | 522 | var x = try factory.allocTensor(2, Rowwise, .{ 2, 2 }); 523 | var y = try factory.allocTensor(2, Rowwise, .{ 2, 2 }); 524 | var z = try factory.allocTensor(2, Rowwise, .{ 2, 2 }); 525 | 526 | Ops.fill(&x, 1, 0); 527 | Ops.fill(&y, 1, 1); 528 | 529 | Ops.innerProduct("ij,jk->ik", &x, &y, &z); 530 | 531 | try std.testing.expectEqual(z.values[0], 4); 532 | try std.testing.expectEqual(z.values[1], 6); 533 | try std.testing.expectEqual(z.values[2], 4); 534 | try std.testing.expectEqual(z.values[3], 6); 535 | 536 | Ops.innerProduct("ij,jk->ki", &x, &y, &z); 537 | 538 | try std.testing.expectEqual(z.values[0], 4); 539 | try std.testing.expectEqual(z.values[1], 4); 540 | try std.testing.expectEqual(z.values[2], 6); 541 | try std.testing.expectEqual(z.values[3], 6); 542 | } 543 | 544 | test "inner product 2" { 545 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 546 | 547 | var factory = TensorFactory(f32).init(.{ 548 | .system_allocator = gpa.allocator(), 549 | .tensor_allocator = gpa.allocator(), 550 | }); 551 | 552 | defer { 553 | factory.deinit(); 554 | if (gpa.deinit() == .leak) @panic("!!! LEAK DETECTED !!!"); 555 | } 556 | 557 | factory.tracking(.start); 558 | 559 | var x = try factory.allocTensor(3, Rowwise, .{ 2, 3, 2 }); 560 | var y = try factory.allocTensor(3, Rowwise, .{ 2, 3, 2 }); 561 | 562 | Ops.fill(&x, 0, 1); 563 | Ops.fill(&y, 0, 1); 564 | 565 | const z = try factory.innerProduct("ijk,kjm->im", &x, &y); 566 | 567 | try std.testing.expectEqual(z.values[0], 100); 568 | try std.testing.expectEqual(z.values[1], 115); 569 | try std.testing.expectEqual(z.values[2], 280); 570 | try std.testing.expectEqual(z.values[3], 331); 571 | 572 | const w = try factory.innerProduct("ikj,jkl->kl", &x, &y); 573 | 574 | try std.testing.expectEqual(w.values[0], 48); 575 | try std.testing.expectEqual(w.values[1], 62); 576 | try std.testing.expectEqual(w.values[2], 116); 577 | try std.testing.expectEqual(w.values[3], 138); 578 | try std.testing.expectEqual(w.values[4], 216); 579 | try std.testing.expectEqual(w.values[5], 246); 580 | } 581 | 582 | test "arithmetic 1" { 583 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 584 | 585 | // null uses the general purpose allocator. 586 | // It also means that it will call deinit 587 | // on the gpa allocator when we call deinit. 588 | var factory = TensorFactory(f32).init(.{ 589 | .system_allocator = gpa.allocator(), 590 | .tensor_allocator = gpa.allocator(), 591 | }); 592 | 593 | defer { 594 | factory.deinit(); 595 | if (gpa.deinit() == .leak) @panic("!!! LEAK DETECTED !!!"); 596 | } 597 | 598 | factory.tracking(.start); 599 | 600 | var x = try factory.allocTensor(1, Rowwise, .{100_000}); 601 | var y = try factory.allocTensor(1, Rowwise, .{100_000}); 602 | 603 | Ops.fill(&x, 1, 0); 604 | Ops.fill(&y, 2, 0); 605 | 606 | // factory versions... 
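// x is filled with 1s and y with 2s over 100_000 elements, so each block
// below checks a whole-tensor sum: add -> 3 per element (300_000),
// mul -> 2 (200_000), sub -> -1 (-100_000), bias of 4 -> 5 (500_000),
// and scale by 4 -> 4 (400_000).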
607 | { 608 | var z = try factory.add(&x, &y); 609 | const s = Ops.sum(&z); 610 | try std.testing.expect(s == 300_000); 611 | } 612 | { 613 | var z = try factory.mul(&x, &y); 614 | const s = Ops.sum(&z); 615 | try std.testing.expect(s == 200_000); 616 | } 617 | { 618 | var z = try factory.sub(&x, &y); 619 | const s = Ops.sum(&z); 620 | try std.testing.expect(s == -100_000); 621 | } 622 | { 623 | const b: i64 = 4; 624 | var z = try factory.bias(&x, b); 625 | const s = Ops.sum(&z); 626 | try std.testing.expect(s == 500_000); 627 | } 628 | { 629 | const b: i64 = 4; 630 | var z = try factory.scale(&x, b); 631 | const s = Ops.sum(&z); 632 | try std.testing.expect(s == 400_000); 633 | } 634 | 635 | var z = try factory.allocTensor(1, Rowwise, .{100_000}); 636 | 637 | // free versions... 638 | { 639 | Ops.add(&x, &y, &z); 640 | const s = Ops.sum(&z); 641 | try std.testing.expect(s == 300_000); 642 | } 643 | { 644 | Ops.mul(&x, &y, &z); 645 | const s = Ops.sum(&z); 646 | try std.testing.expect(s == 200_000); 647 | } 648 | { 649 | Ops.sub(&x, &y, &z); 650 | const s = Ops.sum(&z); 651 | try std.testing.expect(s == -100_000); 652 | } 653 | { 654 | const b: i64 = 4; 655 | Ops.bias(&x, &z, b); 656 | const s = Ops.sum(&z); 657 | try std.testing.expect(s == 500_000); 658 | } 659 | { 660 | const b: i64 = 4; 661 | Ops.scale(&x, &z, b); 662 | const s = Ops.sum(&z); 663 | try std.testing.expect(s == 400_000); 664 | } 665 | } 666 | 667 | -------------------------------------------------------------------------------- /src/tensor_ops.zig: -------------------------------------------------------------------------------- 1 | // DESIGN PHILOSOPHY June 6th, 2023 // 2 | 3 | // The goal for V1 is simple. Provide reliable (albeit naive) functionality 4 | // that focuses on correctness first. Once that is established, V2 can use 5 | // V1 as a reference for future versions, therefore creating a baseline 6 | // for correctness. As such, the current goal is to provide a complete set 7 | // of functionalities and replace them with more optimal solutions over time. 
8 | 9 | const std = @import("std"); 10 | const ReduceOp = std.builtin.ReduceOp; 11 | const math = std.math; 12 | 13 | const Util = @import("utility.zig"); 14 | const Tensor = @import("./tensor.zig").Tensor; 15 | const TensorError = @import("./tensor.zig").TensorError; 16 | const Rowwise = @import("./sizes_and_strides.zig").Rowwise; 17 | const Colwise = @import("./sizes_and_strides.zig").Colwise; 18 | const SizeType = @import("./sizes_and_strides.zig").SizeAndStride.ValueType; 19 | 20 | pub const InnerProductPlan = @import("./expression_parsing.zig").InnerProductPlan; 21 | pub const defaultPermuation = @import("./sizes_and_strides.zig").defaultPermutation; 22 | pub const contractionParse = @import("./expression_parsing.zig").contractionParse; 23 | pub const innerProductParse = @import("./expression_parsing.zig").innerProductParse; 24 | pub const outerProductParse = @import("./expression_parsing.zig").outerProductParse; 25 | pub const computeTensorIndex = @import("./tensor.zig").computeTensorIndex; 26 | 27 | pub const OpsError = error{ UnequalSize, InvalidDimensions, InvalidSizes, SizeZeroTensor, IntegerOverflow }; 28 | 29 | inline fn reduceInit(comptime op: ReduceOp, comptime T: type) T { 30 | 31 | const info = @typeInfo(T); 32 | 33 | return switch (op) { 34 | .Add => 0, // implicit cast 35 | .Mul => 1, // implicit cast 36 | .Min => if (comptime info == .Int) 37 | math.maxInt(T) else math.floatMax(T), 38 | .Max => if (comptime info == .Int) 39 | math.minInt(T) else -math.floatMax(T), 40 | else => @compileError("reduceInit: unsupported op"), 41 | }; 42 | } 43 | 44 | pub fn sum(x: anytype) @TypeOf(x.*).ValueType { 45 | std.debug.assert(x.valueSize() > 0); 46 | return simdReduce(ReduceOp.Add, addGeneric, x, reduceInit(ReduceOp.Add, @TypeOf(x.*).ValueType)); 47 | } 48 | pub fn product(x: anytype) @TypeOf(x.*).ValueType { 49 | std.debug.assert(x.valueSize() > 0); 50 | return simdReduce(ReduceOp.Mul, mulGeneric, x, reduceInit(ReduceOp.Mul, @TypeOf(x.*).ValueType)); 51 | } 52 | 53 | pub fn min(x: anytype) @TypeOf(x.*).ValueType { 54 | std.debug.assert(x.valueSize() > 0); 55 | return simdReduce(ReduceOp.Min, minGeneric, x, reduceInit(ReduceOp.Min, @TypeOf(x.*).ValueType)); 56 | } 57 | pub fn max(x: anytype) @TypeOf(x.*).ValueType { 58 | std.debug.assert(x.valueSize() > 0); 59 | return simdReduce(ReduceOp.Max, maxGeneric, x, reduceInit(ReduceOp.Max, @TypeOf(x.*).ValueType)); 60 | } 61 | 62 | // TODO: Address the issue with checked vs unchecked absGeneric at call sight 63 | pub fn absmax(x: anytype) @TypeOf(x.*).ValueType { 64 | return simdMapReduce(ReduceOp.Max, absGenericUnchecked, maxGeneric, x, reduceInit(ReduceOp.Max, @TypeOf(x.*).ValueType)); 65 | } 66 | 67 | // TODO: Address the issue with checked vs unchecked absGeneric at call sight 68 | pub fn absmin(x: anytype) @TypeOf(x.*).ValueType { 69 | return simdMapReduce(ReduceOp.Min, absGenericUnchecked, maxGeneric, x, reduceInit(ReduceOp.Min, @TypeOf(x.*).ValueType)); 70 | } 71 | 72 | // TODO: does this belong here? 
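// fill writes init, init + step, init + 2 * step, ... into x; with a step
// of 0 it behaves like a memset (see the tests in tensor_factory.zig).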
73 | pub fn fill( 74 | x: anytype, 75 | init: @TypeOf(x.*).ValueType, 76 | step: @TypeOf(x.*).ValueType 77 | ) void { 78 | var incr = init; 79 | for (x.values) |*value| { 80 | value.* = incr; 81 | incr += step; 82 | } 83 | } 84 | 85 | ////////////////////////////////////////////////////////////// 86 | ///////// BINARY ARITHMETIC FUNCTIONS //////////////////////// 87 | 88 | fn elementwiseCheck(x: anytype, y: anytype, z: anytype) void { 89 | if (comptime @TypeOf(x) != @TypeOf(y) or @TypeOf(y) != @TypeOf(z)) { 90 | @compileError("Mismatched tensor types for addition."); 91 | } 92 | std.debug.assert(x.isValid() and y.isValid() and z.isValid()); 93 | std.debug.assert(x.valueSize() == y.valueSize() and y.valueSize() == z.valueSize()); 94 | } 95 | 96 | pub fn add(x: anytype, y: anytype, z: anytype) void { 97 | elementwiseCheck(x, y, z); 98 | simdArithmetic(addGeneric, x, y, z); 99 | } 100 | 101 | // <>--------------------------------------------------------<> 102 | 103 | pub fn sub(x: anytype, y: anytype, z: anytype) void { 104 | elementwiseCheck(x, y, z); 105 | simdArithmetic(subGeneric, x, y, z); 106 | } 107 | 108 | // <>--------------------------------------------------------<> 109 | 110 | // TODO: should this be called mul? It's actually a hadamard 111 | pub fn mul(x: anytype, y: anytype, z: anytype) void { 112 | elementwiseCheck(x, y, z); 113 | simdArithmetic(mulGeneric, x, y, z); 114 | } 115 | 116 | // <>--------------------------------------------------------<> 117 | 118 | // TODO: scale seems like a bad name? 119 | pub fn scale(x: anytype, y: @TypeOf(x), s: @TypeOf(x.*).ValueType) void { 120 | std.debug.assert(x.isValid() and y.isValid()); 121 | std.debug.assert(x.valueSize() == y.valueSize()); 122 | simdScalarBroadcast(mulGeneric, x, y, s); 123 | } 124 | 125 | // <>--------------------------------------------------------<> 126 | 127 | pub fn bias(x: anytype, y: @TypeOf(x), b: @TypeOf(x.*).ValueType) void { 128 | std.debug.assert(x.isValid() and y.isValid()); 129 | std.debug.assert(x.valueSize() == y.valueSize()); 130 | simdScalarBroadcast(addGeneric, x, y, b); 131 | } 132 | 133 | // <>--------------------------------------------------------<> 134 | 135 | inline fn quantizeGeneric(comptime int: type, x: anytype) int { 136 | return @intFromFloat(@round(x * comptime @as(@TypeOf(x), math.maxInt(int)))); 137 | } 138 | 139 | pub fn quantize(x: anytype, y: anytype) @TypeOf(x.*).ValueType { 140 | const m = absmax(x); 141 | 142 | if (m > 1.0) { 143 | const s = 1.0 / m; 144 | var i: usize = 0; 145 | while (i < x.values.len) : (i += 1) { 146 | y.values[i] = quantizeGeneric(@TypeOf(y.*).ValueType, x.values[i] * s); 147 | } 148 | } else { 149 | var i: usize = 0; 150 | while (i < 100) : (i += 1) { 151 | y.values[i] = quantizeGeneric(@TypeOf(y.*).ValueType, x.values[i]); 152 | } 153 | } 154 | return m; 155 | } 156 | 157 | // <>--------------------------------------------------------<> 158 | 159 | inline fn unquantizeGeneric(comptime float: type, x: anytype) float { 160 | return @as(float, @floatFromInt(x)) / comptime @as(float, @floatFromInt(math.maxInt(@TypeOf(x)))); 161 | } 162 | 163 | pub fn unquantize(x: anytype, y: anytype, s: @TypeOf(y.*).ValueType) void { 164 | const FT = @TypeOf(y.*).ValueType; 165 | 166 | if (s > 1.0) { 167 | var i: usize = 0; 168 | while (i < x.values.len) : (i += 1) { 169 | y.values[i] = s * unquantizeGeneric(FT, x.values[i]); 170 | } 171 | } else { 172 | var i: usize = 0; 173 | while (i < 100) : (i += 1) { 174 | y.values[i] = unquantizeGeneric(FT, x.values[i]); 175 | } 176 
| } 177 | } 178 | 179 | 180 | ///////////////////////////////////////////////////////////// 181 | // This is the naive version of a general tensor permutation. 182 | // In the future, I plan on making more optimal versions of 183 | // this, but it's reliable baseline for future work. 184 | // 185 | // If all goes well, it will unroll to something like this: 186 | // 187 | // for i..I 188 | // indices[0] = i 189 | // for j..J 190 | // indices[1] = j 191 | // ... 192 | // for n..N 193 | // scratch[count] = x.getValue(indices); 194 | // count += 1 195 | // 196 | 197 | pub inline fn recursivePermutate( 198 | comptime VT: type, // value type 199 | comptime IT: type, // int type 200 | comptime R: usize, // tensor rank 201 | comptime I: usize, // starting index 202 | x: anytype, // source tensor 203 | y: []VT, // destination memory 204 | c: *[R]IT, // index container 205 | n: *IT, // scratch counter 206 | ) void { 207 | if (I == (R - 1)) { 208 | // we only need to make this once really... 209 | const x_ss: @Vector(R, IT) = x.*.sizes_and_strides.strides; 210 | 211 | var i: IT = 0; 212 | var n_i = n.*; 213 | while (i < x.*.getSize(I)) : ({ 214 | i += 1; 215 | n_i += 1; 216 | }) { 217 | c[I] = i; 218 | const x_c: @Vector(R, IT) = c.*; 219 | const x_i = @reduce(ReduceOp.Add, x_c * x_ss); 220 | 221 | y[n_i] = x.*.values[x_i]; 222 | } 223 | n.* += i; 224 | } else { 225 | var i: IT = 0; 226 | while (i < x.*.getSize(I)) : (i += 1) { 227 | c[I] = i; 228 | 229 | @call(.always_inline, recursivePermutate, .{ VT, IT, R, (I + 1), x, y, c, n }); 230 | } 231 | } 232 | } 233 | 234 | ///////////////////////////////////////////////////////////// 235 | // This is the naive version of a general tensor contraction. 236 | // In the future, I plan on making more optimal versions of 237 | // this, but it's reliable baseline for future work. 238 | // 239 | // If all goes well, it will unroll to something like this: 240 | // 241 | // for i..I 242 | // x_indices[0] = i 243 | // y_indices[0] = i 244 | // for j..J 245 | // x_indices[1] = j 246 | // y_indices[1] = j 247 | // ... 
248 | // for n..N 249 | // x_indices[I] = n; 250 | // y[y_indices] += x.getValue(x_indices); 251 | 252 | pub fn contraction(comptime expression: []const u8, x: anytype, y: anytype) void { 253 | std.debug.assert(x.isValid() and y.isValid()); 254 | 255 | const XT = @TypeOf(x.*); 256 | const YT = @TypeOf(y.*); 257 | const ip = comptime contractionParse(XT.Rank, YT.Rank, expression); 258 | 259 | if (comptime Util.debug) { 260 | const xs = x.getSizes(); 261 | const ys = y.getSizes(); 262 | var i: usize = 1; 263 | while (i < YT.Rank) : (i += 1) { 264 | std.debug.assert(xs[ip.lhs[i]] == ys[ip.rhs[i]]); 265 | } 266 | } 267 | var xc: [XT.Rank]XT.SizesType = undefined; 268 | var yc: [YT.Rank]YT.SizesType = undefined; 269 | 270 | @memset(y.values, 0); 271 | 272 | @call(.always_inline, recursiveContraction, .{ XT.ValueType, XT.SizesType, XT.Rank, YT.Rank, ip.lhs, ip.rhs, 0, x, y, &xc, &yc }); 273 | } 274 | 275 | pub inline fn recursiveContraction( 276 | comptime VT: type, // value type 277 | comptime IT: type, // int type 278 | comptime XR: usize, // tensor x rank 279 | comptime YR: usize, // tensor y rank 280 | comptime xp: [XR]IT, // x permutation 281 | comptime yp: [YR]IT, // y permutation 282 | comptime I: usize, // starting index 283 | x: anytype, // source tensor 284 | y: anytype, // destination memory 285 | xc: *[XR]IT, // index container 286 | yc: *[YR]IT, // index container 287 | ) void { 288 | if (XR <= YR) { 289 | @compileError("Contraction must go from a larger tensor to a smaller one."); 290 | } 291 | 292 | if (I < YR) { 293 | const x_perm_index = xp[I]; 294 | const y_perm_index = yp[I]; 295 | 296 | // this first branch loads up the x and y indices 297 | // and passes them to the next loop. In this case, 298 | // I is still in bounds of both x and y ranks. 299 | 300 | var i: IT = 0; 301 | while (i < x.getSize(x_perm_index)) : (i += 1) { 302 | xc[x_perm_index] = i; 303 | yc[y_perm_index] = i; 304 | 305 | @call(.always_inline, recursiveContraction, .{ VT, IT, XR, YR, xp, yp, (I + 1), x, y, xc, yc }); 306 | } 307 | } else if ((YR <= I) and (I < (XR - 1))) { 308 | 309 | // the second branch deals with values of I that are 310 | // out-of-bounds for y rank, but still in-bounds for 311 | // the x rank. 312 | 313 | const x_perm_index = xp[I]; 314 | 315 | var i: IT = 0; 316 | while (i < x.getSize(x_perm_index)) : (i += 1) { 317 | xc[x_perm_index] = i; 318 | 319 | @call(.always_inline, recursiveContraction, .{ VT, IT, XR, YR, xp, yp, (I + 1), x, y, xc, yc }); 320 | } 321 | } else { 322 | 323 | // the third branch deals with summing up the contracted 324 | // indices and writing them to the related y index 325 | 326 | const x_ss: @Vector(XR, IT) = x.*.sizes_and_strides.strides; 327 | 328 | const x_perm_index = xp[I]; 329 | 330 | var i: IT = 0; 331 | var t: VT = 0; 332 | while (i < x.getSize(x_perm_index)) : (i += 1) { 333 | xc[x_perm_index] = i; 334 | const x_c: @Vector(XR, IT) = xc.*; 335 | const x_i = @reduce(ReduceOp.Add, x_c * x_ss); 336 | t += x.values[x_i]; // accumulate summations 337 | } 338 | const y_ss: @Vector(YR, IT) = y.sizes_and_strides.strides; 339 | const y_c: @Vector(YR, IT) = yc.*; 340 | const y_i = @reduce(ReduceOp.Add, y_c * y_ss); 341 | y.*.values[y_i] += t; 342 | } 343 | } 344 | 345 | // <>--------------------------------------------------------<> 346 | 347 | // TODO: Add explanation for this crazy thing... 
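// A rough sketch of what this does: the expression names an index for every
// axis of x, y, and the output z. Indices that appear in both inputs but not
// in the output are summed over; everything else is carried through. For
// example, "ij,jk->ik" computes
//
//     z[i][k] = sum over j of ( x[i][j] * y[j][k] )
//
// i.e. a matrix product: with x = [[1, 2], [3, 4]] and y = [[5, 6], [7, 8]],
// z comes out as [[19, 22], [43, 50]]. The comptime plan encodes, for each
// loop level, which axis of x, y, and z it drives (or plan.pass when a
// tensor does not use that index), and s_ctrl picks whether the loop bound
// is read from x or from y.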
348 | 349 | pub fn innerProduct(comptime expression: []const u8, x: anytype, y: anytype, z: anytype) void { 350 | std.debug.assert(x.isValid() and y.isValid() and z.isValid()); 351 | 352 | const XT = @TypeOf(x.*); 353 | const YT = @TypeOf(y.*); 354 | const ZT = @TypeOf(z.*); 355 | 356 | const plan = comptime innerProductParse(XT.Rank, YT.Rank, ZT.Rank, expression); 357 | 358 | if (comptime Util.debug) { 359 | for (0..plan.total) |i| { 360 | if (plan.x_perm[i] != plan.pass and plan.y_perm[i] != plan.pass) { 361 | std.debug.assert(x.getSize(plan.x_perm[i]) == y.getSize(plan.y_perm[i])); 362 | } 363 | // TODO: Add a check for output dimensions... 364 | } 365 | } 366 | 367 | var x_i: [XT.Rank]XT.SizesType = undefined; 368 | var y_i: [YT.Rank]YT.SizesType = undefined; 369 | var z_i: [ZT.Rank]ZT.SizesType = undefined; 370 | 371 | @memset(z.values, 0); 372 | 373 | @call(.always_inline, recursiveInnerProduct, .{ XT.ValueType, XT.SizesType, 0, plan, x, y, z, &x_i, &y_i, &z_i }); 374 | } 375 | 376 | pub inline fn sizeSelector(comptime x_index: usize, comptime y_index: usize, comptime select: usize, x: anytype, y: anytype) usize { 377 | if (select == 0) { 378 | return x.getSize(x_index); 379 | } else { 380 | return y.getSize(y_index); 381 | } 382 | } 383 | 384 | pub inline fn recursiveInnerProduct( 385 | comptime VT: type, // value type 386 | comptime IT: type, // int type 387 | comptime I: usize, // starting index 388 | comptime plan: anytype, // InnerProductPlan 389 | x: anytype, // lhs operand tensor 390 | y: anytype, // rhs operand tensor 391 | z: anytype, // output tensor 392 | xc: *[@TypeOf(x.*).Rank]IT, // index container 393 | yc: *[@TypeOf(y.*).Rank]IT, // index container 394 | zc: *[@TypeOf(z.*).Rank]IT, // index container 395 | ) void { 396 | const XT = @TypeOf(x.*); 397 | const YT = @TypeOf(y.*); 398 | const ZT = @TypeOf(z.*); 399 | 400 | const size = @call(.always_inline, sizeSelector, .{ plan.x_perm[I], plan.y_perm[I], plan.s_ctrl[I], x, y }); 401 | 402 | if (I < (plan.total - 1)) { 403 | var i: IT = 0; 404 | while (i < size) : (i += 1) { 405 | if (comptime plan.x_perm[I] != plan.pass) { 406 | xc[plan.x_perm[I]] = i; 407 | } 408 | if (comptime plan.y_perm[I] != plan.pass) { 409 | yc[plan.y_perm[I]] = i; 410 | } 411 | if (comptime plan.z_perm[I] != plan.pass) { 412 | zc[plan.z_perm[I]] = i; 413 | } 414 | @call(.always_inline, recursiveInnerProduct, .{ VT, IT, (I + 1), plan, x, y, z, xc, yc, zc }); 415 | } 416 | } else { 417 | var i: IT = 0; 418 | while (i < size) : (i += 1) { 419 | if (comptime plan.x_perm[I] != plan.pass) { 420 | xc[plan.x_perm[I]] = i; 421 | } 422 | if (comptime plan.y_perm[I] != plan.pass) { 423 | yc[plan.y_perm[I]] = i; 424 | } 425 | if (comptime plan.z_perm[I] != plan.pass) { 426 | zc[plan.z_perm[I]] = i; 427 | } 428 | const x_n = computeTensorIndex(XT.Rank, XT.SizesType, &x.sizes_and_strides.strides, xc); 429 | const y_n = computeTensorIndex(YT.Rank, YT.SizesType, &y.sizes_and_strides.strides, yc); 430 | const z_n = computeTensorIndex(ZT.Rank, ZT.SizesType, &z.sizes_and_strides.strides, zc); 431 | z.values[z_n] += x.values[x_n] * y.values[y_n]; 432 | } 433 | } 434 | } 435 | 436 | // <>--------------------------------------------------------<> 437 | 438 | // TODO: Add explanation for this crazy thing... 
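// A rough sketch of what this does: unlike innerProduct, no index is summed
// away -- every axis of x and y maps to its own axis of z, so for example
// "ij,kl->ijkl" gives
//
//     z[i][j][k][l] = x[i][j] * y[k][l]
//
// The recursion walks every index combination exactly once and accumulates
// the single product into the corresponding element of z.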
439 | 440 | pub fn outerProduct(comptime expression: []const u8, x: anytype, y: anytype, z: anytype) void { 441 | std.debug.assert(x.isValid() and y.isValid() and z.isValid()); 442 | const XT = @TypeOf(x.*); 443 | const YT = @TypeOf(y.*); 444 | const ZT = @TypeOf(z.*); 445 | 446 | const plan = comptime outerProductParse(XT.Rank, YT.Rank, ZT.Rank, expression); 447 | 448 | if (comptime Util.debug) { 449 | for (plan.x_perm, plan.y_perm, plan.z_perm) |xp, yp, zp| { 450 | if (xp != plan.pass) std.debug.assert(x.getSize(xp) == z.getSize(zp)); 451 | if (yp != plan.pass) std.debug.assert(y.getSize(yp) == z.getSize(zp)); 452 | } 453 | } 454 | 455 | var x_i: [XT.Rank]SizeType = undefined; 456 | var y_i: [YT.Rank]SizeType = undefined; 457 | var z_i: [ZT.Rank]SizeType = undefined; 458 | 459 | @memset(z.values, 0); 460 | 461 | @call(.always_inline, recursiveInnerProduct, .{ XT.ValueType, XT.SizesType, 0, plan, x, y, z, &x_i, &y_i, &z_i }); 462 | } 463 | 464 | pub inline fn recursiveOuterProduct( 465 | comptime VT: type, // value type 466 | comptime IT: type, // int type 467 | comptime I: usize, // starting index 468 | comptime plan: anytype, // InnerProductPlan 469 | x: anytype, // lhs operand tensor 470 | y: anytype, // rhs operand tensor 471 | z: anytype, // output tensor 472 | xc: *[@TypeOf(x.*).Rank]IT, // index container 473 | yc: *[@TypeOf(y.*).Rank]IT, // index container 474 | zc: *[@TypeOf(z.*).Rank]IT, // index container 475 | ) void { 476 | const XT = @TypeOf(x.*); 477 | const YT = @TypeOf(y.*); 478 | const ZT = @TypeOf(z.*); 479 | 480 | const size = @call(.always_inline, sizeSelector, .{ plan.x_perm[I], plan.y_perm[I], plan.s_ctrl[I], x, y }); 481 | 482 | if (I < (plan.total - 1)) { 483 | var i: IT = 0; 484 | while (i < size) : (i += 1) { 485 | if (comptime plan.x_perm[I] != plan.pass) { 486 | xc[plan.x_perm[I]] = i; 487 | } 488 | if (comptime plan.y_perm[I] != plan.pass) { 489 | yc[plan.y_perm[I]] = i; 490 | } 491 | zc[plan.z_perm[I]] = i; 492 | @call(.always_inline, recursiveInnerProduct, .{ VT, IT, (I + 1), plan, x, y, z, xc, yc, zc }); 493 | } 494 | } else { 495 | var i: IT = 0; 496 | while (i < size) : (i += 1) { 497 | if (comptime plan.x_perm[I] != plan.pass) { 498 | xc[plan.x_perm[I]] = i; 499 | } 500 | if (comptime plan.y_perm[I] != plan.pass) { 501 | yc[plan.y_perm[I]] = i; 502 | } 503 | zc[plan.z_perm[I]] = i; 504 | const x_n = computeTensorIndex(XT.Rank, XT.SizesType, &x.sizes_and_strides.strides, xc); 505 | const y_n = computeTensorIndex(YT.Rank, YT.SizesType, &y.sizes_and_strides.strides, yc); 506 | const z_n = computeTensorIndex(ZT.Rank, ZT.SizesType, &z.sizes_and_strides.strides, zc); 507 | z.values[z_n] += x.values[x_n] * y.values[y_n]; 508 | } 509 | } 510 | } 511 | 512 | 513 | // <>--------------------------------------------------------<> 514 | 515 | fn simdReduce( 516 | comptime ReduceType: anytype, 517 | comptime BinaryFunc: anytype, 518 | x: anytype, 519 | init: @TypeOf(x.*).ValueType 520 | ) @TypeOf(x.*).ValueType { 521 | const T = @TypeOf(x.*).ValueType; 522 | var i: usize = 0; 523 | var rdx = init; 524 | 525 | // reduce in size N chunks... 526 | if (comptime std.simd.suggestVectorLength(T)) |N| { 527 | while ((i + N) < x.valueSize()) : (i += N) { 528 | const vec: @Vector(N, T) = x.values[i..i + N][0..N].*; // needs compile time length 529 | rdx = @call(.always_inline, BinaryFunc, .{ rdx, @reduce(ReduceType, vec) }); 530 | } 531 | } 532 | 533 | // reduce remainder... 
534 | while (i < x.valueSize()) : (i += 1) { 535 | rdx = @call(.always_inline, BinaryFunc, .{ rdx, x.values[i] }); 536 | } 537 | return rdx; 538 | } 539 | 540 | // <>--------------------------------------------------------<> 541 | 542 | fn simdArithmetic( 543 | comptime BinaryFunc: anytype, 544 | x: anytype, 545 | y: anytype, 546 | z: anytype, 547 | ) void { 548 | 549 | const T = @TypeOf(x.*).ValueType; 550 | var i: usize = 0; 551 | 552 | if (comptime std.simd.suggestVectorLength(T)) |N| { 553 | var j: usize = N; 554 | while(j <= x.valueSize()) : ({i += N; j += N; }) { 555 | const v: @Vector(N, T) = x.values[i..j][0..N].*; 556 | const u: @Vector(N, T) = y.values[i..j][0..N].*; 557 | z.values[i..j][0..N].* = @call(.always_inline, BinaryFunc, .{v, u}); 558 | } 559 | } 560 | 561 | while (i < x.valueSize()) : (i += 1) { 562 | z.values[i] = @call(.always_inline, BinaryFunc, .{ x.values[i], y.values[i] }); 563 | } 564 | } 565 | 566 | // <>--------------------------------------------------------<> 567 | 568 | // TODO: limited in terms of what "map" can be 569 | fn simdMapReduce( 570 | comptime ReduceType: anytype, 571 | comptime UnaryFunc: anytype, 572 | comptime BinaryFunc: anytype, 573 | x: anytype, 574 | init: @TypeOf(x.*).ValueType 575 | ) @TypeOf(x.*).ValueType { 576 | const T = @TypeOf(x.*).ValueType; 577 | 578 | var i: usize = 0; 579 | var rdx = init; 580 | 581 | // reduce in size N chunks... 582 | if (comptime std.simd.suggestVectorLength(T)) |N| { 583 | while ((i + N) < x.valueSize()) : (i += N) { 584 | var vec: @Vector(N, T) = x.values[i..i + N][0..N].*; 585 | vec = @call(.always_inline, UnaryFunc, .{vec}); 586 | rdx = @call(.always_inline, BinaryFunc, .{ rdx, @reduce(ReduceType, vec) }); 587 | } 588 | } 589 | 590 | // reduce remainder... 591 | while (i < x.valueSize()) : (i += 1) { 592 | rdx = @call(.always_inline, BinaryFunc, .{ rdx, x.values[i] }); 593 | } 594 | return rdx; 595 | } 596 | 597 | // <>--------------------------------------------------------<> 598 | 599 | fn simdScalarBroadcast( 600 | comptime BinaryFunc: anytype, 601 | x: anytype, 602 | y: anytype, 603 | s: @TypeOf(x.*).ValueType 604 | ) void { 605 | const T = @TypeOf(x.*).ValueType; 606 | 607 | var i: usize = 0; 608 | 609 | // broadcast in size N chunks... 610 | if (comptime std.simd.suggestVectorLength(T)) |N| { 611 | const u: @Vector(N, T) = @splat(s); 612 | 613 | var j: usize = N; 614 | while (j <= x.values.len) : ({ i += N; j += N; }) { 615 | const v: @Vector(N, T) = x.values[i..j][0..N].*; 616 | y.values[i..j][0..N].* = @call(.always_inline, BinaryFunc, .{ v, u }); 617 | } 618 | } 619 | 620 | // broadcast remainder... 
621 | while (i < x.values.len) : (i += 1) { 622 | y.values[i] = @call(.always_inline, BinaryFunc, .{ x.values[i], s }); 623 | } 624 | } 625 | 626 | // <>--------------------------------------------------------<> 627 | 628 | inline fn addGeneric(x: anytype, y: anytype) @TypeOf(x) { 629 | return x + y; 630 | } 631 | inline fn mulGeneric(x: anytype, y: anytype) @TypeOf(x) { 632 | return x * y; 633 | } 634 | inline fn subGeneric(x: anytype, y: anytype) @TypeOf(x) { 635 | return x - y; 636 | } 637 | inline fn divGeneric(x: anytype, y: anytype) @TypeOf(x) { 638 | return x / y; 639 | } 640 | inline fn maxGeneric(x: anytype, y: anytype) @TypeOf(x) { 641 | return @max(x, y); 642 | } 643 | inline fn minGeneric(x: anytype, y: anytype) @TypeOf(x) { 644 | return @min(x, y); 645 | } 646 | 647 | // <>--------------------------------------------------------<> 648 | 649 | pub inline fn absGenericUnchecked(x: anytype) @TypeOf(x) { 650 | const T = @TypeOf(x); 651 | switch (comptime @typeInfo(T)) { 652 | .Float => { 653 | return @abs(x); 654 | }, 655 | .Int => |info| { 656 | if (comptime info.signedness == true) { 657 | const mask = x >> (comptime @bitSizeOf(T) - 1); 658 | return (x + mask) ^ mask; 659 | } else { 660 | return x; 661 | } 662 | }, 663 | .Vector => |info| { 664 | switch (comptime @typeInfo(info.child)) { 665 | .Float => { 666 | return @abs(x); 667 | }, 668 | else => { 669 | @compileError("Absolute value for integer vectors unimplemented."); 670 | }, 671 | } 672 | }, 673 | else => @compileError("Invalid type passed to absGeneric function: " ++ @typeName(T)), 674 | } 675 | } 676 | 677 | pub inline fn absGeneric(x: anytype) !@TypeOf(x) { 678 | const T = @TypeOf(x); 679 | return switch (comptime @typeInfo(T)) { 680 | .Int => |info| { 681 | if (info.signedness) { 682 | if (x == math.minInt(T)) return OpsError.IntegerOverflow; 683 | } 684 | return @call(.always_inline, absGenericUnchecked, .{x}); 685 | }, 686 | else => @call(.always_inline, absGenericUnchecked, .{x}), 687 | }; 688 | } 689 | --------------------------------------------------------------------------------