├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── build.zig ├── build.zig.zon └── src ├── strided_array.zig ├── test.zig └── test_utils.zig /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | workflow_dispatch: 10 | 11 | schedule: 12 | - cron: '17 15 * * 4' 13 | 14 | jobs: 15 | build: 16 | strategy: 17 | matrix: 18 | os: [ubuntu-latest, windows-latest, macos-latest] 19 | runs-on: ${{ matrix.os }} 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v4 23 | 24 | - name: Setup Zig 25 | uses: goto-bus-stop/setup-zig@v2 26 | with: 27 | version: master 28 | 29 | - name: Check formatting 30 | if: ${{ runner.os != 'Windows' }} 31 | run: zig fmt --check . 32 | 33 | - name: Run tests 34 | run: zig build test 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | zig-out/ 2 | .zig-cache/ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Dominic Weiller 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zig-strided-arrays 2 | 3 | A library implementing strided arrays for [Zig](https://ziglang.org). 4 | 5 | #### Features 6 | 7 | Strided arrays allow flexible manipulation of and iteration over an underlying slice. This library provides the `StridedArrayView(T, n)` generic type, which is an `n`-dimensional view of a `[]T`. The `StridedArrayView(T, n)` type provides get/set helpers with access by coordinate, iterators over the data 'in view' (in row-major order), utilities to produce sub-views (e.g. with `slice()`), and `flip()` and `transpose()` for cheap (i.e. without a copy) logical reordering of data. 8 | 9 | The strides of a view can be manipulated to achieve a range of effects; currently there is only one 'more advanced' helper for manipulating strides, `slidingWindow()`, which produces a (higher dimensional) view where the inner-most dimensions act as a sliding window over the original view data. If there are utilities you would like to see included please raise an issue or submit a pull request. 10 | 11 | #### Limitations 12 | 13 | The current implementation measures strides in terms of the data type of the underlying slice (not `u8`s), so you can't add exactly `n` bytes of padding to each row of a 2-dimensional array for example (unless the size of the array's data type divides `n`). 
14 | 15 | #### Contributing 16 | 17 | Feel free to open issues or submit pull requests if you think something can be improved, or there is other functionality you think would be useful (e.g. helpers for manipulating strides). The current API might be a bit crufty as well, so feel free to suggest improvements. 18 | -------------------------------------------------------------------------------- /build.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | pub fn build(b: *std.Build) void { 4 | const optimize = b.standardOptimizeOption(.{}); 5 | 6 | _ = b.addModule("strided-arrays", .{ 7 | .root_source_file = b.path("src/strided_array.zig"), 8 | }); 9 | 10 | const lib_tests = b.addTest(.{ 11 | .root_source_file = b.path("src/test.zig"), 12 | .optimize = optimize, 13 | }); 14 | const run_lib_tests = b.addRunArtifact(lib_tests); 15 | 16 | const run_tests_step = b.step("test", "Run library tests"); 17 | run_tests_step.dependOn(&run_lib_tests.step); 18 | 19 | b.default_step = run_tests_step; 20 | } 21 | -------------------------------------------------------------------------------- /build.zig.zon: -------------------------------------------------------------------------------- 1 | .{ 2 | .name = .strided_array, 3 | .fingerprint = 0xd8319dc2ec61126d, 4 | .version = "0.0.0", 5 | .paths = .{ 6 | "src/", 7 | "build.zig", 8 | "build.zig.zon", 9 | "README.md", 10 | "LICENSE", 11 | }, 12 | } 13 | -------------------------------------------------------------------------------- /src/strided_array.zig: -------------------------------------------------------------------------------- 1 | //! This module provides a generic StridedArrayView type. 2 | //! 3 | //! A strided array view is an `n`-dimensional view into a slice of type `[]T`. A particular view allows access to a subset of the slice via `n`-dimensional coordinates, or by iterating over elements of the subset in row-major order. 4 | //! 5 | //! 
//! Strided array views also allow cheap _logical_ reordering of the underlying slice, allowing dimensions to be arbitrarily transposed, flipped and rotated and resized (as long as the new shape makes sense).

const std = @import("std");
const Allocator = std.mem.Allocator;

/// Error set for grid-style access. Nothing in this file refers to it.
const GridError = error{
    IndexOutOfBounds,
    BufferTooSmall,
    ZeroLengthDimension,
};

/// Errors returned by view construction (`InvalidView`) and by operations that
/// need every slice element to belong to at most one coordinate
/// (`OverlappingElementsUnsupported`).
pub const ViewError = error{
    InvalidView,
    OverlappingElementsUnsupported,
};

/// Returns `StridedArrayViewIdx(T, num_dims, I)` where `I` is the unsigned
/// integer type with half as many bits as `usize`.
pub fn StridedArrayView(comptime T: type, comptime num_dims: usize) type {
    const index_bits = @typeInfo(usize).int.bits / 2;
    const HalfWord = @Type(.{
        .int = .{ .bits = index_bits, .signedness = .unsigned },
    });
    return StridedArrayViewIdx(T, num_dims, HalfWord);
}

/// Returns the type of a `num_dims`-dimensional view into a `[]T`.
/// `IndexType` is the type of coordinates, extents and the offset; it must be
/// an unsigned integer type of at most 64 bits (checked at compile time).
pub fn StridedArrayViewIdx(comptime T: type, comptime num_dims: usize, comptime IndexType: type) type {
    return struct {
        const Self = @This();

        const info = @typeInfo(IndexType);
        comptime {
            if (info != .int or info.int.signedness == .signed) {
                @compileError("StridedArrayView IndexType must be an unsigned integer type");
            }
            if (info.int.bits > 64) {
                @compileError(std.fmt.comptimePrint(
                    "Maximum allowed bit size is for IndexType is 64; got {d}-bit type",
                    .{info.int.bits},
                ));
            }
        }

        /// A coordinate or shape: one `IndexType` per dimension.
        pub const Indices = [num_dims]IndexType;
        /// Signed type with twice the bits of `IndexType`, so strides may be negative.
        pub const StrideType = @Type(.{ .int = .{ .bits = 2 * info.int.bits, .signedness = .signed } });
        /// One stride per dimension, measured in elements of `T`.
        pub const Stride = [num_dims]StrideType;

        pub const dim_count = num_dims;
        pub const EltType = T;

        /// The underlying slice the view indexes into.
        items: []T,
        /// Distance in `items` between neighbouring coordinates along each dimension.
        stride: Stride,
        /// Extent of the view along each dimension.
        shape: Indices,
        /// Index in `items` of the element at the all-zero coordinate.
        offset: IndexType,

        /// Returns the strides of a packed, row-major view of `shape`: the
        /// inner-most stride is 1 and each outer stride is the product of the
        /// extents inside it.
        fn strideOfShapePacked(shape: Indices) Stride {
            var stride: Stride = undefined;
            stride[stride.len - 1] = 1;
            // Walk outwards; every remaining slot is written exactly once.
            var i = shape.len - 1;
            while (i > 0) : (i -= 1) {
                stride[i - 1] = stride[i] * shape[i];
            }
            return stride;
        }

        /// Returns true if the view is non-empty and every coordinate maps to an
        /// index inside `items`.
        ///
        /// Both ends of the reachable index range are checked: a dimension with a
        /// negative stride (e.g. after `flip()`) extends the range downwards from
        /// `offset`, a positive one extends it upwards.
        fn validView(self: Self) bool {
            if (self.size() == 0) return false;

            var lowest: StrideType = self.offset;
            var highest: StrideType = self.offset;
            for (self.stride, self.shape) |stride_elt, shape_elt| {
                // Signed distance from coordinate 0 to the last coordinate of this dimension.
                const reach = stride_elt * (@as(StrideType, shape_elt) - 1);
                if (reach < 0) {
                    lowest += reach;
                } else {
                    highest += reach;
                }
            }
            return lowest >= 0 and highest < self.items.len;
        }

        /// Returns the slice index of the element whose coordinate is
        /// `shape - 1` along every dimension. Every extent must be at least 1.
        fn maxCoordIndex(self: Self) usize {
            var max_coord = self.shape;
            for (&max_coord) |*v| {
                v.* -= 1;
            }
            return self.sliceIndex(max_coord);
        }

        /// Create a packed view (i.e. a view with no gaps between elements in the underlying slice).
        /// Returns `ViewError.InvalidView` if the view is empty or does not fit in `items`.
        pub fn ofSlicePacked(items: []T, shape: Indices) !Self {
            return ofSliceStrided(items, strideOfShapePacked(shape), shape);
        }

        /// Create a view with the given `stride`s.
        /// Returns `ViewError.InvalidView` if the view is empty or does not fit in `items`.
        pub fn ofSliceStrided(items: []T, stride: Stride, shape: Indices) !Self {
            return ofSliceExtra(items, 0, stride, shape);
        }

        /// Create a view with the given `stride`s and `offset` into the underlying slice.
        /// Returns `ViewError.InvalidView` if the view is empty or does not fit in `items`.
        pub fn ofSliceExtra(items: []T, offset: IndexType, stride: Stride, shape: Indices) !Self {
            const view = Self{
                .items = items,
                .stride = stride,
                .shape = shape,
                .offset = offset,
            };
            if (!view.validView())
                return ViewError.InvalidView;
            return view;
        }

        /// Returns true if every component of `coord` is below the matching extent of the view.
        fn isValid(self: Self, coord: Indices) bool {
            inline for (coord, self.shape) |coord_elt, shape_elt| {
                if (coord_elt >= shape_elt) {
                    return false;
                }
            }
            return true;
        }

        /// Returns the index in the underlying slice of the element at `coord`,
        /// or `null` if `coord` is not valid.
        pub fn sliceIndexOrNull(self: Self, coord: Indices) ?usize {
            return if (self.isValid(coord)) self.sliceIndex(coord) else null;
        }

        /// Returns the index in the underlying slice of the element at `coord`.
        /// The caller guarantees that `coord` is valid.
pub fn sliceIndex(self: Self, coord: Indices) usize {
    // Accumulate in the signed stride type: individual terms may be negative.
    var index: StrideType = @as(StrideType, self.offset);
    inline for (coord, self.stride) |coord_elt, stride_elt| {
        index += coord_elt * stride_elt;
    }
    return @intCast(index);
}

/// Returns the iteration index for row-major ordering of the element at `coord`,
/// or `null` if `coord` is not valid.
pub fn iterIndexOrNull(self: Self, coord: Indices) ?usize {
    return if (self.isValid(coord)) self.iterIndex(coord) else null;
}

/// Returns the iteration index for row-major ordering of the element at `coord`.
pub fn iterIndex(self: Self, coord: Indices) usize {
    var index: usize = coord[num_dims - 1];

    comptime var dim = num_dims - 1;
    // Number of elements spanned by one step along dimension `dim - 1`.
    var s: usize = 1;
    inline while (dim > 0) : (dim -= 1) {
        s *= self.shape[dim];
        index += coord[dim - 1] * s;
    }
    return index;
}

/// Returns the coordinate of the element with row-major iteration index `index`.
pub fn coordOfIterIndex(self: Self, index: usize) Indices {
    var coord: Indices = undefined;
    var idx = index;
    comptime var i = num_dims;
    // Peel off one dimension at a time, inner-most first.
    inline while (i > 0) : (i -= 1) {
        coord[i - 1] = @intCast(idx % @as(usize, self.shape[i - 1]));
        idx /= self.shape[i - 1];
    }
    return coord;
}

/// Sort predicate: dimension `a` comes before `b` if its stride is larger in magnitude.
fn strideGreaterThan(stride: Stride, a: usize, b: usize) bool {
    return @abs(stride[a]) > @abs(stride[b]);
}

/// Returns the dimensions ordered by decreasing stride magnitude.
/// Dimensions with equal magnitude keep their relative order.
fn strideOrdering(self: Self) [num_dims]usize {
    // Start from the identity permutation 0, 1, ..., num_dims - 1.
    var order = comptime identity: {
        var seq: [num_dims]usize = undefined;
        for (&seq, 0..) |*slot, i| {
            slot.* = i;
        }
        break :identity seq;
    };
    // given that `order` should be small, insertion sort should be fine
    std.sort.insertion(usize, order[0..], self.stride, strideGreaterThan);
    return order;
}

/// Returns true if, for some pair of neighbouring dimensions in `order`, the
/// outer stride is smaller in magnitude than the full span (stride times
/// extent) of the next dimension. `order` is expected to come from
/// `strideOrdering()`.
fn viewOverlapping(self: Self, order: [num_dims]usize) bool {
    inline for (0..num_dims - 1) |i| {
        const outer = order[i];
        const inner = order[i + 1];
        if (@abs(self.stride[outer]) < @abs(self.stride[inner] * self.shape[inner])) {
            return true;
        }
    }
    return false;
}

/// Returns the coordinate of the element stored at slice index `index`.
/// `index` must address an element of the view.
/// Returns ViewError.OverlappingElementsUnsupported if the view has overlapping elements.
fn coordOfSliceIndex(self: Self, index: IndexType) !Indices {
    const dims_in_order = self.strideOrdering();
    if (self.viewOverlapping(dims_in_order)) {
        return ViewError.OverlappingElementsUnsupported;
    }

    // Slice index of the view's lowest-addressed element. Without negative
    // strides this is just `offset`; every flipped dimension moves it down by
    // that dimension's full reach.
    var base: StrideType = self.offset;
    for (self.stride, self.shape) |stride_elt, shape_elt| {
        if (stride_elt < 0) base += stride_elt * (@as(StrideType, shape_elt) - 1);
    }

    var coord: Indices = undefined;
    // Work in the signed type so `index < offset` cannot wrap around.
    var rest = @as(StrideType, index) - base;
    inline for (0..num_dims) |i| {
        const dim = dims_in_order[i];
        const stride_elt = self.stride[dim];
        const magnitude = if (stride_elt < 0) -stride_elt else stride_elt;
        const steps: IndexType = @intCast(@divTrunc(rest, magnitude));
        rest = @rem(rest, magnitude);
        // Along a flipped dimension the lowest address is the last coordinate.
        coord[dim] = if (stride_elt < 0) self.shape[dim] - 1 - steps else steps;
    }
    return coord;
}

/// Returns the element at `coord`, or `null` if `coord` is invalid.
pub fn getOrNull(self: Self, coord: Indices) ?T {
    return if (self.isValid(coord)) self.get(coord) else null;
}

/// Returns the element at `coord`; asserts that `coord` is valid.
pub fn get(self: Self, coord: Indices) T {
    return self.items[self.sliceIndex(coord)];
}

/// Returns a pointer to the element at `coord`, or `null` if `coord` is invalid.
pub fn getPtrOrNull(self: Self, coord: Indices) ?*T {
    return if (self.isValid(coord)) self.getPtr(coord) else null;
}

/// Returns a pointer to the element at `coord`; asserts that `coord` is valid.
pub fn getPtr(self: Self, coord: Indices) *T {
    return &self.items[self.sliceIndex(coord)];
}

/// Sets the value at `coord`; asserts that `coord` is valid.
pub fn set(self: Self, coord: Indices, value: T) void {
    self.items[self.sliceIndex(coord)] = value;
}

/// Returns the size of a view with the provided `shape`, i.e. the product of
/// its extents.
pub fn sizeOf(shape: Indices) usize {
    var result: usize = 1;
    inline for (shape) |s| {
        result *= s;
    }
    return result;
}

/// Returns the size of a view.
pub fn size(self: Self) usize {
    return sizeOf(self.shape);
}

/// Iterates over the elements of a view in row-major order, from iteration
/// index `index` up to (but not including) `last`.
pub const Iterator = struct {
    /// Iteration index of the element the next call yields.
    index: usize,
    /// Iteration index at which iteration stops (exclusive).
    last: usize,
    array_view: Self,

    /// Everything known about one iteration step.
    const Step = struct {
        iter_index: usize,
        coord: Indices,
        slice_index: usize,
    };

    /// Yields the current step and moves on, or returns `null` when the range
    /// is exhausted. The exhaustion check comes first, so a finished iterator
    /// never decodes a coordinate, and each step decodes it exactly once.
    fn advance(self: *Iterator) ?Step {
        if (self.index >= self.last) return null;
        const iter_index = self.index;
        const coord = self.array_view.coordOfIterIndex(iter_index);
        const slice_index = self.array_view.sliceIndex(coord);
        self.index += 1;
        return Step{
            .iter_index = iter_index,
            .coord = coord,
            .slice_index = slice_index,
        };
    }

    /// Returns the index in the underlying slice of the next element.
    pub fn nextSliceIndex(self: *Iterator) ?usize {
        const step = self.advance() orelse return null;
        return step.slice_index;
    }

    /// Returns the next element.
    pub fn next(self: *Iterator) ?T {
        const step = self.advance() orelse return null;
        return self.array_view.items[step.slice_index];
    }

    /// Returns a pointer to the next element.
    pub fn nextPtr(self: *Iterator) ?*T {
        const step = self.advance() orelse return null;
        return &self.array_view.items[step.slice_index];
    }

    pub const PtrInd = struct {
        ptr: *T,
        index: usize,
    };

    /// Returns a pointer to the next element together with its iteration index.
    pub fn nextPtrWithIndex(self: *Iterator) ?PtrInd {
        const step = self.advance() orelse return null;
        return PtrInd{
            .ptr = &self.array_view.items[step.slice_index],
            .index = step.iter_index,
        };
    }

    pub const PtrCoord = struct {
        ptr: *T,
        coord: Indices,
    };

    /// Returns a pointer to the next element together with its coordinate.
    pub fn nextPtrWithCoord(self: *Iterator) ?PtrCoord {
        const step = self.advance() orelse return null;
        return PtrCoord{
            .ptr = &self.array_view.items[step.slice_index],
            .coord = step.coord,
        };
    }

    pub const PtrCoordInd = struct {
        ptr: *T,
        coord: Indices,
        index: usize,
    };

    /// Returns a pointer to the next element together with its coordinate and iteration index.
    pub fn nextPtrWithBoth(self: *Iterator) ?PtrCoordInd {
        const step = self.advance() orelse return null;
        return PtrCoordInd{
            .ptr = &self.array_view.items[step.slice_index],
            .coord = step.coord,
            .index = step.iter_index,
        };
    }

    pub const TInd = struct {
        val: T,
        index: usize,
    };

    /// Returns the next element together with its iteration index.
    pub fn nextWithIndex(self: *Iterator) ?TInd {
        const step = self.advance() orelse return null;
        return TInd{
            .val = self.array_view.items[step.slice_index],
            .index = step.iter_index,
        };
    }

    pub const TCoord = struct {
        val: T,
        coord: Indices,
    };

    /// Returns the next element together with its coordinate.
    pub fn nextWithCoord(self: *Iterator) ?TCoord {
        const step = self.advance() orelse return null;
        return TCoord{
            .val = self.array_view.items[step.slice_index],
            .coord = step.coord,
        };
    }

    pub const TCoordInd = struct {
        val: T,
        coord: Indices,
        index: usize,
    };

    /// Returns the next element together with its coordinate and iteration index.
    pub fn nextWithBoth(self: *Iterator) ?TCoordInd {
        const step = self.advance() orelse return null;
        return TCoordInd{
            .val = self.array_view.items[step.slice_index],
            .coord = step.coord,
            .index = step.iter_index,
        };
    }
};

/// Iterate over the whole view.
pub fn iterate(self: Self) Iterator {
    return self.iterateFrom(0);
}

/// Iterate from (iteration) index `first` to the end of the view.
pub fn iterateFrom(self: Self, first: usize) Iterator {
    return self.iterateRange(first, self.size());
}

/// Iterate over the view up to (iteration) index `last`.
pub fn iterateTo(self: Self, last: usize) Iterator {
    return self.iterateRange(0, last);
}

/// Iterate from (iteration) index `first` to `last`.
pub fn iterateRange(self: Self, first: usize, last: usize) Iterator {
    return .{
        .array_view = self,
        .index = first,
        .last = last,
    };
}

/// Iterates over a region of `shape` elements whose coordinates are shifted
/// by `offset` and wrapped around the extents of `array_view`.
pub const WrapIterator = struct {
    offset: Indices,
    coord: Indices,
    shape: Indices,
    done: bool,
    array_view: Self,

    /// Returns the next element of the region, or `null` once the region has
    /// been traversed completely.
    pub fn next(self: *WrapIterator) ?T {
        if (self.done) return null;

        // Coordinate in the underlying view of the element to yield: the
        // current region coordinate shifted by `offset`, wrapped per dimension.
        var source: Indices = undefined;
        for (0..num_dims) |axis| {
            source[axis] = (self.offset[axis] + self.coord[axis]) % self.array_view.shape[axis];
        }

        // Advance the region coordinate like an odometer: bump the inner-most
        // dimension and carry outwards for as long as a dimension wraps to 0.
        var dim: usize = num_dims - 1;
        self.coord[dim] = (self.coord[dim] + 1) % self.shape[dim];
        while (dim > 0 and self.coord[dim] == 0) {
            dim -= 1;
            self.coord[dim] = (self.coord[dim] + 1) % self.shape[dim];
        }

        // The outer-most dimension wrapping means every coordinate was visited.
        if (dim == 0 and self.coord[0] == 0) {
            self.done = true;
        }

        return self.array_view.get(source);
    }
};

/// Iterate over the given region, but wrap coordinates along each dimension.
/// If wrapping behaviour is not needed or desired, you can `slice()` to the
/// shape desired and then `iterate()` instead.
pub fn iterateWrap(self: Self, from: Indices, shape: Indices) WrapIterator {
    return .{
        .array_view = self,
        .offset = from,
        .shape = shape,
        .coord = std.mem.zeroes(Indices),
        .done = false,
    };
}

/// Copy data to `buf`. See `copyToAlloc()` for a wrapper that takes an allocator.
pub fn copyTo(self: Self, buf: []T) void {
    std.debug.assert(buf.len >= self.size());
    var source = self.iterate();
    var written: usize = 0;
    while (source.next()) |element| {
        buf[written] = element;
        written += 1;
    }
}

/// Copy data to a newly allocated slice of `size()` elements.
/// The slice is allocated with `allocator` and returned to the caller.
pub fn copyToAlloc(self: Self, allocator: Allocator) ![]T {
    const element_count = self.size();
    const copy = try allocator.alloc(T, element_count);
    self.copyTo(copy);
    return copy;
}

/// Transpose dimensions `dim_1` and `dim_2` by exchanging their strides and extents.
pub fn transpose(self: *Self, dim_1: usize, dim_2: usize) void {
    const stride_1 = self.stride[dim_1];
    self.stride[dim_1] = self.stride[dim_2];
    self.stride[dim_2] = stride_1;

    const extent_1 = self.shape[dim_1];
    self.shape[dim_1] = self.shape[dim_2];
    self.shape[dim_2] = extent_1;
}

/// Flip dimension `dim` to run in the opposite direction.
pub fn flip(self: *Self, dim: usize) void {
    // The element that used to be last along `dim` becomes the first one.
    const last_coord = self.shape[dim] - 1;
    const old_offset = @as(StrideType, self.offset);
    self.offset = @intCast(old_offset + self.stride[dim] * last_coord);
    self.stride[dim] = -self.stride[dim];
}

/// Returns a new view containing the sub-region starting at `from`
/// with the given shape, or null if `shape` is not valid.
pub fn sliceOrNull(self: Self, from: Indices, shape: Indices) ?Self {
    const candidate = self.slice(from, shape);
    if (!candidate.validView()) return null;
    return candidate;
}

/// Returns a new view containing the sub-region starting at `from`
/// with the given shape. Do not slice to a shape that has zero size
/// (i.e. any dimension is zero).
pub fn slice(self: Self, from: Indices, shape: Indices) Self {
    // The sub-view shares the slice and strides; only its origin and extent differ.
    const origin = self.sliceIndex(from);
    return .{
        .items = self.items,
        .stride = self.stride,
        .shape = shape,
        .offset = @intCast(origin),
    };
}

/// Returns a new view containing the sub-region starting at `from`
/// with the given shape and step size or null if invalid shape/step
/// combination is passed.
/// Note that `shape` corresponds to the region
/// in the original view (i.e. if all steps are 1) and each component of
/// `steps` must be non-zero. Prefer `sliceOrNull` if stepping is not
/// required.
pub fn sliceStepOrNull(self: Self, from: Indices, shape: Indices, steps: Indices) ?Self {
    const view = self.sliceStep(from, shape, steps);
    return if (view.validView()) view else null;
}

/// Returns a new view containing the sub-region starting at `from`
/// with the given shape and step size. Note that `shape` corresponds
/// to the region in the original view (i.e. if all steps are 1)
/// and each component of `steps` must be non-zero. Prefer `slice`
/// if stepping is not required.
pub fn sliceStep(self: Self, from: Indices, shape: Indices, steps: Indices) Self {
    // Taking every `step`-th element multiplies the stride by `step`.
    var stride = self.stride;
    for (&stride, steps) |*s, step| {
        s.* *= step;
    }
    // Number of elements that remain along each dimension: ceil(shape / step).
    var result_shape: Indices = undefined;
    for (&result_shape, shape, steps) |*r, shape_elt, step| {
        r.* = (shape_elt + step - 1) / step;
    }
    return Self{
        .items = self.items,
        .shape = result_shape,
        .stride = stride,
        .offset = @intCast(self.sliceIndex(from)),
    };
}

/// Creates a new view with `dims` additional inner dimensions that act as a
/// sliding window over the `dims` inner-most dimensions of `self`. The
/// strides of the window dimensions are copied from the corresponding
/// dimensions of `self`.
/// Requires `dims <= num_dims` (checked at compile time) and every component
/// of `window_shape` to be no larger than the extent it slides along.
pub fn slidingWindow(
    self: Self,
    comptime dims: usize,
    window_shape: [dims]IndexType,
) StridedArrayViewIdx(T, dims + num_dims, IndexType) {
    if (dims > num_dims) {
        @compileError("slidingWindow: `dims` must not exceed the number of dimensions of the view");
    }

    const outer_dims = num_dims - dims;
    const total_dims = dims + num_dims;

    var shape: [total_dims]IndexType = undefined;
    // copy shape from dimensions we're not sliding along
    @memcpy(shape[0..outer_dims], self.shape[0..outer_dims]);
    // reduce shape size in directions we slide along: only positions where
    // the whole window fits remain
    for (
        shape[outer_dims..num_dims],
        self.shape[outer_dims..],
        window_shape,
    ) |*s, shape_elt, window_shape_elt| {
        std.debug.assert(window_shape_elt <= shape_elt);
        s.* = shape_elt - window_shape_elt + 1;
    }
    @memcpy(shape[num_dims..], window_shape[0..]);

    var stride: [total_dims]StrideType = undefined;
    // copy stride for the original dimensions
    @memcpy(stride[0..num_dims], self.stride[0..]);
    // copy strides into corresponding window dimensions
    for (stride[num_dims..], self.stride[outer_dims..]) |*s, stride_elt| {
        s.* = stride_elt;
    }
    return .{
        .items = self.items,
        .stride = stride,
        .shape = shape,
        .offset = self.offset,
    };
}
}; // closes the struct returned by StridedArrayViewIdx
} // closes StridedArrayViewIdx

const TestArrayView = StridedArrayView(u8, 3);
var one_to_23 = [24]TestArrayView.EltType{
    // zig fmt: off
    0, 1, 2, 3, 4, 5, 6, 7,
    8, 9, 10, 11, 12, 13, 14, 15,
    16, 17, 18, 19, 20, 21, 22, 23,
    // zig fmt: on
};

const testing = std.testing;

test {
    std.testing.refAllDeclsRecursive(@This());
    std.testing.refAllDeclsRecursive(TestArrayView);
}

test "strided_array refAllDecls" {
    std.testing.refAllDecls(@This());
}

test "StridedArrayView.strideOfShapePacked()" {
    const shape = TestArrayView.Indices{ 2, 3, 4 };
    const expected = TestArrayView.Stride{ 12, 4, 1 };
    try testing.expectEqual(expected, TestArrayView.strideOfShapePacked(shape));
}

test "StridedArrayView.validView()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };
    // differing offsets (along with associated stride/shape)
    try
testing.expect(array_view.validView()); 594 | array_view.stride = .{ 12, 4, 2 }; 595 | try testing.expect(!array_view.validView()); 596 | array_view.shape = .{ 2, 3, 2 }; 597 | try testing.expect(array_view.validView()); 598 | array_view.offset = 1; 599 | try testing.expect(array_view.validView()); 600 | array_view.offset = 2; 601 | try testing.expect(!array_view.validView()); 602 | array_view.offset = 3; 603 | try testing.expect(!array_view.validView()); 604 | array_view.offset = 4; 605 | try testing.expect(!array_view.validView()); 606 | array_view.shape = .{ 2, 2, 2 }; 607 | array_view.offset = 1; 608 | try testing.expect(array_view.validView()); 609 | array_view.offset = 2; 610 | try testing.expect(array_view.validView()); 611 | array_view.offset = 3; 612 | try testing.expect(array_view.validView()); 613 | array_view.offset = 4; 614 | try testing.expect(array_view.validView()); 615 | array_view.offset = 5; 616 | try testing.expect(array_view.validView()); 617 | array_view.offset = 6; 618 | try testing.expect(!array_view.validView()); 619 | 620 | // overlapping window 621 | array_view.offset = 4; 622 | array_view.stride = .{ 4, 1, 1 }; 623 | array_view.shape = .{ 5, 3, 2 }; 624 | try testing.expect(array_view.validView()); 625 | array_view.offset = 5; 626 | try testing.expect(!array_view.validView()); 627 | 628 | // overlapping but not unit 629 | array_view.offset = 0; 630 | array_view.stride = .{ 4, 2, 1 }; 631 | array_view.shape = .{ 5, 3, 3 }; 632 | try testing.expect(array_view.validView()); 633 | array_view.offset = 1; 634 | try testing.expect(array_view.validView()); 635 | array_view.offset = 2; 636 | try testing.expect(!array_view.validView()); 637 | 638 | // zero size is not valid 639 | array_view.offset = 0; 640 | array_view.shape = .{ 0, 3, 4 }; 641 | array_view.stride = .{ 12, 4, 1 }; 642 | try testing.expect(!array_view.validView()); 643 | } 644 | 645 | test "StridedArrayView.isValid()" { 646 | var array_view = TestArrayView{ 647 | .items = 
one_to_23[0..], 648 | .stride = .{ 12, 4, 1 }, 649 | .shape = .{ 2, 3, 4 }, 650 | .offset = 0, 651 | }; 652 | try testing.expect(array_view.isValid(.{ 1, 2, 3 })); 653 | try testing.expect(array_view.isValid(.{ 0, 0, 0 })); 654 | try testing.expect(array_view.isValid(.{ 1, 1, 1 })); 655 | try testing.expect(!array_view.isValid(.{ 2, 2, 3 })); 656 | try testing.expect(!array_view.isValid(.{ 1, 3, 3 })); 657 | try testing.expect(!array_view.isValid(.{ 1, 2, 4 })); 658 | 659 | // overlapping window 660 | array_view.stride = .{ 4, 1, 1 }; 661 | array_view.shape = .{ 6, 3, 2 }; 662 | try testing.expect(array_view.isValid(.{ 4, 2, 1 })); 663 | array_view.offset = 4; 664 | array_view.shape = .{ 5, 3, 2 }; 665 | try testing.expect(array_view.isValid(.{ 4, 2, 1 })); 666 | } 667 | 668 | test "StridedArrayView.strideOrdering()" { 669 | var array_view = TestArrayView{ 670 | .items = one_to_23[0..], 671 | .stride = .{ 12, 4, 1 }, 672 | .shape = .{ 2, 3, 4 }, 673 | .offset = 0, 674 | }; 675 | 676 | try testing.expectEqualSlices(usize, &.{ 0, 1, 2 }, &array_view.strideOrdering()); 677 | array_view.transpose(0, 1); 678 | try testing.expectEqualSlices(usize, &.{ 1, 0, 2 }, &array_view.strideOrdering()); 679 | array_view.transpose(0, 1); 680 | array_view.transpose(1, 2); 681 | try testing.expectEqualSlices(usize, &.{ 0, 2, 1 }, &array_view.strideOrdering()); 682 | array_view.transpose(1, 2); 683 | array_view.transpose(0, 2); 684 | try testing.expectEqualSlices(usize, &.{ 2, 1, 0 }, &array_view.strideOrdering()); 685 | 686 | array_view.stride = .{ 4, 1, 1 }; 687 | array_view.shape = .{ 6, 3, 2 }; 688 | 689 | try testing.expectEqualSlices(usize, &.{ 0, 1, 2 }, &array_view.strideOrdering()); 690 | array_view.transpose(0, 1); 691 | try testing.expectEqualSlices(usize, &.{ 1, 0, 2 }, &array_view.strideOrdering()); 692 | array_view.transpose(0, 1); 693 | array_view.transpose(1, 2); 694 | try testing.expectEqualSlices(usize, &.{ 0, 1, 2 }, &array_view.strideOrdering()); 695 | 
array_view.transpose(1, 2);
    array_view.transpose(0, 2);
    try testing.expectEqualSlices(usize, &.{ 2, 0, 1 }, &array_view.strideOrdering());

    // overlapping but not unit
    array_view.offset = 0;
    array_view.stride = .{ 4, 2, 1 };
    array_view.shape = .{ 5, 3, 3 };
    try testing.expectEqualSlices(usize, &.{ 0, 1, 2 }, &array_view.strideOrdering());
    array_view.transpose(0, 1);
    try testing.expectEqualSlices(usize, &.{ 1, 0, 2 }, &array_view.strideOrdering());
    array_view.transpose(0, 1);
    array_view.transpose(1, 2);
    try testing.expectEqualSlices(usize, &.{ 0, 2, 1 }, &array_view.strideOrdering());
    array_view.transpose(1, 2);
    array_view.transpose(0, 2);
    try testing.expectEqualSlices(usize, &.{ 2, 1, 0 }, &array_view.strideOrdering());
}

const ForAllSymmetries = @import("test_utils.zig").ForAllSymmetries;

// Checks `viewOverlapping()` under every transposition/flip of three views:
// a dense one (no overlap expected) and two whose strides make distinct
// coordinates land on the same slice index (overlap expected).
test "StridedArrayView.viewOverlapping()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };

    const tests = struct {
        fn overlap(ctx: void, av: TestArrayView) error{TestUnexpectedResult}!void {
            _ = ctx;
            try testing.expect(av.viewOverlapping(av.strideOrdering()));
        }
        fn noOverlap(ctx: void, av: TestArrayView) error{TestUnexpectedResult}!void {
            _ = ctx;
            try testing.expect(!av.viewOverlapping(av.strideOrdering()));
        }
    };
    const no_overlap = ForAllSymmetries(void, TestArrayView, tests.noOverlap){ .ctx = {} };
    const overlap = ForAllSymmetries(void, TestArrayView, tests.overlap){ .ctx = {} };

    try no_overlap.run(&array_view);

    // overlapping window
    array_view.offset = 4;
    array_view.stride = .{ 4, 1, 1 };
    array_view.shape = .{ 5, 3, 2 };

    try overlap.run(&array_view);

    // overlapping but not unit
    array_view.offset = 0;
    array_view.stride = .{ 4, 2, 1 };
    array_view.shape = .{ 5, 3, 3 };

    try overlap.run(&array_view);
}

// Checks that `coordOfSliceIndex()` inverts the coordinate -> slice index
// mapping for dense views, and reports `OverlappingElementsUnsupported` for
// views whose strides overlap.
test "StridedArrayView.coordOfSliceIndex()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };

    try testing.expectEqual(TestArrayView.Indices{ 0, 0, 0 }, try array_view.coordOfSliceIndex(0));
    try testing.expectEqual(TestArrayView.Indices{ 0, 0, 3 }, try array_view.coordOfSliceIndex(3));
    try testing.expectEqual(TestArrayView.Indices{ 1, 2, 0 }, try array_view.coordOfSliceIndex(20));
    try testing.expectEqual(TestArrayView.Indices{ 1, 1, 2 }, try array_view.coordOfSliceIndex(18));
    try testing.expectEqual(TestArrayView.Indices{ 1, 2, 3 }, try array_view.coordOfSliceIndex(23));

    // transposed
    array_view.transpose(1, 2);
    try testing.expectEqual(TestArrayView.Indices{ 0, 0, 0 }, try array_view.coordOfSliceIndex(0));
    try testing.expectEqual(TestArrayView.Indices{ 0, 1, 0 }, try array_view.coordOfSliceIndex(1));
    try testing.expectEqual(TestArrayView.Indices{ 0, 2, 0 }, try array_view.coordOfSliceIndex(2));
    try testing.expectEqual(TestArrayView.Indices{ 0, 3, 0 }, try array_view.coordOfSliceIndex(3));
    try testing.expectEqual(TestArrayView.Indices{ 1, 0, 2 }, try array_view.coordOfSliceIndex(20));
    try testing.expectEqual(TestArrayView.Indices{ 1, 2, 1 }, try array_view.coordOfSliceIndex(18));
    try testing.expectEqual(TestArrayView.Indices{ 1, 3, 2 }, try array_view.coordOfSliceIndex(23));

    // overlapping window
    array_view.offset = 4;
    array_view.stride = .{ 4, 1, 1 };
    array_view.shape = .{ 5, 3, 2 };
    try testing.expectError(ViewError.OverlappingElementsUnsupported, array_view.coordOfSliceIndex(4));

    // transposed
    array_view.transpose(1, 2);
    try testing.expectError(ViewError.OverlappingElementsUnsupported, array_view.coordOfSliceIndex(4));

    // overlapping but not unit
    array_view.offset = 0;
    array_view.stride = .{ 4, 2, 1 };
    array_view.shape = .{ 5, 3, 3 };
    try testing.expectError(ViewError.OverlappingElementsUnsupported, array_view.coordOfSliceIndex(4));
}

--------------------------------------------------------------------------------
/src/test.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const testing = std.testing;

const strided_array = @import("strided_array.zig");
const ViewError = strided_array.ViewError;

const TestArrayView = strided_array.StridedArrayView(u8, 3);
// Backing storage shared by every test: the values 0 through 23, so an
// element's value equals its slice index.
var one_to_23 = [24]TestArrayView.EltType{
    // zig fmt: off
     0,  1,  2,  3,  4,  5,  6,  7,
     8,  9, 10, 11, 12, 13, 14, 15,
    16, 17, 18, 19, 20, 21, 22, 23,
    // zig fmt: on
};

const ForAllSymmetries = @import("test_utils.zig").ForAllSymmetries;

// Checks the coordinate -> slice index mapping, including views whose
// strides make several coordinates map to the same index.
test "StridedArrayView.sliceIndex()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };

    try testing.expectEqual(@as(usize, 0), array_view.sliceIndex(.{ 0, 0, 0 }));
    try testing.expectEqual(@as(usize, 15), array_view.sliceIndex(.{ 1, 0, 3 }));
    try testing.expectEqual(@as(usize, 23), array_view.sliceIndex(.{ 1, 2, 3 }));

    // overlapping window
    array_view.offset = 4;
    array_view.stride = .{ 4, 1, 1 };
    array_view.shape = .{ 5, 3, 2 };
    try testing.expectEqual(@as(usize, 4), array_view.sliceIndex(.{ 0, 0, 0 }));
    try testing.expectEqual(@as(usize, 5), array_view.sliceIndex(.{ 0, 0, 1 }));
    try testing.expectEqual(@as(usize, 5), array_view.sliceIndex(.{ 0, 1, 0 }));
    try testing.expectEqual(@as(usize, 7), array_view.sliceIndex(.{ 0, 2, 1 }));
    try testing.expectEqual(@as(usize, 8), array_view.sliceIndex(.{ 1, 0, 0 }));
    try testing.expectEqual(@as(usize, 15), array_view.sliceIndex(.{ 2, 2, 1 }));
    try testing.expectEqual(@as(usize, 16), array_view.sliceIndex(.{ 3, 0, 0 }));
    try testing.expectEqual(@as(usize, 18), array_view.sliceIndex(.{ 3, 1, 1 }));
    try testing.expectEqual(@as(usize, 18), array_view.sliceIndex(.{ 3, 2, 0 }));
    try testing.expectEqual(@as(usize, 23), array_view.sliceIndex(.{ 4, 2, 1 }));

    // overlapping but not unit
    array_view.offset = 0;
    array_view.stride = .{ 4, 2, 1 };
    array_view.shape = .{ 5, 3, 3 };
    try testing.expectEqual(@as(usize, 0), array_view.sliceIndex(.{ 0, 0, 0 }));
    try testing.expectEqual(@as(usize, 1), array_view.sliceIndex(.{ 0, 0, 1 }));
    try testing.expectEqual(@as(usize, 2), array_view.sliceIndex(.{ 0, 0, 2 }));
    try testing.expectEqual(@as(usize, 2), array_view.sliceIndex(.{ 0, 1, 0 }));
    try testing.expectEqual(@as(usize, 4), array_view.sliceIndex(.{ 0, 1, 2 }));
    try testing.expectEqual(@as(usize, 4), array_view.sliceIndex(.{ 1, 0, 0 }));
    try testing.expectEqual(@as(usize, 15), array_view.sliceIndex(.{ 3, 1, 1 }));
    try testing.expectEqual(@as(usize, 22), array_view.sliceIndex(.{ 4, 2, 2 }));
}

// Checks that `size()` is the product of the shape and is unchanged by
// transpositions and flips.
test "StridedArrayView.size()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };

    try testing.expectEqual(@as(usize, 0), TestArrayView.sizeOf(.{ 0, 2, 4 }));

    const size_func = struct {
        fn f(ctx: usize, av: TestArrayView) error{TestExpectedEqual}!void {
            try testing.expectEqual(ctx, av.size());
        }
    }.f;

    const size_test = ForAllSymmetries(usize, TestArrayView, size_func);

    {
        const s = size_test{ .ctx = 24 };
        try s.run(&array_view);
    }

    // strided
    array_view.stride = .{ 12, 4, 2 };
    array_view.shape = .{ 2, 3, 2 };
    {
        const s = size_test{ .ctx = 12 };
        try s.run(&array_view);
    }

    // overlapping window
    array_view.offset = 4;
    array_view.stride = .{ 4, 1, 1 };
    array_view.shape = .{ 5, 3, 2 };
    {
        const s = size_test{ .ctx = 30 };
        try s.run(&array_view);
    }

    // overlapping but not unit
    array_view.offset = 0;
    array_view.stride = .{ 4, 2, 1 };
    array_view.shape = .{ 5, 3, 3 };
    {
        const s = size_test{ .ctx = 45 };
        try s.run(&array_view);
    }
}

// Checks the order and the number of elements produced by `iterate()` for
// dense, strided, offset, truncated and overlapping views.
test "StridedArrayView.iterator()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };
    {
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(i, val);
        }
        try testing.expectEqual(@as(usize, 24), i);
    }
    // strided
    array_view.stride = .{ 12, 4, 2 };
    array_view.shape = .{ 2, 3, 2 };
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             0,  2,
             4,  6,
             8, 10,

            12, 14,
            16, 18,
            20, 22,
            // zig fmt: on
        };
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(@as(usize, 12), i);
    }

    // strided + offset
    array_view.stride = .{ 12, 4, 2 };
    array_view.shape = .{ 2, 3, 2 };
    array_view.offset = 1;
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             1,  3,
             5,  7,
             9, 11,

            13, 15,
            17, 19,
            21, 23,
            // zig fmt: on
        };
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(@as(usize, 12), i);
    }

    // leave off last elt of inner dimension
    array_view.stride = .{ 12, 4, 1 };
    array_view.shape = .{ 2, 3, 3 };
    array_view.offset = 0;
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             0,  1,  2,
             4,  5,  6,
             8,  9, 10,

            12, 13, 14,
            16, 17, 18,
            20, 21, 22,
            // zig fmt: on
        };
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(@as(usize, 18), i);
    }

    // overlapping
    array_view.offset = 4;
    array_view.stride = .{ 4, 1, 1 };
    array_view.shape = .{ 5, 3, 2 };
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             4,  5,  5,  6,  6,  7,
             8,  9,  9, 10, 10, 11,
            12, 13, 13, 14, 14, 15,
            16, 17, 17, 18, 18, 19,
            20, 21, 21, 22, 22, 23,
            // zig fmt: on
        };
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(@as(usize, 30), i);
    }

    // overlapping but not unit
    array_view.offset = 0;
    array_view.stride = .{ 4, 2, 1 };
    array_view.shape = .{ 5, 3, 3 };
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             0,  1,  2,  2,  3,  4,  4,  5,  6,
             4,  5,  6,  6,  7,  8,  8,  9, 10,
             8,  9, 10, 10, 11, 12, 12, 13, 14,
            12, 13, 14, 14, 15, 16, 16, 17, 18,
            16, 17, 18, 18, 19, 20, 20, 21, 22,
            // zig fmt: on
        };
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(@as(usize, 45), i);
    }
}

// Checks `iterateWrap()` over a requested 2x5x8 region of a 2x3x4 view: the
// expected table repeats rows and columns once the view's extent is passed.
test "StridedArrayView.iterateWrap()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             0,  1,  2,  3,  0,  1,  2,  3,
             4,  5,  6,  7,  4,  5,  6,  7,
             8,  9, 10, 11,  8,  9, 10, 11,
             0,  1,  2,  3,  0,  1,  2,  3,
             4,  5,  6,  7,  4,  5,  6,  7,

            12, 13, 14, 15, 12, 13, 14, 15,
            16, 17, 18, 19, 16, 17, 18, 19,
            20, 21, 22, 23, 20, 21, 22, 23,
            12, 13, 14, 15, 12, 13, 14, 15,
            16, 17, 18, 19, 16, 17, 18, 19,
            // zig fmt: on
        };
        var iter = array_view.iterateWrap(.{ 0, 0, 0 }, .{ 2, 5, 8 });
        var i: usize = 0;
        while (iter.next()) |item| : (i += 1) {
            try testing.expectEqual(exp[i], item);
        }
        // Without this an iterator that stops early (or yields nothing)
        // would pass the loop above vacuously.
        try testing.expectEqual(exp.len, i);
    }
}

// Checks the iteration order after swapping two axes with `transpose()`.
test "StridedArrayView.transpose()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };
    array_view.transpose(1, 2);
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             0,  4,  8,  1,
             5,  9,  2,  6,
            10,  3,  7, 11,

            12, 16, 20, 13,
            17, 21, 14, 18,
            22, 15, 19, 23,
            // zig fmt: on
        };
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(exp.len, i);
    }
    // undo the first transposition before applying the next one
    array_view.transpose(1, 2);
    array_view.transpose(0, 2);
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             0, 12,  4, 16,
             8, 20,  1, 13,
             5, 17,  9, 21,

             2, 14,  6, 18,
            10, 22,  3, 15,
             7, 19, 11, 23,
            // zig fmt: on
        };
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(exp.len, i);
    }

    // overlapping but not unit
    array_view.offset = 0;
    array_view.stride = .{ 4, 2, 1 };
    array_view.shape = .{ 5, 3, 3 };
    array_view.transpose(1, 2);
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             0,  2,  4,  1,  3,  5,  2,  4,  6,
             4,  6,  8,  5,  7,  9,  6,  8, 10,
             8, 10, 12,  9, 11, 13, 10, 12, 14,
            12, 14, 16, 13, 15, 17, 14, 16, 18,
            16, 18, 20, 17, 19, 21, 18, 20, 22,
            // zig fmt: on
        };
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(exp.len, i);
    }
}

// Checks the iteration order after reversing a single axis with `flip()`.
test "StridedArrayView.flip()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };
    array_view.flip(2);
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             3,  2,  1,  0,
             7,  6,  5,  4,
            11, 10,  9,  8,

            15, 14, 13, 12,
            19, 18, 17, 16,
            23, 22, 21, 20,
            // zig fmt: on
        };
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(exp.len, i);
    }
    // undo the first flip before applying the next one
    array_view.flip(2);
    array_view.flip(1);
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             8,  9, 10, 11,
             4,  5,  6,  7,
             0,  1,  2,  3,

            20, 21, 22, 23,
            16, 17, 18, 19,
            12, 13, 14, 15,
            // zig fmt: on
        };
        var iter = array_view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(exp.len, i);
    }
}

// Checks `sliceOrNull()`: out-of-bounds and zero-sized requests yield null,
// and a valid request yields a sub-view over the expected elements.
test "StridedArrayView.slice()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };

    // trying to slice outside the bounds
    try testing.expect(array_view.sliceOrNull(.{ 2, 2, 2 }, .{ 1, 1, 1 }) == null);
    try testing.expect(array_view.sliceOrNull(.{ 1, 2, 2 }, .{ 1, 2, 1 }) == null);
    try testing.expect(array_view.sliceOrNull(.{ 1, 2, 2 }, .{ 1, 1, 3 }) == null);

    // a '2D' slice
    {
        const view_opt = array_view.sliceOrNull(.{ 0, 1, 1 }, .{ 1, 2, 3 });
        try testing.expect(view_opt != null);
        const view = view_opt.?;
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
            5,  6,  7,
            9, 10, 11,
            // zig fmt: on
        };
        var iter = view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(exp.len, i);
    }

    // a slice with zero size is not valid
    {
        const view_opt = array_view.sliceOrNull(.{ 0, 1, 1 }, .{ 0, 2, 3 });
        try testing.expect(view_opt == null);
    }
}

// Checks `sliceStepOrNull()`: same bounds rules as `sliceOrNull()`, plus a
// per-axis step applied to the selected elements.
test "StridedArrayView.sliceStep()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };

    // trying to slice outside the bounds
    try testing.expect(array_view.sliceStepOrNull(.{ 2, 2, 2 }, .{ 1, 1, 1 }, .{ 1, 1, 1 }) == null);
    try testing.expect(array_view.sliceStepOrNull(.{ 1, 2, 2 }, .{ 1, 2, 1 }, .{ 1, 1, 1 }) == null);
    try testing.expect(array_view.sliceStepOrNull(.{ 1, 2, 2 }, .{ 1, 1, 3 }, .{ 1, 1, 1 }) == null);

    // a '2D' slice
    {
        const view_opt = array_view.sliceStepOrNull(.{ 0, 0, 1 }, .{ 2, 3, 3 }, .{ 1, 2, 2 });
        try testing.expect(view_opt != null);
        const view = view_opt.?;
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             1,  3,
             9, 11,

            13, 15,
            21, 23,
            // zig fmt: on
        };
        var iter = view.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(exp.len, i);
    }

    // a slice with zero size is not valid
    {
        const view_opt = array_view.sliceStepOrNull(.{ 0, 1, 1 }, .{ 0, 2, 3 }, .{ 1, 1, 1 });
        try testing.expect(view_opt == null);
    }
}

// Checks the shape, strides and iteration order of the 5-dimensional view
// produced by a 3x3 `slidingWindow()` over the two inner axes.
test "StridedArrayView.slidingWindow()" {
    var array_view = TestArrayView{
        .items = one_to_23[0..],
        .stride = .{ 12, 4, 1 },
        .shape = .{ 2, 3, 4 },
        .offset = 0,
    };

    const window = array_view.slidingWindow(2, .{ 3, 3 });
    try testing.expectEqualSlices(u32, &.{ 2, 1, 2, 3, 3 }, window.shape[0..]);
    try testing.expectEqualSlices(TestArrayView.StrideType, &.{ 12, 4, 1, 4, 1 }, window.stride[0..]);
    {
        const exp = [_]TestArrayView.EltType{
            // zig fmt: off
             0,  1,  2,
             4,  5,  6,
             8,  9, 10,

             1,  2,  3,
             5,  6,  7,
             9, 10, 11,

            12, 13, 14,
            16, 17, 18,
            20, 21, 22,

            13, 14, 15,
            17, 18, 19,
            21, 22, 23,
            // zig fmt: on
        };
        var iter = window.iterate();
        var i: usize = 0;
        while (iter.next()) |val| : (i += 1) {
            try testing.expectEqual(exp[i], val);
        }
        try testing.expectEqual(exp.len, i);
    }
}

--------------------------------------------------------------------------------
/src/test_utils.zig:
--------------------------------------------------------------------------------
const std = @import("std");

/// Returns a test-driver type that applies `func` to a view under a family
/// of axis symmetries.
///
/// `T` is the type of the user context forwarded to `func` as its first
/// argument. `TestArrayView` is the view type under test; it must expose
/// `dim_count`, `transpose(i, j)` and `flip(k)`. `func` is the check to run;
/// it receives the context and a copy of the transformed view.
pub fn ForAllSymmetries(
    comptime T: type,
    comptime TestArrayView: type,
    comptime func: fn (T, TestArrayView) anyerror!void,
) type {
    return struct {
        const Self = @This();

        ctx: T,

        /// Runs `func` once for every ordered pair of axes `(i, j)` with the
        /// view transposed on that pair, then once more for every
        /// `(i, j, k)` with axis `k` additionally flipped.
        ///
        /// Each transformation is undone before the next one is applied, and
        /// also when `func` returns an error, so `array_view` is left as it
        /// was found. The first error from `func` is returned.
        pub fn run(self: Self, array_view: *TestArrayView) !void {
            // transpositions
            for (0..TestArrayView.dim_count) |i| {
                for (0..TestArrayView.dim_count) |j| {
                    array_view.transpose(i, j);
                    // `defer` undoes the transposition on the error path too.
                    defer array_view.transpose(i, j);
                    try func(self.ctx, array_view.*);
                }
            }

            // transposition + flip (i.e. rotation) note that when i == j, it's just a flip
            for (0..TestArrayView.dim_count) |i| {
                for (0..TestArrayView.dim_count) |j| {
                    array_view.transpose(i, j);
                    defer array_view.transpose(i, j);
                    for (0..TestArrayView.dim_count) |k| {
                        array_view.flip(k);
                        // Runs before the outer `defer`, so the flip is
                        // undone first and the transposition second.
                        defer array_view.flip(k);
                        try func(self.ctx, array_view.*);
                    }
                }
            }
        }
    };
}

--------------------------------------------------------------------------------