├── Makefile
├── README.md
├── requirements-test.txt
├── requirements.txt
├── tensor1d.c
├── tensor1d.h
├── tensor1d.py
└── test_tensor1d.py

/Makefile:
--------------------------------------------------------------------------------
CC = gcc
CFLAGS = -Wall -O3
LDFLAGS = -lm

# turn on all the warnings
# https://github.com/mcinglis/c-style
CFLAGS += -Wextra -Wpedantic \
	-Wformat=2 -Wno-unused-parameter -Wshadow \
	-Wwrite-strings -Wstrict-prototypes -Wold-style-definition \
	-Wredundant-decls -Wnested-externs -Wmissing-include-dirs

# Main targets
all: tensor1d libtensor1d.so

# Compile the main executable
tensor1d: tensor1d.c tensor1d.h
	$(CC) $(CFLAGS) -o $@ $< $(LDFLAGS)

# Create shared library
libtensor1d.so: tensor1d.c tensor1d.h
	$(CC) $(CFLAGS) -shared -fPIC -o $@ $< $(LDFLAGS)

# Clean up build artifacts
clean:
	rm -f tensor1d libtensor1d.so

# Test using pytest
test:
	pytest

# note: tensor1d is a real file target, so it is deliberately not listed as .PHONY
.PHONY: all clean test
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tensor

In this module we build a small `Tensor` in C, along the lines of `torch.Tensor` or `numpy.ndarray`. The current code implements a simple 1-dimensional float tensor that we can access and slice. We get to see that the tensor object maintains both a `Storage` that holds the 1-dimensional data as it is in physical memory, and a `View` over that memory that has some start, end, and stride. This allows us to efficiently slice into a Tensor without allocating any additional memory, because the `Storage` is re-used, while the `View` is updated to reflect the new start, end, and stride. We then get to see how we can wrap our C tensor into a Python module, just like PyTorch and numpy do.

The source code of the 1D Tensor is in [tensor1d.h](tensor1d.h) and [tensor1d.c](tensor1d.c). You can compile and run it simply as:

```bash
gcc -Wall -O3 tensor1d.c -o tensor1d
./tensor1d
```

The code contains both the `Tensor` class and a short `int main` with a toy example. We can now wrap this C code up into a Python module. For that, compile it as a shared library:

```bash
gcc -O3 -shared -fPIC -o libtensor1d.so tensor1d.c
```

This writes a `libtensor1d.so` shared library that we can load from Python using the [cffi](https://cffi.readthedocs.io/en/latest/) library, which you can see in the [tensor1d.py](tensor1d.py) file. We can then use it in Python simply like:

```python
import tensor1d

# 1D tensor of [0, 1, 2, ..., 19]
t = tensor1d.arange(20)

# getitem / setitem functionality
print(t[3])   # prints the 1-element Tensor [3.0]; use t[3].item() for the float
t[-1] = 100.0 # sets the last element to 100.0

# slicing, prints [5.0, 7.0, 9.0, 11.0, 13.0]
print(t[5:15:2])

# slice of a slice works ok! prints [9.0, 11.0, 13.0]
# (note how the end range is oob and gets cropped)
print(t[5:15:2][2:7])

# add a scalar to the whole tensor
t = t + 10.0

# add two tensors (of the same size) together
t2 = tensor1d.arange(20)
t3 = t + t2

# add two tensors together with broadcasting
t4 = t + tensor1d.tensor([10.0])
```
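Because slicing only creates a new `View` over the same `Storage`, writes through a slice are visible in the parent tensor. Here is a minimal sketch of that behavior, using only the API shown above:

```python
v = tensor1d.arange(10)
s = v[2:8:2]        # a view: same Storage, offset=2, stride=2, no copy
s[0] = -1.0         # writes through to the shared Storage
print(v[2].item())  # prints -1.0
```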
Finally, the tests use [pytest](https://docs.pytest.org/en/stable/) and can be found in [test_tensor1d.py](test_tensor1d.py). You can run them as `pytest test_tensor1d.py`.

It is well worth understanding this topic because you can get fairly fancy with torch tensors, and you have to be careful and aware of the memory underlying your code: when an operation creates new storage and when it creates just a new view, and which functions only accept "contiguous" tensors. Another pitfall: you create a small slice of a big tensor and assume that the big tensor will somehow be garbage collected, but in reality the big tensor sticks around, because the small slice is just a view over the big tensor's storage. The same would be true of our own tensor here.
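Concretely, with our Python wrapper this pitfall looks like the following sketch (the size is only illustrative):

```python
big = tensor1d.arange(1_000_000)
small = big[:10]  # just a view: bumps the Storage's ref_count, copies nothing
del big           # the million-float Storage is NOT freed; small still references it
```

In PyTorch you would break the dependency with an explicit copy, e.g. `small = big[:10].clone()`; our tensor has no copy operation yet, so the storage lives for as long as any view of it does.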
Actual production-grade tensors like `torch.Tensor` have a lot more functionality we won't cover. You can have a different `dtype` (not just float), a different `device`, a different `layout`, and tensors can be quantized, encrypted, etc.

TODOs:

- bring our own implementation closer to `torch.Tensor`
- make tests better
- implement 2D tensor, where we have to start worrying about 2D shapes/strides
- implement broadcasting for 2D tensor

Good related resources:
- [PyTorch internals](http://blog.ezyang.com/2019/05/pytorch-internals/)
- [Numpy paper](https://arxiv.org/abs/1102.1523)

### License

MIT
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
pytest
torch
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
cffi
--------------------------------------------------------------------------------
/tensor1d.c:
--------------------------------------------------------------------------------
/*
Implements a 1-dimensional Tensor, similar to torch.Tensor.

Compile and run like:
gcc -Wall -O3 tensor1d.c -o tensor1d && ./tensor1d

Or create .so for use with cffi:
gcc -O3 -shared -fPIC -o libtensor1d.so tensor1d.c
*/

// (the include targets were lost in extraction; this set is reconstructed so the file compiles)
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <assert.h>
#include <math.h>
#include "tensor1d.h"

// ----------------------------------------------------------------------------
// memory allocation

void *malloc_check(size_t size, const char *file, int line) {
    void *ptr = malloc(size);
    if (ptr == NULL) {
        fprintf(stderr, "Error: Memory allocation failed at %s:%d\n", file, line);
        exit(EXIT_FAILURE);
    }
    return ptr;
}
#define mallocCheck(size) malloc_check(size, __FILE__, __LINE__)

// ----------------------------------------------------------------------------
// utils

int ceil_div(int a, int b) {
    // integer division that rounds up, i.e. ceil(a / b)
    return (a + b - 1) / b;
}

int min(int a, int b) {
    return (a < b) ? a : b;
}

int max(int a, int b) {
    return (a > b) ? a : b;
}

// ----------------------------------------------------------------------------
// Storage: simple array of floats, defensive on index access, reference-counted.
// The reference counting allows multiple Tensors to share the same Storage.
// Similar to torch.Storage.

Storage* storage_new(int size) {
    assert(size >= 0);
    Storage* storage = mallocCheck(sizeof(Storage));
    storage->data = mallocCheck(size * sizeof(float));
    storage->data_size = size;
    storage->ref_count = 1;
    return storage;
}

float storage_getitem(Storage* s, int idx) {
    assert(idx >= 0 && idx < s->data_size);
    return s->data[idx];
}

void storage_setitem(Storage* s, int idx, float val) {
    assert(idx >= 0 && idx < s->data_size);
    s->data[idx] = val;
}

void storage_incref(Storage* s) {
    s->ref_count++;
}

void storage_decref(Storage* s) {
    s->ref_count--;
    if (s->ref_count == 0) {
        free(s->data);
        free(s);
    }
}

// ----------------------------------------------------------------------------
// Tensor class functions

// torch.empty(size)
Tensor* tensor_empty(int size) {
    Tensor* t = mallocCheck(sizeof(Tensor));
    t->storage = storage_new(size);
    // at init we cover the whole storage, i.e. range(start=0, stop=size, step=1)
    t->offset = 0;
    t->size = size;
    t->stride = 1;
    // holds the text representation of the tensor
    t->repr = NULL;
    return t;
}

// torch.arange(size)
Tensor* tensor_arange(int size) {
    Tensor* t = tensor_empty(size);
    for (int i = 0; i < t->size; i++) {
        tensor_setitem(t, i, (float) i);
    }
    return t;
}

int logical_to_physical(Tensor *t, int ix) {
    int idx = t->offset + ix * t->stride;
    return idx;
}
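/*
Worked example of the logical -> physical mapping above: for a view with
offset=5 and stride=2, logical index 3 lands at physical index 5 + 3*2 = 11
in the underlying Storage.
*/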
// Index into the tensor.
// Note that both PyTorch and numpy actually return a 1-element Tensor when you index like:
//   val = t[ix]
// This particular function returns the actual float, i.e.:
//   val = t[ix].item()
float tensor_getitem(Tensor* t, int ix) {
    // handle negative indices by wrapping around
    if (ix < 0) { ix = t->size + ix; }
    // oob indices (including ones still negative after the wrap) raise IndexError, and we return NaN
    if (ix < 0 || ix >= t->size) {
        fprintf(stderr, "IndexError: index %d is out of bounds for size %d\n", ix, t->size);
        return NAN;
    }
    // get the physical index into the storage and return the value
    int idx = logical_to_physical(t, ix);
    float val = storage_getitem(t->storage, idx);
    return val;
}

// The _astensor version of getitem:
//   val = t[ix]
// i.e. consistent with PyTorch/numpy, create a 1-element Tensor and return it
Tensor* tensor_getitem_astensor(Tensor* t, int ix) {
    // wrap around negative indices so we can do +1 below with confidence
    if (ix < 0) { ix = t->size + ix; }
    // effectively: t[ix:ix+1:1] <=> t[ix:ix+1] <=> t[ix]
    Tensor* slice = tensor_slice(t, ix, ix + 1, 1);
    return slice;
}

// t[ix] = val
void tensor_setitem(Tensor* t, int ix, float val) {
    // handle negative indices by wrapping around
    if (ix < 0) { ix = t->size + ix; }
    if (ix < 0 || ix >= t->size) {
        fprintf(stderr, "IndexError: index %d is out of bounds for size %d\n", ix, t->size);
        return;
    }
    int idx = logical_to_physical(t, ix);
    storage_setitem(t->storage, idx, val);
}

// same as .item() on a torch.Tensor: strips a 1-element Tensor down to a simple scalar
float tensor_item(Tensor* t) {
    if (t->size != 1) {
        fprintf(stderr, "ValueError: can only convert an array of size 1 to a Python scalar\n");
        return NAN;
    }
    return tensor_getitem(t, 0);
}

// return a new Tensor with a new view, but the same Storage, i.e.:
//   t[start:end:step]
Tensor* tensor_slice(Tensor* t, int start, int end, int step) {
    // 1) handle negative indices by wrapping around
    if (start < 0) { start = t->size + start; }
    if (end < 0) { end = t->size + end; }
    // 2) handle out-of-bounds indices: clip to [0, t->size] range
    start = min(max(start, 0), t->size);
    end = min(max(end, 0), t->size);
    // 3) handle step
    if (step == 0) {
        fprintf(stderr, "ValueError: slice step cannot be zero\n");
        return tensor_empty(0);
    }
    if (step < 0) {
        // TODO possibly support negative step
        // PyTorch does not support negative step (numpy does)
        fprintf(stderr, "ValueError: slice step cannot be negative\n");
        return tensor_empty(0);
    }
    // create the new Tensor: same Storage but new View
    Tensor* s = mallocCheck(sizeof(Tensor));
    s->storage = t->storage; // inherit the underlying storage!
    // guard against a negative size when start > end (e.g. t[5:3])
    s->size = (end > start) ? ceil_div(end - start, step) : 0;
    s->offset = t->offset + start * t->stride;
    s->stride = t->stride * step;
    s->repr = NULL;
    storage_incref(s->storage); // increment the reference count
    return s;
}
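/*
Worked example of how views compose, matching the README: start from
t = arange(20), i.e. offset=0, size=20, stride=1.
  t[5:15:2] -> offset = 0 + 5*1 = 5, size = ceil_div(15-5, 2) = 5, stride = 1*2 = 2
  ...[2:7]  -> end is clipped from 7 down to size 5, then:
               offset = 5 + 2*2 = 9, size = ceil_div(5-2, 1) = 3, stride = 2*1 = 2
The resulting view reads physical indices 9, 11, 13 (values 9.0, 11.0, 13.0),
and no storage was copied at any point.
*/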
Tensor* tensor_addf(Tensor* t, float val) {
    // adds a float to each element of the tensor, returns a new tensor
    Tensor* result = tensor_empty(t->size);
    for (int i = 0; i < t->size; i++) {
        float old_val = tensor_getitem(t, i);
        float new_val = old_val + val;
        tensor_setitem(result, i, new_val);
    }
    return result;
}

bool broadcastable(Tensor* t1, Tensor* t2) {
    // two tensors broadcast if, in each dimension (we only have 1 here),
    // the tensors either have the same size, or one of their sizes is 1
    return t1->size == t2->size || t1->size == 1 || t2->size == 1;
}

Tensor* tensor_add(Tensor* t1, Tensor* t2) {
    if (!broadcastable(t1, t2)) { return NULL; }
    int result_size = max(t1->size, t2->size);
    Tensor* result = tensor_empty(result_size);
    int t1_index = 0;
    int t2_index = 0;
    int t1_stride = t1->size > 1 ? 1 : 0; // either we walk this tensor or not
    int t2_stride = t2->size > 1 ? 1 : 0; // either we walk this tensor or not
    // walk the output tensor and add the values
    for (int result_index = 0; result_index < result_size; result_index++) {
        float val1 = tensor_getitem(t1, t1_index);
        float val2 = tensor_getitem(t2, t2_index);
        float val = val1 + val2;
        tensor_setitem(result, result_index, val);
        t1_index += t1_stride;
        t2_index += t2_stride;
    }
    return result;
}

char* tensor_to_string(Tensor* t) {
    // if we already have a string representation, return it
    if (t->repr != NULL) { return t->repr; }
    // otherwise create a new string representation
    int max_size = t->size * 20 + 3; // 20 chars/number, brackets and commas
    t->repr = mallocCheck(max_size);
    char* current = t->repr;
    current += sprintf(current, "[");
    for (int i = 0; i < t->size; i++) {
        float val = tensor_getitem(t, i);
        current += sprintf(current, "%.1f", val);
        if (i < t->size - 1) {
            current += sprintf(current, ", ");
        }
    }
    current += sprintf(current, "]");
    // ensure we didn't write past the end of the buffer
    assert(current - t->repr < max_size);
    return t->repr;
}

void tensor_print(Tensor* t) {
    char* str = tensor_to_string(t);
    printf("%s\n", str);
}

void tensor_free(Tensor* t) {
    storage_decref(t->storage);
    free(t->repr);
    free(t);
}

// ----------------------------------------------------------------------------

int main(int argc, char *argv[]) {
    // create a tensor with 20 elements
    Tensor* t = tensor_arange(20);
    tensor_print(t);
    // slice the tensor as t[5:15:1]
    Tensor* s = tensor_slice(t, 5, 15, 1);
    tensor_print(s);
    // slice that tensor as s[2:7:2]
    Tensor* ss = tensor_slice(s, 2, 7, 2);
    tensor_print(ss);
    // print element -1
    float val = tensor_getitem(ss, -1);
    printf("ss[-1] = %.1f\n", val);

    tensor_free(ss);
    tensor_free(s);
    tensor_free(t);

    return 0;
}
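/*
For reference, compiling and running this file as described at the top should print:

[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0]
[5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0]
[7.0, 9.0, 11.0]
ss[-1] = 11.0

Note how ss = t[5:15:1][2:7:2] composes to offset=7, size=3, stride=2,
which reads physical indices 7, 9, 11 of the original storage.
*/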
--------------------------------------------------------------------------------
/tensor1d.h:
--------------------------------------------------------------------------------
/*
tensor1d.h
*/

#ifndef TENSOR1D_H
#define TENSOR1D_H

// (the include targets were lost in extraction; these two are a harmless reconstruction)
#include <stddef.h>
#include <stdbool.h>

typedef struct {
    float* data;
    int data_size;
    int ref_count;
} Storage;

// The equivalent of tensor in PyTorch
typedef struct {
    Storage* storage;
    int offset;
    int size;
    int stride;
    char* repr; // holds the text representation of the tensor
} Tensor;

Tensor* tensor_empty(int size);
int logical_to_physical(Tensor *t, int ix);
float tensor_getitem(Tensor* t, int ix);
Tensor* tensor_getitem_astensor(Tensor* t, int ix);
float tensor_item(Tensor* t);
void tensor_setitem(Tensor* t, int ix, float val);
Tensor* tensor_arange(int size);
char* tensor_to_string(Tensor* t);
void tensor_print(Tensor* t);
Tensor* tensor_slice(Tensor* t, int start, int end, int step);
Tensor* tensor_addf(Tensor* t, float val);
Tensor* tensor_add(Tensor* t1, Tensor* t2);
void tensor_incref(Tensor* t); // note: declared but not yet defined in tensor1d.c
void tensor_decref(Tensor* t); // note: declared but not yet defined in tensor1d.c
void tensor_free(Tensor* t);

#endif // TENSOR1D_H
--------------------------------------------------------------------------------
/tensor1d.py:
--------------------------------------------------------------------------------
import cffi

# -----------------------------------------------------------------------------
ffi = cffi.FFI()
ffi.cdef("""
typedef struct {
    float* data;
    int data_size;
    int ref_count;
} Storage;

// The equivalent of tensor in PyTorch
typedef struct {
    Storage* storage;
    int offset;
    int size;
    int stride;
    char* repr; // holds the text representation of the tensor
} Tensor;

Tensor* tensor_empty(int size);
int logical_to_physical(Tensor *t, int ix);
float tensor_getitem(Tensor* t, int ix);
Tensor* tensor_getitem_astensor(Tensor* t, int ix);
float tensor_item(Tensor* t);
void tensor_setitem(Tensor* t, int ix, float val);
Tensor* tensor_arange(int size);
char* tensor_to_string(Tensor* t);
void tensor_print(Tensor* t);
Tensor* tensor_slice(Tensor* t, int start, int end, int step);
Tensor* tensor_addf(Tensor* t, float val);
Tensor* tensor_add(Tensor* t1, Tensor* t2);
void tensor_incref(Tensor* t);
void tensor_decref(Tensor* t);
void tensor_free(Tensor* t);
""")
lib = ffi.dlopen("./libtensor1d.so")  # make sure to first compile the C code into a shared library
# -----------------------------------------------------------------------------
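# Note: this is cffi's "ABI mode": we declare the signatures with ffi.cdef and
# dlopen a prebuilt shared library, so nothing gets compiled at import time.
# cffi also offers an "API mode" that compiles a small extension module against
# the real header instead; a hypothetical build script (not part of this repo)
# would look roughly like:
#
#   ffibuilder = cffi.FFI()
#   ffibuilder.cdef(...)  # the same declarations as above
#   ffibuilder.set_source("_tensor1d", '#include "tensor1d.h"', sources=["tensor1d.c"])
#   ffibuilder.compile(verbose=True)
#
# after which you would `from _tensor1d import ffi, lib`.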

class Tensor:
    def __init__(self, size_or_data=None, c_tensor=None):
        # ensure exactly one of size_or_data and c_tensor is passed
        assert (size_or_data is not None) ^ (c_tensor is not None), "Exactly one of size_or_data or c_tensor must be passed"
        # initialize the tensor
        if c_tensor is not None:
            self.tensor = c_tensor
        elif isinstance(size_or_data, int):
            self.tensor = lib.tensor_empty(size_or_data)
        elif isinstance(size_or_data, (list, range)):
            # allocate empty storage, then fill it in (no need to arange first)
            self.tensor = lib.tensor_empty(len(size_or_data))
            for i, val in enumerate(size_or_data):
                lib.tensor_setitem(self.tensor, i, float(val))
        else:
            raise TypeError("Input must be an integer size or a list/range of values")

    def __del__(self):
        # TODO: when the Python interpreter is shutting down, lib can become None
        # I'm not 100% sure how to do cleanup in cffi here properly
        if lib is not None:
            if hasattr(self, 'tensor'):
                lib.tensor_free(self.tensor)

    def __getitem__(self, key):
        if isinstance(key, int):
            c_tensor = lib.tensor_getitem_astensor(self.tensor, key)
            return Tensor(c_tensor=c_tensor)
        elif isinstance(key, slice):
            # assign default values to start, stop, and step
            start = key.start if key.start is not None else 0
            stop = self.tensor.size if key.stop is None else key.stop
            step = 1 if key.step is None else key.step
            # call the C function to slice the tensor
            sliced_tensor = lib.tensor_slice(self.tensor, start, stop, step)
            return Tensor(c_tensor=sliced_tensor)  # pass the C tensor directly
        else:
            raise TypeError("Invalid index type")

    def __setitem__(self, key, value):
        if isinstance(key, int):
            lib.tensor_setitem(self.tensor, key, float(value))
        else:
            raise TypeError("Invalid index type")

    def __add__(self, other):
        if isinstance(other, (int, float)):
            c_tensor = lib.tensor_addf(self.tensor, float(other))
        elif isinstance(other, Tensor):
            c_tensor = lib.tensor_add(self.tensor, other.tensor)
        else:
            raise TypeError("Invalid type for addition")
        if c_tensor == ffi.NULL:
            raise ValueError("tensor_add returned NULL: tensors are not broadcastable")
        return Tensor(c_tensor=c_tensor)

    def __len__(self):
        return self.tensor.size

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        c_str = lib.tensor_to_string(self.tensor)
        py_str = ffi.string(c_str).decode('utf-8')
        return py_str

    def tolist(self):
        return [lib.tensor_getitem(self.tensor, i) for i in range(len(self))]

    def item(self):
        return lib.tensor_item(self.tensor)

def empty(size):
    return Tensor(size)

def arange(size):
    c_tensor = lib.tensor_arange(size)
    return Tensor(c_tensor=c_tensor)

def tensor(data):
    return Tensor(data)
--------------------------------------------------------------------------------
/test_tensor1d.py:
--------------------------------------------------------------------------------
import pytest
import torch
import tensor1d

def assert_tensor_equal(torch_tensor, tensor1d_tensor):
    assert torch_tensor.tolist() == tensor1d_tensor.tolist()

@pytest.mark.parametrize("size", [0, 1, 10, 100])
def test_arange(size):
    torch_tensor = torch.arange(size, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(size)
    assert_tensor_equal(torch_tensor, tensor1d_tensor)

@pytest.mark.parametrize("case", [[], [1], [1, 2, 3], list(range(100))])
def test_tensor_creation(case):
    torch_tensor = torch.tensor(case)
    tensor1d_tensor = tensor1d.tensor(case)
    assert_tensor_equal(torch_tensor, tensor1d_tensor)

@pytest.mark.parametrize("size", [0, 1, 10, 100])
def test_empty(size):
    torch_tensor = torch.empty(size)
    tensor1d_tensor = tensor1d.empty(size)
    assert len(torch_tensor) == len(tensor1d_tensor)

@pytest.mark.parametrize("index", range(1, 10))
def test_indexing(index):
    torch_tensor = torch.arange(10, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(10)
    assert torch_tensor[index].item() == tensor1d_tensor[index].item()

@pytest.mark.parametrize("slice_params", [
    (None, None, None),  # [:]
    (5, None, None),     # [5:]
    (None, 15, None),    # [:15]
    (5, 15, None),       # [5:15]
    (None, None, 2),     # [::2]
    (5, 15, 2),          # [5:15:2]
    (5, 15, 15),         # [5:15:15]
])
def test_slicing(slice_params):
    torch_tensor = torch.arange(20, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(20)
    s = slice(*slice_params)
    assert_tensor_equal(torch_tensor[s], tensor1d_tensor[s])

def test_invalid_input():
    with pytest.raises(TypeError):
        tensor1d.tensor("not a valid input")

def test_invalid_index():
    t = tensor1d.arange(5)
    with pytest.raises(TypeError):
        t["invalid index"]

@pytest.mark.parametrize("initial_slice, second_slice", [
    ((5, 15, 1), (2, 7, 1)),        # Basic case
    ((5, 15, 1), (None, None, 1)),  # Full slice
    ((5, 15, 1), (None, None, 2)),  # Every other element
    ((5, 15, 2), (None, None, 2)),  # Every other of every other
    ((0, 20, 1), (-5, None, 1)),    # Negative start index
    ((0, 20, 1), (None, -5, 1)),    # Negative end index
    ((0, 20, 1), (-15, -5, 1)),     # Negative start and end indices
    ((5, 15, 1), (100, None, 1)),   # Start index out of range
    ((5, 15, 1), (None, 100, 1)),   # End index out of range
    ((5, 15, 1), (-100, None, 1)),  # Negative start index out of range
    ((5, 15, 1), (None, -100, 1)),  # Negative end index out of range
    ((0, 20, 1), (0, 0, 1)),        # Empty slice
    ((0, 0, 1), (None, None, 1)),   # Slice of empty slice
])
def test_slice_of_slice(initial_slice, second_slice):
    torch_tensor = torch.arange(20, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(20)

    torch_slice = torch_tensor[slice(*initial_slice)]
    tensor1d_slice = tensor1d_tensor[slice(*initial_slice)]

    torch_result = torch_slice[slice(*second_slice)]
    tensor1d_result = tensor1d_slice[slice(*second_slice)]

    assert_tensor_equal(torch_result, tensor1d_result)

def test_multiple_slices():
    torch_tensor = torch.arange(100, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(100)

    torch_result = torch_tensor[10:90:2][5:35:3][::2]
    tensor1d_result = tensor1d_tensor[10:90:2][5:35:3][::2]

    assert_tensor_equal(torch_result, tensor1d_result)

# Test for behavior with step sizes > 1
@pytest.mark.parametrize("step", [2, 3, 5])
def test_slices_with_steps(step):
    torch_tensor = torch.arange(50, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(50)

    torch_result = torch_tensor[::step][5:20]
    tensor1d_result = tensor1d_tensor[::step][5:20]

    assert_tensor_equal(torch_result, tensor1d_result)

# Test for behavior with different slice sizes
@pytest.mark.parametrize("size", [10, 20, 50, 100])
def test_slices_with_different_sizes(size):
    torch_tensor = torch.arange(size, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(size)

    torch_result = torch_tensor[size//4:3*size//4][::2]
    tensor1d_result = tensor1d_tensor[size//4:3*size//4][::2]

    assert_tensor_equal(torch_result, tensor1d_result)

# Test for behavior with overlapping slices
def test_overlapping_slices():
    torch_tensor = torch.arange(30, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(30)

    torch_result = torch_tensor[5:25][3:15]
    tensor1d_result = tensor1d_tensor[5:25][3:15]

    assert_tensor_equal(torch_result, tensor1d_result)

# Test for behavior with adjacent slices
def test_adjacent_slices():
    torch_tensor = torch.arange(20, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(20)

    torch_result = torch_tensor[5:15][0:10]
    tensor1d_result = tensor1d_tensor[5:15][0:10]

    assert_tensor_equal(torch_result, tensor1d_result)

# Test accessing elements, including negative indices
def test_getitem():
    torch_tensor = torch.arange(20, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(20)
    assert torch_tensor[0].item() == tensor1d_tensor[0].item()
    assert torch_tensor[5].item() == tensor1d_tensor[5].item()
    assert torch_tensor[-1].item() == tensor1d_tensor[-1].item()
    assert torch_tensor[-5].item() == tensor1d_tensor[-5].item()

# Test setting elements, including negative indices
def test_setitem():
    torch_tensor = torch.arange(20, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(20)

    torch_tensor[0] = 100
    tensor1d_tensor[0] = 100
    assert_tensor_equal(torch_tensor, tensor1d_tensor)

    torch_tensor[5] = 200
    tensor1d_tensor[5] = 200
    assert_tensor_equal(torch_tensor, tensor1d_tensor)

    torch_tensor[-1] = 300
    tensor1d_tensor[-1] = 300
    assert_tensor_equal(torch_tensor, tensor1d_tensor)

    torch_tensor[-5] = 400
    tensor1d_tensor[-5] = 400
    assert_tensor_equal(torch_tensor, tensor1d_tensor)

# Test setting elements indirectly (via a slice)
def test_setitem_indirect():
    torch_tensor = torch.arange(20, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(20)
    torch_view = torch_tensor[5:15]
    tensor1d_view = tensor1d_tensor[5:15]

    torch_view[0] = 100
    tensor1d_view[0] = 100
    assert_tensor_equal(torch_tensor, tensor1d_tensor)

    torch_view[-1] = 200
    tensor1d_view[-1] = 200
    assert_tensor_equal(torch_tensor, tensor1d_tensor)

# test addition
def test_addition():

    # add a scalar to every element
    torch_tensor = torch.arange(20, dtype=torch.float32)
    tensor1d_tensor = tensor1d.arange(20)
    torch_result = torch_tensor + 5.0
    tensor1d_result = tensor1d_tensor + 5.0
    assert_tensor_equal(torch_result, tensor1d_result)

    # and again with a different float
    torch_result = torch_tensor + 6.0
    tensor1d_result = tensor1d_tensor + 6.0
    assert_tensor_equal(torch_result, tensor1d_result)

    # test broadcasting add with a 1-element tensor on the right
    torch_result = torch_tensor + torch.tensor([123.0])
    tensor1d_result = tensor1d_tensor + tensor1d.tensor([123.0])
    assert_tensor_equal(torch_result, tensor1d_result)

    # and on the left
    torch_result = torch.tensor([42.0]) + torch_tensor
    tensor1d_result = tensor1d.tensor([42.0]) + tensor1d_tensor
    assert_tensor_equal(torch_result, tensor1d_result)

    # and now test invalid cases
    with pytest.raises(TypeError):
        tensor1d_tensor + "not a valid input"

    with pytest.raises(TypeError):
        tensor1d_tensor + [1, 2, 3]

    with pytest.raises(ValueError):
        tensor1d_tensor + tensor1d.arange(5)
--------------------------------------------------------------------------------