├── src ├── utils │ ├── types.h │ ├── augment.h │ ├── cifar10.h │ ├── aligned_alloc.h │ ├── metrics.h │ ├── cifar10.c │ ├── augment.c │ └── tqdm.h ├── ops │ ├── reshapes.h │ ├── losses.h │ ├── reductions.h │ ├── activations.h │ ├── convolutions.h │ ├── losses_backward.h │ ├── reshapes_backward.h │ ├── activations_backward.h │ ├── arithmetic.h │ ├── reductions_backward.h │ ├── arithmetic_backward.h │ ├── convolutions_backward.h │ ├── reshapes_backward.c │ ├── activations.c │ ├── reshapes.c │ ├── losses.c │ ├── reductions_backward.c │ ├── losses_backward.c │ ├── arithmetic_backward.c │ ├── arithmetic.c │ ├── activations_backward.c │ └── reductions.c ├── layers.h ├── optimizers.h ├── tensor.h ├── autograd.h ├── autograd.c ├── optimizers.c ├── layers.c └── tensor.c ├── cppcheck-suppressions.txt ├── docs ├── about.md ├── pytorch_ast_optimization.py ├── design.md └── autodiff.py ├── entitlements.plist ├── suppr.txt ├── .gitignore ├── Makefile ├── README ├── CONTRIBUTING.md ├── CMakeLists.txt ├── .clang-format └── test ├── test_losses_backward.c ├── test_activations_backward.c ├── test_convolutions_backward.c └── test_arithmetic.c /src/utils/types.h: -------------------------------------------------------------------------------- 1 | typedef float float32_t; 2 | typedef double float64_t; 3 | -------------------------------------------------------------------------------- /cppcheck-suppressions.txt: -------------------------------------------------------------------------------- 1 | unusedFunction 2 | staticFunction 3 | checkersReport 4 | missingIncludeSystem 5 | duplicateBranch 6 | -------------------------------------------------------------------------------- /src/ops/reshapes.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | #include 5 | 6 | Tensor *tensor_reshape(const Tensor *t, const int64_t *new_shape, uint64_t new_ndim); 7 | Tensor *tensor_transpose(const Tensor *t, uint64_t dim0, uint64_t dim1); 8 | -------------------------------------------------------------------------------- /src/ops/losses.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | 5 | Tensor *mse_loss(const Tensor *predictions, const Tensor *targets); 6 | Tensor *cross_entropy_loss(const Tensor *logits, const Tensor *targets); 7 | Tensor *binary_cross_entropy_loss(const Tensor *predictions, const Tensor *targets); 8 | -------------------------------------------------------------------------------- /docs/about.md: -------------------------------------------------------------------------------- 1 | a minimal reverse mode autograd engine in c with reference counted tensors, arena allocated function nodes, explicit dependency counting, centralized gradient accumulation, scalar loss backpropagation and a small set of core tensor ops implemented with tightly coupled forward and backward code. 
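
for orientation, a minimal usage sketch (illustrative only, pieced together from the public headers in `tensor.h`, `autograd.h`, `ops/arithmetic.h` and `ops/losses.h`; the ownership and cleanup conventions shown here are assumptions, not guaranteed behavior):

```c
// illustrative usage sketch (not a file in this repo). assumptions: outputs are
// released by the caller and function nodes live in the thread-local arena
// until arena_free(); the real ownership rules may differ.
#include "autograd.h"
#include "ops/arithmetic.h"
#include "ops/losses.h"
#include "tensor.h"

int main(void) {
    const uint64_t shape[] = {2, 2};
    const float32_t w_data[] = {1.0f, 2.0f, 3.0f, 4.0f};
    const float32_t x_data[] = {0.5f, 0.5f, 0.5f, 0.5f};
    const float32_t y_data[] = {1.0f, 1.0f, 1.0f, 1.0f};

    Tensor *w = tensor_create(w_data, shape, 2, true);  // leaf, gradients tracked
    Tensor *x = tensor_create(x_data, shape, 2, false); // input, not tracked
    Tensor *y = tensor_create(y_data, shape, 2, false); // target

    Tensor *pred = tensor_mul(w, x);  // forward op records a Function node
    Tensor *loss = mse_loss(pred, y); // scalar loss

    backward(loss);        // reverse pass: accumulates into w->grad
    tensor_print(w->grad); // d loss / d w

    tensor_free(loss);
    tensor_free(pred);
    tensor_free(y);
    tensor_free(x);
    tensor_free(w);
    arena_free(); // function nodes are arena-allocated per thread
    return 0;
}
```

the `Function` nodes recording the graph come from a thread-local arena (see `autograd.h`), which is why a single `arena_free()` at the end should be enough to drop the graph bookkeeping.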
2 | -------------------------------------------------------------------------------- /src/ops/reductions.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | #include 5 | #include 6 | 7 | Tensor *tensor_sum(const Tensor *t, int64_t dim_idx, bool keepdims); 8 | Tensor *tensor_mean(const Tensor *t, int64_t dim_idx, bool keepdims); 9 | Tensor *tensor_max(const Tensor *t, int64_t dim_idx, bool keepdims); 10 | -------------------------------------------------------------------------------- /src/ops/activations.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | 5 | Tensor *tensor_sigmoid(const Tensor *t); 6 | Tensor *tensor_relu(const Tensor *t); 7 | Tensor *tensor_tanh(const Tensor *t); // needs `tensor_` prefix to avoid conflict with math.h 8 | Tensor *tensor_gelu(const Tensor *t); 9 | Tensor *tensor_softmax(const Tensor *t, int64_t dim); 10 | -------------------------------------------------------------------------------- /entitlements.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | com.apple.security.get-task-allow 7 | 8 | 9 | -------------------------------------------------------------------------------- /src/utils/augment.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | #include 5 | 6 | // randomly flip images horizontally with given probability p 7 | void random_horizontal_flip_mut(Tensor *t, float32_t p); 8 | 9 | // randomly crop image to (target_h, target_w) after padding 10 | void random_crop_mut(Tensor *t, uint64_t target_h, uint64_t target_w, uint64_t padding); 11 | -------------------------------------------------------------------------------- /src/ops/convolutions.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | 5 | Tensor *tensor_conv2d(const Tensor *input, const Tensor *weight, const Tensor *bias, uint64_t stride, uint64_t padding, uint64_t dilation); 6 | Tensor *tensor_maxpool2d(const Tensor *input, uint64_t kernel_size, uint64_t stride, uint64_t padding); 7 | Tensor *tensor_avgpool2d(const Tensor *input, uint64_t kernel_size, uint64_t stride, uint64_t padding); 8 | Tensor *tensor_batchnorm2d(const Tensor *input, const Tensor *gamma, const Tensor *beta, Tensor *running_mean, Tensor *running_var, bool training, float32_t momentum, float32_t eps); 9 | -------------------------------------------------------------------------------- /src/ops/losses_backward.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "autograd.h" 4 | #include "tensor.h" 5 | 6 | Tensor *mse_loss_backward(const Tensor *predictions, const Tensor *targets); 7 | void mse_loss_backward_fn(Function *fn, const Tensor *grad_output); 8 | 9 | Tensor *cross_entropy_loss_backward(const Tensor *logits, const Tensor *targets); 10 | void cross_entropy_loss_backward_fn(Function *fn, const Tensor *grad_output); 11 | 12 | Tensor *binary_cross_entropy_loss_backward(const Tensor *predictions, const Tensor *targets); 13 | void binary_cross_entropy_loss_backward_fn(Function *fn, const Tensor *grad_output); 14 | -------------------------------------------------------------------------------- /src/ops/reshapes_backward.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "autograd.h" 4 | #include "tensor.h" 5 | 6 | Tensor *tensor_reshape_backward(const Tensor *grad_output, const Tensor *input); 7 | typedef struct { 8 | uint64_t shape[MAX_NDIM]; 9 | uint64_t ndim; 10 | } ReshapeContext; 11 | void reshape_backward(Function *fn, const Tensor *grad_output); 12 | 13 | Tensor *tensor_transpose_backward(const Tensor *grad_output, uint64_t dim0, uint64_t dim1); 14 | typedef struct { 15 | uint64_t dim0; 16 | uint64_t dim1; 17 | } TransposeContext; 18 | void transpose_backward(Function *fn, const Tensor *grad_output); 19 | -------------------------------------------------------------------------------- /suppr.txt: -------------------------------------------------------------------------------- 1 | # LeakSanitizer suppressions for false positives from macOS system libraries. 2 | # 3 | # Some changes were merged upstream and have yet to be released in a stable clang/llvm version. 4 | 5 | leak:_fetchInitializingClassList 6 | # thread-local storage used by Objective-C runtime 7 | # - https://github.com/apple-oss-distributions/objc4/blob/f126469408dc82bd3f327217ae678fd0e6e3b37c/runtime/objc-initialize.mm#L287 8 | # - https://github.com/llvm/llvm-project/issues/115992 9 | # - https://github.com/google/sanitizers/wiki/AddressSanitizerLeakSanitizer#suppressions 10 | # - https://github.com/llvm/llvm-project/pull/117478 11 | 12 | leak:__Balloc_D2A 13 | # double to ASCII conversion library 14 | -------------------------------------------------------------------------------- /src/ops/activations_backward.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "autograd.h" 4 | #include "tensor.h" 5 | 6 | Tensor *tensor_sigmoid_backward(const Tensor *t); 7 | void sigmoid_backward(Function *fn, const Tensor *grad_output); 8 | 9 | Tensor *tensor_relu_backward(const Tensor *t); 10 | void relu_backward(Function *fn, const Tensor *grad_output); 11 | 12 | void tanh_backward(Function *fn, const Tensor *grad_output); 13 | Tensor *tensor_tanh_backward(const Tensor *t); 14 | 15 | Tensor *tensor_gelu_backward(const Tensor *t); 16 | void gelu_backward(Function *fn, const Tensor *grad_output); 17 | 18 | Tensor *tensor_softmax_backward(const Tensor *t, int64_t dim); 19 | void softmax_backward(Function *fn, const Tensor *grad_output); 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .cache/ 2 | data/ 3 | build/ 4 | 5 | # Prerequisites 6 | *.d 7 | 8 | # Object files 9 | *.o 10 | *.ko 11 | *.obj 12 | *.elf 13 | 14 | # Linker output 15 | *.ilk 16 | *.map 17 | *.exp 18 | 19 | # Precompiled Headers 20 | *.gch 21 | *.pch 22 | 23 | # Libraries 24 | *.lib 25 | *.a 26 | *.la 27 | *.lo 28 | 29 | # Shared objects (inc. 
Windows DLLs) 30 | *.dll 31 | *.so 32 | *.so.* 33 | *.dylib 34 | 35 | # Executables 36 | *.exe 37 | *.out 38 | *.app 39 | *.i*86 40 | *.x86_64 41 | *.hex 42 | 43 | # Debug files 44 | *.dSYM/ 45 | *.su 46 | *.idb 47 | *.pdb 48 | 49 | # Kernel Module Compile Results 50 | *.mod* 51 | *.cmd 52 | .tmp_versions/ 53 | modules.order 54 | Module.symvers 55 | Mkfile.old 56 | dkms.conf 57 | 58 | # debug information files 59 | *.dwo -------------------------------------------------------------------------------- /src/utils/cifar10.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | #include 5 | 6 | #define CHANNELS 3 7 | #define HEIGHT 32 8 | #define WIDTH 32 9 | #define INPUT_SIZE (CHANNELS * HEIGHT * WIDTH) 10 | #define NUM_TRAIN_SAMPLES 50000 11 | #define NUM_TEST_SAMPLES 10000 12 | 13 | #define NUM_CLASSES 10 14 | typedef enum { AIRPLANE = 0, AUTOMOBILE = 1, BIRD = 2, CAT = 3, DEER = 4, DOG = 5, FROG = 6, HORSE = 7, SHIP = 8, TRUCK = 9 } label_t; 15 | 16 | Tensor *cifar10_get_train_images(void); 17 | Tensor *cifar10_get_train_labels(void); 18 | Tensor *cifar10_get_test_images(void); 19 | Tensor *cifar10_get_test_labels(void); 20 | 21 | const char *label_to_str(label_t label); 22 | Tensor *get_batch(const Tensor *data, uint64_t batch_idx, uint64_t batch_size); 23 | -------------------------------------------------------------------------------- /src/ops/arithmetic.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | #include 5 | 6 | Tensor *tensor_add(const Tensor *a, const Tensor *b); 7 | Tensor *tensor_sub(const Tensor *a, const Tensor *b); 8 | 9 | // macro overloading to make the `disable_grad` arg optional (default to `false`) 10 | #define tensor_mul(...) TENSOR_MUL_SELECT(__VA_ARGS__, tensor_mul_3, tensor_mul_2)(__VA_ARGS__) 11 | #define TENSOR_MUL_SELECT(_1, _2, _3, NAME, ...) 
NAME 12 | #define tensor_mul_2(a, b) tensor_mul_impl(a, b, false) 13 | #define tensor_mul_3(a, b, disable_grad) tensor_mul_impl(a, b, disable_grad) 14 | Tensor *tensor_mul_impl(const Tensor *a, const Tensor *b, bool disable_grad); 15 | 16 | Tensor *tensor_div(const Tensor *a, const Tensor *b); 17 | Tensor *tensor_matmul(const Tensor *a, const Tensor *b); 18 | -------------------------------------------------------------------------------- /src/utils/aligned_alloc.h: -------------------------------------------------------------------------------- 1 | #include "utils/types.h" 2 | #include 3 | #include 4 | 5 | #define CACHELINE_SIZE 64 6 | _Static_assert(CACHELINE_SIZE >= sizeof(float32_t), "cacheline_size must be at least 4 bytes"); 7 | _Static_assert((CACHELINE_SIZE & (CACHELINE_SIZE - 1)) == 0, "cacheline_size must be power of 2"); 8 | 9 | static inline void *safe_aligned_alloc(uint64_t size_bytes) { 10 | size_t s_bytes = (size_t)size_bytes; 11 | if (s_bytes % CACHELINE_SIZE != 0) { 12 | s_bytes = (s_bytes / CACHELINE_SIZE + 1) * CACHELINE_SIZE; 13 | } 14 | void *ptr = aligned_alloc(CACHELINE_SIZE, s_bytes); 15 | assert(ptr != NULL && "aligned_alloc failed: out of memory"); 16 | assert((uintptr_t)ptr % CACHELINE_SIZE == 0 && "allocated pointer is not properly aligned"); 17 | return ptr; 18 | } 19 | -------------------------------------------------------------------------------- /src/ops/reductions_backward.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "autograd.h" 4 | #include "tensor.h" 5 | #include 6 | 7 | Tensor *tensor_sum_backward(const Tensor *grad_output, const Tensor *t, int64_t dim_idx, bool keepdims); 8 | typedef struct { 9 | int64_t dim_idx; 10 | bool keepdims; 11 | } SumContext; 12 | void sum_backward(Function *fn, const Tensor *grad_output); 13 | 14 | Tensor *tensor_mean_backward(const Tensor *grad_output, const Tensor *t, int64_t dim_idx, bool keepdims); 15 | typedef struct { 16 | int64_t dim_idx; 17 | bool keepdims; 18 | } MeanContext; 19 | void mean_backward(Function *fn, const Tensor *grad_output); 20 | 21 | Tensor *tensor_max_backward(const Tensor *grad_output, const Tensor *t, const Tensor *out, int64_t dim_idx, bool keepdims); 22 | typedef struct { 23 | int64_t dim_idx; 24 | bool keepdims; 25 | Tensor *output; 26 | } MaxContext; 27 | void max_backward(Function *fn, const Tensor *grad_output); 28 | -------------------------------------------------------------------------------- /src/utils/metrics.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | #include 5 | #include 6 | #include 7 | 8 | static inline float32_t accuracy(const Tensor *logits, const Tensor *labels) { 9 | assert(logits != NULL); 10 | assert(labels != NULL); 11 | assert(logits->ndim == 2); 12 | assert(labels->ndim == 1); 13 | 14 | uint64_t batch_size = logits->shape[0]; 15 | uint64_t num_classes = logits->shape[1]; 16 | 17 | uint64_t correct = 0; 18 | for (uint64_t i = 0; i < batch_size; i++) { 19 | // argmax 20 | uint64_t predicted = 0; 21 | float32_t max_val = logits->data[i * num_classes]; 22 | for (uint64_t j = 1; j < num_classes; j++) { 23 | float32_t val = logits->data[i * num_classes + j]; 24 | if (val > max_val) { 25 | max_val = val; 26 | predicted = j; 27 | } 28 | } 29 | 30 | uint64_t true_label = (uint64_t)labels->data[i]; 31 | if (predicted == true_label) { 32 | correct++; 33 | } 34 | } 35 | 36 | return (float32_t)correct / 
(float32_t)batch_size; 37 | } 38 | -------------------------------------------------------------------------------- /src/ops/arithmetic_backward.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "autograd.h" 4 | #include "tensor.h" 5 | 6 | Tensor *tensor_add_backward_a(const Tensor *grad_output, const Tensor *a); 7 | Tensor *tensor_add_backward_b(const Tensor *grad_output, const Tensor *b); 8 | void add_backward(Function *fn, const Tensor *grad_output); 9 | 10 | Tensor *tensor_sub_backward_a(const Tensor *grad_output, const Tensor *a); 11 | Tensor *tensor_sub_backward_b(const Tensor *grad_output, const Tensor *b); 12 | void sub_backward(Function *fn, const Tensor *grad_output); 13 | 14 | Tensor *tensor_mul_backward_a(const Tensor *grad_output, const Tensor *a, const Tensor *b); 15 | Tensor *tensor_mul_backward_b(const Tensor *grad_output, const Tensor *a, const Tensor *b); 16 | void mul_backward(Function *fn, const Tensor *grad_output); 17 | 18 | Tensor *tensor_div_backward_a(const Tensor *grad_output, const Tensor *a, const Tensor *b); 19 | Tensor *tensor_div_backward_b(const Tensor *grad_output, const Tensor *a, const Tensor *b); 20 | void div_backward(Function *fn, const Tensor *grad_output); 21 | 22 | Tensor *tensor_matmul_backward_a(const Tensor *grad_output, const Tensor *a, const Tensor *b); 23 | Tensor *tensor_matmul_backward_b(const Tensor *grad_output, const Tensor *a, const Tensor *b); 24 | void matmul_backward(Function *fn, const Tensor *grad_output); 25 | -------------------------------------------------------------------------------- /src/layers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | #include 5 | 6 | // layer base type using vtable polymorphism 7 | typedef struct Layer Layer; 8 | typedef Tensor *(*LayerForwardFunc)(Layer *layer, const Tensor *input, bool training); 9 | typedef void (*LayerFreeFunc)(Layer *layer); 10 | typedef void (*LayerParametersFunc)(Layer *layer, Tensor ***out_params, size_t *out_count); 11 | 12 | struct Layer { 13 | LayerForwardFunc forward; // forward pass 14 | LayerFreeFunc free; // cleanup 15 | LayerParametersFunc parameters; // get trainable params 16 | char *name; // for debugging 17 | }; 18 | 19 | Tensor *layer_forward(Layer *layer, const Tensor *input, bool training); // forward pass through layer 20 | void layer_free(Layer *layer); // frees layer resources 21 | void layer_parameters(Layer *layer, Tensor ***out_params, size_t *out_count); // gets trainable parameters 22 | 23 | Layer *layer_linear_create(uint64_t in_features, uint64_t features_out, bool bias); // linear: y = xW + b 24 | Layer *layer_dropout_create(float32_t p); // randomly zeros elements during training 25 | Layer *layer_sequential_create(Layer **layers, size_t count); // chains layers together 26 | -------------------------------------------------------------------------------- /src/optimizers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tensor.h" 4 | #include 5 | 6 | // optimizer base type using vtable polymorphism 7 | typedef struct Optimizer Optimizer; 8 | typedef void (*OptimizerStepFunc)(Optimizer *opt); 9 | typedef void (*OptimizerFreeFunc)(Optimizer *opt); 10 | 11 | struct Optimizer { 12 | Tensor **params; // parameters being optimized 13 | size_t param_count; // number of parameters 14 | size_t step_count; // optimization steps taken 15 | 16 
| OptimizerStepFunc step; // update rule implementation 17 | OptimizerFreeFunc free; // cleanup implementation 18 | }; 19 | 20 | void optimizer_zero_grad(Optimizer *opt); // zeros all parameter gradients 21 | void optimizer_step(Optimizer *opt); // updates parameters 22 | void optimizer_free(Optimizer *opt); // frees optimizer 23 | 24 | // 25 | // initializers 26 | // 27 | 28 | // sgd: v = momentum * v - lr * grad; param += v 29 | Optimizer *optimizer_sgd_create(Tensor **params, size_t count, float32_t lr, float32_t momentum, float32_t weight_decay); 30 | 31 | // adam: adaptive moment estimation 32 | Optimizer *optimizer_adam_create(Tensor **params, size_t count, float32_t lr, float32_t beta1, float32_t beta2, float32_t eps, float32_t weight_decay); 33 | 34 | // adamw: adam with decoupled weight decay 35 | Optimizer *optimizer_adamw_create(Tensor **params, size_t count, float32_t lr, float32_t beta1, float32_t beta2, float32_t eps, float32_t weight_decay); 36 | -------------------------------------------------------------------------------- /src/tensor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "autograd.h" 4 | #include "utils/types.h" 5 | #include 6 | #include 7 | #include 8 | 9 | #define MAX_NDIM 32 10 | #define MAX_TENSOR_SIZE (UINT64_MAX / sizeof(float32_t)) 11 | 12 | typedef struct Tensor { 13 | float32_t *data; // flat contiguous array, row-major 14 | uint64_t *shape; // array of dimension sizes 15 | uint64_t *strides; // array of elements to skip to get to next element in each dimension 16 | uint64_t ndim; // rank (ie. 1 for vector, 2 for matrix, etc.) 17 | uint64_t size; // total number of elements 18 | 19 | // autograd fields 20 | bool requires_grad; // whether to track operations for autograd 21 | struct Tensor *grad; // accumulated gradient (del loss / del tensor) during backprop 22 | Function *grad_fn; // function that created this tensor (NULL for leaves) 23 | uint32_t ref_count; // reference count for memory management 24 | } Tensor; 25 | 26 | // internals 27 | void linear_to_multidim_mut(uint64_t lin, const uint64_t *shape, uint64_t ndim, uint64_t *out_multidim); 28 | uint64_t multidim_to_linear(const uint64_t *target, uint64_t target_ndim, const uint64_t *shape, uint64_t ndim, const uint64_t *strides); 29 | 30 | // memory management 31 | Tensor *tensor_create(const float32_t *data, const uint64_t *shape, uint64_t ndim, bool requires_grad); 32 | Tensor *tensor_zeros(const uint64_t *shape, uint64_t ndim, bool requires_grad); 33 | void tensor_free(Tensor *t); 34 | Tensor *tensor_retain(Tensor *t); 35 | void tensor_release(Tensor *t); 36 | void tensor_zero_grad(Tensor *t); 37 | 38 | // utils 39 | void tensor_print(const Tensor *t); 40 | Tensor *tensor_get(const Tensor *t, const uint64_t *multidim); 41 | -------------------------------------------------------------------------------- /src/ops/convolutions_backward.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "autograd.h" 4 | #include "layers.h" 5 | #include "tensor.h" 6 | 7 | void conv2d_backward(const Tensor *input, const Tensor *weight, const Tensor *bias, uint64_t stride, uint64_t padding, uint64_t kernel_size, const Tensor *grad_output, Tensor **out_grad_in, Tensor **out_grad_w, Tensor **out_grad_b); 8 | typedef struct { 9 | uint64_t stride; 10 | uint64_t padding; 11 | uint64_t dilation; 12 | uint64_t kernel_h; 13 | uint64_t kernel_w; 14 | } Conv2dContext; 15 | void 
conv2d_backward_fn(Function *fn, const Tensor *grad_output); 16 | 17 | Tensor *maxpool2d_backward(const Tensor *input, const uint64_t *output_shape, uint64_t kernel_size, uint64_t stride, uint64_t padding, const Tensor *grad_output); 18 | typedef struct { 19 | uint64_t kernel_size; 20 | uint64_t stride; 21 | uint64_t padding; 22 | uint64_t output_shape[4]; 23 | } MaxPool2dContext; 24 | void maxpool2d_backward_fn(Function *fn, const Tensor *grad_output); 25 | 26 | Tensor *avgpool2d_backward(const Tensor *input, const uint64_t *output_shape, uint64_t kernel_size, uint64_t stride, uint64_t padding, const Tensor *grad_output); 27 | typedef struct { 28 | uint64_t kernel_size; 29 | uint64_t stride; 30 | uint64_t padding; 31 | uint64_t output_shape[4]; 32 | } AvgPool2dContext; 33 | void avgpool2d_backward_fn(Function *fn, const Tensor *grad_output); 34 | 35 | void batchnorm2d_backward(const Tensor *input, const Tensor *gamma, const Tensor *batch_mean, const Tensor *batch_var, float32_t eps, const Tensor *grad_output, Tensor **out_grad_in, Tensor **out_grad_gamma, Tensor **out_grad_beta); 36 | typedef struct { 37 | float32_t eps; 38 | bool training; 39 | Tensor *batch_mean; 40 | Tensor *batch_var; 41 | } BatchNorm2dContext; 42 | void batchnorm2d_backward_fn(Function *fn, const Tensor *grad_output); 43 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: run 2 | run: lint 3 | mkdir -p /tmp/build && cd /tmp/build && cmake -DCMAKE_C_COMPILER=/opt/homebrew/opt/llvm/bin/clang $(PWD) && cmake --build . -j$$(sysctl -n hw.ncpu) && cd $(PWD) && ASAN_OPTIONS=detect_leaks=1 LSAN_OPTIONS=suppressions=$(PWD)/suppr.txt /tmp/build/binary 4 | 5 | .PHONY: leaks 6 | leaks: lint 7 | rm -rf /tmp/leaks-build && mkdir -p /tmp/leaks-build && cd /tmp/leaks-build && cmake -DDISABLE_ASAN=ON $(PWD) && cmake --build . -j$$(sysctl -n hw.ncpu) 8 | codesign -s - -f --entitlements entitlements.plist /tmp/leaks-build/binary 9 | cd $(PWD) && leaks --atExit --list --groupByType -- /tmp/leaks-build/binary 10 | 11 | .PHONY: test 12 | test: lint 13 | rm -rf /tmp/test-build && mkdir -p /tmp/test-build && cd /tmp/test-build && cmake -DBUILD_TESTS=ON $(PWD) && cmake --build . -j$$(sysctl -n hw.ncpu) && ctest --output-on-failure 14 | 15 | .PHONY: run-release 16 | run-release: 17 | rm -rf /tmp/release-build && mkdir -p /tmp/release-build && cd /tmp/release-build && cmake -DCMAKE_C_COMPILER=/opt/homebrew/opt/llvm/bin/clang -DCMAKE_BUILD_TYPE=Release -DDISABLE_ASAN=ON -DDISABLE_UBSAN=ON -DCMAKE_INTERPROCEDURAL_OPTIMIZATION=ON $(PWD) && cmake --build . -j$$(sysctl -n hw.ncpu) && cd $(PWD) && /tmp/release-build/binary 18 | 19 | # 20 | # utils 21 | # 22 | 23 | .PHONY: download 24 | download: 25 | test ! -f data/data_batch_1.bin # check if already downloaded 26 | mkdir -p data 27 | test -f data/cifar-10-binary.tar.gz || wget -O data/cifar-10-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz 28 | tar -xzf data/cifar-10-binary.tar.gz -C data --strip-components=1 29 | 30 | .PHONY: lint 31 | lint: 32 | cppcheck --enable=all --std=c23 --language=c --suppressions-list=cppcheck-suppressions.txt --check-level=exhaustive --inconclusive --inline-suppr -I src/ src/ 33 | 34 | .PHONY: fmt 35 | fmt: 36 | uvx --from cmakelang cmake-format --dangle-parens --line-width 500 -i CMakeLists.txt 37 | find . 
-name "*.c" -o -name "*.h" | xargs clang-format -i 38 | -------------------------------------------------------------------------------- /src/ops/reshapes_backward.c: -------------------------------------------------------------------------------- 1 | #include "ops/reshapes_backward.h" 2 | #include "ops/reshapes.h" 3 | #include "tensor.h" 4 | #include 5 | #include 6 | #include 7 | 8 | // 9 | // reshape 10 | // 11 | 12 | Tensor *tensor_reshape_backward(const Tensor *grad_output, const Tensor *input) { 13 | assert(grad_output != NULL); 14 | assert(input != NULL); 15 | return tensor_reshape(grad_output, (const int64_t *)input->shape, input->ndim); 16 | } 17 | 18 | void reshape_backward(Function *fn, const Tensor *grad_output) { 19 | assert(fn != NULL); 20 | assert(grad_output != NULL); 21 | assert(fn->num_inputs == 1); 22 | assert(fn->ctx != NULL && "reshape_backward requires context"); 23 | 24 | Tensor *t = fn->inputs[0]; 25 | const ReshapeContext *ctx = (ReshapeContext *)fn->ctx; 26 | 27 | if (t != NULL && t->requires_grad) { 28 | // Create temporary tensor with original shape for backward function 29 | Tensor temp_input; 30 | temp_input.shape = (uint64_t *)ctx->shape; 31 | temp_input.ndim = ctx->ndim; 32 | 33 | Tensor *grad_t = tensor_reshape_backward(grad_output, &temp_input); 34 | accumulate_grad(t, grad_t); 35 | } 36 | 37 | free(fn->ctx); 38 | fn->ctx = NULL; 39 | } 40 | 41 | // 42 | // transpose 43 | // 44 | 45 | Tensor *tensor_transpose_backward(const Tensor *grad_output, uint64_t dim0, uint64_t dim1) { 46 | assert(grad_output != NULL); 47 | return tensor_transpose(grad_output, dim0, dim1); 48 | } 49 | 50 | void transpose_backward(Function *fn, const Tensor *grad_output) { 51 | assert(fn != NULL); 52 | assert(grad_output != NULL); 53 | assert(fn->num_inputs == 1); 54 | assert(fn->ctx != NULL && "transpose_backward requires context"); 55 | 56 | Tensor *t = fn->inputs[0]; 57 | const TransposeContext *ctx = (TransposeContext *)fn->ctx; 58 | 59 | if (t != NULL && t->requires_grad) { 60 | Tensor *grad_t = tensor_transpose_backward(grad_output, ctx->dim0, ctx->dim1); 61 | accumulate_grad(t, grad_t); 62 | } 63 | 64 | free(fn->ctx); 65 | fn->ctx = NULL; 66 | } 67 | -------------------------------------------------------------------------------- /src/autograd.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #define MAX_INPUTS 4 7 | 8 | // avoid circular dependency 9 | typedef struct Tensor Tensor; 10 | 11 | /* 12 | * represents an op in the computation graph 13 | * 14 | * example: 15 | * forward graph: (a, b) -> mul -> result -> sum -> loss 16 | * backward flow: loss -> sum -> mul -> (a, b) 17 | * 18 | * forward: 19 | * Tensor *result = mul(a, b); 20 | * if (result->requires_grad) { 21 | * Function *fn = malloc_arena(); 22 | * fn->apply = mul_backward; // callback to compute gradients on backprop 23 | * fn->num_inputs = 2; 24 | * fn->inputs[0] = a; 25 | * fn->inputs[1] = b; 26 | * fn->output = result; 27 | * fn->pending_count = 0; // how many ops must finish before this one can run 28 | * if (a->grad_fn) a->grad_fn->pending_count++; 29 | * if (b->grad_fn) b->grad_fn->pending_count++; 30 | * fn->ctx = NULL; // no extra storage needed 31 | * result->grad_fn = fn; 32 | * } 33 | * 34 | * backward: 35 | * after forward pass in neural net, traverse graph backward from loss 36 | * calling each Function's apply() to compute gradients 37 | * to find the contribution of each operation to the final loss. 
38 | * 39 | * results can then be retrieved from each Tensor's grad field. 40 | */ 41 | typedef struct Function { 42 | void (*apply)(struct Function *self, const Tensor *grad_output); // callback to compute gradients 43 | uint32_t num_inputs; // number of args for op 44 | Tensor *inputs[MAX_INPUTS]; // args 45 | Tensor *output; // result Tensor 46 | uint32_t pending_count; // how many ops must finish before this one can run 47 | void *ctx; // extra metadata storage, used by complex operations 48 | } Function; 49 | 50 | // memory arena to store Function structs for all ops in a thread 51 | typedef struct Arena { 52 | void *memory; 53 | size_t capacity; 54 | size_t offset; 55 | } Arena; 56 | Function *arena_alloc_function(void); 57 | void arena_free(void); 58 | 59 | // traverse computation graph backward from loss tensor 60 | void backward(Tensor *loss); 61 | 62 | // add a new gradient to tensor->grad 63 | void accumulate_grad(Tensor *tensor, Tensor *new_grad); 64 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | $ toilet -f mono9 -w 100 "autograd.c" 2 | 3 | ▄ █ 4 | ▄▄▄ ▄ ▄ ▄▄█▄▄ ▄▄▄ ▄▄▄▄ ▄ ▄▄ ▄▄▄ ▄▄▄█ ▄▄▄ 5 | ▀ █ █ █ █ █▀ ▀█ █▀ ▀█ █▀ ▀ ▀ █ █▀ ▀█ █▀ ▀ 6 | ▄▀▀▀█ █ █ █ █ █ █ █ █ ▄▀▀▀█ █ █ █ 7 | ▀▄▄▀█ ▀▄▄▀█ ▀▄▄ ▀█▄█▀ ▀█▄▀█ █ ▀▄▄▀█ ▀█▄██ █ ▀█▄▄▀ 8 | ▄ █ 9 | ▀▀ 10 | 11 | $ cat docs/about.md | fold -w 75 12 | 13 | a minimal reverse mode autograd engine in c with reference counted tensors, 14 | arena allocated function nodes, explicit dependency counting, centralized 15 | gradient accumulation, scalar loss backpropagation and a small set of core 16 | tensor ops implemented with tightly coupled forward and backward code. 17 | 18 | $ make download 19 | 20 | data/cifar-10-binary 40%[====> ] 33.10M 2.55MB/s eta 47s 21 | 22 | $ make run-release 23 | 24 | loaded data 25 | train samples: 50000 26 | test samples: 10000 27 | created model 28 | starting training 29 | 30 | avg loss: 2.2632, avg acc: 20.48% 31 | evaluating: 100%|█████████████████████████████████████| 79/79 [0.7it/s] 32 | test acc: 23.56% 33 | 34 | ┌─────────────────── loss (epoch 2/15) ───────────────────┐ 35 | │ • │ 2.3112 (max) 36 | │ │ 37 | │ •• │ 38 | │ • • • │ 39 | │• • │ 40 | │ •••• • • • │ 41 | │ •• • • • • │ 42 | │ •• • • │ 43 | │ • • • • • • │ 44 | │ • • • │ 45 | │ • • • • • • • • • • │ 46 | │ • • • │ 47 | │ • • • • ┤––– 2.2941 48 | │ • │ 49 | │ • • • │ 2.2899 (min) 50 | └──────────────────────────────────────────────────────────┘ 51 | 96%|██████████████████████████████████████████████▊ | 750/782 [1.5it/s] 52 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # (1) Safety 2 | 3 | - Zero Technical Debt: Fix issues immediately. Never rely on future refactoring. 4 | - Testing: 5 | - Mandatory Unit Tests: Every new feature must be accompanied by unit tests. 6 | - Coverage: The more tests, the better. Cover edge cases, boundary conditions and failure modes. 7 | - Assert Aggressively: Validate inputs, outputs, tensor shapes and invariants in every function. 8 | - Pair Assertions: Check critical data at multiple points to catch internal inconsistencies. 9 | - Bounded Execution: Set fixed upper limits on all loops, queues, and recursion depth (especially for graph traversal). 10 | - Fail Fast: Detect unexpected conditions immediately. Crash rather than corrupt state. 
11 | - No Undefined Behavior: Rely on explicit code, not compiler optimizations. Treat all compiler warnings as errors. 12 | - Controlled Memory: 13 | - Strongly prefer static allocation over dynamic allocation. 14 | - When dynamic allocation is necessary (e.g., graph construction), use arenas or memory pools. 15 | - Enforce strict upper bounds on memory usage. 16 | - Assert Allocation: Wrap every arena allocation, `malloc`, or resource claim in an assert (e.g., `assert(ptr != NULL)`). 17 | 18 | C Specific: 19 | 20 | - Use explicit types only (e.g., `int64_t`, `float64_t`), no implicit types (`int`, `float`) 21 | - Avoid architecture-dependent types (e.g., `size_t`, `long`) to ensure portability; use fixed-size explicit types instead 22 | - Strict Const Correctness: Pointers are `const` by default. Only drop `const` when mutation is strictly required. 23 | - Explicit Mutation: Avoid manipulating function arguments or causing side effects 24 | - Exception: If copying has an extremely high memory cost (e.g., large tensors), mutation is allowed but must be obvious in the naming (e.g., `out_result`, `tensor_mut`). 25 | 26 | Pre-commit: 27 | 28 | - Add unit tests for each new function and feature 29 | - Run `make fmt`, `make lint`, `make test` 30 | - Do NOT run `make run` as it requires downloading large data files 31 | 32 | # (2) Quality 33 | 34 | - Obvious Code > Clever Code 35 | - Maximize Locality: Keep related code together. Define things near usage. Minimize variable scope. 36 | - Centralize Control Flow: Branching logic belongs in parents. Leaf functions should be pure logic. 37 | - Guard Clauses: Handle checks first, return early, minimize nesting. 38 | - Functions: Do one coherent thing (ideally <70 lines). Prefer lambdas/inline logic over tiny single-use functions. 39 | - Decompose Conditionals: Use named variables to simplify complex `if` conditions. 40 | - Naming & Comments: 41 | - Variable names must include units/qualifiers (e.g., `timeout_ms`, `size_bytes`). 42 | - Distinguish types: Index (0-based), Count (1-based), and Size (memory bytes). 43 | - Comments explain *why*, not *what*; use lowercase single lines. ASCII illustrations are welcome. 44 | - Paradigm Balance: 45 | - Functional: Prefer pure functions (data in, data out) and immutability for logic. 46 | - Procedural: Use direct loops and local mutation when simpler or significantly more performant. 47 | 48 | # (3) Performance 49 | 50 | - Follow Data-Oriented Design (DoD) principles 51 | - Design for Hardware: Organize data to match how the hardware reads it (cache lines). 52 | - Struct of Arrays (SoA): Prefer SoA over Array of Structs (AoS) for heavy computation to maximize SIMD usage (see the sketch after this list). 53 | - Data Alignment: Ensure critical data (tensors) is **aligned** (e.g., 64 bytes) for SIMD efficiency. Assert alignment on access. 54 | - Batch Processing: Write functions that transform arrays of data rather than single elements (data transformation > object interaction). 55 | - Existence-based Processing: Filter data *before* processing so loops run on contiguous, valid data (avoid `if (obj->active)` inside hot loops).
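
To make the SoA and existence-based-processing guidance concrete, here is a small illustrative sketch. It is not code from this repository; `Particle`/`Particles` are hypothetical names, and `float32_t` is typedef'd locally to mirror `src/utils/types.h`.

```c
#include <stdint.h>
#include <stdio.h>

typedef float float32_t; // mirrors src/utils/types.h

#define N 8

// AoS (shown only for contrast): fields of one element are interleaved, so a
// loop that only needs `x` still pulls `y` and `active` through the cache and
// branches per element.
typedef struct {
    float32_t x;
    float32_t y;
    int32_t active;
} Particle;

// SoA: one contiguous array per field; hot loops read a dense stream that
// vectorizes well. existence-based processing keeps only live elements at the
// front of the arrays, so no `active` flag check is needed inside the loop.
typedef struct {
    float32_t x[N];
    float32_t y[N];
    uint64_t live_count; // indices [0, live_count) are valid
} Particles;

int main(void) {
    Particles p = {0};
    p.live_count = 5;
    for (uint64_t i = 0; i < p.live_count; i++) {
        p.x[i] = (float32_t)i;
    }

    // hot loop over contiguous, pre-filtered data: no per-element branch
    float32_t sum_x = 0.0f;
    for (uint64_t i = 0; i < p.live_count; i++) {
        sum_x += p.x[i];
    }
    printf("sum_x = %f\n", (double)sum_x);
    return 0;
}
```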
56 | 57 | --- 58 | 59 | > Based on https://tigerstyle.dev/ 60 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.27 FATAL_ERROR) 2 | project(defer LANGUAGES C) 3 | 4 | set(CMAKE_C_STANDARD 23) 5 | set(CMAKE_C_STANDARD_REQUIRED ON) 6 | 7 | if(CMAKE_C_COMPILER_ID MATCHES "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS 13.0.0) 8 | message(FATAL_ERROR "GCC is outdated: ${CMAKE_C_COMPILER_VERSION}") 9 | endif() 10 | if(CMAKE_C_COMPILER_ID MATCHES "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS 16.0.0) 11 | message(FATAL_ERROR "Clang is outdated: ${CMAKE_C_COMPILER_VERSION}") 12 | endif() 13 | 14 | # 15 | # compiler hardening 16 | # 17 | # see: https://best.openssf.org/Compiler-Hardening-Guides/Compiler-Options-Hardening-Guide-for-C-and-C++ 18 | # 19 | 20 | add_compile_options( 21 | -O2 22 | -Wall 23 | -Wextra 24 | -Wformat 25 | -Wformat=2 26 | -Wconversion 27 | -Wsign-conversion 28 | -Wimplicit-fallthrough 29 | -Werror=format-security 30 | # more portable and explicit than -fhardened 31 | -U_FORTIFY_SOURCE 32 | -D_FORTIFY_SOURCE=3 33 | -D_GLIBCXX_ASSERTIONS 34 | -fstrict-flex-arrays=3 35 | -fstack-protector-strong 36 | # deprecated c calls 37 | -Werror=implicit 38 | -Werror=incompatible-pointer-types 39 | -Werror=int-conversion 40 | # multithreading with pthreads 41 | -fexceptions 42 | # for shared libraries use `-fPIC` 43 | -fPIE 44 | ) 45 | 46 | if(CMAKE_BUILD_TYPE STREQUAL "Release") 47 | add_compile_options(-fno-delete-null-pointer-checks -fno-strict-overflow -fno-strict-aliasing -ftrivial-auto-var-init=zero) 48 | endif() 49 | 50 | if(CMAKE_SYSTEM_NAME STREQUAL "Linux") 51 | add_compile_options(-fstack-clash-protection) 52 | add_link_options( 53 | -pie 54 | -Wl,-z,nodlopen 55 | -Wl,-z,noexecstack 56 | -Wl,-z,relro 57 | -Wl,-z,now 58 | -Wl,--as-needed 59 | -Wl,--no-copy-dt-needed-entries 60 | ) 61 | endif() 62 | 63 | if(CMAKE_C_COMPILER_ID MATCHES "GNU") 64 | # add `-fzero-init-padding-bits=all` from gcc 15 65 | add_compile_options(-Wtrampolines -Wbidi-chars=any) 66 | endif() 67 | if(CMAKE_C_COMPILER_ID MATCHES "Clang") 68 | # nop 69 | endif() 70 | 71 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|AMD64|i686|i386") 72 | add_compile_options(-fcf-protection=full) 73 | endif() 74 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64|ARM64") 75 | add_compile_options(-mbranch-protection=standard) 76 | endif() 77 | 78 | # 79 | # sanitizers 80 | # 81 | 82 | if(NOT DISABLE_ASAN) 83 | # apple silicon requires clang from homebrew with paths set accordingly: `brew info llvm` 84 | add_compile_options(-fsanitize=address -fsanitize-address-use-after-scope -fno-omit-frame-pointer -g) 85 | add_link_options(-fsanitize=address) 86 | endif() 87 | 88 | if(NOT DISABLE_UBSAN) 89 | add_compile_options(-fsanitize=undefined -fno-omit-frame-pointer -g) 90 | add_link_options(-fsanitize=undefined) 91 | endif() 92 | 93 | # 94 | # dependencies 95 | # 96 | 97 | include(FetchContent) 98 | 99 | if(BUILD_TESTS) 100 | FetchContent_Declare( 101 | unity 102 | GIT_REPOSITORY https://github.com/ThrowTheSwitch/Unity.git 103 | GIT_TAG v2.6.1 104 | GIT_SHALLOW TRUE 105 | ) 106 | 107 | set(BUILD_TESTING 108 | OFF 109 | CACHE BOOL "" FORCE 110 | ) 111 | set(BUILD_SHARED_LIBS 112 | OFF 113 | CACHE BOOL "" FORCE 114 | ) 115 | set(UNITY_EXTENSION_FIXTURE 116 | OFF 117 | CACHE BOOL "" FORCE 118 | ) 119 | set(UNITY_EXTENSION_MEMORY 120 | OFF 121 | CACHE BOOL "" FORCE 122 | ) 123 
| set(UNITY_BUILD_TESTS 124 | OFF 125 | CACHE BOOL "" FORCE 126 | ) 127 | 128 | FetchContent_MakeAvailable(unity) 129 | endif() 130 | 131 | # 132 | # sources 133 | # 134 | 135 | file(GLOB_RECURSE SRCS src/*.c) 136 | list(FILTER SRCS EXCLUDE REGEX "src/main\\.c$") 137 | add_library(lib ${SRCS}) 138 | target_include_directories(lib PUBLIC src) 139 | add_executable(binary src/main.c) 140 | target_link_libraries(binary PRIVATE lib) 141 | 142 | # 143 | # tests 144 | # 145 | 146 | file(GLOB_RECURSE TEST_FILES test/*.c) 147 | 148 | if(BUILD_TESTS) 149 | enable_testing() 150 | 151 | foreach(TEST_FILE ${TEST_FILES}) 152 | get_filename_component(TEST_NAME ${TEST_FILE} NAME_WE) 153 | set(TEST_SOURCE "test/${TEST_NAME}.c") 154 | set(TEST_EXECUTABLE "${TEST_NAME}_binary") 155 | 156 | add_executable(${TEST_EXECUTABLE} ${TEST_SOURCE}) 157 | target_link_libraries(${TEST_EXECUTABLE} PRIVATE unity lib) 158 | target_include_directories(${TEST_EXECUTABLE} PRIVATE ${unity_SOURCE_DIR}/src) 159 | 160 | add_test(NAME ${TEST_NAME} COMMAND ${TEST_EXECUTABLE}) 161 | endforeach() 162 | endif() 163 | -------------------------------------------------------------------------------- /docs/pytorch_ast_optimization.py: -------------------------------------------------------------------------------- 1 | # author: sueszli 2 | # reviewer: ivan yashchuk 3 | # /// script 4 | # requires-python = ">=3.11" 5 | # dependencies = [ 6 | # "astunparse==1.6.3", 7 | # "jax==0.4.20", 8 | # "click==8.1.7", 9 | # "torch==2.1.1", 10 | # ] 11 | # /// 12 | 13 | import ast 14 | import astunparse 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | 20 | import os 21 | import time 22 | from dataclasses import dataclass 23 | 24 | 25 | # source: https://github.com/karpathy/nanoGPT/blob/master/model.py#L29 26 | csa_str = """ 27 | import torch 28 | import math 29 | from torch import nn 30 | class CausalSelfAttention(nn.Module): 31 | def __init__(self, config): 32 | super().__init__() 33 | assert config.n_embd % config.n_head == 0 34 | # key, query, value projections for all heads, but in a batch 35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 36 | # output projection 37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 38 | # regularization 39 | self.attn_dropout = nn.Dropout(config.dropout) 40 | self.resid_dropout = nn.Dropout(config.dropout) 41 | self.n_head = config.n_head 42 | self.n_embd = config.n_embd 43 | self.dropout = config.dropout 44 | def forward(self, x): 45 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 46 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 47 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 48 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 49 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 50 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 51 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 52 | ###!!! Beginning of the scaled dot product attention 53 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 54 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 55 | att = F.softmax(att, dim=-1) 56 | att = self.attn_dropout(att) 57 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 58 | ###!!! 
End of the scaled dot product attention 59 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 60 | # output projection 61 | y = self.resid_dropout(self.c_proj(y)) 62 | return y 63 | """ 64 | 65 | 66 | def get_exectime(python_str: str) -> float: 67 | exec(python_str, globals()) 68 | assert "CausalSelfAttention" in globals(), "class not found in namespace" 69 | 70 | @dataclass 71 | class Config: 72 | n_embd: int = 3072 # embedding dimensionality 73 | n_head: int = 16 # number of attention heads 74 | dropout: float = 0.1 # dropout probability 75 | bias: bool = True # whether to use bias in projection layers 76 | 77 | def gen_bias(size): 78 | mask = torch.triu(torch.ones(size, size), diagonal=1) 79 | mask = mask.masked_fill(mask == 1, -1e10).unsqueeze(0).unsqueeze(1) 80 | return mask 81 | 82 | config = Config() 83 | model = CausalSelfAttention(config) # type: ignore 84 | setattr(model, "bias", gen_bias(config.n_embd)) # missing attribute in the original code 85 | x = torch.rand(16, 1024, 3072, dtype=torch.float32) 86 | 87 | t_start = time.time() 88 | output = model.forward(x) 89 | t_end = time.time() 90 | return t_end - t_start 91 | 92 | 93 | def transform(python_str: str) -> str: 94 | tree = ast.parse(python_str) 95 | forward_method = tree.body[3].body[1] # type: ignore 96 | forward_method.body[5:10] = ast.parse("y = F.scaled_dot_product_attention(q, k, v)").body 97 | return astunparse.unparse(tree) 98 | 99 | 100 | def main(): 101 | os.system("clear" if os.name == "posix" else "cls") 102 | 103 | new_csa_str = transform(csa_str) 104 | print(f"Original code:\n\n{csa_str}\n\nOptimized code:\n\n{new_csa_str}\n\n") 105 | 106 | iters = 5 107 | print(f"Running {iters} benchmark iterations...") 108 | 109 | avg_t1 = 0.0 110 | for i in range(iters): 111 | score = get_exectime(csa_str) 112 | print(f"\t{i + 1}/{iters} - Original time: {score}") 113 | avg_t1 += score 114 | avg_t1 /= iters 115 | 116 | avg_t2 = 0.0 117 | for i in range(iters): 118 | score = get_exectime(new_csa_str) 119 | print(f"\t{i + 1}/{iters} - Optimized execution time: {score}") 120 | avg_t2 += score 121 | avg_t2 /= iters 122 | 123 | print(f"Avg original time: {avg_t1}s") # Avg original time: 2.4412187576293944s 124 | print(f"Avg optimized execution time: {avg_t2}s") # Avg optimized execution time: 0.9933661937713623s 125 | print(f"🔥 Avg speedup: {avg_t1 / avg_t2 * 100:.2f}%") # 🔥 Avg speedup: 245.75% 126 | print(f"🔥 Avg seconds saved: {avg_t1 - avg_t2:.2f}s") # 🔥 Avg seconds saved: 1.45s 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /src/ops/activations.c: -------------------------------------------------------------------------------- 1 | #include "activations.h" 2 | #include "autograd.h" 3 | #include "ops/activations_backward.h" 4 | #include "ops/arithmetic.h" 5 | #include "ops/reductions.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | Tensor *tensor_sigmoid(const Tensor *t) { 12 | assert(t != NULL); 13 | assert(t->data != NULL || t->size == 0); 14 | 15 | Tensor *out = tensor_create(NULL, t->shape, t->ndim, t->requires_grad); 16 | 17 | for (uint64_t i = 0; i < t->size; i++) { 18 | float32_t x = t->data[i]; 19 | if (x > 500.0f) { 20 | x = 500.0f; 21 | } 22 | if (x < -500.0f) { 23 | x = -500.0f; 24 | } 25 | 26 | if (x >= 0.0f) { 27 | out->data[i] = 1.0f / (1.0f + expf(-x)); 28 | } else { 29 | float32_t ex = expf(x); 30 | out->data[i] = ex / (1.0f + ex); 31 | } 32 | } 33 | 34 | 
if (out->requires_grad) { 35 | Function *fn = arena_alloc_function(); 36 | fn->apply = sigmoid_backward; 37 | fn->output = out; 38 | fn->num_inputs = 1; 39 | fn->inputs[0] = (Tensor *)t; 40 | fn->pending_count = 0; 41 | fn->ctx = NULL; 42 | if (t->grad_fn != NULL) { 43 | t->grad_fn->pending_count++; 44 | } 45 | out->grad_fn = fn; 46 | } 47 | 48 | return out; 49 | } 50 | 51 | Tensor *tensor_relu(const Tensor *t) { 52 | assert(t != NULL); 53 | assert(t->data != NULL || t->size == 0); 54 | 55 | Tensor *out = tensor_create(NULL, t->shape, t->ndim, t->requires_grad); 56 | 57 | for (uint64_t i = 0; i < t->size; i++) { 58 | float32_t x = t->data[i]; 59 | out->data[i] = (x > 0.0f) ? x : 0.0f; 60 | } 61 | 62 | if (out->requires_grad) { 63 | Function *fn = arena_alloc_function(); 64 | fn->apply = relu_backward; 65 | fn->output = out; 66 | fn->num_inputs = 1; 67 | fn->inputs[0] = (Tensor *)t; 68 | fn->pending_count = 0; 69 | fn->ctx = NULL; 70 | if (t->grad_fn != NULL) { 71 | t->grad_fn->pending_count++; 72 | } 73 | out->grad_fn = fn; 74 | } 75 | 76 | return out; 77 | } 78 | 79 | Tensor *tensor_tanh(const Tensor *t) { 80 | assert(t != NULL); 81 | assert(t->data != NULL || t->size == 0); 82 | 83 | Tensor *out = tensor_create(NULL, t->shape, t->ndim, t->requires_grad); 84 | 85 | for (uint64_t i = 0; i < t->size; i++) { 86 | out->data[i] = tanhf(t->data[i]); 87 | } 88 | 89 | if (out->requires_grad) { 90 | Function *fn = arena_alloc_function(); 91 | fn->apply = tanh_backward; 92 | fn->output = out; 93 | fn->num_inputs = 1; 94 | fn->inputs[0] = (Tensor *)t; 95 | fn->pending_count = 0; 96 | fn->ctx = NULL; 97 | if (t->grad_fn != NULL) { 98 | t->grad_fn->pending_count++; 99 | } 100 | out->grad_fn = fn; 101 | } 102 | 103 | return out; 104 | } 105 | 106 | Tensor *tensor_gelu(const Tensor *t) { 107 | assert(t != NULL); 108 | assert(t->data != NULL || t->size == 0); 109 | 110 | Tensor *out = tensor_create(NULL, t->shape, t->ndim, t->requires_grad); 111 | 112 | for (uint64_t i = 0; i < t->size; i++) { 113 | float32_t x = t->data[i]; 114 | out->data[i] = 0.5f * x * (1.0f + erff(x * 1 / (float32_t)sqrt(2))); 115 | } 116 | 117 | if (out->requires_grad) { 118 | Function *fn = arena_alloc_function(); 119 | fn->apply = gelu_backward; 120 | fn->output = out; 121 | fn->num_inputs = 1; 122 | fn->inputs[0] = (Tensor *)t; 123 | fn->pending_count = 0; 124 | fn->ctx = NULL; 125 | if (t->grad_fn != NULL) { 126 | t->grad_fn->pending_count++; 127 | } 128 | out->grad_fn = fn; 129 | } 130 | 131 | return out; 132 | } 133 | 134 | Tensor *tensor_softmax(const Tensor *t, int64_t dim) { 135 | assert(t != NULL); 136 | assert(t->data != NULL || t->size == 0); 137 | 138 | Tensor *max_val = tensor_max(t, dim, true); 139 | Tensor *shifted = tensor_sub(t, max_val); 140 | tensor_free(max_val); 141 | 142 | for (uint64_t i = 0; i < shifted->size; i++) { 143 | shifted->data[i] = expf(shifted->data[i]); 144 | } 145 | 146 | Tensor *sum_exp = tensor_sum(shifted, dim, true); 147 | Tensor *out = tensor_div(shifted, sum_exp); 148 | 149 | tensor_free(shifted); 150 | tensor_free(sum_exp); 151 | 152 | if (out->requires_grad) { 153 | Function *fn = arena_alloc_function(); 154 | fn->apply = softmax_backward; 155 | fn->output = out; 156 | fn->num_inputs = 1; 157 | fn->inputs[0] = (Tensor *)t; 158 | fn->pending_count = 0; 159 | 160 | // store dimension in context 161 | int64_t *ctx = (int64_t *)malloc(sizeof(int64_t)); 162 | assert(ctx != NULL && "malloc failed"); 163 | *ctx = dim; 164 | fn->ctx = ctx; 165 | 166 | if (t->grad_fn != NULL) { 167 | 
t->grad_fn->pending_count++; 168 | } 169 | out->grad_fn = fn; 170 | } 171 | 172 | return out; 173 | } 174 | -------------------------------------------------------------------------------- /src/ops/reshapes.c: -------------------------------------------------------------------------------- 1 | #include "ops/reshapes.h" 2 | #include "autograd.h" 3 | #include "ops/reshapes_backward.h" 4 | #include "utils/aligned_alloc.h" 5 | #include 6 | #include 7 | #include 8 | 9 | Tensor *tensor_reshape(const Tensor *t, const int64_t *new_shape, uint64_t new_ndim) { 10 | assert(t != NULL); 11 | assert(new_shape != NULL); 12 | assert(t->data != NULL || t->size == 0); 13 | assert(new_ndim <= MAX_NDIM); 14 | 15 | uint64_t new_size = 1; 16 | int64_t unknown_idx = -1; // one dimension can be -1 (inferred) 17 | 18 | // validate 19 | for (uint64_t i = 0; i < new_ndim; i++) { 20 | if (new_shape[i] == -1) { 21 | assert(unknown_idx == -1 && "only one dimension can be -1"); 22 | unknown_idx = (int64_t)i; 23 | } else { 24 | assert(new_shape[i] >= 0 && "dimension cannot be negative (except -1)"); 25 | new_size *= (uint64_t)new_shape[i]; 26 | } 27 | } 28 | if (unknown_idx != -1) { 29 | assert(t->size % new_size == 0 && "invalid shape (cannot infer dimension)"); 30 | } else { 31 | assert(new_size == t->size && "total elements must match"); 32 | } 33 | 34 | uint64_t *resolved_shape = (uint64_t *)malloc((size_t)new_ndim * sizeof(uint64_t)); 35 | assert(resolved_shape != NULL && "malloc failed"); 36 | 37 | // fill in 38 | for (uint64_t i = 0; i < new_ndim; i++) { 39 | if ((int64_t)i == unknown_idx) { 40 | resolved_shape[i] = t->size / new_size; 41 | } else { 42 | resolved_shape[i] = (uint64_t)new_shape[i]; 43 | } 44 | } 45 | 46 | Tensor *result = tensor_create(t->data, resolved_shape, new_ndim, t->requires_grad); 47 | free(resolved_shape); 48 | assert(result != NULL); 49 | assert(result->size == t->size); 50 | 51 | if (result->requires_grad) { 52 | Function *fn = arena_alloc_function(); 53 | fn->apply = reshape_backward; 54 | fn->output = result; 55 | fn->num_inputs = 1; 56 | fn->inputs[0] = (Tensor *)t; 57 | fn->pending_count = 0; 58 | 59 | ReshapeContext *ctx = (ReshapeContext *)malloc(sizeof(ReshapeContext)); 60 | assert(ctx != NULL && "malloc failed"); 61 | assert(t->ndim <= MAX_NDIM); 62 | memcpy(ctx->shape, t->shape, t->ndim * sizeof(uint64_t)); 63 | ctx->ndim = t->ndim; 64 | fn->ctx = ctx; 65 | 66 | if (t->grad_fn != NULL) { 67 | t->grad_fn->pending_count++; 68 | } 69 | 70 | result->grad_fn = fn; 71 | } 72 | 73 | return result; 74 | } 75 | 76 | /* 77 | * transpose swaps two dimensions 78 | * 79 | * example: (2, 3) -> (3, 2) 80 | * 81 | * T: 82 | * [[1, 2, 3], 83 | * [4, 5, 6]] 84 | * 85 | * T.T: 86 | * [[1, 4], 87 | * [2, 5], 88 | * [3, 6]] 89 | */ 90 | Tensor *tensor_transpose(const Tensor *t, uint64_t dim0, uint64_t dim1) { 91 | assert(t != NULL); 92 | assert(t->data != NULL || t->size == 0); 93 | assert(t->ndim <= MAX_NDIM); 94 | if (t->size > 0) { 95 | assert((uintptr_t)t->data % CACHELINE_SIZE == 0 && "data is not properly aligned"); 96 | } 97 | 98 | if (t->ndim < 2) { 99 | return tensor_create(t->data, t->shape, t->ndim, t->requires_grad); 100 | } 101 | 102 | assert(dim0 < t->ndim && "dimension 0 out of bounds"); 103 | assert(dim1 < t->ndim && "dimension 1 out of bounds"); 104 | 105 | uint64_t *new_shape = (uint64_t *)malloc((size_t)t->ndim * sizeof(uint64_t)); 106 | assert(new_shape != NULL && "malloc failed"); 107 | memcpy(new_shape, t->shape, (size_t)t->ndim * sizeof(uint64_t)); 108 | 109 | // swap dims 110 
| uint64_t temp = new_shape[dim0]; 111 | new_shape[dim0] = new_shape[dim1]; 112 | new_shape[dim1] = temp; 113 | 114 | Tensor *result = tensor_zeros(new_shape, t->ndim, t->requires_grad); 115 | free(new_shape); 116 | 117 | uint64_t *curr = (uint64_t *)calloc((size_t)t->ndim, sizeof(uint64_t)); 118 | assert(curr != NULL && "calloc failed"); 119 | 120 | for (uint64_t i = 0; i < result->size; i++) { 121 | linear_to_multidim_mut(i, result->shape, t->ndim, curr); 122 | 123 | uint64_t offset = 0; 124 | for (uint64_t d = 0; d < t->ndim; d++) { 125 | // multidim indices are swapped compared to output at dim0/dim1 126 | uint64_t idx_val = curr[d]; 127 | if (d == dim0) { 128 | idx_val = curr[dim1]; 129 | } else if (d == dim1) { 130 | idx_val = curr[dim0]; 131 | } 132 | offset += idx_val * t->strides[d]; 133 | } 134 | assert(offset < t->size && "offset out of bounds"); 135 | 136 | result->data[i] = t->data[offset]; 137 | } 138 | free(curr); 139 | 140 | if (result->requires_grad) { 141 | Function *fn = arena_alloc_function(); 142 | fn->apply = transpose_backward; 143 | fn->output = result; 144 | fn->num_inputs = 1; 145 | fn->inputs[0] = (Tensor *)t; 146 | fn->pending_count = 0; 147 | 148 | TransposeContext *ctx = (TransposeContext *)malloc(sizeof(TransposeContext)); 149 | assert(ctx != NULL && "malloc failed"); 150 | ctx->dim0 = dim0; 151 | ctx->dim1 = dim1; 152 | fn->ctx = ctx; 153 | 154 | if (t->grad_fn != NULL) { 155 | t->grad_fn->pending_count++; 156 | } 157 | 158 | result->grad_fn = fn; 159 | } 160 | 161 | return result; 162 | } 163 | -------------------------------------------------------------------------------- /src/utils/cifar10.c: -------------------------------------------------------------------------------- 1 | #include "cifar10.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | // 8 | // data loader from disk 9 | // 10 | 11 | // path to subdirectory containing the downloaded data files 12 | #define DATA_DIRECTORY "data" 13 | 14 | static uint8_t train_images[NUM_TRAIN_SAMPLES * INPUT_SIZE]; 15 | static label_t train_labels[NUM_TRAIN_SAMPLES]; 16 | static uint8_t test_images[NUM_TEST_SAMPLES * INPUT_SIZE]; 17 | static label_t test_labels[NUM_TEST_SAMPLES]; 18 | 19 | static void load_batch(const char *filepath, uint8_t *images_out, label_t *labels_out, int32_t count) { 20 | assert(filepath != NULL); 21 | assert(images_out != NULL); 22 | assert(labels_out != NULL); 23 | assert(count > 0); 24 | 25 | FILE *f = fopen(filepath, "rb"); 26 | assert(f != NULL && "failed to open batch file"); 27 | 28 | for (int32_t i = 0; i < count; i++) { 29 | uint8_t label; 30 | int64_t label_read = (int64_t)fread(&label, 1, 1, f); 31 | (void)label_read; 32 | assert(label_read == 1 && "failed to read label"); 33 | assert(label < NUM_CLASSES && "invalid label"); 34 | 35 | labels_out[i] = (label_t)label; 36 | 37 | int64_t read = (int64_t)fread(images_out + (i * INPUT_SIZE), 1, INPUT_SIZE, f); 38 | (void)read; 39 | assert(read == INPUT_SIZE && "failed to read image data"); 40 | } 41 | 42 | int32_t close_result = fclose(f); 43 | (void)close_result; 44 | assert(close_result == 0 && "failed to close batch file"); 45 | } 46 | 47 | // constructor to load data on program start 48 | __attribute__((constructor)) static void load_data(void) { 49 | static const char *const batches[] = {"data_batch_1.bin", "data_batch_2.bin", "data_batch_3.bin", "data_batch_4.bin", "data_batch_5.bin"}; 50 | assert(NUM_TRAIN_SAMPLES == 50000); 51 | assert(NUM_TEST_SAMPLES == 10000); 52 | 53 | int32_t samples_per_batch = 
NUM_TRAIN_SAMPLES / 5; 54 | assert(samples_per_batch == 10000); 55 | 56 | // load train data 57 | for (int32_t i = 0; i < 5; i++) { 58 | char path[512]; 59 | int32_t written = snprintf(path, sizeof(path), "%s/%s", DATA_DIRECTORY, batches[i]); 60 | (void)written; 61 | assert(written > 0 && written < (int32_t)sizeof(path) && "path buffer overflow"); 62 | 63 | int32_t offset = i * samples_per_batch; 64 | load_batch(path, train_images + (offset * INPUT_SIZE), train_labels + offset, samples_per_batch); 65 | } 66 | 67 | // load test data 68 | char path[512]; 69 | int32_t written = snprintf(path, sizeof(path), "%s/test_batch.bin", DATA_DIRECTORY); 70 | (void)written; 71 | assert(written > 0 && written < (int32_t)sizeof(path) && "path buffer overflow"); 72 | load_batch(path, test_images, test_labels, NUM_TEST_SAMPLES); 73 | } 74 | 75 | static Tensor *images_to_tensor(const uint8_t *data, uint64_t count) { 76 | uint64_t total_elements = count * INPUT_SIZE; 77 | const uint64_t shape[] = {count, CHANNELS, HEIGHT, WIDTH}; 78 | Tensor *t = tensor_zeros(shape, 4, false); 79 | assert(t != NULL); 80 | 81 | for (uint64_t i = 0; i < total_elements; i++) { 82 | t->data[i] = (float32_t)data[i] / 255.0f; // normalize in place 83 | } 84 | return t; 85 | } 86 | 87 | static Tensor *labels_to_tensor(const label_t *labels, uint64_t count) { 88 | const uint64_t shape[] = {count}; 89 | Tensor *t = tensor_zeros(shape, 1, false); 90 | assert(t != NULL); 91 | 92 | for (uint64_t i = 0; i < count; i++) { 93 | t->data[i] = (float32_t)labels[i]; 94 | } 95 | return t; 96 | } 97 | 98 | Tensor *cifar10_get_train_images(void) { return images_to_tensor(train_images, NUM_TRAIN_SAMPLES); } 99 | 100 | Tensor *cifar10_get_train_labels(void) { return labels_to_tensor(train_labels, NUM_TRAIN_SAMPLES); } 101 | 102 | Tensor *cifar10_get_test_images(void) { return images_to_tensor(test_images, NUM_TEST_SAMPLES); } 103 | 104 | Tensor *cifar10_get_test_labels(void) { return labels_to_tensor(test_labels, NUM_TEST_SAMPLES); } 105 | 106 | // 107 | // utils 108 | // 109 | 110 | const char *label_to_str(label_t label) { 111 | static const char *const labels[] = {"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"}; 112 | assert(label >= 0 && label < NUM_CLASSES); 113 | return labels[label]; 114 | } 115 | 116 | Tensor *get_batch(const Tensor *data, uint64_t batch_idx, uint64_t batch_size) { 117 | assert(data != NULL); 118 | assert(data->ndim >= 1); 119 | 120 | uint64_t total_samples = data->shape[0]; 121 | uint64_t start = batch_idx * batch_size; 122 | 123 | if (start >= total_samples) { 124 | return NULL; 125 | } 126 | 127 | uint64_t actual_batch = (start + batch_size > total_samples) ? 
(total_samples - start) : batch_size; // edge case for last batch 128 | uint64_t elements_per_sample = data->size / total_samples; 129 | uint64_t batch_elements = actual_batch * elements_per_sample; 130 | 131 | float32_t *batch_data = (float32_t *)malloc(batch_elements * sizeof(float32_t)); 132 | assert(batch_data != NULL && "malloc failed"); 133 | memcpy(batch_data, &data->data[start * elements_per_sample], batch_elements * sizeof(float32_t)); 134 | 135 | uint64_t *batch_shape = (uint64_t *)malloc(data->ndim * sizeof(uint64_t)); 136 | assert(batch_shape != NULL && "malloc failed"); 137 | memcpy(batch_shape, data->shape, data->ndim * sizeof(uint64_t)); 138 | batch_shape[0] = actual_batch; 139 | 140 | Tensor *batch = tensor_create(batch_data, batch_shape, data->ndim, false); 141 | free(batch_data); 142 | free(batch_shape); 143 | 144 | return batch; 145 | } 146 | -------------------------------------------------------------------------------- /src/autograd.c: -------------------------------------------------------------------------------- 1 | #include "autograd.h" 2 | #include "tensor.h" 3 | #include "utils/aligned_alloc.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | // 12 | // arena allocator 13 | // 14 | 15 | #define ARENA_CAPACITY (4 * 1024 * 1024) // 4MB 16 | 17 | static __thread Arena *thread_arena = NULL; 18 | 19 | static Arena *arena_create(void) { 20 | Arena *arena = (Arena *)malloc(sizeof(Arena)); 21 | assert(arena != NULL && "malloc failed"); 22 | arena->memory = safe_aligned_alloc(ARENA_CAPACITY); 23 | arena->capacity = ARENA_CAPACITY; 24 | arena->offset = 0; 25 | return arena; 26 | } 27 | 28 | static Arena *get_or_create_arena(void) { 29 | if (thread_arena == NULL) { 30 | thread_arena = arena_create(); 31 | } 32 | return thread_arena; 33 | } 34 | 35 | Function *arena_alloc_function(void) { 36 | Arena *arena = get_or_create_arena(); 37 | 38 | // align to 8 bytes 39 | size_t aligned_offset = (arena->offset + 7) & ~(size_t)7; 40 | 41 | // oom, but in prod code you would allocate a new arena 42 | if (aligned_offset + sizeof(Function) > arena->capacity) { 43 | assert(false && "arena out of memory"); 44 | } 45 | 46 | Function *fn = (Function *)((char *)arena->memory + aligned_offset); 47 | arena->offset = aligned_offset + sizeof(Function); 48 | 49 | memset(fn, 0, sizeof(Function)); 50 | 51 | return fn; 52 | } 53 | 54 | void arena_free(void) { 55 | if (thread_arena) { 56 | free(thread_arena->memory); 57 | free(thread_arena); 58 | thread_arena = NULL; 59 | } 60 | } 61 | 62 | // 63 | // backward pass 64 | // 65 | 66 | #define MAX_QUEUE_SIZE 10000 67 | 68 | typedef struct { 69 | Function *items[MAX_QUEUE_SIZE]; 70 | int front; 71 | int rear; 72 | int count; 73 | } Queue; 74 | 75 | static void queue_init(Queue *q) { 76 | q->front = 0; 77 | q->rear = 0; 78 | q->count = 0; 79 | } 80 | 81 | static void queue_push(Queue *q, Function *fn) { 82 | assert(q->count < MAX_QUEUE_SIZE && "queue overflow"); 83 | q->items[q->rear] = fn; 84 | q->rear = (q->rear + 1) % MAX_QUEUE_SIZE; 85 | q->count++; 86 | } 87 | 88 | static Function *queue_pop(Queue *q) { 89 | assert(q->count > 0 && "queue underflow"); 90 | Function *fn = q->items[q->front]; 91 | q->front = (q->front + 1) % MAX_QUEUE_SIZE; 92 | q->count--; 93 | return fn; 94 | } 95 | 96 | static bool queue_empty(const Queue *q) { return q->count == 0; } 97 | 98 | void backward(Tensor *loss) { 99 | assert(loss != NULL); 100 | assert(loss->ndim == 0 && "loss must be scalar"); 101 | assert(loss->size == 1 && "loss 
must be scalar"); 102 | 103 | // initialize loss->grad to 1.0 (d loss / d loss = 1) 104 | const uint64_t shape_scalar[] = {}; 105 | float32_t one = 1.0f; 106 | loss->grad = tensor_create(&one, shape_scalar, 0, false); 107 | 108 | // if loss has no grad_fn, it's a leaf and there's nothing to backprop 109 | if (loss->grad_fn == NULL) { 110 | arena_free(); 111 | return; 112 | } 113 | 114 | // work queue 115 | Queue queue; 116 | queue_init(&queue); 117 | queue_push(&queue, loss->grad_fn); 118 | 119 | // process queue 120 | while (!queue_empty(&queue)) { 121 | Function *fn = queue_pop(&queue); 122 | assert(fn->output != NULL); 123 | assert(fn->output->grad != NULL && "fn->output->grad is NULL"); 124 | 125 | // call backward kernel 126 | if (fn->apply != NULL) { 127 | fn->apply(fn, fn->output->grad); 128 | } 129 | 130 | // for each parent with non-NULL grad_fn, decrement pending_count 131 | for (uint32_t i = 0; i < fn->num_inputs; i++) { 132 | Tensor *parent = fn->inputs[i]; 133 | if (parent == NULL) { 134 | continue; 135 | } 136 | 137 | // only tensors with grad_fn have pending_count 138 | if (parent->grad_fn != NULL) { 139 | assert(parent->grad_fn->pending_count > 0 && "pending_count already zero"); 140 | parent->grad_fn->pending_count--; 141 | 142 | // if pending_count reaches zero, all consumers processed 143 | if (parent->grad_fn->pending_count == 0) { 144 | queue_push(&queue, parent->grad_fn); 145 | } 146 | } 147 | } 148 | } 149 | 150 | arena_free(); 151 | } 152 | 153 | // 154 | // gradient accumulation 155 | // 156 | 157 | // helper function to check if two shapes are broadcastable 158 | static bool shapes_equal(const Tensor *a, const Tensor *b) { 159 | if (a->ndim != b->ndim) { 160 | return false; 161 | } 162 | for (uint64_t i = 0; i < a->ndim; i++) { 163 | if (a->shape[i] != b->shape[i]) { 164 | return false; 165 | } 166 | } 167 | return true; 168 | } 169 | 170 | void accumulate_grad(Tensor *tensor, Tensor *new_grad) { 171 | assert(tensor != NULL); 172 | assert(new_grad != NULL); 173 | if (!shapes_equal(tensor, new_grad)) { 174 | assert(false && "shape mismatch in accumulate_grad. 
broadcast reduction not yet implemented"); 175 | } 176 | 177 | if (tensor->grad == NULL) { 178 | tensor->grad = new_grad; 179 | return; 180 | } 181 | 182 | Tensor *summed = tensor_zeros(tensor->shape, tensor->ndim, false); 183 | for (uint64_t i = 0; i < tensor->size; i++) { 184 | summed->data[i] = tensor->grad->data[i] + new_grad->data[i]; 185 | } 186 | 187 | tensor_free(tensor->grad); 188 | tensor_free(new_grad); 189 | 190 | tensor->grad = summed; 191 | } 192 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: LLVM 4 | AccessModifierOffset: -2 5 | AlignAfterOpenBracket: Align 6 | AlignArrayOfStructures: None 7 | AlignConsecutiveMacros: None 8 | AlignConsecutiveAssignments: None 9 | AlignConsecutiveBitFields: None 10 | AlignConsecutiveDeclarations: None 11 | AlignEscapedNewlines: Right 12 | AlignOperands: Align 13 | AlignTrailingComments: true 14 | AllowAllArgumentsOnNextLine: true 15 | AllowAllParametersOfDeclarationOnNextLine: true 16 | AllowShortEnumsOnASingleLine: true 17 | AllowShortBlocksOnASingleLine: Never 18 | AllowShortCaseLabelsOnASingleLine: false 19 | AllowShortFunctionsOnASingleLine: All 20 | AllowShortLambdasOnASingleLine: All 21 | AllowShortIfStatementsOnASingleLine: Never 22 | AllowShortLoopsOnASingleLine: false 23 | AlwaysBreakAfterDefinitionReturnType: None 24 | AlwaysBreakAfterReturnType: None 25 | AlwaysBreakBeforeMultilineStrings: false 26 | AlwaysBreakTemplateDeclarations: MultiLine 27 | AttributeMacros: 28 | - __capability 29 | BinPackArguments: true 30 | BinPackParameters: true 31 | BraceWrapping: 32 | AfterCaseLabel: false 33 | AfterClass: false 34 | AfterControlStatement: Never 35 | AfterEnum: false 36 | AfterFunction: false 37 | AfterNamespace: false 38 | AfterObjCDeclaration: false 39 | AfterStruct: false 40 | AfterUnion: false 41 | AfterExternBlock: false 42 | BeforeCatch: false 43 | BeforeElse: false 44 | BeforeLambdaBody: false 45 | BeforeWhile: false 46 | IndentBraces: false 47 | SplitEmptyFunction: true 48 | SplitEmptyRecord: true 49 | SplitEmptyNamespace: true 50 | BreakBeforeBinaryOperators: None 51 | BreakBeforeConceptDeclarations: true 52 | BreakBeforeBraces: Attach 53 | BreakBeforeInheritanceComma: false 54 | BreakInheritanceList: BeforeColon 55 | BreakBeforeTernaryOperators: true 56 | BreakConstructorInitializersBeforeComma: false 57 | BreakConstructorInitializers: BeforeColon 58 | BreakAfterJavaFieldAnnotations: false 59 | BreakStringLiterals: true 60 | ColumnLimit: 5000 61 | CommentPragmas: '^ IWYU pragma:' 62 | QualifierAlignment: Leave 63 | CompactNamespaces: false 64 | ConstructorInitializerIndentWidth: 4 65 | ContinuationIndentWidth: 4 66 | Cpp11BracedListStyle: true 67 | DeriveLineEnding: true 68 | DerivePointerAlignment: false 69 | DisableFormat: false 70 | EmptyLineAfterAccessModifier: Never 71 | EmptyLineBeforeAccessModifier: LogicalBlock 72 | ExperimentalAutoDetectBinPacking: false 73 | PackConstructorInitializers: BinPack 74 | BasedOnStyle: '' 75 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 76 | AllowAllConstructorInitializersOnNextLine: true 77 | FixNamespaceComments: true 78 | ForEachMacros: 79 | - foreach 80 | - Q_FOREACH 81 | - BOOST_FOREACH 82 | IfMacros: 83 | - KJ_IF_MAYBE 84 | IncludeBlocks: Preserve 85 | IncludeCategories: 86 | - Regex: '^"(llvm|llvm-c|clang|clang-c)/' 87 | Priority: 2 88 | SortPriority: 0 89 | CaseSensitive: false 90 | - Regex: 
'^(<|"(gtest|gmock|isl|json)/)' 91 | Priority: 3 92 | SortPriority: 0 93 | CaseSensitive: false 94 | - Regex: '.*' 95 | Priority: 1 96 | SortPriority: 0 97 | CaseSensitive: false 98 | IncludeIsMainRegex: '(Test)?$' 99 | IncludeIsMainSourceRegex: '' 100 | IndentAccessModifiers: false 101 | IndentCaseLabels: false 102 | IndentCaseBlocks: false 103 | IndentGotoLabels: true 104 | IndentPPDirectives: None 105 | IndentExternBlock: AfterExternBlock 106 | IndentRequires: false 107 | IndentWidth: 4 108 | IndentWrappedFunctionNames: false 109 | InsertTrailingCommas: None 110 | JavaScriptQuotes: Leave 111 | JavaScriptWrapImports: true 112 | KeepEmptyLinesAtTheStartOfBlocks: true 113 | LambdaBodyIndentation: Signature 114 | MacroBlockBegin: '' 115 | MacroBlockEnd: '' 116 | MaxEmptyLinesToKeep: 1 117 | NamespaceIndentation: None 118 | ObjCBinPackProtocolList: Auto 119 | ObjCBlockIndentWidth: 4 120 | ObjCBreakBeforeNestedBlockParam: true 121 | ObjCSpaceAfterProperty: false 122 | ObjCSpaceBeforeProtocolList: true 123 | PenaltyBreakAssignment: 2 124 | PenaltyBreakBeforeFirstCallParameter: 19 125 | PenaltyBreakComment: 300 126 | PenaltyBreakFirstLessLess: 120 127 | PenaltyBreakOpenParenthesis: 0 128 | PenaltyBreakString: 1000 129 | PenaltyBreakTemplateDeclaration: 10 130 | PenaltyExcessCharacter: 1000000 131 | PenaltyReturnTypeOnItsOwnLine: 60 132 | PenaltyIndentedWhitespace: 0 133 | PointerAlignment: Right 134 | PPIndentWidth: -1 135 | ReferenceAlignment: Pointer 136 | ReflowComments: true 137 | RemoveBracesLLVM: false 138 | SeparateDefinitionBlocks: Leave 139 | ShortNamespaceLines: 1 140 | SortIncludes: CaseSensitive 141 | SortJavaStaticImport: Before 142 | SortUsingDeclarations: true 143 | SpaceAfterCStyleCast: false 144 | SpaceAfterLogicalNot: false 145 | SpaceAfterTemplateKeyword: true 146 | SpaceBeforeAssignmentOperators: true 147 | SpaceBeforeCaseColon: false 148 | SpaceBeforeCpp11BracedList: false 149 | SpaceBeforeCtorInitializerColon: true 150 | SpaceBeforeInheritanceColon: true 151 | SpaceBeforeParens: ControlStatements 152 | SpaceBeforeParensOptions: 153 | AfterControlStatements: true 154 | AfterForeachMacros: true 155 | AfterFunctionDefinitionName: false 156 | AfterFunctionDeclarationName: false 157 | AfterIfMacros: true 158 | AfterOverloadedOperator: false 159 | BeforeNonEmptyParentheses: false 160 | SpaceAroundPointerQualifiers: Default 161 | SpaceBeforeRangeBasedForLoopColon: true 162 | SpaceInEmptyBlock: false 163 | SpaceInEmptyParentheses: false 164 | SpacesBeforeTrailingComments: 1 165 | SpacesInAngles: Never 166 | SpacesInConditionalStatement: false 167 | SpacesInContainerLiterals: true 168 | SpacesInCStyleCastParentheses: false 169 | SpacesInLineCommentPrefix: 170 | Minimum: 1 171 | Maximum: -1 172 | SpacesInParentheses: false 173 | SpacesInSquareBrackets: false 174 | SpaceBeforeSquareBrackets: false 175 | BitFieldColonSpacing: Both 176 | Standard: Latest 177 | StatementAttributeLikeMacros: 178 | - Q_EMIT 179 | StatementMacros: 180 | - Q_UNUSED 181 | - QT_REQUIRE_VERSION 182 | TabWidth: 8 183 | UseCRLF: false 184 | UseTab: Never 185 | WhitespaceSensitiveMacros: 186 | - STRINGIZE 187 | - PP_STRINGIZE 188 | - BOOST_PP_STRINGIZE 189 | - NS_SWIFT_NAME 190 | - CF_SWIFT_NAME -------------------------------------------------------------------------------- /test/test_losses_backward.c: -------------------------------------------------------------------------------- 1 | #include "ops/losses_backward.h" 2 | #include "tensor.h" 3 | #include "unity.h" 4 | #include 5 | #include 6 | #include 7 | 8 | 
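// unity runs setUp/tearDown around each RUN_TEST; they are no-ops here because these tests share no fixture.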
void setUp(void) {} 9 | void tearDown(void) {} 10 | 11 | static Tensor *create_tensor_1d(float32_t *data, uint64_t size) { 12 | uint64_t shape[] = {size}; 13 | return tensor_create(data, shape, 1, false); 14 | } 15 | 16 | static Tensor *create_tensor_2d(float32_t *data, uint64_t rows, uint64_t cols) { 17 | uint64_t shape[] = {rows, cols}; 18 | return tensor_create(data, shape, 2, false); 19 | } 20 | 21 | void test_mse_loss_backward_standard(void) { 22 | float32_t p_data[] = {1.0f, 2.0f, 3.0f}; 23 | float32_t t_data[] = {1.5f, 2.5f, 2.8f}; 24 | Tensor *pred = create_tensor_1d(p_data, 3); 25 | Tensor *target = create_tensor_1d(t_data, 3); 26 | 27 | Tensor *grad = mse_loss_backward(pred, target); 28 | 29 | TEST_ASSERT_FLOAT_WITHIN(1e-5, -0.3333333f, grad->data[0]); 30 | TEST_ASSERT_FLOAT_WITHIN(1e-5, -0.3333333f, grad->data[1]); 31 | TEST_ASSERT_FLOAT_WITHIN(1e-5, 0.1333333f, grad->data[2]); 32 | 33 | tensor_free(grad); 34 | tensor_free(pred); 35 | tensor_free(target); 36 | } 37 | 38 | void test_mse_loss_backward_size_scaling(void) { 39 | float32_t p_data[] = {10.0f, 10.0f}; 40 | float32_t t_data[] = {0.0f, 0.0f}; 41 | Tensor *pred = create_tensor_1d(p_data, 2); 42 | Tensor *target = create_tensor_1d(t_data, 2); 43 | 44 | Tensor *grad = mse_loss_backward(pred, target); 45 | 46 | TEST_ASSERT_FLOAT_WITHIN(1e-5, 10.0f, grad->data[0]); 47 | TEST_ASSERT_FLOAT_WITHIN(1e-5, 10.0f, grad->data[1]); 48 | 49 | tensor_free(grad); 50 | tensor_free(pred); 51 | tensor_free(target); 52 | } 53 | 54 | void test_mse_loss_backward_zeros(void) { 55 | float32_t data[] = {1.0f, 2.0f, -1.0f}; 56 | Tensor *pred = create_tensor_1d(data, 3); 57 | Tensor *target = create_tensor_1d(data, 3); 58 | 59 | Tensor *grad = mse_loss_backward(pred, target); 60 | 61 | for (int i = 0; i < 3; i++) { 62 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[i]); 63 | } 64 | 65 | tensor_free(grad); 66 | tensor_free(pred); 67 | tensor_free(target); 68 | } 69 | 70 | void test_cross_entropy_backward_standard(void) { 71 | float32_t l_data_in[] = {0.0f, 0.0f, 0.0f, 0.0f}; 72 | float32_t t_data_in[] = {0.0f, 1.0f}; 73 | 74 | Tensor *logits = create_tensor_2d(l_data_in, 2, 2); 75 | Tensor *targets = create_tensor_1d(t_data_in, 2); 76 | 77 | Tensor *grad = cross_entropy_loss_backward(logits, targets); 78 | 79 | TEST_ASSERT_FLOAT_WITHIN(1e-5, -0.25f, grad->data[0]); 80 | TEST_ASSERT_FLOAT_WITHIN(1e-5, 0.25f, grad->data[1]); 81 | TEST_ASSERT_FLOAT_WITHIN(1e-5, 0.25f, grad->data[2]); 82 | TEST_ASSERT_FLOAT_WITHIN(1e-5, -0.25f, grad->data[3]); 83 | 84 | tensor_free(grad); 85 | tensor_free(logits); 86 | tensor_free(targets); 87 | } 88 | 89 | void test_cross_entropy_backward_batch(void) { 90 | float32_t l_data[] = {100.0f, 0.0f, 0.0f}; 91 | float32_t t_data[] = {0.0f}; 92 | 93 | Tensor *logits = create_tensor_2d(l_data, 1, 3); 94 | Tensor *targets = create_tensor_1d(t_data, 1); 95 | 96 | Tensor *grad = cross_entropy_loss_backward(logits, targets); 97 | 98 | TEST_ASSERT_FLOAT_WITHIN(1e-5, 0.0f, grad->data[0]); 99 | TEST_ASSERT_FLOAT_WITHIN(1e-5, 0.0f, grad->data[1]); 100 | TEST_ASSERT_FLOAT_WITHIN(1e-5, 0.0f, grad->data[2]); 101 | 102 | tensor_free(grad); 103 | tensor_free(logits); 104 | tensor_free(targets); 105 | } 106 | 107 | void test_cross_entropy_backward_sum_zero(void) { 108 | float32_t l_data[] = {1.0f, 2.0f, 3.0f}; 109 | float32_t t_data[] = {1.0f}; 110 | 111 | Tensor *logits = create_tensor_2d(l_data, 1, 3); 112 | Tensor *targets = create_tensor_1d(t_data, 1); 113 | 114 | Tensor *grad = cross_entropy_loss_backward(logits, targets); 115 | 116 | 
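// each row of the cross-entropy gradient is (softmax_j - indicator_j) / batch_size; the softmax and the indicator both sum to 1 across classes, so every row sums to zero.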
float32_t sum = 0.0f; 117 | for (int i = 0; i < 3; i++) { 118 | sum += grad->data[i]; 119 | } 120 | 121 | TEST_ASSERT_FLOAT_WITHIN(1e-5, 0.0f, sum); 122 | 123 | tensor_free(grad); 124 | tensor_free(logits); 125 | tensor_free(targets); 126 | } 127 | 128 | void test_bce_loss_backward_standard(void) { 129 | float32_t p_data[] = {0.5f}; 130 | float32_t t_data[] = {1.0f}; 131 | 132 | Tensor *pred = create_tensor_1d(p_data, 1); 133 | Tensor *target = create_tensor_1d(t_data, 1); 134 | 135 | Tensor *grad = binary_cross_entropy_loss_backward(pred, target); 136 | 137 | TEST_ASSERT_FLOAT_WITHIN(1e-5, -2.0f, grad->data[0]); 138 | 139 | tensor_free(grad); 140 | tensor_free(pred); 141 | tensor_free(target); 142 | } 143 | 144 | void test_bce_loss_backward_batch(void) { 145 | float32_t p_data[] = {0.5f, 0.5f}; 146 | float32_t t_data[] = {1.0f, 0.0f}; 147 | 148 | Tensor *pred = create_tensor_1d(p_data, 2); 149 | Tensor *target = create_tensor_1d(t_data, 2); 150 | 151 | Tensor *grad = binary_cross_entropy_loss_backward(pred, target); 152 | 153 | TEST_ASSERT_FLOAT_WITHIN(1e-5, -1.0f, grad->data[0]); 154 | TEST_ASSERT_FLOAT_WITHIN(1e-5, 1.0f, grad->data[1]); 155 | 156 | tensor_free(grad); 157 | tensor_free(pred); 158 | tensor_free(target); 159 | } 160 | 161 | void test_bce_loss_backward_edge_handling(void) { 162 | float32_t p_data[] = {0.0001f}; 163 | float32_t t_data[] = {0.0f}; 164 | 165 | Tensor *pred = create_tensor_1d(p_data, 1); 166 | Tensor *target = create_tensor_1d(t_data, 1); 167 | 168 | Tensor *grad = binary_cross_entropy_loss_backward(pred, target); 169 | 170 | TEST_ASSERT_FLOAT_WITHIN(0.1f, 1.0f, grad->data[0]); 171 | 172 | tensor_free(grad); 173 | tensor_free(pred); 174 | tensor_free(target); 175 | } 176 | 177 | int main(void) { 178 | UNITY_BEGIN(); 179 | RUN_TEST(test_mse_loss_backward_standard); 180 | RUN_TEST(test_mse_loss_backward_size_scaling); 181 | RUN_TEST(test_mse_loss_backward_zeros); 182 | RUN_TEST(test_cross_entropy_backward_standard); 183 | RUN_TEST(test_cross_entropy_backward_batch); 184 | RUN_TEST(test_cross_entropy_backward_sum_zero); 185 | RUN_TEST(test_bce_loss_backward_standard); 186 | RUN_TEST(test_bce_loss_backward_batch); 187 | RUN_TEST(test_bce_loss_backward_edge_handling); 188 | return UNITY_END(); 189 | } 190 | -------------------------------------------------------------------------------- /src/utils/augment.c: -------------------------------------------------------------------------------- 1 | #include "augment.h" 2 | #include 3 | #include 4 | #include 5 | 6 | // applies an in-place random horizontal flip to the tensor with probability p. 7 | // only makes sense for images where semantics are invariant to horizontal flipping. 
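// example for a single row: [a, b, c] -> [c, b, a]
//
// minimal usage sketch (the tensor name `img` is hypothetical; assumes one CHW CIFAR-10 image):
//   random_horizontal_flip_mut(img, 0.5f); // flip the whole tensor with probability 0.5
//   random_crop_mut(img, 32, 32, 4);       // virtually zero-pad by 4, then crop back to 32x32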
8 | void random_horizontal_flip_mut(Tensor *t, float32_t p) { 9 | assert(t != NULL); 10 | assert(t->data != NULL); 11 | assert(p >= 0.0f && p <= 1.0f); 12 | 13 | float32_t rand_float = (float32_t)rand() / (float32_t)RAND_MAX; 14 | if (rand_float >= p) { 15 | return; 16 | } 17 | 18 | uint64_t width = t->shape[t->ndim - 1]; 19 | assert(width > 0); 20 | assert(t->strides[t->ndim - 1] == 1 && "tensor must be contiguous in last dimension"); 21 | 22 | Tensor *out = tensor_create(NULL, t->shape, t->ndim, t->requires_grad); 23 | assert(out != NULL); 24 | 25 | uint64_t num_rows = t->size / width; 26 | for (uint64_t r = 0; r < num_rows; r++) { 27 | const float32_t *src_row = t->data + r * width; 28 | float32_t *dst_row = out->data + r * width; 29 | for (uint64_t c = 0; c < width; c++) { 30 | dst_row[c] = src_row[width - 1 - c]; 31 | } 32 | } 33 | 34 | free(t->data); 35 | t->data = out->data; 36 | out->data = NULL; 37 | tensor_free(out); 38 | } 39 | 40 | /* 41 | * applies a random crop to the tensor. 42 | * simulates translation invariance by shifting content. 43 | * virtually pads the image with zeros, then selects a random window. 44 | * 45 | * example (padding=1, target=2x2): 46 | * 47 | * input (2x2): virtual padded (4x4): random crop (2x2): 48 | * [1, 1] 0 0 0 0 [0, 0] 49 | * [1, 1] -> 0 [1, 1] 0 -> [0, 1] 50 | * 0 [1, 1] 0 (if top=0, left=0) 51 | * 0 0 0 0 52 | */ 53 | void random_crop_mut(Tensor *t, uint64_t target_h, uint64_t target_w, uint64_t padding) { 54 | assert(t != NULL); 55 | assert(t->data != NULL); 56 | assert(target_h > 0); 57 | assert(target_w > 0); 58 | assert(t->ndim == 2 || t->ndim == 3); 59 | 60 | // 2D case: [H, W] 61 | if (t->ndim == 2) { 62 | uint64_t h = t->shape[0]; 63 | uint64_t w = t->shape[1]; 64 | uint64_t padded_h = h + 2 * padding; 65 | uint64_t padded_w = w + 2 * padding; 66 | uint64_t max_top = padded_h >= target_h ? padded_h - target_h : 0; 67 | uint64_t max_left = padded_w >= target_w ? padded_w - target_w : 0; 68 | uint64_t top = max_top > 0 ? ((uint64_t)rand() % (max_top + 1)) : 0; 69 | uint64_t left = max_left > 0 ? ((uint64_t)rand() % (max_left + 1)) : 0; 70 | 71 | const uint64_t out_shape[2] = {target_h, target_w}; 72 | Tensor *out = tensor_zeros(out_shape, 2, t->requires_grad); 73 | assert(out != NULL); 74 | 75 | for (uint64_t oy = 0; oy < target_h; oy++) { 76 | for (uint64_t ox = 0; ox < target_w; ox++) { 77 | int64_t iy = (int64_t)(top + oy) - (int64_t)padding; 78 | int64_t ix = (int64_t)(left + ox) - (int64_t)padding; 79 | if (iy < 0 || iy >= (int64_t)h || ix < 0 || ix >= (int64_t)w) { 80 | continue; 81 | } 82 | uint64_t out_idx = oy * out->strides[0] + ox * out->strides[1]; 83 | uint64_t in_idx = (uint64_t)iy * t->strides[0] + (uint64_t)ix * t->strides[1]; 84 | out->data[out_idx] = t->data[in_idx]; 85 | } 86 | } 87 | 88 | free(t->data); 89 | free(t->shape); 90 | free(t->strides); 91 | t->data = out->data; 92 | t->shape = out->shape; 93 | t->strides = out->strides; 94 | t->size = out->size; 95 | out->data = NULL; 96 | out->shape = NULL; 97 | out->strides = NULL; 98 | tensor_free(out); 99 | return; 100 | } 101 | 102 | // 3D case: detect [C, H, W] vs [H, W, C] 103 | bool is_chw = t->shape[0] <= 4; 104 | uint64_t h = is_chw ? t->shape[1] : t->shape[0]; 105 | uint64_t w = is_chw ? t->shape[2] : t->shape[1]; 106 | uint64_t c = is_chw ? t->shape[0] : t->shape[2]; 107 | 108 | uint64_t padded_h = h + 2 * padding; 109 | uint64_t padded_w = w + 2 * padding; 110 | uint64_t max_top = padded_h >= target_h ? 
padded_h - target_h : 0; 111 | uint64_t max_left = padded_w >= target_w ? padded_w - target_w : 0; 112 | uint64_t top = max_top > 0 ? ((uint64_t)rand() % (max_top + 1)) : 0; 113 | uint64_t left = max_left > 0 ? ((uint64_t)rand() % (max_left + 1)) : 0; 114 | 115 | uint64_t out_shape[3]; 116 | if (is_chw) { 117 | out_shape[0] = c; 118 | out_shape[1] = target_h; 119 | out_shape[2] = target_w; 120 | } else { 121 | out_shape[0] = target_h; 122 | out_shape[1] = target_w; 123 | out_shape[2] = c; 124 | } 125 | 126 | Tensor *out = tensor_zeros(out_shape, 3, t->requires_grad); 127 | assert(out != NULL); 128 | 129 | for (uint64_t oy = 0; oy < target_h; oy++) { 130 | for (uint64_t ox = 0; ox < target_w; ox++) { 131 | int64_t iy = (int64_t)(top + oy) - (int64_t)padding; 132 | int64_t ix = (int64_t)(left + ox) - (int64_t)padding; 133 | if (iy < 0 || iy >= (int64_t)h || ix < 0 || ix >= (int64_t)w) { 134 | continue; 135 | } 136 | 137 | for (uint64_t ci = 0; ci < c; ci++) { 138 | uint64_t out_idx, in_idx; 139 | if (is_chw) { 140 | out_idx = ci * out->strides[0] + oy * out->strides[1] + ox * out->strides[2]; 141 | in_idx = ci * t->strides[0] + (uint64_t)iy * t->strides[1] + (uint64_t)ix * t->strides[2]; 142 | } else { 143 | out_idx = oy * out->strides[0] + ox * out->strides[1] + ci * out->strides[2]; 144 | in_idx = (uint64_t)iy * t->strides[0] + (uint64_t)ix * t->strides[1] + ci * t->strides[2]; 145 | } 146 | out->data[out_idx] = t->data[in_idx]; 147 | } 148 | } 149 | } 150 | 151 | free(t->data); 152 | free(t->shape); 153 | free(t->strides); 154 | t->data = out->data; 155 | t->shape = out->shape; 156 | t->strides = out->strides; 157 | t->size = out->size; 158 | t->ndim = out->ndim; 159 | out->data = NULL; 160 | out->shape = NULL; 161 | out->strides = NULL; 162 | tensor_free(out); 163 | } 164 | -------------------------------------------------------------------------------- /docs/design.md: -------------------------------------------------------------------------------- 1 | # Autograd Engine Design 2 | 3 | Context 4 | 5 | - This is a hobby autograd engine in C for learning and experimentation. This is a learning tool, not a hardened library. 6 | - The design prioritizes simplicity and code locality over flexibility. We make explicit tradeoffs: no in-place ops, no `retain_graph`, no non-leaf gradient retention, no views or aliasing, co-located forward/backward code, arena allocation for graph nodes and dependency counting instead of topological sort. These constraints eliminate entire classes of bugs and keep the codebase small. 7 | 8 | Core Structures 9 | 10 | - Tensor carries standard fields plus autograd metadata: `requires_grad` bool, `grad` pointer (`NULL` until backward writes to it), `grad_fn` pointer to the Function that created it (`NULL` for leaves) and `ref_count`. 11 | - Function contains: apply function pointer, output pointer to the tensor this function produced, inputs array of parent tensor pointers (fixed-size, maximum 4), num_inputs count, pending_count integer tracking downstream consumers, and `ctx void` pointer strictly for non-input saved data. 12 | 13 | Memory Model 14 | 15 | - Tensors use reference counting. Operations return new tensors with `ref_count=1;` caller owns them. When `ref_count` hits zero, tensor releases its grad if non-`NULL` and frees itself. 16 | - Every tensor owns a unique data buffer for its entire lifetime. No aliasing, no shared storage, no views. 17 | - Functions and ctx are allocated from a bump arena owned exclusively by `backward()`. 
One active arena per thread. Nested forwards and concurrent graphs within a thread are unsupported. 18 | - The arena is created on first `requires_grad` operation and stored in thread-local state. `backward()` frees it. If backward is never called, the arena leaks. Forward without backward is allowed only for debugging or REPL-style exploration. In real use it is misuse. Name it, do not fix it in code. 19 | - Ownership between Tensor and Function is one-way. `Tensor->grad_fn` points to Function. `Function->output` points to Tensor. Neither retains the other. Safe because both live in arena and are freed together. Do not add refcounting here. 20 | 21 | Invariants 22 | 23 | - Non-negotiable. Correctness depends on all of them. 24 | - (1) `pending_count` equals downstream consumer edges. Incremented once per consumer during graph construction. 25 | - (2) `pending_count` only incremented for parents with non-`NULL` `grad_fn`. Leaves have pending_count zero and never appear in the work queue. 26 | - (3) Backward kernels produce gradients in output shape space. Never pre-reduce. `accumulate_grad` handles all broadcasting reduction. 27 | - (4) `accumulate_grad` is the only place gradients are summed. Backward kernels must not call `tensor_add` on gradients. 28 | - (5) Backward kernels must not read `tensor->grad`. They receive `grad_output` as input and write via `accumulate_grad`. Gradient flow is strictly directional. 29 | - (6) `backward()` assumes all grad fields are `NULL` at entry. Only `loss->grad` is initialized by backward. Others are written when reached through the graph. 30 | - (7) Intermediate `tensor->grad` is allocated during backward and must not be inspected or retained. Only leaf gradients survive. 31 | - (8) `backward()` is only defined for scalar losses. Calling backward on non-scalar tensors is undefined behavior. 32 | - (9) `pending_count` must never go negative. Assert in debug builds. 33 | - (10) `fn->output->grad` must be non-NULL when apply is called. Assert in debug builds. If this fires, dependency counting or accumulation is broken. 34 | - (11) One active arena per thread. No nested forwards. 35 | 36 | Restrictions 37 | 38 | - No in-place operations on tensors with `requires_grad`. Fatal error on mutation. 39 | - No `retain_graph`. Run forward again to differentiate again. 40 | - Only leaf gradients survive backward. 41 | - No views, slicing, or aliasing. Every tensor owns unique data. 42 | - Fixed-size inputs array (maximum 4). Variable-arity ops like concat are out of scope unless the limit is raised. Do not silently truncate. 43 | 44 | Graph Construction 45 | 46 | - When an operation produces output with `requires_grad=true`: 47 | - (1) allocate Function from arena 48 | - (2) set `fn->output` 49 | - (3) populate `fn->inputs` with `requires_grad` parents 50 | - (4) for each parent with non-`NULL` `grad_fn`, increment `pending_count` 51 | - (5) set `output->grad_fn = fn` 52 | - `ctx` is for non-input data only. For `add`/`mul`, `ctx` is `NULL`. For `sum`, `ctx` holds original shape. For `matmul`, `ctx` holds transpose flags. 53 | 54 | Backward Algorithm 55 | 56 | - (1) Assert loss is scalar. Set `loss->grad = 1.0`. 57 | - (2) Create work queue. Push `loss->grad_fn`. 58 | - (3) Process queue: for each `fn`, assert `fn->output->grad` is non-`NULL`, call apply, for each parent with non-`NULL` `grad_fn` decrement pending_count and assert non-negative and if zero enqueue. 59 | - (4) Free arena. 60 | - Queue order does not affect correctness. 
Do not rely on FIFO or LIFO semantics. 61 | 62 | Gradient Accumulation 63 | 64 | - `accumulate_grad(tensor, new_grad)` takes ownership of `new_grad`. Reduces to tensor's shape. If `tensor->grad` is `NULL`, assigns directly. Otherwise sums, releases old, assigns result. 65 | - Single location for gradient addition and broadcast handling. 66 | 67 | Operation Implementation 68 | 69 | - Forward and backward adjacent in same file. 70 | - `mul`: Forward computes result. If `requires_grad`, allocate Function, set output, populate inputs, increment `pending_counts`. `ctx` is `NULL`. Backward reads inputs, computes `grad_a = grad_output * b` and `grad_b = grad_output * a`, calls `accumulate_grad`. 71 | - `sum`: Forward computes scalar. `ctx` holds original shape. Backward broadcasts `grad_output` to shape, calls `accumulate_grad`. 72 | - Apply signature: void `apply(Function *self, void *ctx, Tensor *grad_output)`. 73 | 74 | API Surface 75 | 76 | - `tensor_create`, `tensor_retain`, `tensor_release`, `tensor_zero_grad`. 77 | - `tensor_add`, `tensor_mul`, `tensor_matmul`, `tensor_sum`. 78 | - `backward(loss)`. 79 | 80 | File Organization 81 | 82 | - `src/tensor.c`: `Tensor`, `refcounting`, `creation`, `zero_grad`. 83 | - `src/autograd.c`: `Function`, `arena`, `backward` loop, `accumulate_grad`. Nothing else. 84 | - `src/ops/add.c`, `mul.c`, `matmul.c`, `reduce.c`: each op with forward and backward adjacent. 85 | 86 | Development Phases 87 | 88 | - Phase 1: Tensor with refcounting, arena, Function, backward loop without apply. Assert `pending_count >= 0`. 89 | - Phase 2: `accumulate_grad`, add op, finite difference verification. 90 | - Phase 3: `mul`, `sum`, `matmul`, each verified. 91 | - Phase 4: Diamond DAG `(z = x + x)`, valgrind, edge cases. 92 | -------------------------------------------------------------------------------- /src/ops/losses.c: -------------------------------------------------------------------------------- 1 | #include "losses.h" 2 | #include "autograd.h" 3 | #include "ops/losses_backward.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | Tensor *mse_loss(const Tensor *predictions, const Tensor *targets) { 11 | assert(predictions != NULL); 12 | assert(targets != NULL); 13 | assert(predictions->data != NULL || predictions->size == 0); 14 | assert(targets->data != NULL || targets->size == 0); 15 | assert(predictions->size == targets->size); 16 | assert(predictions->ndim == targets->ndim); 17 | for (uint64_t i = 0; i < predictions->ndim; i++) { 18 | assert(predictions->shape[i] == targets->shape[i]); 19 | } 20 | 21 | float32_t sum_squared_error = 0.0f; 22 | for (uint64_t i = 0; i < predictions->size; i++) { 23 | float32_t diff = predictions->data[i] - targets->data[i]; 24 | sum_squared_error += diff * diff; 25 | } 26 | float32_t loss_val = sum_squared_error / (float32_t)predictions->size; 27 | 28 | const uint64_t shape[] = {1}; 29 | Tensor *out = tensor_create(&loss_val, shape, 0, predictions->requires_grad); 30 | 31 | if (out->requires_grad) { 32 | Function *fn = arena_alloc_function(); 33 | fn->apply = mse_loss_backward_fn; 34 | fn->output = out; 35 | fn->num_inputs = 2; 36 | fn->inputs[0] = (Tensor *)predictions; 37 | fn->inputs[1] = (Tensor *)targets; 38 | fn->pending_count = 0; 39 | fn->ctx = NULL; 40 | if (predictions->grad_fn != NULL) { 41 | predictions->grad_fn->pending_count++; 42 | } 43 | if (targets->grad_fn != NULL) { 44 | targets->grad_fn->pending_count++; 45 | } 46 | out->grad_fn = fn; 47 | } 48 | 49 | return out; 50 | } 51 | 52 | Tensor 
*cross_entropy_loss(const Tensor *logits, const Tensor *targets) { 53 | assert(logits != NULL); 54 | assert(targets != NULL); 55 | assert(logits->data != NULL || logits->size == 0); 56 | assert(targets->data != NULL || targets->size == 0); 57 | assert(logits->ndim == 2); 58 | assert(targets->ndim == 1 || (targets->ndim == 2 && targets->shape[1] == 1)); 59 | assert(logits->shape[0] == targets->shape[0]); 60 | 61 | uint64_t batch_size = logits->shape[0]; 62 | uint64_t num_classes = logits->shape[1]; 63 | if (num_classes == 0) { 64 | return tensor_zeros(NULL, 0, false); 65 | } 66 | assert(num_classes > 0); 67 | 68 | float32_t sum_loss = 0.0f; 69 | 70 | for (uint64_t i = 0; i < batch_size; i++) { 71 | float32_t target_float = targets->data[i]; 72 | assert(target_float >= 0.0f && target_float < (float32_t)num_classes); 73 | uint64_t target_idx = (uint64_t)target_float; 74 | 75 | // log-softmax for numerical stability 76 | float32_t max_logit = -FLT_MAX; 77 | for (uint64_t j = 0; j < num_classes; j++) { 78 | float32_t logit = logits->data[i * num_classes + j]; 79 | if (logit > max_logit) { 80 | max_logit = logit; 81 | } 82 | } 83 | 84 | // compute sum(exp(logit - max_logit)) 85 | float32_t sum_exp = 0.0f; 86 | for (uint64_t j = 0; j < num_classes; j++) { 87 | float32_t logit = logits->data[i * num_classes + j]; 88 | sum_exp += expf(logit - max_logit); 89 | } 90 | 91 | // clip for numerical stability 92 | if (sum_exp < 1.0f) { 93 | sum_exp = 1.0f; 94 | } 95 | float32_t log_sum_exp = logf(sum_exp) + max_logit; 96 | float32_t correct_logit = logits->data[i * num_classes + target_idx]; 97 | 98 | // log_softmax = correct_logit - log_sum_exp 99 | // loss = -log_softmax 100 | float32_t loss = -(correct_logit - log_sum_exp); 101 | sum_loss += loss; 102 | } 103 | float32_t loss_val = sum_loss / (float32_t)batch_size; 104 | 105 | const uint64_t shape[] = {1}; 106 | Tensor *out = tensor_create(&loss_val, shape, 0, logits->requires_grad); 107 | 108 | if (out->requires_grad) { 109 | Function *fn = arena_alloc_function(); 110 | fn->apply = cross_entropy_loss_backward_fn; 111 | fn->output = out; 112 | fn->num_inputs = 2; 113 | fn->inputs[0] = (Tensor *)logits; 114 | fn->inputs[1] = (Tensor *)targets; 115 | fn->pending_count = 0; 116 | fn->ctx = NULL; 117 | if (logits->grad_fn != NULL) { 118 | logits->grad_fn->pending_count++; 119 | } 120 | if (targets->grad_fn != NULL) { 121 | targets->grad_fn->pending_count++; 122 | } 123 | out->grad_fn = fn; 124 | } 125 | 126 | return out; 127 | } 128 | 129 | #define EPSILON 1e-7f 130 | 131 | Tensor *binary_cross_entropy_loss(const Tensor *predictions, const Tensor *targets) { 132 | assert(predictions != NULL); 133 | assert(targets != NULL); 134 | assert(predictions->data != NULL || predictions->size == 0); 135 | assert(targets->data != NULL || targets->size == 0); 136 | assert(predictions->size == targets->size); 137 | assert(predictions->ndim == targets->ndim); 138 | for (uint64_t i = 0; i < predictions->ndim; i++) { 139 | assert(predictions->shape[i] == targets->shape[i]); 140 | } 141 | 142 | float32_t sum_loss = 0.0f; 143 | 144 | for (uint64_t i = 0; i < predictions->size; i++) { 145 | float32_t pred = predictions->data[i]; 146 | float32_t target = targets->data[i]; 147 | 148 | // clamp to avoid log(0) 149 | if (pred < EPSILON) { 150 | pred = EPSILON; 151 | } 152 | if (pred > 1.0f - EPSILON) { 153 | pred = 1.0f - EPSILON; 154 | } 155 | 156 | float32_t term1 = target * logf(pred); 157 | float32_t term2 = (1.0f - target) * logf(1.0f - pred); 158 | sum_loss += -(term1 + 
term2); 159 | } 160 | float32_t loss_val = sum_loss / (float32_t)predictions->size; 161 | 162 | const uint64_t shape[] = {1}; 163 | Tensor *out = tensor_create(&loss_val, shape, 0, predictions->requires_grad); 164 | 165 | if (out->requires_grad) { 166 | Function *fn = arena_alloc_function(); 167 | fn->apply = binary_cross_entropy_loss_backward_fn; 168 | fn->output = out; 169 | fn->num_inputs = 2; 170 | fn->inputs[0] = (Tensor *)predictions; 171 | fn->inputs[1] = (Tensor *)targets; 172 | fn->pending_count = 0; 173 | fn->ctx = NULL; 174 | if (predictions->grad_fn != NULL) { 175 | predictions->grad_fn->pending_count++; 176 | } 177 | if (targets->grad_fn != NULL) { 178 | targets->grad_fn->pending_count++; 179 | } 180 | out->grad_fn = fn; 181 | } 182 | 183 | return out; 184 | } 185 | -------------------------------------------------------------------------------- /src/ops/reductions_backward.c: -------------------------------------------------------------------------------- 1 | #include "ops/reductions_backward.h" 2 | #include "ops/arithmetic.h" 3 | #include "ops/reductions.h" 4 | #include "ops/reshapes.h" 5 | #include "tensor.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | // 13 | // sum 14 | // 15 | 16 | Tensor *tensor_sum_backward(const Tensor *grad_output, const Tensor *t, int64_t dim_idx, bool keepdims) { 17 | assert(grad_output != NULL); 18 | assert(t != NULL); 19 | 20 | int64_t ndim_signed = (int64_t)t->ndim; 21 | int64_t target_dim_signed = (dim_idx < 0) ? (dim_idx + ndim_signed) : dim_idx; 22 | assert(target_dim_signed >= 0 && target_dim_signed < ndim_signed && "target_dim out of bounds"); 23 | 24 | uint64_t target_dim = (uint64_t)target_dim_signed; 25 | 26 | const Tensor *grad_expanded = grad_output; 27 | bool needs_free = false; 28 | 29 | if (!keepdims) { 30 | assert(t->ndim <= MAX_NDIM); 31 | int64_t new_shape[MAX_NDIM] = {0}; 32 | uint64_t grad_dim_idx = 0; 33 | 34 | for (uint64_t d = 0; d < t->ndim; d++) { 35 | if (d == target_dim) { 36 | new_shape[d] = 1; 37 | } else { 38 | new_shape[d] = (int64_t)grad_output->shape[grad_dim_idx++]; 39 | } 40 | } 41 | 42 | grad_expanded = tensor_reshape(grad_output, new_shape, t->ndim); 43 | needs_free = true; 44 | } 45 | 46 | Tensor *zeros = tensor_zeros(t->shape, t->ndim, false); 47 | assert(zeros != NULL); 48 | 49 | Tensor *grad_input = tensor_add(zeros, grad_expanded); 50 | assert(grad_input != NULL); 51 | 52 | tensor_free(zeros); 53 | if (needs_free) { 54 | tensor_free((Tensor *)grad_expanded); 55 | } 56 | 57 | return grad_input; 58 | } 59 | 60 | void sum_backward(Function *fn, const Tensor *grad_output) { 61 | assert(fn != NULL); 62 | assert(grad_output != NULL); 63 | assert(fn->num_inputs == 1); 64 | assert(fn->ctx != NULL && "sum_backward requires context"); 65 | 66 | Tensor *t = fn->inputs[0]; 67 | const SumContext *ctx = (SumContext *)fn->ctx; 68 | 69 | if (t != NULL && t->requires_grad) { 70 | Tensor *grad_t = tensor_sum_backward(grad_output, t, ctx->dim_idx, ctx->keepdims); 71 | accumulate_grad(t, grad_t); 72 | } 73 | 74 | free(fn->ctx); 75 | fn->ctx = NULL; 76 | } 77 | 78 | // 79 | // mean 80 | // 81 | 82 | Tensor *tensor_mean_backward(const Tensor *grad_output, const Tensor *t, int64_t dim_idx, bool keepdims) { 83 | assert(grad_output != NULL); 84 | assert(t != NULL); 85 | 86 | Tensor *sum_grad = tensor_sum_backward(grad_output, t, dim_idx, keepdims); 87 | assert(sum_grad != NULL); 88 | 89 | int64_t ndim_signed = (int64_t)t->ndim; 90 | int64_t target_dim_signed = (dim_idx < 0) ? 
(dim_idx + ndim_signed) : dim_idx; 91 | assert(target_dim_signed >= 0 && target_dim_signed < ndim_signed && "target_dim out of bounds"); 92 | uint64_t target_dim = (uint64_t)target_dim_signed; 93 | 94 | uint64_t count = t->shape[target_dim]; 95 | 96 | if (count == 0) { 97 | return sum_grad; 98 | } 99 | 100 | const uint64_t scalar_shape[] = {1}; 101 | Tensor *scale_t = tensor_create(NULL, scalar_shape, 1, false); 102 | assert(scale_t != NULL); 103 | 104 | scale_t->data[0] = 1.0f / (float32_t)count; 105 | 106 | Tensor *grad_input = tensor_mul(sum_grad, scale_t, true); // disable_grad=true 107 | assert(grad_input != NULL); 108 | 109 | tensor_free(sum_grad); 110 | tensor_free(scale_t); 111 | 112 | return grad_input; 113 | } 114 | 115 | void mean_backward(Function *fn, const Tensor *grad_output) { 116 | assert(fn != NULL); 117 | assert(grad_output != NULL); 118 | assert(fn->num_inputs == 1); 119 | assert(fn->ctx != NULL && "mean_backward requires context"); 120 | 121 | Tensor *t = fn->inputs[0]; 122 | const MeanContext *ctx = (MeanContext *)fn->ctx; 123 | 124 | if (t != NULL && t->requires_grad) { 125 | Tensor *grad_t = tensor_mean_backward(grad_output, t, ctx->dim_idx, ctx->keepdims); 126 | accumulate_grad(t, grad_t); 127 | } 128 | 129 | free(fn->ctx); 130 | fn->ctx = NULL; 131 | } 132 | 133 | // 134 | // max 135 | // 136 | 137 | Tensor *tensor_max_backward(const Tensor *grad_output, const Tensor *t, const Tensor *out, int64_t dim_idx, bool keepdims) { 138 | assert(grad_output != NULL); 139 | assert(t != NULL); 140 | assert(out != NULL); 141 | 142 | int64_t ndim_signed = (int64_t)t->ndim; 143 | int64_t target_dim_signed = (dim_idx < 0) ? (dim_idx + ndim_signed) : dim_idx; 144 | assert(target_dim_signed >= 0 && target_dim_signed < ndim_signed && "target_dim out of bounds"); 145 | uint64_t target_dim = (uint64_t)target_dim_signed; 146 | 147 | Tensor *grad_input = tensor_zeros(t->shape, t->ndim, false); 148 | assert(grad_input != NULL); 149 | 150 | uint64_t curr_coords[MAX_NDIM] = {0}; 151 | uint64_t out_coords[MAX_NDIM] = {0}; 152 | 153 | for (uint64_t i = 0; i < t->size; i++) { 154 | linear_to_multidim_mut(i, t->shape, t->ndim, curr_coords); 155 | 156 | uint64_t out_offset = 0; 157 | if (keepdims) { 158 | uint64_t saved_dim_val = curr_coords[target_dim]; 159 | // for keepdims=true, target dimension is collapsed to size 1 (index 0) 160 | curr_coords[target_dim] = 0; 161 | out_offset = multidim_to_linear(curr_coords, t->ndim, out->shape, out->ndim, out->strides); 162 | curr_coords[target_dim] = saved_dim_val; 163 | } else { 164 | uint64_t k = 0; 165 | for (uint64_t d = 0; d < t->ndim; d++) { 166 | if (d != target_dim) { 167 | out_coords[k++] = curr_coords[d]; 168 | } 169 | } 170 | out_offset = multidim_to_linear(out_coords, out->ndim, out->shape, out->ndim, out->strides); 171 | } 172 | 173 | float32_t val = t->data[i]; 174 | float32_t max_val = out->data[out_offset]; 175 | 176 | if (val == max_val) { 177 | grad_input->data[i] += grad_output->data[out_offset]; 178 | } 179 | } 180 | 181 | return grad_input; 182 | } 183 | 184 | void max_backward(Function *fn, const Tensor *grad_output) { 185 | assert(fn != NULL); 186 | assert(grad_output != NULL); 187 | assert(fn->num_inputs == 1); 188 | assert(fn->ctx != NULL && "max_backward requires context"); 189 | 190 | Tensor *t = fn->inputs[0]; 191 | const MaxContext *ctx = (MaxContext *)fn->ctx; 192 | 193 | if (t != NULL && t->requires_grad) { 194 | Tensor *grad_t = tensor_max_backward(grad_output, t, ctx->output, ctx->dim_idx, ctx->keepdims); 195 | 
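// accumulate_grad takes ownership of grad_t: it either becomes t->grad directly or is summed with the existing grad and then freed, so no explicit free is needed here.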
accumulate_grad(t, grad_t); 196 | } 197 | 198 | free(fn->ctx); 199 | fn->ctx = NULL; 200 | } 201 | -------------------------------------------------------------------------------- /test/test_activations_backward.c: -------------------------------------------------------------------------------- 1 | #include "ops/activations.h" 2 | #include "ops/activations_backward.h" 3 | #include "tensor.h" 4 | #include "unity.h" 5 | #include 6 | #include 7 | #include 8 | 9 | void setUp(void) {} 10 | void tearDown(void) {} 11 | 12 | static Tensor *create_tensor_from_data(float32_t *data, uint64_t size) { 13 | uint64_t shape[] = {size}; 14 | return tensor_create(data, shape, 1, false); 15 | } 16 | 17 | void test_sigmoid_backward_standard_values(void) { 18 | float32_t data[] = {0.0f, 1.0f, -1.0f}; 19 | Tensor *t = create_tensor_from_data(data, 3); 20 | Tensor *grad = tensor_sigmoid_backward(t); 21 | 22 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.25f, grad->data[0]); 23 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.1966119f, grad->data[1]); 24 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.1966119f, grad->data[2]); 25 | 26 | tensor_free(t); 27 | tensor_free(grad); 28 | } 29 | 30 | void test_relu_backward_standard_values(void) { 31 | float32_t data[] = {-5.0f, -0.1f, 0.0f, 0.1f, 5.0f}; 32 | Tensor *t = create_tensor_from_data(data, 5); 33 | Tensor *grad = tensor_relu_backward(t); 34 | 35 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[0]); 36 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[1]); 37 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[2]); 38 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 1.0f, grad->data[3]); 39 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 1.0f, grad->data[4]); 40 | 41 | tensor_free(t); 42 | tensor_free(grad); 43 | } 44 | 45 | void test_tanh_backward_standard_values(void) { 46 | float32_t data[] = {0.0f, 1.0f, -1.0f}; 47 | Tensor *t = create_tensor_from_data(data, 3); 48 | Tensor *grad = tensor_tanh_backward(t); 49 | 50 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 1.0f, grad->data[0]); 51 | 52 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.4199743f, grad->data[1]); 53 | 54 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.4199743f, grad->data[2]); 55 | 56 | tensor_free(t); 57 | tensor_free(grad); 58 | } 59 | 60 | void test_gelu_backward_standard_values(void) { 61 | float32_t data[] = {0.0f, 1.0f, -1.0f}; 62 | Tensor *t = create_tensor_from_data(data, 3); 63 | Tensor *grad = tensor_gelu_backward(t); 64 | 65 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.5f, grad->data[0]); 66 | 67 | TEST_ASSERT_FLOAT_WITHIN(1e-4, 1.08296f, grad->data[1]); 68 | TEST_ASSERT_FLOAT_WITHIN(1e-4, -0.08296f, grad->data[2]); 69 | 70 | tensor_free(t); 71 | tensor_free(grad); 72 | } 73 | 74 | void test_softmax_backward_diagonal(void) { 75 | float32_t data[] = {0.0f, 0.0f}; 76 | uint64_t shape[] = {2}; 77 | Tensor *t = tensor_create(data, shape, 1, false); 78 | Tensor *grad = tensor_softmax_backward(t, 0); 79 | 80 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.25f, grad->data[0]); 81 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.25f, grad->data[1]); 82 | 83 | tensor_free(t); 84 | tensor_free(grad); 85 | } 86 | 87 | void test_sigmoid_backward_stability(void) { 88 | float32_t data[] = {100.0f, -100.0f}; 89 | Tensor *t = create_tensor_from_data(data, 2); 90 | Tensor *grad = tensor_sigmoid_backward(t); 91 | 92 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[0]); 93 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[1]); 94 | 95 | tensor_free(t); 96 | tensor_free(grad); 97 | } 98 | 99 | void test_tanh_backward_stability(void) { 100 | float32_t data[] = {50.0f, -50.0f}; 101 | Tensor *t = create_tensor_from_data(data, 2); 102 | Tensor 
*grad = tensor_tanh_backward(t); 103 | 104 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[0]); 105 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[1]); 106 | 107 | tensor_free(t); 108 | tensor_free(grad); 109 | } 110 | 111 | void test_sigmoid_backward_edge_cases(void) { 112 | float32_t data[] = {0.0f, 100.0f, -100.0f}; 113 | Tensor *t = create_tensor_from_data(data, 3); 114 | Tensor *grad = tensor_sigmoid_backward(t); 115 | 116 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.25f, grad->data[0]); 117 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[1]); 118 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[2]); 119 | 120 | tensor_free(t); 121 | tensor_free(grad); 122 | } 123 | 124 | void test_relu_backward_edge_cases(void) { 125 | float32_t data[] = {0.0f, 10.0f, -10.0f}; 126 | Tensor *t = create_tensor_from_data(data, 3); 127 | Tensor *grad = tensor_relu_backward(t); 128 | 129 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[0]); 130 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 1.0f, grad->data[1]); 131 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[2]); 132 | 133 | tensor_free(t); 134 | tensor_free(grad); 135 | } 136 | 137 | void test_tanh_backward_edge_cases(void) { 138 | float32_t data[] = {0.0f, 50.0f, -50.0f}; 139 | Tensor *t = create_tensor_from_data(data, 3); 140 | Tensor *grad = tensor_tanh_backward(t); 141 | 142 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 1.0f, grad->data[0]); 143 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[1]); 144 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, grad->data[2]); 145 | 146 | tensor_free(t); 147 | tensor_free(grad); 148 | } 149 | 150 | void test_gelu_backward_edge_cases(void) { 151 | float32_t data[] = {0.0f, 1.0f, -1.0f}; 152 | Tensor *t = create_tensor_from_data(data, 3); 153 | Tensor *grad = tensor_gelu_backward(t); 154 | 155 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.5f, grad->data[0]); 156 | 157 | TEST_ASSERT_FLOAT_WITHIN(1e-4, 1.08296f, grad->data[1]); 158 | TEST_ASSERT_FLOAT_WITHIN(1e-4, -0.08296f, grad->data[2]); 159 | 160 | tensor_free(t); 161 | tensor_free(grad); 162 | } 163 | 164 | void test_softmax_backward_shapes(void) { 165 | float32_t data[] = {0.0f, 0.0f, 0.0f, 0.0f}; 166 | uint64_t shape[] = {2, 2}; 167 | Tensor *t = tensor_create(data, shape, 2, false); 168 | 169 | Tensor *grad = tensor_softmax_backward(t, 1); 170 | 171 | for (int i = 0; i < 4; i++) { 172 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.25f, grad->data[i]); 173 | } 174 | 175 | tensor_free(grad); 176 | 177 | grad = tensor_softmax_backward(t, 0); 178 | for (int i = 0; i < 4; i++) { 179 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.25f, grad->data[i]); 180 | } 181 | 182 | tensor_free(t); 183 | tensor_free(grad); 184 | } 185 | 186 | int main(void) { 187 | UNITY_BEGIN(); 188 | RUN_TEST(test_sigmoid_backward_standard_values); 189 | RUN_TEST(test_relu_backward_standard_values); 190 | RUN_TEST(test_tanh_backward_standard_values); 191 | RUN_TEST(test_gelu_backward_standard_values); 192 | RUN_TEST(test_softmax_backward_diagonal); 193 | RUN_TEST(test_sigmoid_backward_stability); 194 | RUN_TEST(test_sigmoid_backward_edge_cases); 195 | RUN_TEST(test_relu_backward_edge_cases); 196 | RUN_TEST(test_tanh_backward_edge_cases); 197 | RUN_TEST(test_gelu_backward_edge_cases); 198 | RUN_TEST(test_softmax_backward_shapes); 199 | RUN_TEST(test_tanh_backward_stability); 200 | return UNITY_END(); 201 | } 202 | -------------------------------------------------------------------------------- /src/ops/losses_backward.c: -------------------------------------------------------------------------------- 1 | #include "losses_backward.h" 2 | #include 
"autograd.h" 3 | #include "tensor.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | Tensor *mse_loss_backward(const Tensor *predictions, const Tensor *targets) { 10 | assert(predictions != NULL); 11 | assert(targets != NULL); 12 | assert(predictions->data != NULL || predictions->size == 0); 13 | assert(targets->data != NULL || targets->size == 0); 14 | assert(predictions->size == targets->size); 15 | assert(predictions->ndim == targets->ndim); 16 | 17 | for (uint64_t i = 0; i < predictions->ndim; i++) { 18 | assert(predictions->shape[i] == targets->shape[i]); 19 | } 20 | 21 | Tensor *grad = tensor_create(NULL, predictions->shape, predictions->ndim, false); 22 | if (predictions->size == 0) { 23 | return grad; 24 | } 25 | 26 | float32_t inv_n = 2.0f / (float32_t)predictions->size; 27 | for (uint64_t i = 0; i < predictions->size; i++) { 28 | float32_t diff = predictions->data[i] - targets->data[i]; 29 | grad->data[i] = inv_n * diff; 30 | } 31 | 32 | return grad; 33 | } 34 | 35 | void mse_loss_backward_fn(Function *fn, const Tensor *grad_output) { 36 | assert(fn != NULL); 37 | assert(grad_output != NULL); 38 | assert(fn->num_inputs == 2); 39 | assert(grad_output->size == 1 && "loss output must be scalar"); 40 | 41 | Tensor *predictions = fn->inputs[0]; 42 | const Tensor *targets = fn->inputs[1]; 43 | 44 | if (predictions != NULL && predictions->requires_grad) { 45 | Tensor *grad = mse_loss_backward(predictions, targets); 46 | float32_t scale = grad_output->data[0]; 47 | for (uint64_t i = 0; i < grad->size; i++) { 48 | grad->data[i] *= scale; 49 | } 50 | accumulate_grad(predictions, grad); 51 | } 52 | } 53 | 54 | Tensor *cross_entropy_loss_backward(const Tensor *logits, const Tensor *targets) { 55 | assert(logits != NULL); 56 | assert(targets != NULL); 57 | assert(logits->data != NULL || logits->size == 0); 58 | assert(targets->data != NULL || targets->size == 0); 59 | 60 | assert(logits->ndim == 2); 61 | assert(targets->ndim == 1 || (targets->ndim == 2 && targets->shape[1] == 1)); 62 | assert(logits->shape[0] == targets->shape[0]); 63 | 64 | uint64_t batch_size = logits->shape[0]; 65 | uint64_t num_classes = logits->shape[1]; 66 | 67 | Tensor *grad = tensor_create(NULL, logits->shape, logits->ndim, false); 68 | if (batch_size == 0 || num_classes == 0) { 69 | return grad; 70 | } 71 | 72 | float32_t inv_batch = 1.0f / (float32_t)batch_size; 73 | 74 | for (uint64_t i = 0; i < batch_size; i++) { 75 | float32_t target_float = targets->data[i]; 76 | assert(target_float >= 0.0f && target_float < (float32_t)num_classes); 77 | uint64_t target_idx = (uint64_t)target_float; 78 | 79 | // for numerical stability, compute softmax via log-sum-exp trick 80 | float32_t max_logit = -FLT_MAX; 81 | for (uint64_t j = 0; j < num_classes; j++) { 82 | float32_t logit = logits->data[i * num_classes + j]; 83 | if (logit > max_logit) { 84 | max_logit = logit; 85 | } 86 | } 87 | 88 | float32_t sum_exp = 0.0f; 89 | for (uint64_t j = 0; j < num_classes; j++) { 90 | float32_t logit = logits->data[i * num_classes + j]; 91 | sum_exp += expf(logit - max_logit); 92 | } 93 | 94 | if (sum_exp < 1.0f) { 95 | sum_exp = 1.0f; 96 | } 97 | 98 | // softmax probabilities 99 | for (uint64_t j = 0; j < num_classes; j++) { 100 | float32_t logit = logits->data[i * num_classes + j]; 101 | float32_t prob = expf(logit - max_logit) / sum_exp; 102 | float32_t indicator = (j == target_idx) ? 
1.0f : 0.0f; 103 | grad->data[i * num_classes + j] = (prob - indicator) * inv_batch; 104 | } 105 | } 106 | 107 | return grad; 108 | } 109 | 110 | void cross_entropy_loss_backward_fn(Function *fn, const Tensor *grad_output) { 111 | assert(fn != NULL); 112 | assert(grad_output != NULL); 113 | assert(fn->num_inputs == 2); 114 | assert(grad_output->size == 1 && "loss output must be scalar"); 115 | 116 | Tensor *logits = fn->inputs[0]; 117 | const Tensor *targets = fn->inputs[1]; 118 | 119 | if (logits != NULL && logits->requires_grad) { 120 | Tensor *grad = cross_entropy_loss_backward(logits, targets); 121 | float32_t scale = grad_output->data[0]; 122 | for (uint64_t i = 0; i < grad->size; i++) { 123 | grad->data[i] *= scale; 124 | } 125 | accumulate_grad(logits, grad); 126 | } 127 | } 128 | 129 | #define EPSILON 1e-7f 130 | 131 | Tensor *binary_cross_entropy_loss_backward(const Tensor *predictions, const Tensor *targets) { 132 | assert(predictions != NULL); 133 | assert(targets != NULL); 134 | assert(predictions->data != NULL || predictions->size == 0); 135 | assert(targets->data != NULL || targets->size == 0); 136 | assert(predictions->size == targets->size); 137 | assert(predictions->ndim == targets->ndim); 138 | 139 | for (uint64_t i = 0; i < predictions->ndim; i++) { 140 | assert(predictions->shape[i] == targets->shape[i]); 141 | } 142 | 143 | Tensor *grad = tensor_create(NULL, predictions->shape, predictions->ndim, false); 144 | if (predictions->size == 0) { 145 | return grad; 146 | } 147 | 148 | float32_t inv_n = 1.0f / (float32_t)predictions->size; 149 | 150 | for (uint64_t i = 0; i < predictions->size; i++) { 151 | float32_t p = predictions->data[i]; 152 | float32_t t = targets->data[i]; 153 | 154 | // clamp p into (0,1) for numerical stability, mirroring forward 155 | if (p < EPSILON) { 156 | p = EPSILON; 157 | } 158 | if (p > 1.0f - EPSILON) { 159 | p = 1.0f - EPSILON; 160 | } 161 | 162 | float32_t denom = p * (1.0f - p); 163 | if (denom < EPSILON) { 164 | denom = EPSILON; 165 | } 166 | 167 | grad->data[i] = ((p - t) / denom) * inv_n; 168 | } 169 | 170 | return grad; 171 | } 172 | 173 | void binary_cross_entropy_loss_backward_fn(Function *fn, const Tensor *grad_output) { 174 | assert(fn != NULL); 175 | assert(grad_output != NULL); 176 | assert(fn->num_inputs == 2); 177 | assert(grad_output->size == 1 && "loss output must be scalar"); 178 | 179 | Tensor *predictions = fn->inputs[0]; 180 | const Tensor *targets = fn->inputs[1]; 181 | 182 | if (predictions != NULL && predictions->requires_grad) { 183 | Tensor *grad = binary_cross_entropy_loss_backward(predictions, targets); 184 | float32_t scale = grad_output->data[0]; 185 | for (uint64_t i = 0; i < grad->size; i++) { 186 | grad->data[i] *= scale; 187 | } 188 | accumulate_grad(predictions, grad); 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /test/test_convolutions_backward.c: -------------------------------------------------------------------------------- 1 | #include "ops/convolutions_backward.h" 2 | #include "tensor.h" 3 | #include "unity.h" 4 | #include 5 | #include 6 | 7 | void setUp(void) {} 8 | void tearDown(void) {} 9 | 10 | void test_conv2d_backward_simple(void) { 11 | uint64_t shape[] = {1, 1, 1, 1}; 12 | Tensor *input = tensor_create(NULL, shape, 4, false); 13 | input->data[0] = 2.0f; 14 | 15 | Tensor *weight = tensor_create(NULL, shape, 4, false); 16 | weight->data[0] = 3.0f; 17 | 18 | Tensor *bias = tensor_create(NULL, shape, 1, false); 19 | bias->data[0] = 0.5f; 20 
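// for this degenerate 1x1 convolution the expected gradients follow directly from the
// chain rule: d_in = weight * grad_out = 3, d_w = input * grad_out = 2, d_b = grad_out = 1,
// which is exactly what the assertions below check.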
| 21 | Tensor *grad_output = tensor_create(NULL, shape, 4, false); 22 | grad_output->data[0] = 1.0f; 23 | 24 | Tensor *d_in = NULL; 25 | Tensor *d_w = NULL; 26 | Tensor *d_b = NULL; 27 | 28 | conv2d_backward(input, weight, bias, 1, 0, 1, grad_output, &d_in, &d_w, &d_b); 29 | 30 | TEST_ASSERT_NOT_NULL(d_in); 31 | TEST_ASSERT_NOT_NULL(d_w); 32 | TEST_ASSERT_NOT_NULL(d_b); 33 | 34 | TEST_ASSERT_EQUAL_FLOAT(3.0f, d_in->data[0]); 35 | TEST_ASSERT_EQUAL_FLOAT(2.0f, d_w->data[0]); 36 | TEST_ASSERT_EQUAL_FLOAT(1.0f, d_b->data[0]); 37 | 38 | tensor_free(input); 39 | tensor_free(weight); 40 | tensor_free(bias); 41 | tensor_free(grad_output); 42 | tensor_free(d_in); 43 | tensor_free(d_w); 44 | tensor_free(d_b); 45 | } 46 | 47 | void test_conv2d_backward_stride(void) { 48 | uint64_t in_shape[] = {1, 1, 3, 3}; 49 | Tensor *input = tensor_zeros(in_shape, 4, false); 50 | 51 | uint64_t w_shape[] = {1, 1, 2, 2}; 52 | Tensor *weight = tensor_zeros(w_shape, 4, false); 53 | weight->data[0] = 1.0f; 54 | weight->data[1] = 2.0f; 55 | weight->data[2] = 3.0f; 56 | weight->data[3] = 4.0f; 57 | 58 | uint64_t out_shape[] = {1, 1, 1, 1}; 59 | Tensor *grad_output = tensor_zeros(out_shape, 4, false); 60 | grad_output->data[0] = 10.0f; 61 | 62 | Tensor *d_in = NULL; 63 | Tensor *d_w = NULL; 64 | Tensor *d_b = NULL; 65 | 66 | conv2d_backward(input, weight, NULL, 2, 0, 2, grad_output, &d_in, &d_w, &d_b); 67 | 68 | TEST_ASSERT_EQUAL_FLOAT(10.0f, d_in->data[0]); 69 | TEST_ASSERT_EQUAL_FLOAT(20.0f, d_in->data[1]); 70 | TEST_ASSERT_EQUAL_FLOAT(30.0f, d_in->data[3]); 71 | TEST_ASSERT_EQUAL_FLOAT(40.0f, d_in->data[4]); 72 | TEST_ASSERT_EQUAL_FLOAT(0.0f, d_in->data[2]); 73 | TEST_ASSERT_EQUAL_FLOAT(0.0f, d_in->data[5]); 74 | 75 | tensor_free(input); 76 | tensor_free(weight); 77 | tensor_free(grad_output); 78 | tensor_free(d_in); 79 | tensor_free(d_w); 80 | } 81 | 82 | void test_conv2d_backward_padding(void) { 83 | uint64_t in_shape[] = {1, 1, 1, 1}; 84 | Tensor *input = tensor_zeros(in_shape, 4, false); 85 | 86 | uint64_t w_shape[] = {1, 1, 3, 3}; 87 | Tensor *weight = tensor_zeros(w_shape, 4, false); 88 | for (int i = 0; i < 9; ++i) 89 | weight->data[i] = 1.0f; 90 | weight->data[4] = 5.0f; 91 | 92 | uint64_t out_shape[] = {1, 1, 1, 1}; 93 | Tensor *grad_output = tensor_zeros(out_shape, 4, false); 94 | grad_output->data[0] = 2.0f; 95 | 96 | Tensor *d_in = NULL; 97 | Tensor *d_w = NULL; 98 | Tensor *d_b = NULL; 99 | 100 | conv2d_backward(input, weight, NULL, 1, 1, 3, grad_output, &d_in, &d_w, &d_b); 101 | 102 | TEST_ASSERT_EQUAL_UINT64(1, d_in->shape[2]); 103 | TEST_ASSERT_EQUAL_UINT64(1, d_in->shape[3]); 104 | TEST_ASSERT_EQUAL_FLOAT(10.0f, d_in->data[0]); 105 | 106 | tensor_free(input); 107 | tensor_free(weight); 108 | tensor_free(grad_output); 109 | tensor_free(d_in); 110 | tensor_free(d_w); 111 | } 112 | 113 | void test_maxpool2d_backward_simple(void) { 114 | uint64_t in_shape[] = {1, 1, 2, 2}; 115 | Tensor *input = tensor_zeros(in_shape, 4, false); 116 | input->data[0] = 1.0f; 117 | input->data[1] = 2.0f; 118 | input->data[2] = 3.0f; 119 | input->data[3] = 4.0f; 120 | 121 | uint64_t out_shape_arr[] = {1, 1, 1, 1}; 122 | Tensor *grad_output = tensor_zeros(out_shape_arr, 4, false); 123 | grad_output->data[0] = 10.0f; 124 | 125 | Tensor *d_in = maxpool2d_backward(input, out_shape_arr, 2, 2, 0, grad_output); 126 | 127 | TEST_ASSERT_EQUAL_FLOAT(0.0f, d_in->data[0]); 128 | TEST_ASSERT_EQUAL_FLOAT(0.0f, d_in->data[1]); 129 | TEST_ASSERT_EQUAL_FLOAT(0.0f, d_in->data[2]); 130 | TEST_ASSERT_EQUAL_FLOAT(10.0f, d_in->data[3]); 131 | 132 
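// maxpool backward routes the entire upstream gradient to the argmax of each window:
// the max of {1, 2, 3, 4} sits at index 3, so it receives the full 10.0 and every other
// position stays 0, as asserted above.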
| tensor_free(input); 133 | tensor_free(grad_output); 134 | tensor_free(d_in); 135 | } 136 | 137 | void test_avgpool2d_backward_simple(void) { 138 | uint64_t in_shape[] = {1, 1, 2, 2}; 139 | Tensor *input = tensor_zeros(in_shape, 4, false); 140 | 141 | uint64_t out_shape_arr[] = {1, 1, 1, 1}; 142 | Tensor *grad_output = tensor_zeros(out_shape_arr, 4, false); 143 | grad_output->data[0] = 4.0f; 144 | 145 | Tensor *d_in = avgpool2d_backward(input, out_shape_arr, 2, 2, 0, grad_output); 146 | 147 | for (int i = 0; i < 4; ++i) { 148 | TEST_ASSERT_EQUAL_FLOAT(1.0f, d_in->data[i]); 149 | } 150 | 151 | tensor_free(input); 152 | tensor_free(grad_output); 153 | tensor_free(d_in); 154 | } 155 | 156 | void test_batchnorm2d_backward_simple(void) { 157 | uint64_t shape[] = {1, 1, 2, 2}; 158 | Tensor *input = tensor_zeros(shape, 4, false); 159 | input->data[0] = 0.0f; 160 | input->data[1] = 4.0f; 161 | input->data[2] = 0.0f; 162 | input->data[3] = 4.0f; 163 | 164 | Tensor *gamma = tensor_zeros(shape, 1, false); 165 | gamma->shape[0] = 1; 166 | gamma->ndim = 1; 167 | gamma->size = 1; 168 | gamma->data[0] = 1.0f; 169 | 170 | Tensor *mean = tensor_zeros(shape, 1, false); 171 | mean->shape[0] = 1; 172 | mean->ndim = 1; 173 | mean->size = 1; 174 | mean->data[0] = 2.0f; 175 | Tensor *var = tensor_zeros(shape, 1, false); 176 | var->shape[0] = 1; 177 | var->ndim = 1; 178 | var->size = 1; 179 | var->data[0] = 4.0f; 180 | 181 | Tensor *grad_output = tensor_zeros(shape, 4, false); 182 | for (int i = 0; i < 4; ++i) 183 | grad_output->data[i] = 1.0f; 184 | 185 | Tensor *d_in = NULL; 186 | Tensor *d_gamma = NULL; 187 | Tensor *d_beta = NULL; 188 | 189 | batchnorm2d_backward(input, gamma, mean, var, 0.0f, grad_output, &d_in, &d_gamma, &d_beta); 190 | 191 | TEST_ASSERT_EQUAL_FLOAT(0.0f, d_gamma->data[0]); 192 | TEST_ASSERT_EQUAL_FLOAT(4.0f, d_beta->data[0]); 193 | TEST_ASSERT_EQUAL_FLOAT(0.0f, d_in->data[0]); 194 | TEST_ASSERT_EQUAL_FLOAT(0.0f, d_in->data[1]); 195 | 196 | tensor_free(input); 197 | tensor_free(gamma); 198 | tensor_free(mean); 199 | tensor_free(var); 200 | tensor_free(grad_output); 201 | tensor_free(d_in); 202 | tensor_free(d_gamma); 203 | tensor_free(d_beta); 204 | } 205 | 206 | int main(void) { 207 | UNITY_BEGIN(); 208 | RUN_TEST(test_conv2d_backward_simple); 209 | RUN_TEST(test_conv2d_backward_stride); 210 | RUN_TEST(test_conv2d_backward_padding); 211 | RUN_TEST(test_maxpool2d_backward_simple); 212 | RUN_TEST(test_avgpool2d_backward_simple); 213 | RUN_TEST(test_batchnorm2d_backward_simple); 214 | return UNITY_END(); 215 | } 216 | -------------------------------------------------------------------------------- /src/utils/tqdm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define TQDM_BAR_WIDTH 70 9 | #define TQDM_PLOT_HEIGHT 15 10 | #define TQDM_PLOT_WIDTH 78 11 | 12 | // macro overloading to make the `prefix` and `postfix` args optional (default to `NULL`) 13 | #define tqdm(...) TQDM_SELECT(__VA_ARGS__, tqdm_4, tqdm_3, tqdm_2)(__VA_ARGS__) 14 | #define TQDM_SELECT(_1, _2, _3, _4, NAME, ...) 
NAME // selects macro based on arg count 15 | #define tqdm_2(current, total) tqdm_impl(current, total, NULL, NULL) 16 | #define tqdm_3(current, total, prefix) tqdm_impl(current, total, prefix, NULL) 17 | #define tqdm_4(current, total, prefix, postfix) tqdm_impl(current, total, prefix, postfix) 18 | 19 | static struct timeval start_time = {0, 0}; 20 | 21 | static inline void tqdm_impl(uint64_t current, uint64_t total, const char *prefix, const char *postfix) { 22 | if (total == 0) { 23 | return; 24 | } 25 | 26 | if (start_time.tv_sec == 0 && start_time.tv_usec == 0) { 27 | gettimeofday(&start_time, NULL); 28 | } 29 | 30 | const double progress = (double)current / (double)total; 31 | const uint32_t percentage = (uint32_t)(progress * 100.0); 32 | const uint32_t bar_width = TQDM_BAR_WIDTH; 33 | const uint32_t filled = (uint32_t)(progress * bar_width); 34 | 35 | struct timeval now; 36 | gettimeofday(&now, NULL); 37 | const double elapsed = (double)(now.tv_sec - start_time.tv_sec) + (double)(now.tv_usec - start_time.tv_usec) / 1e6; 38 | const double rate = (elapsed > 0) ? (double)current / elapsed : 0.0; 39 | 40 | printf("\r"); 41 | if (prefix && prefix[0] != '\0') { 42 | printf("%s: ", prefix); 43 | } 44 | printf("%3u%%|", percentage); 45 | for (uint32_t i = 0; i < filled; i++) { 46 | printf("█"); 47 | } 48 | if (filled < bar_width) { 49 | double partial = (progress * bar_width) - filled; 50 | if (partial > 0.75) { 51 | printf("▊"); 52 | } else if (partial > 0.5) { 53 | printf("▌"); 54 | } else if (partial > 0.25) { 55 | printf("▎"); 56 | } else { 57 | printf("▏"); 58 | } 59 | for (uint32_t i = filled + 1; i < bar_width; i++) { 60 | printf(" "); 61 | } 62 | } 63 | printf("| %" PRIu64 "/%" PRIu64 " [%.1fit/s]", current, total, rate); 64 | if (postfix && postfix[0] != '\0') { 65 | printf(" %s", postfix); 66 | } 67 | printf(" "); 68 | fflush(stdout); 69 | 70 | // reset for next use 71 | if (current >= total) { 72 | start_time.tv_sec = 0; 73 | start_time.tv_usec = 0; 74 | printf("\n"); 75 | } 76 | } 77 | 78 | // macro overloading for tqdm_plot with optional prefix 79 | #define tqdm_plot(...) TQDM_PLOT_SELECT(__VA_ARGS__, tqdm_plot_4, tqdm_plot_3)(__VA_ARGS__) 80 | #define TQDM_PLOT_SELECT(_1, _2, _3, _4, NAME, ...) 
NAME 81 | #define tqdm_plot_3(current, total, loss) tqdm_plot_impl(current, total, loss, NULL) 82 | #define tqdm_plot_4(current, total, loss, prefix) tqdm_plot_impl(current, total, loss, prefix) 83 | 84 | static double plot_loss_history[TQDM_PLOT_WIDTH]; 85 | static int plot_history_count = 0; 86 | static int plot_initialized = 0; 87 | 88 | static inline void tqdm_plot_impl(uint64_t current, uint64_t total, double loss, const char *prefix) { 89 | if (total == 0) { 90 | return; 91 | } 92 | 93 | // plot spacing 94 | if (!plot_initialized) { 95 | for (int i = 0; i < TQDM_PLOT_HEIGHT + 2; i++) { 96 | printf("\n"); 97 | } 98 | for (int i = 0; i < TQDM_PLOT_WIDTH; i++) { 99 | plot_loss_history[i] = 0.0; 100 | } 101 | plot_initialized = 1; 102 | } 103 | 104 | // update history buffer 105 | if (plot_history_count < TQDM_PLOT_WIDTH) { 106 | plot_loss_history[plot_history_count++] = loss; 107 | } else { 108 | for (int i = 0; i < TQDM_PLOT_WIDTH - 1; i++) { 109 | plot_loss_history[i] = plot_loss_history[i + 1]; 110 | } 111 | plot_loss_history[TQDM_PLOT_WIDTH - 1] = loss; 112 | } 113 | 114 | // determine min/max for scaling 115 | double min_loss = plot_loss_history[0]; 116 | double max_loss = plot_loss_history[0]; 117 | for (int i = 1; i < plot_history_count; i++) { 118 | if (plot_loss_history[i] < min_loss) 119 | min_loss = plot_loss_history[i]; 120 | if (plot_loss_history[i] > max_loss) 121 | max_loss = plot_loss_history[i]; 122 | } 123 | 124 | if (max_loss == min_loss) { 125 | max_loss += 1e-6; 126 | } 127 | 128 | // move cursor up to redraw plot 129 | printf("\033[%dA", TQDM_PLOT_HEIGHT + 2); 130 | 131 | // frame top 132 | printf("\r\033[K"); 133 | printf("┌"); 134 | if (prefix && prefix[0] != '\0') { 135 | int title_len = 0; 136 | while (prefix[title_len]) { 137 | title_len++; 138 | } 139 | int padding = (TQDM_PLOT_WIDTH - title_len - 2) / 2; // -2 for spaces around title 140 | if (padding < 0) { 141 | padding = 0; 142 | } 143 | for (int i = 0; i < padding; i++) { 144 | printf("─"); 145 | } 146 | printf(" %s ", prefix); 147 | for (int i = 0; i < TQDM_PLOT_WIDTH - padding - title_len - 2; i++) { 148 | printf("─"); 149 | } 150 | } else { 151 | for (int i = 0; i < TQDM_PLOT_WIDTH; i++) { 152 | printf("─"); 153 | } 154 | } 155 | printf("┐\n"); 156 | 157 | // render plot 158 | for (int row = TQDM_PLOT_HEIGHT - 1; row >= 0; row--) { 159 | printf("\r\033[K"); 160 | printf("│"); 161 | 162 | for (int col = 0; col < TQDM_PLOT_WIDTH; col++) { 163 | if (col >= plot_history_count) { 164 | printf(" "); 165 | continue; 166 | } 167 | 168 | double val = plot_loss_history[col]; 169 | double normalized = (val - min_loss) / (max_loss - min_loss); 170 | int tick = (int)(normalized * (TQDM_PLOT_HEIGHT - 1)); 171 | 172 | if (tick == row) { 173 | printf("•"); 174 | } else if (row < tick) { 175 | printf(" "); // below the point 176 | } else { 177 | printf(" "); // above the point 178 | } 179 | } 180 | 181 | // right border with meter indicator 182 | double current_normalized = (loss - min_loss) / (max_loss - min_loss); 183 | int current_tick = (int)(current_normalized * (TQDM_PLOT_HEIGHT - 1)); 184 | 185 | if (row == current_tick) { 186 | printf("┤"); // marker for current value 187 | } else { 188 | printf("│"); 189 | } 190 | 191 | // right border axis labels 192 | if (row == TQDM_PLOT_HEIGHT - 1) { 193 | printf(" %.4f (max)", max_loss); 194 | } else if (row == 0) { 195 | printf(" %.4f (min)", min_loss); 196 | } else if (row == current_tick) { 197 | printf("––– %.4f", loss); 198 | } 199 | 200 | printf("\n"); 201 | } 202 
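// each sample is mapped to a row by its normalized value,
// tick = (val - min_loss) / (max_loss - min_loss) * (TQDM_PLOT_HEIGHT - 1),
// and rows are printed from max (top) down to min (bottom). for illustration, with
// min 0.2, max 1.2 and loss 0.7 the dot lands on tick 7, the middle row of a height-15 plot.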
| 203 | // frame bottom 204 | printf("\r\033[K"); 205 | printf("└"); 206 | for (int i = 0; i < TQDM_PLOT_WIDTH; i++) { 207 | printf("─"); 208 | } 209 | printf("┘\n"); 210 | 211 | tqdm(current, total, NULL, NULL); 212 | 213 | if (current >= total) { 214 | plot_initialized = 0; 215 | plot_history_count = 0; 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /src/ops/arithmetic_backward.c: -------------------------------------------------------------------------------- 1 | #include "ops/arithmetic_backward.h" 2 | #include "ops/arithmetic.h" 3 | #include "ops/reductions.h" 4 | #include "ops/reshapes.h" 5 | #include "tensor.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | static Tensor *unbroadcast(const Tensor *grad, const Tensor *input) { 12 | assert(grad != NULL); 13 | assert(input != NULL); 14 | 15 | const Tensor *curr_grad = grad; 16 | bool owns_tensor = false; 17 | 18 | // broadcasting adds dimensions on the left, so collapse extra leading dimensions 19 | // e.g., grad (2, 3, 4) -> input (3, 4) => sum dim 0 20 | while (curr_grad->ndim > input->ndim) { 21 | Tensor *summed = tensor_sum(curr_grad, 0, false); 22 | if (owns_tensor) { 23 | tensor_free((Tensor *)curr_grad); 24 | } 25 | curr_grad = summed; 26 | owns_tensor = true; 27 | } 28 | 29 | assert(curr_grad->ndim == input->ndim); 30 | assert(input->ndim < MAX_NDIM); 31 | 32 | // dimensions match. collapse any dims where input had size 1 (broadcasted dim) 33 | // e.g., grad (3, 4) -> input (1, 4) => sum dim 0 34 | for (uint64_t dim_idx = 0; dim_idx < input->ndim; dim_idx++) { 35 | if (input->shape[dim_idx] == 1 && curr_grad->shape[dim_idx] > 1) { 36 | Tensor *summed = tensor_sum(curr_grad, (int64_t)dim_idx, true); 37 | if (owns_tensor) { 38 | tensor_free((Tensor *)curr_grad); 39 | } 40 | curr_grad = summed; 41 | owns_tensor = true; 42 | } 43 | } 44 | 45 | // if no reduction happened, ensure we return a new tensor to respect ownership contract 46 | if (!owns_tensor) { 47 | return tensor_create(grad->data, grad->shape, grad->ndim, false); 48 | } 49 | 50 | return (Tensor *)curr_grad; 51 | } 52 | 53 | // 54 | // add 55 | // 56 | 57 | Tensor *tensor_add_backward_a(const Tensor *grad_output, const Tensor *a) { 58 | assert(grad_output != NULL); 59 | assert(a != NULL); 60 | return unbroadcast(grad_output, a); 61 | } 62 | 63 | Tensor *tensor_add_backward_b(const Tensor *grad_output, const Tensor *b) { 64 | assert(grad_output != NULL); 65 | assert(b != NULL); 66 | return unbroadcast(grad_output, b); 67 | } 68 | 69 | void add_backward(Function *fn, const Tensor *grad_output) { 70 | assert(fn != NULL); 71 | assert(grad_output != NULL); 72 | assert(fn->num_inputs == 2); 73 | 74 | Tensor *a = fn->inputs[0]; 75 | Tensor *b = fn->inputs[1]; 76 | 77 | if (a != NULL && a->requires_grad) { 78 | Tensor *grad_a = tensor_add_backward_a(grad_output, a); 79 | accumulate_grad(a, grad_a); 80 | } 81 | 82 | if (b != NULL && b->requires_grad) { 83 | Tensor *grad_b = tensor_add_backward_b(grad_output, b); 84 | accumulate_grad(b, grad_b); 85 | } 86 | } 87 | 88 | // 89 | // sub 90 | // 91 | 92 | Tensor *tensor_sub_backward_a(const Tensor *grad_output, const Tensor *a) { 93 | assert(grad_output != NULL); 94 | assert(a != NULL); 95 | return unbroadcast(grad_output, a); 96 | } 97 | 98 | Tensor *tensor_sub_backward_b(const Tensor *grad_output, const Tensor *b) { 99 | assert(grad_output != NULL); 100 | assert(b != NULL); 101 | 102 | Tensor *zeros = tensor_zeros(grad_output->shape, grad_output->ndim, false); 
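// d(a - b)/db = -1, so grad_b is simply -grad_output; it is built as (0 - grad_output)
// below and then unbroadcast back to b's shape.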
103 | Tensor *neg_grad = tensor_sub(zeros, grad_output); 104 | tensor_free(zeros); 105 | 106 | Tensor *result = unbroadcast(neg_grad, b); 107 | tensor_free(neg_grad); 108 | return result; 109 | } 110 | 111 | void sub_backward(Function *fn, const Tensor *grad_output) { 112 | assert(fn != NULL); 113 | assert(grad_output != NULL); 114 | assert(fn->num_inputs == 2); 115 | 116 | Tensor *a = fn->inputs[0]; 117 | Tensor *b = fn->inputs[1]; 118 | 119 | if (a != NULL && a->requires_grad) { 120 | Tensor *grad_a = tensor_sub_backward_a(grad_output, a); 121 | accumulate_grad(a, grad_a); 122 | } 123 | 124 | if (b != NULL && b->requires_grad) { 125 | Tensor *grad_b = tensor_sub_backward_b(grad_output, b); 126 | accumulate_grad(b, grad_b); 127 | } 128 | } 129 | 130 | // 131 | // mul 132 | // 133 | 134 | Tensor *tensor_mul_backward_a(const Tensor *grad_output, const Tensor *a, const Tensor *b) { 135 | assert(grad_output != NULL); 136 | assert(a != NULL); 137 | assert(b != NULL); 138 | 139 | Tensor *temp = tensor_mul(grad_output, b, true); // disable_grad=true 140 | Tensor *result = unbroadcast(temp, a); 141 | tensor_free(temp); 142 | return result; 143 | } 144 | 145 | Tensor *tensor_mul_backward_b(const Tensor *grad_output, const Tensor *a, const Tensor *b) { 146 | assert(grad_output != NULL); 147 | assert(a != NULL); 148 | assert(b != NULL); 149 | 150 | Tensor *temp = tensor_mul(grad_output, a, true); // disable_grad=true 151 | Tensor *result = unbroadcast(temp, b); 152 | tensor_free(temp); 153 | return result; 154 | } 155 | 156 | void mul_backward(Function *fn, const Tensor *grad_output) { 157 | assert(fn != NULL); 158 | assert(grad_output != NULL); 159 | assert(fn->num_inputs == 2); 160 | 161 | Tensor *a = fn->inputs[0]; 162 | Tensor *b = fn->inputs[1]; 163 | 164 | if (a != NULL && a->requires_grad) { 165 | Tensor *grad_a = tensor_mul_backward_a(grad_output, a, b); 166 | accumulate_grad(a, grad_a); 167 | } 168 | 169 | if (b != NULL && b->requires_grad) { 170 | Tensor *grad_b = tensor_mul_backward_b(grad_output, a, b); 171 | accumulate_grad(b, grad_b); 172 | } 173 | } 174 | 175 | // 176 | // div 177 | // 178 | 179 | Tensor *tensor_div_backward_a(const Tensor *grad_output, const Tensor *a, const Tensor *b) { 180 | assert(grad_output != NULL); 181 | assert(a != NULL); 182 | assert(b != NULL); 183 | 184 | Tensor *temp = tensor_div(grad_output, b); 185 | Tensor *result = unbroadcast(temp, a); 186 | tensor_free(temp); 187 | return result; 188 | } 189 | 190 | Tensor *tensor_div_backward_b(const Tensor *grad_output, const Tensor *a, const Tensor *b) { 191 | assert(grad_output != NULL); 192 | assert(a != NULL); 193 | assert(b != NULL); 194 | 195 | Tensor *zeros = tensor_zeros(grad_output->shape, grad_output->ndim, false); 196 | Tensor *neg_grad = tensor_sub(zeros, grad_output); 197 | tensor_free(zeros); 198 | 199 | Tensor *num = tensor_mul(neg_grad, a); 200 | tensor_free(neg_grad); 201 | 202 | Tensor *b_sq = tensor_mul(b, b); 203 | 204 | Tensor *temp = tensor_div(num, b_sq); 205 | tensor_free(num); 206 | tensor_free(b_sq); 207 | 208 | Tensor *result = unbroadcast(temp, b); 209 | tensor_free(temp); 210 | return result; 211 | } 212 | 213 | void div_backward(Function *fn, const Tensor *grad_output) { 214 | assert(fn != NULL); 215 | assert(grad_output != NULL); 216 | assert(fn->num_inputs == 2); 217 | 218 | Tensor *a = fn->inputs[0]; 219 | Tensor *b = fn->inputs[1]; 220 | 221 | if (a != NULL && a->requires_grad) { 222 | Tensor *grad_a = tensor_div_backward_a(grad_output, a, b); 223 | accumulate_grad(a, grad_a); 224 
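// quotient rule: d(a/b)/da = 1/b (handled above via grad_output / b) and
// d(a/b)/db = -a / b^2 (handled in the branch below).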
| } 225 | 226 | if (b != NULL && b->requires_grad) { 227 | Tensor *grad_b = tensor_div_backward_b(grad_output, a, b); 228 | accumulate_grad(b, grad_b); 229 | } 230 | } 231 | 232 | // 233 | // matmul 234 | // 235 | 236 | Tensor *tensor_matmul_backward_a(const Tensor *grad_output, const Tensor *a, const Tensor *b) { 237 | assert(grad_output != NULL); 238 | assert(a != NULL); 239 | assert(b != NULL); 240 | assert(b->ndim >= 2); 241 | 242 | Tensor *b_T = tensor_transpose(b, b->ndim - 2, b->ndim - 1); 243 | Tensor *temp = tensor_matmul(grad_output, b_T); 244 | tensor_free(b_T); 245 | 246 | Tensor *result = unbroadcast(temp, a); 247 | tensor_free(temp); 248 | return result; 249 | } 250 | 251 | Tensor *tensor_matmul_backward_b(const Tensor *grad_output, const Tensor *a, const Tensor *b) { 252 | assert(grad_output != NULL); 253 | assert(a != NULL); 254 | assert(b != NULL); 255 | assert(a->ndim >= 2); 256 | 257 | Tensor *a_T = tensor_transpose(a, a->ndim - 2, a->ndim - 1); 258 | Tensor *temp = tensor_matmul(a_T, grad_output); 259 | tensor_free(a_T); 260 | 261 | Tensor *result = unbroadcast(temp, b); 262 | tensor_free(temp); 263 | return result; 264 | } 265 | 266 | void matmul_backward(Function *fn, const Tensor *grad_output) { 267 | assert(fn != NULL); 268 | assert(grad_output != NULL); 269 | assert(fn->num_inputs == 2); 270 | 271 | Tensor *a = fn->inputs[0]; 272 | Tensor *b = fn->inputs[1]; 273 | 274 | if (a != NULL && a->requires_grad) { 275 | Tensor *grad_a = tensor_matmul_backward_a(grad_output, a, b); 276 | accumulate_grad(a, grad_a); 277 | } 278 | 279 | if (b != NULL && b->requires_grad) { 280 | Tensor *grad_b = tensor_matmul_backward_b(grad_output, a, b); 281 | accumulate_grad(b, grad_b); 282 | } 283 | } 284 | -------------------------------------------------------------------------------- /src/ops/arithmetic.c: -------------------------------------------------------------------------------- 1 | #include "ops/arithmetic.h" 2 | #include "autograd.h" 3 | #include "ops/arithmetic_backward.h" 4 | #include "utils/aligned_alloc.h" 5 | #include 6 | #include 7 | #include 8 | 9 | /* 10 | * aligns dimensions from the right. 11 | * compatible if dimensions are equal or one of them is 1. 12 | * 13 | * shape_a: [ 3, 1] 14 | * shape_b: [2, 1, 5] 15 | * ^ ^ ^ 16 | * out: [2, 3, 5] 17 | */ 18 | bool broadcast_shapes_mut(const uint64_t *shape_a, uint64_t ndim_a, const uint64_t *shape_b, uint64_t ndim_b, uint64_t *out_shape, uint64_t *out_ndim) { 19 | assert(ndim_a <= MAX_NDIM); 20 | assert(ndim_b <= MAX_NDIM); 21 | assert(out_shape != NULL); 22 | assert(out_ndim != NULL); 23 | 24 | uint64_t max_ndim = (ndim_a > ndim_b) ? ndim_a : ndim_b; 25 | assert(max_ndim <= MAX_NDIM); 26 | *out_ndim = max_ndim; 27 | 28 | int64_t idx_a = (int64_t)ndim_a - 1; 29 | int64_t idx_b = (int64_t)ndim_b - 1; 30 | int64_t idx_out = (int64_t)max_ndim - 1; 31 | 32 | assert(idx_out < (int64_t)MAX_NDIM); 33 | while (idx_out >= 0) { 34 | uint64_t dim_a = (idx_a >= 0 && shape_a) ? shape_a[idx_a] : 1; 35 | uint64_t dim_b = (idx_b >= 0 && shape_b) ? 
shape_b[idx_b] : 1; 36 | 37 | if (dim_a != dim_b && dim_a != 1 && dim_b != 1) { 38 | return false; 39 | } 40 | 41 | if (dim_a == 1) { 42 | out_shape[idx_out] = dim_b; 43 | } else { 44 | out_shape[idx_out] = dim_a; 45 | } 46 | 47 | idx_a--; 48 | idx_b--; 49 | idx_out--; 50 | } 51 | return true; 52 | } 53 | 54 | typedef float32_t (*binary_op_t)(float32_t, float32_t); 55 | 56 | Tensor *tensor_binary_op(const Tensor *a, const Tensor *b, binary_op_t op) { 57 | assert(a != NULL); 58 | assert(b != NULL); 59 | assert(op != NULL); 60 | assert(a->data != NULL || a->size == 0); 61 | assert(b->data != NULL || b->size == 0); 62 | if (a->size > 0) { 63 | assert((uintptr_t)a->data % CACHELINE_SIZE == 0 && "a->data is not properly aligned"); 64 | } 65 | if (b->size > 0) { 66 | assert((uintptr_t)b->data % CACHELINE_SIZE == 0 && "b->data is not properly aligned"); 67 | } 68 | 69 | uint64_t out_shape[MAX_NDIM]; 70 | uint64_t out_ndim; 71 | if (!broadcast_shapes_mut(a->shape, a->ndim, b->shape, b->ndim, out_shape, &out_ndim)) { 72 | assert(false && "shapes cannot be broadcasted"); 73 | } 74 | assert(out_ndim <= MAX_NDIM); 75 | Tensor *out_tensor = tensor_zeros(out_shape, out_ndim, a->requires_grad || b->requires_grad); 76 | 77 | // curr = current position in output tensor as multidim indices 78 | uint64_t *curr = (uint64_t *)calloc((size_t)out_ndim, sizeof(uint64_t)); 79 | assert(curr != NULL && "calloc failed"); 80 | 81 | // i = current position in output tensor as linear index 82 | for (uint64_t i = 0; i < out_tensor->size; i++) { 83 | // convert i to curr (mutates curr array) 84 | linear_to_multidim_mut(i, out_shape, out_ndim, curr); 85 | 86 | uint64_t offset_a = multidim_to_linear(curr, out_ndim, a->shape, a->ndim, a->strides); 87 | assert(offset_a < a->size && "offset_a out of bounds"); 88 | 89 | uint64_t offset_b = multidim_to_linear(curr, out_ndim, b->shape, b->ndim, b->strides); 90 | assert(offset_b < b->size && "offset_b out of bounds"); 91 | 92 | out_tensor->data[i] = op(a->data[offset_a], b->data[offset_b]); 93 | } 94 | 95 | free(curr); 96 | 97 | assert(out_tensor != NULL); 98 | assert(out_tensor->ndim == out_ndim); 99 | assert(out_tensor->data != NULL || out_tensor->size == 0); 100 | return out_tensor; 101 | } 102 | 103 | // 104 | // add 105 | // 106 | 107 | static float32_t op_add(float32_t a, float32_t b) { return a + b; } 108 | Tensor *tensor_add(const Tensor *a, const Tensor *b) { 109 | Tensor *result = tensor_binary_op(a, b, op_add); 110 | 111 | if (result->requires_grad) { 112 | Function *fn = arena_alloc_function(); 113 | fn->apply = add_backward; 114 | fn->output = result; 115 | fn->num_inputs = 2; 116 | fn->inputs[0] = (Tensor *)a; 117 | fn->inputs[1] = (Tensor *)b; 118 | fn->pending_count = 0; 119 | fn->ctx = NULL; 120 | if (a->grad_fn != NULL) { 121 | a->grad_fn->pending_count++; 122 | } 123 | if (b->grad_fn != NULL) { 124 | b->grad_fn->pending_count++; 125 | } 126 | result->grad_fn = fn; 127 | } 128 | 129 | return result; 130 | } 131 | 132 | // 133 | // sub 134 | // 135 | 136 | static float32_t op_sub(float32_t a, float32_t b) { return a - b; } 137 | Tensor *tensor_sub(const Tensor *a, const Tensor *b) { 138 | Tensor *result = tensor_binary_op(a, b, op_sub); 139 | 140 | if (result->requires_grad) { 141 | Function *fn = arena_alloc_function(); 142 | fn->apply = sub_backward; 143 | fn->output = result; 144 | fn->num_inputs = 2; 145 | fn->inputs[0] = (Tensor *)a; 146 | fn->inputs[1] = (Tensor *)b; 147 | fn->pending_count = 0; 148 | fn->ctx = NULL; 149 | if (a->grad_fn != NULL) { 150 | 
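// [assumption] pending_count appears to act as a dependency counter: every new consumer
// of a tensor that already has a grad_fn bumps it here, presumably so the backward pass
// only applies a node once all of its consumers have contributed their gradients.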
a->grad_fn->pending_count++; 151 | } 152 | if (b->grad_fn != NULL) { 153 | b->grad_fn->pending_count++; 154 | } 155 | result->grad_fn = fn; 156 | } 157 | 158 | return result; 159 | } 160 | 161 | // 162 | // mul 163 | // 164 | 165 | static float32_t op_mul(float32_t a, float32_t b) { return a * b; } 166 | Tensor *tensor_mul_impl(const Tensor *a, const Tensor *b, bool disable_grad) { 167 | Tensor *result = tensor_binary_op(a, b, op_mul); 168 | 169 | // skip graph construction 170 | if (disable_grad) { 171 | result->requires_grad = false; 172 | return result; 173 | } 174 | 175 | if (result->requires_grad) { 176 | Function *fn = arena_alloc_function(); 177 | fn->apply = mul_backward; 178 | fn->output = result; 179 | fn->num_inputs = 2; 180 | fn->inputs[0] = (Tensor *)a; 181 | fn->inputs[1] = (Tensor *)b; 182 | fn->pending_count = 0; 183 | fn->ctx = NULL; 184 | if (a->grad_fn != NULL) { 185 | a->grad_fn->pending_count++; 186 | } 187 | if (b->grad_fn != NULL) { 188 | b->grad_fn->pending_count++; 189 | } 190 | result->grad_fn = fn; 191 | } 192 | return result; 193 | } 194 | 195 | // 196 | // div 197 | // 198 | 199 | static float32_t op_div(float32_t a, float32_t b) { return a / b; } 200 | Tensor *tensor_div(const Tensor *a, const Tensor *b) { 201 | Tensor *result = tensor_binary_op(a, b, op_div); 202 | 203 | if (result->requires_grad) { 204 | Function *fn = arena_alloc_function(); 205 | fn->apply = div_backward; 206 | fn->output = result; 207 | fn->num_inputs = 2; 208 | fn->inputs[0] = (Tensor *)a; 209 | fn->inputs[1] = (Tensor *)b; 210 | fn->pending_count = 0; 211 | fn->ctx = NULL; 212 | if (a->grad_fn != NULL) { 213 | a->grad_fn->pending_count++; 214 | } 215 | if (b->grad_fn != NULL) { 216 | b->grad_fn->pending_count++; 217 | } 218 | result->grad_fn = fn; 219 | } 220 | 221 | return result; 222 | } 223 | 224 | // 225 | // matmul 226 | // 227 | 228 | Tensor *tensor_matmul(const Tensor *a, const Tensor *b) { 229 | assert(a != NULL); 230 | assert(b != NULL); 231 | assert(a->data != NULL); 232 | assert(b->data != NULL); 233 | assert((uintptr_t)a->data % CACHELINE_SIZE == 0 && "a->data is not properly aligned"); 234 | assert((uintptr_t)b->data % CACHELINE_SIZE == 0 && "b->data is not properly aligned"); 235 | assert(a->ndim >= 1 && b->ndim >= 1 && "matmul requires at least 1D tensors"); 236 | assert(a->ndim == 2 && b->ndim == 2 && "only 2D matmul supported"); 237 | assert(a->shape[1] == b->shape[0] && "inner dimensions must match"); 238 | 239 | uint64_t M = a->shape[0]; 240 | uint64_t K = a->shape[1]; 241 | uint64_t N = b->shape[1]; 242 | 243 | assert(M <= MAX_TENSOR_SIZE); 244 | assert(K <= MAX_TENSOR_SIZE); 245 | assert(N <= MAX_TENSOR_SIZE); 246 | 247 | const uint64_t out_shape[] = {M, N}; 248 | Tensor *result = tensor_zeros(out_shape, 2, a->requires_grad || b->requires_grad); 249 | 250 | // naive algorithm 251 | for (uint64_t i = 0; i < M; i++) { 252 | for (uint64_t j = 0; j < N; j++) { 253 | float32_t sum = 0.0f; 254 | for (uint64_t k = 0; k < K; k++) { 255 | uint64_t a_offset = i * a->strides[0] + k * a->strides[1]; 256 | uint64_t b_offset = k * b->strides[0] + j * b->strides[1]; 257 | assert(a_offset < a->size && "a_offset out of bounds"); 258 | assert(b_offset < b->size && "b_offset out of bounds"); 259 | sum += a->data[a_offset] * b->data[b_offset]; 260 | } 261 | uint64_t result_offset = i * result->strides[0] + j * result->strides[1]; 262 | assert(result_offset < result->size && "result_offset out of bounds"); 263 | result->data[result_offset] = sum; 264 | } 265 | } 266 | 267 | 
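// naive O(M*N*K) triple loop; indexing goes through explicit strides rather than assuming
// a contiguous row-major layout. the backward pass (matmul_backward) reuses this kernel
// with dA = dC @ B^T and dB = A^T @ dC.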
assert(result != NULL); 268 | assert(result->ndim == 2); 269 | assert(result->shape[0] == M); 270 | assert(result->shape[1] == N); 271 | 272 | if (result->requires_grad) { 273 | Function *fn = arena_alloc_function(); 274 | fn->apply = matmul_backward; 275 | fn->output = result; 276 | fn->num_inputs = 2; 277 | fn->inputs[0] = (Tensor *)a; 278 | fn->inputs[1] = (Tensor *)b; 279 | fn->pending_count = 0; 280 | fn->ctx = NULL; 281 | if (a->grad_fn != NULL) { 282 | a->grad_fn->pending_count++; 283 | } 284 | if (b->grad_fn != NULL) { 285 | b->grad_fn->pending_count++; 286 | } 287 | result->grad_fn = fn; 288 | } 289 | 290 | return result; 291 | } 292 | -------------------------------------------------------------------------------- /src/ops/activations_backward.c: -------------------------------------------------------------------------------- 1 | #include "activations_backward.h" 2 | #include "ops/arithmetic.h" 3 | #include "tensor.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | // 10 | // sigmoid 11 | // 12 | 13 | Tensor *tensor_sigmoid_backward(const Tensor *t) { 14 | assert(t != NULL); 15 | assert(t->data != NULL || t->size == 0); 16 | 17 | Tensor *grad = tensor_create(NULL, t->shape, t->ndim, false); 18 | 19 | for (uint64_t i = 0; i < t->size; i++) { 20 | float32_t x = t->data[i]; 21 | float32_t sigmoid_x = 1.0f / (1.0f + expf(-x)); 22 | grad->data[i] = sigmoid_x * (1.0f - sigmoid_x); 23 | } 24 | 25 | return grad; 26 | } 27 | 28 | void sigmoid_backward(Function *fn, const Tensor *grad_output) { 29 | assert(fn != NULL); 30 | assert(grad_output != NULL); 31 | assert(fn->num_inputs == 1); 32 | 33 | Tensor *input = fn->inputs[0]; 34 | const Tensor *output = fn->output; 35 | 36 | if (input != NULL && input->requires_grad) { 37 | // grad_input = grad_output * output * (1 - output) 38 | // where output = sigmoid(input) 39 | Tensor *local_grad = tensor_create(NULL, output->shape, output->ndim, false); 40 | for (uint64_t i = 0; i < output->size; i++) { 41 | float32_t out_val = output->data[i]; 42 | local_grad->data[i] = out_val * (1.0f - out_val); 43 | } 44 | 45 | Tensor *grad_input = tensor_mul(grad_output, local_grad, true); // disable_grad=true 46 | tensor_free(local_grad); 47 | accumulate_grad(input, grad_input); 48 | } 49 | } 50 | 51 | // 52 | // relu 53 | // 54 | 55 | Tensor *tensor_relu_backward(const Tensor *t) { 56 | assert(t != NULL); 57 | assert(t->data != NULL || t->size == 0); 58 | 59 | Tensor *grad = tensor_create(NULL, t->shape, t->ndim, false); 60 | 61 | for (uint64_t i = 0; i < t->size; i++) { 62 | grad->data[i] = (t->data[i] > 0.0f) ? 1.0f : 0.0f; 63 | } 64 | 65 | return grad; 66 | } 67 | 68 | void relu_backward(Function *fn, const Tensor *grad_output) { 69 | assert(fn != NULL); 70 | assert(grad_output != NULL); 71 | assert(fn->num_inputs == 1); 72 | 73 | Tensor *input = fn->inputs[0]; 74 | 75 | if (input != NULL && input->requires_grad) { 76 | // grad_input = grad_output * (input > 0 ? 1 : 0) 77 | Tensor *local_grad = tensor_create(NULL, input->shape, input->ndim, false); 78 | for (uint64_t i = 0; i < input->size; i++) { 79 | local_grad->data[i] = (input->data[i] > 0.0f) ? 
1.0f : 0.0f; 80 | } 81 | 82 | Tensor *grad_input = tensor_mul(grad_output, local_grad, true); // disable_grad=true 83 | tensor_free(local_grad); 84 | accumulate_grad(input, grad_input); 85 | } 86 | } 87 | 88 | // 89 | // tanh 90 | // 91 | 92 | Tensor *tensor_tanh_backward(const Tensor *t) { 93 | assert(t != NULL); 94 | assert(t->data != NULL || t->size == 0); 95 | 96 | Tensor *grad = tensor_create(NULL, t->shape, t->ndim, false); 97 | 98 | for (uint64_t i = 0; i < t->size; i++) { 99 | float32_t tanh_x = tanhf(t->data[i]); 100 | grad->data[i] = 1.0f - tanh_x * tanh_x; 101 | } 102 | 103 | return grad; 104 | } 105 | 106 | void tanh_backward(Function *fn, const Tensor *grad_output) { 107 | assert(fn != NULL); 108 | assert(grad_output != NULL); 109 | assert(fn->num_inputs == 1); 110 | 111 | Tensor *input = fn->inputs[0]; 112 | const Tensor *output = fn->output; 113 | 114 | if (input != NULL && input->requires_grad) { 115 | // grad_input = grad_output * (1 - output^2) 116 | // where output = tanh(input) 117 | Tensor *local_grad = tensor_create(NULL, output->shape, output->ndim, false); 118 | for (uint64_t i = 0; i < output->size; i++) { 119 | float32_t out_val = output->data[i]; 120 | local_grad->data[i] = 1.0f - out_val * out_val; 121 | } 122 | 123 | Tensor *grad_input = tensor_mul(grad_output, local_grad, true); // disable_grad=true 124 | tensor_free(local_grad); 125 | accumulate_grad(input, grad_input); 126 | } 127 | } 128 | 129 | // 130 | // gelu 131 | // 132 | 133 | Tensor *tensor_gelu_backward(const Tensor *t) { 134 | assert(t != NULL); 135 | assert(t->data != NULL || t->size == 0); 136 | 137 | Tensor *grad = tensor_create(NULL, t->shape, t->ndim, false); 138 | 139 | float32_t sqrt_2_over_pi = sqrtf(2.0f / (float32_t)M_PI); 140 | float32_t coeff = 0.044715f; 141 | 142 | for (uint64_t i = 0; i < t->size; i++) { 143 | float32_t x = t->data[i]; 144 | float32_t x2 = x * x; 145 | float32_t x3 = x2 * x; 146 | 147 | float32_t tanh_arg = sqrt_2_over_pi * (x + coeff * x3); 148 | float32_t tanh_out = tanhf(tanh_arg); 149 | float32_t sech_sq = 1.0f - tanh_out * tanh_out; 150 | 151 | float32_t d_tanh_arg = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2); 152 | 153 | grad->data[i] = 0.5f * (1.0f + tanh_out) + 0.5f * x * sech_sq * d_tanh_arg; 154 | } 155 | 156 | return grad; 157 | } 158 | 159 | void gelu_backward(Function *fn, const Tensor *grad_output) { 160 | assert(fn != NULL); 161 | assert(grad_output != NULL); 162 | assert(fn->num_inputs == 1); 163 | 164 | Tensor *input = fn->inputs[0]; 165 | 166 | if (input != NULL && input->requires_grad) { 167 | // Use the existing tensor_gelu_backward function 168 | Tensor *local_grad = tensor_gelu_backward(input); 169 | Tensor *grad_input = tensor_mul(grad_output, local_grad, true); // disable_grad=true 170 | tensor_free(local_grad); 171 | accumulate_grad(input, grad_input); 172 | } 173 | } 174 | 175 | // 176 | // softmax 177 | // 178 | 179 | Tensor *tensor_softmax_backward(const Tensor *t, int64_t dim) { 180 | assert(t != NULL); 181 | assert(t->data != NULL || t->size == 0); 182 | 183 | int64_t ndim = (int64_t)t->ndim; 184 | int64_t target_dim = (dim < 0) ? 
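// negative dims count from the end (dim = -1 on a 2-D tensor selects dim 1), which is what
// this wrap-around normalizes. note that this helper returns only the diagonal of the
// softmax Jacobian, s_i * (1 - s_i); for an all-zero row of width 2, s = 0.5 and every
// entry is 0.25, matching test_softmax_backward_shapes.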
(dim + ndim) : dim; 185 | assert(target_dim >= 0 && target_dim < ndim && "Invalid dimension"); 186 | 187 | Tensor *grad = tensor_create(NULL, t->shape, t->ndim, false); 188 | 189 | uint64_t outer_size = 1; 190 | for (int64_t d = 0; d < target_dim; d++) { 191 | outer_size *= t->shape[d]; 192 | } 193 | 194 | uint64_t dim_size = t->shape[target_dim]; 195 | 196 | uint64_t inner_size = 1; 197 | for (uint64_t d = (uint64_t)target_dim + 1; d < t->ndim; d++) { 198 | inner_size *= t->shape[d]; 199 | } 200 | 201 | for (uint64_t outer = 0; outer < outer_size; outer++) { 202 | for (uint64_t inner = 0; inner < inner_size; inner++) { 203 | uint64_t base_idx = outer * dim_size * inner_size + inner; 204 | 205 | float32_t max_val = -INFINITY; 206 | for (uint64_t d = 0; d < dim_size; d++) { 207 | uint64_t idx = base_idx + d * inner_size; 208 | if (t->data[idx] > max_val) { 209 | max_val = t->data[idx]; 210 | } 211 | } 212 | 213 | float32_t exp_sum = 0.0f; 214 | for (uint64_t d = 0; d < dim_size; d++) { 215 | uint64_t idx = base_idx + d * inner_size; 216 | exp_sum += expf(t->data[idx] - max_val); 217 | } 218 | 219 | for (uint64_t d = 0; d < dim_size; d++) { 220 | uint64_t idx = base_idx + d * inner_size; 221 | float32_t softmax_val = expf(t->data[idx] - max_val) / exp_sum; 222 | grad->data[idx] = softmax_val * (1.0f - softmax_val); 223 | } 224 | } 225 | } 226 | return grad; 227 | } 228 | 229 | void softmax_backward(Function *fn, const Tensor *grad_output) { 230 | assert(fn != NULL); 231 | assert(grad_output != NULL); 232 | assert(fn->num_inputs == 1); 233 | assert(fn->ctx != NULL && "softmax_backward requires context"); 234 | 235 | Tensor *input = fn->inputs[0]; 236 | const Tensor *output = fn->output; 237 | int64_t dim = *(int64_t *)fn->ctx; 238 | 239 | if (input != NULL && input->requires_grad) { 240 | // softmax backward: grad_input = output * (grad_output - sum(grad_output * output)) 241 | // this is the jacobian-vector product for softmax 242 | 243 | int64_t ndim = (int64_t)output->ndim; 244 | int64_t target_dim = (dim < 0) ? 
(dim + ndim) : dim; 245 | assert(target_dim >= 0 && target_dim < ndim && "Invalid dimension"); 246 | 247 | uint64_t outer_size = 1; 248 | for (int64_t d = 0; d < target_dim; d++) { 249 | outer_size *= output->shape[d]; 250 | } 251 | 252 | uint64_t dim_size = output->shape[target_dim]; 253 | 254 | uint64_t inner_size = 1; 255 | for (uint64_t d = (uint64_t)target_dim + 1; d < output->ndim; d++) { 256 | inner_size *= output->shape[d]; 257 | } 258 | 259 | Tensor *grad_input = tensor_create(NULL, input->shape, input->ndim, false); 260 | 261 | for (uint64_t outer = 0; outer < outer_size; outer++) { 262 | for (uint64_t inner = 0; inner < inner_size; inner++) { 263 | uint64_t base_idx = outer * dim_size * inner_size + inner; 264 | 265 | // sum(grad_output * output) along the softmax dimension 266 | float32_t sum_grad_output_output = 0.0f; 267 | for (uint64_t d = 0; d < dim_size; d++) { 268 | uint64_t idx = base_idx + d * inner_size; 269 | sum_grad_output_output += grad_output->data[idx] * output->data[idx]; 270 | } 271 | 272 | // grad_input = output * (grad_output - sum) 273 | for (uint64_t d = 0; d < dim_size; d++) { 274 | uint64_t idx = base_idx + d * inner_size; 275 | grad_input->data[idx] = output->data[idx] * (grad_output->data[idx] - sum_grad_output_output); 276 | } 277 | } 278 | } 279 | 280 | accumulate_grad(input, grad_input); 281 | } 282 | 283 | free(fn->ctx); 284 | fn->ctx = NULL; 285 | } 286 | -------------------------------------------------------------------------------- /src/optimizers.c: -------------------------------------------------------------------------------- 1 | #include "optimizers.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | // 8 | // base implementation 9 | // 10 | 11 | void optimizer_zero_grad(Optimizer *opt) { 12 | assert(opt != NULL); 13 | for (size_t i = 0; i < opt->param_count; ++i) { 14 | Tensor *param = opt->params[i]; 15 | if (param->grad == NULL) { 16 | continue; 17 | } 18 | tensor_free(param->grad); 19 | param->grad = NULL; 20 | } 21 | } 22 | 23 | void optimizer_step(Optimizer *opt) { 24 | assert(opt != NULL); 25 | opt->step(opt); 26 | } 27 | 28 | void optimizer_free(Optimizer *opt) { 29 | if (opt != NULL) { 30 | opt->free(opt); 31 | } 32 | } 33 | 34 | // 35 | // sgd 36 | // 37 | 38 | typedef struct { 39 | Optimizer base; 40 | float32_t lr; 41 | float32_t momentum; 42 | float32_t weight_decay; 43 | float32_t **momentum_buffers; // array of pointers to float arrays (same size as param data) 44 | } SGD; 45 | 46 | static void sgd_free(Optimizer *opt) { 47 | SGD *sgd = (SGD *)opt; 48 | if (sgd->momentum_buffers) { 49 | for (size_t i = 0; i < opt->param_count; ++i) { 50 | if (sgd->momentum_buffers[i]) { 51 | free(sgd->momentum_buffers[i]); 52 | } 53 | } 54 | free(sgd->momentum_buffers); 55 | } 56 | if (opt->params) { 57 | free(opt->params); 58 | } 59 | free(sgd); 60 | } 61 | 62 | static void sgd_ensure_buffer(SGD *sgd, size_t param_idx, size_t elem_count) { 63 | if (sgd->momentum_buffers[param_idx] != NULL) { 64 | return; 65 | } 66 | float32_t *buf = calloc(elem_count, sizeof(float32_t)); 67 | assert(buf != NULL); 68 | sgd->momentum_buffers[param_idx] = buf; 69 | } 70 | 71 | static void sgd_step(Optimizer *opt) { 72 | SGD *sgd = (SGD *)opt; 73 | opt->step_count++; 74 | 75 | for (size_t i = 0; i < opt->param_count; ++i) { 76 | Tensor *param = opt->params[i]; 77 | if (param->grad == NULL) { 78 | continue; 79 | } 80 | 81 | float32_t *p_data = param->data; 82 | const float32_t *g_data = param->grad->data; 83 | const size_t elem_count = param->size; 84 
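// per-element update in the loop below: g += weight_decay * p (L2 penalty), then with
// momentum v = momentum * v + g and the parameter steps along v, otherwise along g:
// p -= lr * g. for illustration, lr 0.1, momentum 0.9, v 0, g 1 gives v = 1 and a step of 0.1.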
| 85 | if (sgd->momentum != 0.0f) { 86 | sgd_ensure_buffer(sgd, i, elem_count); 87 | } 88 | 89 | float32_t *m_buf = sgd->momentum_buffers[i]; 90 | const float32_t lr = sgd->lr; 91 | const float32_t momentum = sgd->momentum; 92 | const float32_t weight_decay = sgd->weight_decay; 93 | 94 | for (size_t j = 0; j < elem_count; ++j) { 95 | float32_t g = g_data[j]; 96 | 97 | // weight decay (L2 penalty) 98 | if (weight_decay != 0.0f) { 99 | g += weight_decay * p_data[j]; 100 | } 101 | 102 | // momentum 103 | if (momentum != 0.0f) { 104 | // v = momentum * v_prev + g 105 | m_buf[j] = momentum * m_buf[j] + g; 106 | // update gradient to be the velocity 107 | g = m_buf[j]; 108 | } 109 | 110 | // update parameter 111 | p_data[j] -= lr * g; 112 | } 113 | } 114 | } 115 | 116 | Optimizer *optimizer_sgd_create(Tensor **params, size_t count, float32_t lr, float32_t momentum, float32_t weight_decay) { 117 | assert(params != NULL); 118 | assert(count > 0); 119 | 120 | Optimizer *opt = calloc(1, sizeof(SGD)); 121 | assert(opt != NULL); 122 | SGD *sgd = (SGD *)opt; 123 | 124 | sgd->base.param_count = count; 125 | // copy the params array so validation/integrity is kept 126 | sgd->base.params = calloc(count, sizeof(Tensor *)); 127 | assert(sgd->base.params != NULL); 128 | for (size_t i = 0; i < count; ++i) { 129 | assert(params[i] != NULL); 130 | assert(params[i]->requires_grad && "all optimized tensors must require grad"); 131 | sgd->base.params[i] = params[i]; 132 | } 133 | 134 | sgd->base.step = sgd_step; 135 | sgd->base.free = sgd_free; 136 | sgd->base.step_count = 0; 137 | 138 | sgd->lr = lr; 139 | sgd->momentum = momentum; 140 | sgd->weight_decay = weight_decay; 141 | 142 | sgd->momentum_buffers = calloc(count, sizeof(float32_t *)); 143 | assert(sgd->momentum_buffers != NULL); 144 | 145 | return opt; 146 | } 147 | 148 | // 149 | // adam 150 | // 151 | 152 | typedef struct { 153 | Optimizer base; 154 | float32_t lr; 155 | float32_t beta1; 156 | float32_t beta2; 157 | float32_t eps; 158 | float32_t weight_decay; 159 | float32_t **m_buffers; // first moment 160 | float32_t **v_buffers; // second moment 161 | } Adam; 162 | 163 | static void adam_free(Optimizer *opt) { 164 | Adam *adam = (Adam *)opt; 165 | if (adam->m_buffers) { 166 | for (size_t i = 0; i < opt->param_count; ++i) { 167 | if (adam->m_buffers[i]) { 168 | free(adam->m_buffers[i]); 169 | } 170 | } 171 | free(adam->m_buffers); 172 | } 173 | if (adam->v_buffers) { 174 | for (size_t i = 0; i < opt->param_count; ++i) { 175 | if (adam->v_buffers[i]) { 176 | free(adam->v_buffers[i]); 177 | } 178 | } 179 | free(adam->v_buffers); 180 | } 181 | if (opt->params) { 182 | free(opt->params); 183 | } 184 | free(adam); 185 | } 186 | 187 | static void adam_ensure_buffers(Adam *adam, size_t param_idx, size_t elem_count) { 188 | if (adam->m_buffers[param_idx] != NULL) { 189 | return; 190 | } 191 | adam->m_buffers[param_idx] = calloc(elem_count, sizeof(float32_t)); 192 | assert(adam->m_buffers[param_idx] != NULL); 193 | adam->v_buffers[param_idx] = calloc(elem_count, sizeof(float32_t)); 194 | assert(adam->v_buffers[param_idx] != NULL); 195 | } 196 | 197 | static void adam_step_impl(Optimizer *opt, bool is_adamw) { 198 | Adam *adam = (Adam *)opt; 199 | opt->step_count++; 200 | 201 | const float32_t beta1 = adam->beta1; 202 | const float32_t beta2 = adam->beta2; 203 | const float32_t eps = adam->eps; 204 | const float32_t weight_decay = adam->weight_decay; 205 | const float32_t lr = adam->lr; 206 | 207 | // pre-compute bias corrections (constant across all 
parameters) 208 | // 1 - beta^t 209 | const float32_t bias_correction1 = 1.0f - (float32_t)pow(beta1, (double)opt->step_count); 210 | const float32_t bias_correction2 = 1.0f - (float32_t)pow(beta2, (double)opt->step_count); 211 | 212 | for (size_t i = 0; i < opt->param_count; ++i) { 213 | Tensor *param = opt->params[i]; 214 | if (param->grad == NULL) { 215 | continue; 216 | } 217 | 218 | float32_t *p_data = param->data; 219 | const float32_t *g_data = param->grad->data; 220 | const size_t elem_count = param->size; 221 | 222 | // ensure internal state buffers exist 223 | adam_ensure_buffers(adam, i, elem_count); 224 | 225 | float32_t *m = adam->m_buffers[i]; 226 | float32_t *v = adam->v_buffers[i]; 227 | 228 | for (size_t j = 0; j < elem_count; ++j) { 229 | float32_t g = g_data[j]; 230 | 231 | if (!is_adamw) { 232 | // adam: Add weight decay to gradient (L2 regularization equivalent) 233 | if (weight_decay != 0.0f) { 234 | g += weight_decay * p_data[j]; 235 | } 236 | } 237 | 238 | // update biased first moment estimate: m = beta1 * m + (1 - beta1) * g 239 | m[j] = beta1 * m[j] + (1.0f - beta1) * g; 240 | 241 | // update biased second moment estimate: v = beta2 * v + (1 - beta2) * g^2 242 | v[j] = beta2 * v[j] + (1.0f - beta2) * (g * g); 243 | 244 | // compute bias-corrected moments 245 | float32_t m_hat = m[j] / bias_correction1; 246 | float32_t v_hat = v[j] / bias_correction2; 247 | 248 | // update parameter 249 | p_data[j] -= lr * m_hat / (sqrtf(v_hat) + eps); 250 | 251 | if (is_adamw) { 252 | // AdamW: decay weights directly (decoupled weight decay) 253 | // P_new = P_old - lr * (weight_decay * P_old + other_terms) 254 | if (weight_decay != 0.0f) { 255 | p_data[j] *= (1.0f - lr * weight_decay); 256 | } 257 | } 258 | } 259 | } 260 | } 261 | 262 | static void adam_step(Optimizer *opt) { adam_step_impl(opt, false); } 263 | 264 | static void adamw_step(Optimizer *opt) { adam_step_impl(opt, true); } 265 | 266 | static Optimizer *adam_create_internal(Tensor **params, size_t count, float32_t lr, float32_t beta1, float32_t beta2, float32_t eps, float32_t weight_decay, void (*step_fn)(Optimizer *)) { 267 | assert(params != NULL); 268 | assert(count > 0); 269 | 270 | Optimizer *opt = calloc(1, sizeof(Adam)); 271 | assert(opt != NULL); 272 | Adam *adam = (Adam *)opt; 273 | 274 | adam->base.param_count = count; 275 | adam->base.params = calloc(count, sizeof(Tensor *)); 276 | assert(adam->base.params != NULL); 277 | for (size_t i = 0; i < count; ++i) { 278 | assert(params[i] != NULL); 279 | assert(params[i]->requires_grad && "all optimized tensors must require grad"); 280 | adam->base.params[i] = params[i]; 281 | } 282 | 283 | adam->base.step = step_fn; // this is the only difference between adam and adamw 284 | adam->base.free = adam_free; 285 | adam->base.step_count = 0; 286 | 287 | adam->lr = lr; 288 | adam->beta1 = beta1; 289 | adam->beta2 = beta2; 290 | adam->eps = eps; 291 | adam->weight_decay = weight_decay; 292 | 293 | adam->m_buffers = calloc(count, sizeof(float32_t *)); 294 | assert(adam->m_buffers != NULL); 295 | adam->v_buffers = calloc(count, sizeof(float32_t *)); 296 | assert(adam->v_buffers != NULL); 297 | 298 | return opt; 299 | } 300 | 301 | Optimizer *optimizer_adam_create(Tensor **params, size_t count, float32_t lr, float32_t beta1, float32_t beta2, float32_t eps, float32_t weight_decay) { return adam_create_internal(params, count, lr, beta1, beta2, eps, weight_decay, adam_step); } 302 | 303 | Optimizer *optimizer_adamw_create(Tensor **params, size_t count, float32_t lr, float32_t 
beta1, float32_t beta2, float32_t eps, float32_t weight_decay) { return adam_create_internal(params, count, lr, beta1, beta2, eps, weight_decay, adamw_step); } 304 | -------------------------------------------------------------------------------- /src/layers.c: -------------------------------------------------------------------------------- 1 | #include "layers.h" 2 | #include "ops/arithmetic.h" 3 | #include "tensor.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | // 11 | // base layer implementation 12 | // 13 | 14 | Tensor *layer_forward(Layer *layer, const Tensor *input, bool training) { 15 | assert(layer != NULL); 16 | assert(layer->forward != NULL); 17 | assert(input != NULL); 18 | 19 | return layer->forward(layer, input, training); 20 | } 21 | 22 | void layer_free(Layer *layer) { 23 | if (layer != NULL && layer->free != NULL) { 24 | layer->free(layer); 25 | } 26 | } 27 | 28 | void layer_parameters(Layer *layer, Tensor ***out_params, size_t *out_count) { 29 | assert(out_params != NULL); 30 | assert(out_count != NULL); 31 | 32 | if (layer != NULL && layer->parameters != NULL) { 33 | layer->parameters(layer, out_params, out_count); 34 | return; 35 | } 36 | *out_params = NULL; 37 | *out_count = 0; 38 | } 39 | 40 | // 41 | // linear layer 42 | // 43 | 44 | typedef struct { 45 | Layer base; 46 | Tensor *weight; 47 | Tensor *bias; 48 | uint64_t in_features; 49 | uint64_t out_features; 50 | } LinearLayer; 51 | 52 | static Tensor *linear_forward(Layer *layer, const Tensor *input, bool training) { 53 | assert(layer != NULL); 54 | assert(input != NULL); 55 | (void)training; 56 | 57 | const LinearLayer *l = (const LinearLayer *)layer; 58 | assert(l->weight != NULL); 59 | 60 | // output = input @ weight + bias 61 | Tensor *output = tensor_matmul(input, l->weight); 62 | assert(output != NULL); 63 | 64 | if (l->bias != NULL) { 65 | Tensor *output_bias = tensor_add(output, l->bias); 66 | assert(output_bias != NULL); 67 | tensor_free(output); 68 | output = output_bias; 69 | } 70 | 71 | return output; 72 | } 73 | 74 | static void linear_free(Layer *layer) { 75 | if (layer == NULL) { 76 | return; 77 | } 78 | 79 | LinearLayer *l = (LinearLayer *)layer; 80 | if (l->weight != NULL) { 81 | tensor_free(l->weight); 82 | } 83 | if (l->bias != NULL) { 84 | tensor_free(l->bias); 85 | } 86 | free(l); 87 | } 88 | 89 | static void linear_parameters(Layer *layer, Tensor ***out_params, size_t *out_count) { 90 | assert(layer != NULL); 91 | assert(out_params != NULL); 92 | assert(out_count != NULL); 93 | 94 | LinearLayer *l = (LinearLayer *)layer; 95 | size_t const count = (l->bias != NULL) ? 
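// [note] the returned array is malloc'd per call and only holds borrowed pointers; callers
// (e.g. sequential_parameters later in this file) free the array itself, while the weight
// and bias tensors remain owned by the layer and are released in linear_free.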
2 : 1; 96 | 97 | Tensor **params = (Tensor **)malloc(count * sizeof(Tensor *)); 98 | assert(params != NULL && "malloc failed"); 99 | 100 | params[0] = l->weight; 101 | if (l->bias != NULL) { 102 | params[1] = l->bias; 103 | } 104 | *out_params = params; 105 | *out_count = count; 106 | } 107 | 108 | Layer *layer_linear_create(uint64_t in_features, uint64_t features_out, bool bias) { 109 | assert(in_features > 0); 110 | assert(features_out > 0); 111 | assert(in_features < (SIZE_MAX / features_out)); 112 | 113 | LinearLayer *l = (LinearLayer *)calloc(1, sizeof(LinearLayer)); 114 | assert(l != NULL && "calloc failed"); 115 | 116 | l->base.forward = linear_forward; 117 | l->base.free = linear_free; 118 | l->base.parameters = linear_parameters; 119 | l->base.name = "Linear"; 120 | l->in_features = in_features; 121 | l->out_features = features_out; 122 | 123 | float32_t const limit = 1.0f / sqrtf((float32_t)in_features); 124 | uint64_t const w_shape[] = {in_features, features_out}; 125 | uint64_t const w_size = in_features * features_out; 126 | float32_t *w_data = (float32_t *)malloc(w_size * sizeof(float32_t)); 127 | assert(w_data != NULL && "malloc failed"); 128 | 129 | for (size_t i = 0; i < w_size; ++i) { 130 | float32_t const r = (float32_t)rand() / (float32_t)RAND_MAX; 131 | w_data[i] = (r * 2.0f * limit) - limit; 132 | } 133 | 134 | l->weight = tensor_create(w_data, w_shape, 2, true); 135 | assert(l->weight != NULL); 136 | free(w_data); 137 | 138 | if (bias) { 139 | uint64_t const b_shape[] = {features_out}; 140 | l->bias = tensor_zeros(b_shape, 1, true); 141 | assert(l->bias != NULL); 142 | } else { 143 | l->bias = NULL; 144 | } 145 | 146 | return (Layer *)l; 147 | } 148 | 149 | // 150 | // dropout layer 151 | // 152 | 153 | typedef struct { 154 | Layer base; 155 | float32_t p; 156 | } DropoutLayer; 157 | 158 | static Tensor *dropout_forward(const Layer *layer, const Tensor *input, bool training) { 159 | assert(layer != NULL); 160 | assert(input != NULL); 161 | 162 | const DropoutLayer *l = (const DropoutLayer *)layer; 163 | if (!training || l->p <= 0.0f) { 164 | Tensor *t = tensor_create(input->data, input->shape, input->ndim, input->requires_grad); 165 | assert(t != NULL); 166 | return t; 167 | } 168 | 169 | if (l->p >= 1.0f) { 170 | Tensor *t = tensor_zeros(input->shape, input->ndim, input->requires_grad); 171 | assert(t != NULL); 172 | return t; 173 | } 174 | 175 | float32_t const scale = 1.0f / (1.0f - l->p); 176 | float32_t *mask_data = (float32_t *)malloc(input->size * sizeof(float32_t)); 177 | assert(mask_data != NULL && "malloc failed"); 178 | 179 | // mask_data[i] = 0 with probability p, or 1/(1-p) with probability (1-p) 180 | for (size_t i = 0; i < input->size; ++i) { 181 | float32_t const r = (float32_t)rand() / (float32_t)RAND_MAX; 182 | mask_data[i] = (r < (1.0f - l->p)) ? 
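// inverted dropout: surviving entries are scaled by 1 / (1 - p) at training time so the
// expected activation matches eval mode (where the layer is a plain copy, handled above);
// with p = 0.5, roughly half the mask entries are 0 and the rest are 2.0.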
scale : 0.0f; 183 | } 184 | 185 | Tensor *mask = tensor_create(mask_data, input->shape, input->ndim, false); 186 | assert(mask != NULL); 187 | free(mask_data); 188 | 189 | // output = input * mask 190 | Tensor *output = tensor_mul(input, mask); 191 | assert(output != NULL); 192 | tensor_free(mask); 193 | 194 | return output; 195 | } 196 | 197 | static void dropout_free(Layer *layer) { 198 | if (layer == NULL) { 199 | return; 200 | } 201 | free(layer); 202 | } 203 | 204 | static void dropout_parameters(Layer *layer, Tensor ***out_params, size_t *out_count) { 205 | (void)layer; 206 | assert(out_params != NULL); 207 | assert(out_count != NULL); 208 | *out_params = NULL; 209 | *out_count = 0; 210 | } 211 | 212 | Layer *layer_dropout_create(float32_t p) { 213 | assert(p >= 0.0f && p <= 1.0f && "dropout probability must be between 0 and 1"); 214 | 215 | DropoutLayer *l = (DropoutLayer *)calloc(1, sizeof(DropoutLayer)); 216 | assert(l != NULL && "calloc failed"); 217 | 218 | l->base.forward = (Tensor * (*)(Layer *, const Tensor *, bool)) dropout_forward; 219 | l->base.free = dropout_free; 220 | l->base.parameters = dropout_parameters; 221 | l->base.name = "Dropout"; 222 | l->p = p; 223 | return (Layer *)l; 224 | } 225 | 226 | // 227 | // sequential layer 228 | // 229 | 230 | typedef struct { 231 | Layer base; 232 | Layer **layers; 233 | size_t count; 234 | } SequentialLayer; 235 | 236 | static Tensor *sequential_forward(Layer *layer, const Tensor *input, bool training) { 237 | assert(layer != NULL); 238 | assert(input != NULL); 239 | 240 | const SequentialLayer *l = (const SequentialLayer *)layer; 241 | 242 | if (l->count == 0) { 243 | Tensor *t = tensor_create(input->data, input->shape, input->ndim, input->requires_grad); 244 | assert(t != NULL); 245 | return t; 246 | } 247 | 248 | Tensor *current = NULL; 249 | 250 | // first layer 251 | current = layer_forward(l->layers[0], input, training); 252 | assert(current != NULL); 253 | 254 | for (size_t i = 1; i < l->count; ++i) { 255 | // composition: pass output of previous layer as input to next layer 256 | Tensor *next = layer_forward(l->layers[i], current, training); 257 | assert(next != NULL); 258 | tensor_free(current); 259 | current = next; 260 | } 261 | 262 | return current; 263 | } 264 | 265 | static void sequential_free(Layer *layer) { 266 | if (layer == NULL) { 267 | return; 268 | } 269 | 270 | SequentialLayer *l = (SequentialLayer *)layer; 271 | for (size_t i = 0; i < l->count; ++i) { 272 | layer_free(l->layers[i]); 273 | } 274 | free(l->layers); 275 | free(l); 276 | } 277 | 278 | static void sequential_parameters(Layer *layer, Tensor ***out_params, size_t *out_count) { 279 | assert(layer != NULL); 280 | assert(out_params != NULL); 281 | assert(out_count != NULL); 282 | 283 | SequentialLayer *l = (SequentialLayer *)layer; 284 | 285 | size_t total_params = 0; 286 | for (size_t i = 0; i < l->count; ++i) { 287 | Tensor **sub_params; 288 | size_t sub_count; 289 | layer_parameters(l->layers[i], &sub_params, &sub_count); 290 | total_params += sub_count; 291 | if (sub_params != NULL) { 292 | free(sub_params); 293 | } 294 | } 295 | 296 | if (total_params == 0) { 297 | *out_params = NULL; 298 | *out_count = 0; 299 | return; 300 | } 301 | 302 | Tensor **all_params = (Tensor **)malloc(total_params * sizeof(Tensor *)); 303 | assert(all_params != NULL && "malloc failed"); 304 | 305 | size_t current_idx = 0; 306 | for (size_t i = 0; i < l->count; ++i) { 307 | Tensor **sub_params; 308 | size_t sub_count; 309 | layer_parameters(l->layers[i], &sub_params, 
&sub_count); 310 | for (size_t j = 0; j < sub_count; ++j) { 311 | all_params[current_idx++] = sub_params[j]; 312 | } 313 | if (sub_params != NULL) { 314 | free(sub_params); 315 | } 316 | } 317 | 318 | *out_params = all_params; 319 | *out_count = total_params; 320 | } 321 | 322 | Layer *layer_sequential_create(Layer **layers, size_t count) { 323 | SequentialLayer *l = (SequentialLayer *)calloc(1, sizeof(SequentialLayer)); 324 | assert(l != NULL && "calloc failed"); 325 | 326 | l->base.forward = sequential_forward; 327 | l->base.free = sequential_free; 328 | l->base.parameters = sequential_parameters; 329 | l->base.name = "Sequential"; 330 | 331 | l->layers = (Layer **)malloc(count * sizeof(Layer *)); 332 | assert(l->layers != NULL && "malloc failed"); 333 | if (layers != NULL && count > 0) { 334 | memcpy(l->layers, layers, count * sizeof(Layer *)); 335 | } 336 | l->count = count; 337 | 338 | return (Layer *)l; 339 | } 340 | -------------------------------------------------------------------------------- /src/ops/reductions.c: -------------------------------------------------------------------------------- 1 | #include "ops/reductions.h" 2 | #include "autograd.h" 3 | #include "ops/arithmetic.h" 4 | #include "ops/reductions_backward.h" 5 | #include 6 | #include 7 | #include 8 | 9 | /* 10 | * calculates the output shape and ndim for a tensor reduction 11 | * 12 | * example: 13 | * input shape: [2, 3] 14 | * reduce dim_idx: 0 (first dimension) 15 | * 16 | * a) keepdims == false -> output shape: [3] 17 | * b) keepdims == true -> output shape: [1, 3] 18 | */ 19 | static void reduction_shapes_mut(const Tensor *t, int64_t dim_idx, bool keepdims, uint64_t **out_shape, uint64_t *out_ndim) { 20 | assert(t != NULL); 21 | assert(t->ndim <= MAX_NDIM); 22 | dim_idx = (dim_idx < 0) ? (dim_idx + (int64_t)t->ndim) : dim_idx; // handle negative indices 23 | assert(dim_idx >= 0 && dim_idx < (int64_t)t->ndim && "dim_idx out of bounds"); 24 | 25 | *out_ndim = keepdims ? t->ndim : t->ndim - 1; 26 | assert(*out_ndim <= MAX_NDIM); 27 | *out_shape = NULL; 28 | 29 | if (*out_ndim > 0) { 30 | *out_shape = (uint64_t *)malloc((size_t)(*out_ndim) * sizeof(uint64_t)); 31 | assert(*out_shape != NULL && "malloc failed"); 32 | } 33 | 34 | if (keepdims) { 35 | // keep dim_idx, collapse to size 1 36 | for (uint64_t i = 0; i < t->ndim; i++) { 37 | (*out_shape)[i] = ((int64_t)i == dim_idx) ? 1 : t->shape[i]; 38 | } 39 | } else { 40 | // drop dim_idx entirely 41 | uint64_t k = 0; 42 | for (uint64_t i = 0; i < t->ndim; i++) { 43 | if ((int64_t)i != dim_idx) { 44 | (*out_shape)[k++] = t->shape[i]; 45 | } 46 | } 47 | } 48 | } 49 | 50 | // same as multidim_to_linear, but skips the reduced dimension 51 | static uint64_t reduction_multidim_to_linear(const Tensor *t, const uint64_t *multidim, int64_t dim_idx, bool keepdims) { 52 | assert(t != NULL); 53 | dim_idx = (dim_idx < 0) ? (dim_idx + (int64_t)t->ndim) : dim_idx; 54 | assert(dim_idx >= 0 && dim_idx < (int64_t)t->ndim && "dim_idx out of bounds"); 55 | 56 | uint64_t offset = 0; 57 | for (uint64_t d = 0; d < t->ndim; d++) { 58 | // skip reduced dimension 59 | if ((int64_t)d == dim_idx) { 60 | continue; 61 | } 62 | 63 | assert(multidim != NULL); 64 | // map d (original dim) to index in multidim (reduced shape) 65 | uint64_t idx = keepdims ? d : (d > (uint64_t)dim_idx ? 
d - 1 : d); 66 | uint64_t idx_val = multidim[idx]; 67 | assert(t->shape != NULL); 68 | assert(idx_val < t->shape[d] && "index out of bounds"); 69 | 70 | offset += idx_val * t->strides[d]; 71 | } 72 | assert((t->size == 0 || offset < t->size) && "offset out of bounds"); 73 | return offset; 74 | } 75 | 76 | /* 77 | * sums elements along a dimension. 78 | * 79 | * example: 80 | * 81 | * shape: [2, 3] 82 | * 83 | * logical: [[1, 2, 3], 84 | * [4, 5, 6]] 85 | * 86 | * operation (sum along dim 0): 87 | * 88 | * a) keepdims = true 89 | * shape: [1, 3] 90 | * result: [[5, 7, 9]] 91 | * 92 | * b) keepdims = false 93 | * shape: [3] 94 | * result: [5, 7, 9] 95 | */ 96 | Tensor *tensor_sum(const Tensor *t, int64_t dim_idx, bool keepdims) { 97 | assert(t != NULL); 98 | assert(t->data != NULL || t->size == 0); 99 | dim_idx = (dim_idx < 0) ? (dim_idx + (int64_t)t->ndim) : dim_idx; 100 | assert(dim_idx >= 0 && dim_idx < (int64_t)t->ndim && "dim_idx out of bounds"); 101 | 102 | uint64_t *new_shape; 103 | uint64_t new_ndim; 104 | reduction_shapes_mut(t, dim_idx, keepdims, &new_shape, &new_ndim); 105 | 106 | Tensor *result = tensor_zeros(new_shape, new_ndim, t->requires_grad); 107 | if (new_shape) { 108 | free(new_shape); 109 | } 110 | 111 | // buffer for current multidim index 112 | uint64_t *curr = (new_ndim > 0) ? (uint64_t *)calloc((size_t)new_ndim, sizeof(uint64_t)) : NULL; 113 | if (new_ndim > 0) { 114 | assert(curr != NULL && "calloc failed"); 115 | } 116 | 117 | for (uint64_t i = 0; i < result->size; i++) { 118 | linear_to_multidim_mut(i, result->shape, new_ndim, curr); 119 | 120 | uint64_t base_offset = reduction_multidim_to_linear(t, curr, dim_idx, keepdims); 121 | 122 | // sum along axis_dim 123 | float32_t sum = 0.0f; 124 | uint64_t axis_dim = (t->shape) ? t->shape[dim_idx] : 1; 125 | assert(axis_dim <= MAX_TENSOR_SIZE && "axis_dim exceeds maximum tensor size"); 126 | uint64_t axis_stride = t->strides[dim_idx]; 127 | for (uint64_t j = 0; j < axis_dim; j++) { 128 | uint64_t offset = base_offset + j * axis_stride; 129 | assert(offset < t->size && "offset out of bounds"); 130 | sum += t->data[offset]; 131 | } 132 | result->data[i] = sum; 133 | } 134 | 135 | if (curr) { 136 | free(curr); 137 | } 138 | 139 | if (result->requires_grad) { 140 | Function *fn = arena_alloc_function(); 141 | fn->apply = sum_backward; 142 | fn->output = result; 143 | fn->num_inputs = 1; 144 | fn->inputs[0] = (Tensor *)t; 145 | fn->pending_count = 0; 146 | 147 | SumContext *ctx = (SumContext *)malloc(sizeof(SumContext)); 148 | assert(ctx != NULL && "malloc failed"); 149 | ctx->dim_idx = dim_idx; 150 | ctx->keepdims = keepdims; 151 | fn->ctx = ctx; 152 | 153 | if (t->grad_fn != NULL) { 154 | t->grad_fn->pending_count++; 155 | } 156 | 157 | result->grad_fn = fn; 158 | } 159 | 160 | return result; 161 | } 162 | 163 | Tensor *tensor_mean(const Tensor *t, int64_t dim_idx, bool keepdims) { 164 | assert(t != NULL); 165 | assert(t->data != NULL || t->size == 0); 166 | dim_idx = (dim_idx < 0) ? (dim_idx + (int64_t)t->ndim) : dim_idx; 167 | assert(dim_idx >= 0 && dim_idx < (int64_t)t->ndim && "dim_idx out of bounds"); 168 | 169 | uint64_t *new_shape; 170 | uint64_t new_ndim; 171 | reduction_shapes_mut(t, dim_idx, keepdims, &new_shape, &new_ndim); 172 | 173 | Tensor *result = tensor_zeros(new_shape, new_ndim, t->requires_grad); 174 | if (new_shape) { 175 | free(new_shape); 176 | } 177 | 178 | uint64_t *curr = (new_ndim > 0) ? 
(uint64_t *)calloc((size_t)new_ndim, sizeof(uint64_t)) : NULL; 179 | if (new_ndim > 0) { 180 | assert(curr != NULL && "calloc failed"); 181 | } 182 | 183 | uint64_t n = (t->shape) ? t->shape[dim_idx] : 1; 184 | assert(n > 0 && "division by zero: axis dimension is 0"); 185 | float32_t scale = 1.0f / (float32_t)n; 186 | 187 | for (uint64_t i = 0; i < result->size; i++) { 188 | linear_to_multidim_mut(i, result->shape, new_ndim, curr); 189 | 190 | uint64_t base_offset = reduction_multidim_to_linear(t, curr, dim_idx, keepdims); 191 | 192 | float32_t sum = 0.0f; 193 | uint64_t axis_dim = (t->shape) ? t->shape[dim_idx] : 1; 194 | uint64_t axis_stride = t->strides[dim_idx]; 195 | for (uint64_t j = 0; j < axis_dim; j++) { 196 | uint64_t offset = base_offset + j * axis_stride; 197 | assert(offset < t->size && "offset out of bounds"); 198 | sum += t->data[offset]; 199 | } 200 | result->data[i] = sum * scale; 201 | } 202 | 203 | if (curr) { 204 | free(curr); 205 | } 206 | 207 | if (result->requires_grad) { 208 | Function *fn = arena_alloc_function(); 209 | fn->apply = mean_backward; 210 | fn->output = result; 211 | fn->num_inputs = 1; 212 | fn->inputs[0] = (Tensor *)t; 213 | fn->pending_count = 0; 214 | 215 | MeanContext *ctx = (MeanContext *)malloc(sizeof(MeanContext)); 216 | assert(ctx != NULL && "malloc failed"); 217 | ctx->dim_idx = dim_idx; 218 | ctx->keepdims = keepdims; 219 | fn->ctx = ctx; 220 | 221 | if (t->grad_fn != NULL) { 222 | t->grad_fn->pending_count++; 223 | } 224 | 225 | result->grad_fn = fn; 226 | } 227 | 228 | return result; 229 | } 230 | 231 | Tensor *tensor_max(const Tensor *t, int64_t dim_idx, bool keepdims) { 232 | assert(t != NULL); 233 | assert(t->data != NULL || t->size == 0); 234 | dim_idx = (dim_idx < 0) ? (dim_idx + (int64_t)t->ndim) : dim_idx; 235 | assert(dim_idx >= 0 && dim_idx < (int64_t)t->ndim && "dim_idx out of bounds"); 236 | 237 | uint64_t *new_shape; 238 | uint64_t new_ndim; 239 | reduction_shapes_mut(t, dim_idx, keepdims, &new_shape, &new_ndim); 240 | 241 | Tensor *result = tensor_zeros(new_shape, new_ndim, t->requires_grad); 242 | if (new_shape) { 243 | free(new_shape); 244 | } 245 | 246 | uint64_t *curr = (new_ndim > 0) ? (uint64_t *)calloc((size_t)new_ndim, sizeof(uint64_t)) : NULL; 247 | if (new_ndim > 0) { 248 | assert(curr != NULL && "calloc failed"); 249 | } 250 | 251 | for (uint64_t i = 0; i < result->size; i++) { 252 | if (new_ndim > 0) { 253 | linear_to_multidim_mut(i, result->shape, new_ndim, curr); 254 | } 255 | 256 | uint64_t base_offset = reduction_multidim_to_linear(t, curr, dim_idx, keepdims); 257 | 258 | float32_t max_val = -INFINITY; 259 | uint64_t axis_dim = (t->shape) ? 
t->shape[dim_idx] : 1; 260 | uint64_t axis_stride = t->strides[dim_idx]; 261 | 262 | if (axis_dim > 0) { 263 | assert(base_offset < t->size && "base_offset out of bounds"); 264 | max_val = t->data[base_offset]; 265 | for (uint64_t j = 1; j < axis_dim; j++) { 266 | uint64_t offset = base_offset + j * axis_stride; 267 | assert(offset < t->size && "offset out of bounds"); 268 | float32_t val = t->data[offset]; 269 | if (val > max_val) { 270 | max_val = val; 271 | } 272 | } 273 | } 274 | result->data[i] = max_val; 275 | } 276 | 277 | if (curr) { 278 | free(curr); 279 | } 280 | 281 | if (result->requires_grad) { 282 | Function *fn = arena_alloc_function(); 283 | fn->apply = max_backward; 284 | fn->output = result; 285 | fn->num_inputs = 1; 286 | fn->inputs[0] = (Tensor *)t; 287 | fn->pending_count = 0; 288 | 289 | MaxContext *ctx = (MaxContext *)malloc(sizeof(MaxContext)); 290 | assert(ctx != NULL && "malloc failed"); 291 | ctx->dim_idx = dim_idx; 292 | ctx->keepdims = keepdims; 293 | ctx->output = result; 294 | fn->ctx = ctx; 295 | 296 | if (t->grad_fn != NULL) { 297 | t->grad_fn->pending_count++; 298 | } 299 | 300 | result->grad_fn = fn; 301 | } 302 | 303 | return result; 304 | } 305 | -------------------------------------------------------------------------------- /test/test_arithmetic.c: -------------------------------------------------------------------------------- 1 | #include "ops/arithmetic.h" 2 | #include "tensor.h" 3 | #include "unity.h" 4 | #include 5 | 6 | void setUp(void) {} 7 | void tearDown(void) {} 8 | 9 | static Tensor *create_tensor_1d(float32_t *data, uint64_t size) { 10 | uint64_t shape[] = {size}; 11 | return tensor_create(data, shape, 1, false); 12 | } 13 | 14 | static Tensor *create_scalar(float32_t val) { return tensor_create(&val, NULL, 0, false); } 15 | 16 | void test_tensor_add(void) { 17 | float32_t a_data[] = {1.0f, 2.0f, 3.0f}; 18 | float32_t b_data[] = {4.0f, 5.0f, 6.0f}; 19 | Tensor *a = create_tensor_1d(a_data, 3); 20 | Tensor *b = create_tensor_1d(b_data, 3); 21 | Tensor *c = tensor_add(a, b); 22 | 23 | TEST_ASSERT_EQUAL_UINT64(1, c->ndim); 24 | TEST_ASSERT_EQUAL_UINT64(3, c->shape[0]); 25 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 5.0f, c->data[0]); 26 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 7.0f, c->data[1]); 27 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 9.0f, c->data[2]); 28 | 29 | tensor_free(a); 30 | tensor_free(b); 31 | tensor_free(c); 32 | } 33 | 34 | void test_tensor_broadcasting(void) { 35 | Tensor *a = create_scalar(10.0f); 36 | float32_t b_data[] = {1.0f, 2.0f, 3.0f}; 37 | Tensor *b = create_tensor_1d(b_data, 3); 38 | 39 | Tensor *c = tensor_add(a, b); 40 | 41 | TEST_ASSERT_EQUAL_UINT64(1, c->ndim); 42 | TEST_ASSERT_EQUAL_UINT64(3, c->shape[0]); 43 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 11.0f, c->data[0]); 44 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 12.0f, c->data[1]); 45 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 13.0f, c->data[2]); 46 | 47 | tensor_free(a); 48 | tensor_free(b); 49 | tensor_free(c); 50 | } 51 | 52 | void test_tensor_matmul(void) { 53 | float32_t a_data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; 54 | uint64_t a_shape[] = {2, 3}; 55 | Tensor *a = tensor_create(a_data, a_shape, 2, false); 56 | 57 | float32_t b_data[] = {7.0f, 8.0f, 9.0f, 1.0f, 2.0f, 3.0f}; 58 | uint64_t b_shape[] = {3, 2}; 59 | Tensor *b = tensor_create(b_data, b_shape, 2, false); 60 | 61 | Tensor *c = tensor_matmul(a, b); 62 | 63 | TEST_ASSERT_EQUAL_UINT64(2, c->ndim); 64 | TEST_ASSERT_EQUAL_UINT64(2, c->shape[0]); 65 | TEST_ASSERT_EQUAL_UINT64(2, c->shape[1]); 66 | 67 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 31.0f, 
c->data[0]); 68 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 19.0f, c->data[1]); 69 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 85.0f, c->data[2]); 70 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 55.0f, c->data[3]); 71 | 72 | tensor_free(a); 73 | tensor_free(b); 74 | tensor_free(c); 75 | } 76 | 77 | void test_tensor_sub(void) { 78 | float32_t a_data[] = {5.0f, 6.0f, 7.0f}; 79 | float32_t b_data[] = {1.0f, 2.0f, 3.0f}; 80 | Tensor *a = create_tensor_1d(a_data, 3); 81 | Tensor *b = create_tensor_1d(b_data, 3); 82 | Tensor *c = tensor_sub(a, b); 83 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 4.0f, c->data[0]); 84 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 4.0f, c->data[1]); 85 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 4.0f, c->data[2]); 86 | tensor_free(a); 87 | tensor_free(b); 88 | tensor_free(c); 89 | } 90 | 91 | void test_tensor_mul(void) { 92 | float32_t a_data[] = {2.0f, 3.0f, 4.0f}; 93 | float32_t b_data[] = {2.0f, 3.0f, 4.0f}; 94 | Tensor *a = create_tensor_1d(a_data, 3); 95 | Tensor *b = create_tensor_1d(b_data, 3); 96 | Tensor *c = tensor_mul(a, b); 97 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 4.0f, c->data[0]); 98 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 9.0f, c->data[1]); 99 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 16.0f, c->data[2]); 100 | tensor_free(a); 101 | tensor_free(b); 102 | tensor_free(c); 103 | } 104 | 105 | void test_tensor_div(void) { 106 | float32_t a_data[] = {4.0f, 9.0f, 16.0f}; 107 | float32_t b_data[] = {2.0f, 3.0f, 4.0f}; 108 | Tensor *a = create_tensor_1d(a_data, 3); 109 | Tensor *b = create_tensor_1d(b_data, 3); 110 | Tensor *c = tensor_div(a, b); 111 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 2.0f, c->data[0]); 112 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 3.0f, c->data[1]); 113 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 4.0f, c->data[2]); 114 | tensor_free(a); 115 | tensor_free(b); 116 | tensor_free(c); 117 | } 118 | 119 | void test_tensor_add_broadcast_scalar_lhs(void) { 120 | Tensor *a = create_scalar(10.0f); 121 | float32_t b_data[] = {1.0f, 2.0f, 3.0f}; 122 | Tensor *b = create_tensor_1d(b_data, 3); 123 | Tensor *c = tensor_add(a, b); 124 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 11.0f, c->data[0]); 125 | tensor_free(a); 126 | tensor_free(b); 127 | tensor_free(c); 128 | } 129 | 130 | void test_tensor_add_broadcast_scalar_rhs(void) { 131 | float32_t a_data[] = {1.0f, 2.0f, 3.0f}; 132 | Tensor *a = create_tensor_1d(a_data, 3); 133 | Tensor *b = create_scalar(10.0f); 134 | Tensor *c = tensor_add(a, b); 135 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 11.0f, c->data[0]); 136 | tensor_free(a); 137 | tensor_free(b); 138 | tensor_free(c); 139 | } 140 | 141 | void test_tensor_sub_broadcast_scalar(void) { 142 | float32_t a_data[] = {10.0f, 11.0f}; 143 | Tensor *a = create_tensor_1d(a_data, 2); 144 | Tensor *b = create_scalar(5.0f); 145 | Tensor *c = tensor_sub(a, b); 146 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 5.0f, c->data[0]); 147 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 6.0f, c->data[1]); 148 | tensor_free(a); 149 | tensor_free(b); 150 | tensor_free(c); 151 | } 152 | 153 | void test_tensor_mul_broadcast_scalar(void) { 154 | float32_t a_data[] = {2.0f, 3.0f}; 155 | Tensor *a = create_tensor_1d(a_data, 2); 156 | Tensor *b = create_scalar(2.0f); 157 | Tensor *c = tensor_mul(a, b); 158 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 4.0f, c->data[0]); 159 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 6.0f, c->data[1]); 160 | tensor_free(a); 161 | tensor_free(b); 162 | tensor_free(c); 163 | } 164 | 165 | void test_tensor_div_broadcast_scalar(void) { 166 | float32_t a_data[] = {10.0f, 20.0f}; 167 | Tensor *a = create_tensor_1d(a_data, 2); 168 | Tensor *b = create_scalar(2.0f); 169 | Tensor *c = tensor_div(a, b); 170 | 
TEST_ASSERT_FLOAT_WITHIN(1e-6, 5.0f, c->data[0]); 171 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 10.0f, c->data[1]); 172 | tensor_free(a); 173 | tensor_free(b); 174 | tensor_free(c); 175 | } 176 | 177 | void test_tensor_add_zero(void) { 178 | float32_t a_data[] = {1.0f, 2.0f}; 179 | Tensor *a = create_tensor_1d(a_data, 2); 180 | Tensor *b = create_scalar(0.0f); 181 | Tensor *c = tensor_add(a, b); 182 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 1.0f, c->data[0]); 183 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 2.0f, c->data[1]); 184 | tensor_free(a); 185 | tensor_free(b); 186 | tensor_free(c); 187 | } 188 | 189 | void test_tensor_sub_zero(void) { 190 | float32_t a_data[] = {1.0f, 2.0f}; 191 | Tensor *a = create_tensor_1d(a_data, 2); 192 | Tensor *b = create_scalar(0.0f); 193 | Tensor *c = tensor_sub(a, b); 194 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 1.0f, c->data[0]); 195 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 2.0f, c->data[1]); 196 | tensor_free(a); 197 | tensor_free(b); 198 | tensor_free(c); 199 | } 200 | 201 | void test_tensor_mul_zero(void) { 202 | float32_t a_data[] = {1.0f, 2.0f}; 203 | Tensor *a = create_tensor_1d(a_data, 2); 204 | Tensor *b = create_scalar(0.0f); 205 | Tensor *c = tensor_mul(a, b); 206 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, c->data[0]); 207 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, c->data[1]); 208 | tensor_free(a); 209 | tensor_free(b); 210 | tensor_free(c); 211 | } 212 | 213 | void test_tensor_mul_one(void) { 214 | float32_t a_data[] = {1.0f, 2.0f}; 215 | Tensor *a = create_tensor_1d(a_data, 2); 216 | Tensor *b = create_scalar(1.0f); 217 | Tensor *c = tensor_mul(a, b); 218 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 1.0f, c->data[0]); 219 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 2.0f, c->data[1]); 220 | tensor_free(a); 221 | tensor_free(b); 222 | tensor_free(c); 223 | } 224 | 225 | void test_tensor_mul_neg_one(void) { 226 | float32_t a_data[] = {1.0f, 2.0f}; 227 | Tensor *a = create_tensor_1d(a_data, 2); 228 | Tensor *b = create_scalar(-1.0f); 229 | Tensor *c = tensor_mul(a, b); 230 | TEST_ASSERT_FLOAT_WITHIN(1e-6, -1.0f, c->data[0]); 231 | TEST_ASSERT_FLOAT_WITHIN(1e-6, -2.0f, c->data[1]); 232 | tensor_free(a); 233 | tensor_free(b); 234 | tensor_free(c); 235 | } 236 | 237 | void test_tensor_div_one(void) { 238 | float32_t a_data[] = {1.0f, 2.0f}; 239 | Tensor *a = create_tensor_1d(a_data, 2); 240 | Tensor *b = create_scalar(1.0f); 241 | Tensor *c = tensor_div(a, b); 242 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 1.0f, c->data[0]); 243 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 2.0f, c->data[1]); 244 | tensor_free(a); 245 | tensor_free(b); 246 | tensor_free(c); 247 | } 248 | 249 | void test_tensor_matmul_identity(void) { 250 | float32_t a_data[] = {1.0f, 2.0f, 3.0f, 4.0f}; 251 | uint64_t a_shape[] = {2, 2}; 252 | Tensor *a = tensor_create(a_data, a_shape, 2, false); 253 | 254 | float32_t b_data[] = {1.0f, 0.0f, 0.0f, 1.0f}; 255 | uint64_t b_shape[] = {2, 2}; 256 | Tensor *b = tensor_create(b_data, b_shape, 2, false); 257 | 258 | Tensor *c = tensor_matmul(a, b); 259 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 1.0f, c->data[0]); 260 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 2.0f, c->data[1]); 261 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 3.0f, c->data[2]); 262 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 4.0f, c->data[3]); 263 | 264 | tensor_free(a); 265 | tensor_free(b); 266 | tensor_free(c); 267 | } 268 | 269 | void test_tensor_matmul_zero(void) { 270 | float32_t a_data[] = {1.0f, 2.0f, 3.0f, 4.0f}; 271 | uint64_t a_shape[] = {2, 2}; 272 | Tensor *a = tensor_create(a_data, a_shape, 2, false); 273 | 274 | float32_t b_data[] = {0.0f, 0.0f, 0.0f, 0.0f}; 275 | uint64_t b_shape[] = {2, 2}; 276 
| Tensor *b = tensor_create(b_data, b_shape, 2, false); 277 | 278 | Tensor *c = tensor_matmul(a, b); 279 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, c->data[0]); 280 | TEST_ASSERT_FLOAT_WITHIN(1e-6, 0.0f, c->data[1]); 281 | 282 | tensor_free(a); 283 | tensor_free(b); 284 | tensor_free(c); 285 | } 286 | 287 | int main(void) { 288 | UNITY_BEGIN(); 289 | RUN_TEST(test_tensor_add); 290 | RUN_TEST(test_tensor_broadcasting); 291 | RUN_TEST(test_tensor_matmul); 292 | RUN_TEST(test_tensor_sub); 293 | RUN_TEST(test_tensor_mul); 294 | RUN_TEST(test_tensor_div); 295 | RUN_TEST(test_tensor_add_broadcast_scalar_lhs); 296 | RUN_TEST(test_tensor_add_broadcast_scalar_rhs); 297 | RUN_TEST(test_tensor_sub_broadcast_scalar); 298 | RUN_TEST(test_tensor_mul_broadcast_scalar); 299 | RUN_TEST(test_tensor_div_broadcast_scalar); 300 | RUN_TEST(test_tensor_add_zero); 301 | RUN_TEST(test_tensor_sub_zero); 302 | RUN_TEST(test_tensor_mul_zero); 303 | RUN_TEST(test_tensor_mul_one); 304 | RUN_TEST(test_tensor_mul_neg_one); 305 | RUN_TEST(test_tensor_div_one); 306 | RUN_TEST(test_tensor_matmul_identity); 307 | RUN_TEST(test_tensor_matmul_zero); 308 | return UNITY_END(); 309 | } 310 | -------------------------------------------------------------------------------- /docs/autodiff.py: -------------------------------------------------------------------------------- 1 | # author: sueszli 2 | # reviewer: ivan yashchuk 3 | # /// script 4 | # requires-python = ">=3.11" 5 | # dependencies = [ 6 | # "astunparse==1.6.3", 7 | # "jax==0.4.20", 8 | # "click==8.1.7", 9 | # "torch==2.1.1", 10 | # ] 11 | # /// 12 | 13 | import ast 14 | import astunparse 15 | 16 | import os 17 | from math import exp, cos, sin 18 | from collections import namedtuple 19 | from numbers import Number 20 | 21 | import unittest 22 | import jax.numpy as jnp 23 | from jax import grad 24 | import torch 25 | 26 | 27 | # idea: instead of splitting up types and methods as done below, we could also have 28 | # a single abstract data type and overload primitive operators through __add__, __mul__, etc. 
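# for reference, a minimal sketch of that operator-overloading alternative (a hypothetical
# `DualOverloaded` class, not used anywhere below -- the namedtuple + DualNumOps approach
# is what the AST transformer targets):
#
#   class DualOverloaded:
#       def __init__(self, value, derivative=0.0):
#           self.value, self.derivative = value, derivative
#
#       def __add__(self, other):
#           other = other if isinstance(other, DualOverloaded) else DualOverloaded(other)
#           return DualOverloaded(self.value + other.value, self.derivative + other.derivative)
#
#       def __mul__(self, other):
#           other = other if isinstance(other, DualOverloaded) else DualOverloaded(other)
#           return DualOverloaded(self.value * other.value,
#                                 self.derivative * other.value + other.derivative * self.value)
#
#       def __pow__(self, k: int):
#           # d/dx x^k = k * x^(k-1) * x'
#           return DualOverloaded(self.value**k, self.derivative * k * self.value ** (k - 1))
#
# with such a class, f(DualOverloaded(x, 1.0)) differentiates without any AST rewriting,
# provided exp/cos are also dispatched on the dual type.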
29 | DualNum = namedtuple("DualNum", ["value", "derivative"]) 30 | 31 | 32 | # operators are incomplete and only serve the purpose of our single example function f(x) 33 | # see: 34 | # - https://en.wikipedia.org/wiki/Automatic_differentiation#Automatic_differentiation_using_dual_numbers 35 | # - https://youtu.be/5F6roh4pmJU?si=LW1ZKKvaGdl9shCz&t=555 36 | class DualNumOps: 37 | @staticmethod 38 | def custom_exp(inp: DualNum): 39 | return DualNum(exp(inp.value), exp(inp.value) * inp.derivative) 40 | 41 | @staticmethod 42 | def custom_cos(inp: DualNum): 43 | return DualNum(cos(inp.value), -sin(inp.value) * inp.derivative) 44 | 45 | @staticmethod 46 | def custom_add(inp1: DualNum, inp2: DualNum): 47 | if not isinstance(inp1, DualNum): 48 | inp1 = DualNum(inp1, 0.0) 49 | if not isinstance(inp2, DualNum): 50 | inp2 = DualNum(inp2, 0.0) 51 | 52 | return DualNum(inp1.value + inp2.value, inp1.derivative + inp2.derivative) 53 | 54 | @staticmethod 55 | def custom_mul(inp1: DualNum, inp2: DualNum): 56 | return DualNum( 57 | inp1.value * inp2.value, 58 | inp1.derivative * inp2.value + inp2.derivative * inp1.value, 59 | ) 60 | 61 | @staticmethod 62 | def custom_pow(inp: DualNum, k: Number): 63 | assert isinstance(k, int), "k must be an integer" 64 | k_int = int(k) 65 | return DualNum(inp.value**k, inp.derivative * k * inp.value ** (k_int - 1)) 66 | 67 | 68 | # update the abstract syntax tree 69 | # see: 70 | # - https://docs.python.org/3/library/ast.html 71 | # - https://greentreesnakes.readthedocs.io/en/latest/index.html 72 | # - https://greentreesnakes.readthedocs.io/en/latest/manipulating.html 73 | def transform(fstr: str) -> str: 74 | class CustomOpTransformer(ast.NodeTransformer): 75 | def visit_FunctionDef(self, node): 76 | node.name = "f_forward_ad" 77 | node.args.args[0].annotation = ast.Name(id="DualNum", ctx=ast.Load()) 78 | node.returns = ast.Name(id="DualNum", ctx=ast.Load()) 79 | 80 | self.generic_visit(node) # visit children 81 | return node 82 | 83 | def visit_Call(self, node): 84 | if isinstance(node.func, ast.Name): 85 | if node.func.id == "exp": 86 | node.func.id = "DualNumOps.custom_exp" 87 | elif node.func.id == "cos": 88 | node.func.id = "DualNumOps.custom_cos" 89 | 90 | self.generic_visit(node) # visit children 91 | return node 92 | 93 | def visit_BinOp(self, node): 94 | if isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant): 95 | return node # both are constants, reached leaf 96 | 97 | if isinstance(node.op, ast.Add): 98 | node = ast.Call(func=ast.Name(id="DualNumOps.custom_add", ctx=ast.Load()), args=[node.left, node.right], keywords=[]) 99 | elif isinstance(node.op, ast.Mult): 100 | node = ast.Call(func=ast.Name(id="DualNumOps.custom_mul", ctx=ast.Load()), args=[node.left, node.right], keywords=[]) 101 | elif isinstance(node.op, ast.Pow): 102 | node = ast.Call(func=ast.Name(id="DualNumOps.custom_pow", ctx=ast.Load()), args=[node.left, node.right], keywords=[]) 103 | 104 | self.generic_visit(node) # visit children 105 | return node 106 | 107 | tree = ast.parse(fstr) 108 | transformer = CustomOpTransformer() 109 | tree = transformer.visit(tree) # call root 110 | 111 | new_fstr = astunparse.unparse(tree) 112 | return new_fstr 113 | 114 | 115 | class TestCases(unittest.TestCase): 116 | jax_func = lambda x: jnp.exp(x) ** 3 + jnp.cos(x) * x + 10**2 117 | torch_func = lambda x: torch.exp(x) ** 3 + torch.cos(x) * x + 10**2 118 | 119 | def test_jax_1(self): 120 | testarg = 2.4 121 | res_jax = grad(TestCases.jax_func)(testarg) 122 | res_custom = 
f_forward_ad(DualNum(testarg, 1.0)).derivative # type: ignore 123 | res_match = jnp.allclose(res_jax, res_custom) 124 | self.assertTrue(res_match) 125 | 126 | def test_jax_2(self): 127 | testarg = 61.78 128 | res_jax = grad(TestCases.jax_func)(testarg) 129 | res_custom = f_forward_ad(DualNum(testarg, 1.0)).derivative # type: ignore 130 | res_match = jnp.allclose(res_jax, res_custom) 131 | self.assertTrue(res_match) 132 | 133 | def test_jax_3(self): 134 | testarg = 26.42 135 | res_jax = grad(TestCases.jax_func)(testarg) 136 | res_custom = f_forward_ad(DualNum(testarg, 1.0)).derivative # type: ignore 137 | res_match = jnp.allclose(res_jax, res_custom) 138 | self.assertTrue(res_match) 139 | 140 | def test_torch_1(self): 141 | testarg = 2.4 142 | x = torch.tensor(testarg, requires_grad=True) 143 | TestCases.torch_func(x).backward() 144 | result_pytorch = x.grad.item() if x.grad is not None else None 145 | result_custom = f_forward_ad(DualNum(testarg, 1.0)).derivative # type: ignore 146 | torch_match = torch.isclose(torch.tensor(result_pytorch), torch.tensor(result_custom)) 147 | self.assertTrue(torch_match) 148 | 149 | def test_torch_2(self): 150 | testarg = 61.78 151 | x = torch.tensor(testarg, requires_grad=True) 152 | TestCases.torch_func(x).backward() 153 | result_pytorch = x.grad.item() if x.grad is not None else None 154 | result_custom = f_forward_ad(DualNum(testarg, 1.0)).derivative # type: ignore 155 | torch_match = torch.isclose(torch.tensor(result_pytorch), torch.tensor(result_custom)) 156 | self.assertTrue(torch_match) 157 | 158 | def test_torch_3(self): 159 | testarg = 26.42 160 | x = torch.tensor(testarg, requires_grad=True) 161 | TestCases.torch_func(x).backward() 162 | result_pytorch = x.grad.item() if x.grad is not None else None 163 | result_custom = f_forward_ad(DualNum(testarg, 1.0)).derivative # type: ignore 164 | torch_match = torch.isclose(torch.tensor(result_pytorch), torch.tensor(result_custom)) 165 | self.assertTrue(torch_match) 166 | 167 | 168 | f_str = """ 169 | def f(x): 170 | return exp(x)**3 + cos(x) * x + 10**2 171 | """ 172 | 173 | # the following is the AST of `f_str` as dumped by `ast.dump(tree)` for reference 174 | """ 175 | Module( 176 | body=[ 177 | FunctionDef( 178 | body=[ 179 | Return( 180 | value=BinOp( 181 | 182 | left=BinOp( <------ `exp(x)**3 + cos(x) * x` AS LEFT 183 | left=BinOp( <------ `exp(x)**3` AS LEFT 184 | left=Call( 185 | func=Name(id='exp', ctx=Load()), 186 | name='f', 187 | args=arguments( 188 | posonlyargs=[], 189 | args=[arg(arg='x')], 190 | kwonlyargs=[], 191 | kw_defaults=[], 192 | defaults=[] 193 | ), 194 | args=[Name(id='x', ctx=Load())], 195 | keywords=[] 196 | ), 197 | op=Pow(), 198 | right=Constant(value=3) 199 | ), 200 | 201 | op=Add(), <------ LEFT + RIGHT 202 | 203 | right=BinOp( <------ `cos(x) * x` AS RIGHT 204 | left=Call( 205 | func=Name(id='cos', ctx=Load()), 206 | args=[Name(id='x', ctx=Load())], 207 | keywords=[] 208 | ), 209 | op=Mult(), 210 | right=Name(id='x', ctx=Load()) 211 | ) 212 | ), 213 | 214 | op=Add(), <------ LEFT + RIGHT 215 | 216 | right=BinOp( <------ `10**2` AS RIGHT 217 | left=Constant(value=10), 218 | op=Pow(), 219 | right=Constant(value=2) 220 | ) 221 | ) 222 | ) 223 | ], 224 | decorator_list=[], 225 | type_params=[] 226 | ) 227 | ], 228 | type_ignores=[] 229 | ) 230 | """ 231 | 232 | if __name__ == "__main__": 233 | os.system("cls" if os.name == "nt" else "clear") 234 | os.system("uname -a") if os.name == "posix" else os.system("systeminfo") 235 | 236 | # bring f(x) into local namespace 237 | 
exec(f_str) 238 | assert "f" in locals(), "f is not defined" 239 | assert f(2) == exp(2) ** 3 + cos(2) * 2 + 10**2 # type: ignore 240 | 241 | # get f_forward_ad(x) from f(x) 242 | python_str = transform(f_str) 243 | print(f"\n\nOriginal function:\n\033[92m{f_str}\033[0m\nTransformed function:\033[92m{python_str}\033[0m\n\n") 244 | 245 | # bring f_forward_ad(x) into local namespace 246 | exec(python_str) 247 | assert "f_forward_ad" in locals(), "f_forward_ad is not defined" 248 | 249 | # run tests 250 | unittest.main() 251 | 252 | 253 | # Original function: 254 | # 255 | # def f(x): 256 | # return exp(x)**3 + cos(x) * x + 10**2 257 | # 258 | # Transformed function: 259 | # 260 | # def f_forward_ad(x: DualNum) -> DualNum: 261 | # return DualNumOps.custom_add(DualNumOps.custom_add(DualNumOps.custom_pow(DualNumOps.custom_exp(x), 3), DualNumOps.custom_mul(DualNumOps.custom_cos(x), x)), (10 ** 2)) 262 | # 263 | # ---------------------------------------------------------------------- 264 | # Ran 6 tests in 0.388s 265 | # 266 | # OK 267 | -------------------------------------------------------------------------------- /src/tensor.c: -------------------------------------------------------------------------------- 1 | #include "tensor.h" 2 | #include "ops/arithmetic.h" 3 | #include "utils/aligned_alloc.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | // 13 | // memory management 14 | // 15 | 16 | static uint64_t size(const uint64_t *shape, uint64_t ndim) { 17 | assert(ndim <= MAX_NDIM); 18 | 19 | // scalar 20 | if (ndim == 0) { 21 | return 1; 22 | } 23 | assert(shape != NULL); 24 | 25 | // product of dimensions 26 | uint64_t size = 1; 27 | for (uint64_t i = 0; i < ndim; i++) { 28 | assert(shape[i] == 0 || size <= MAX_TENSOR_SIZE / shape[i]); 29 | size *= shape[i]; 30 | } 31 | assert(size <= MAX_TENSOR_SIZE && "tensor size exceeds maximum"); 32 | return size; 33 | } 34 | 35 | /* 36 | * strides: how many elements to skip in flat memory to move 1 step along each dimension. 
37 | * converts multi-dim index to linear offset: `offset = sum_i (index[i] * strides[i])` 38 | * 39 | * example: 40 | * 41 | * shape: [2, 3] (2 rows, 3 cols) 42 | * 43 | * memory: [a, b, c, d, e, f] 44 | * 45 | * logical: [[a, b, c], row 0 46 | * [d, e, f]] row 1 47 | * 48 | * algorithm (iterate backward through dimensions): 49 | * i=1: strides[1] = 1 (within a row, move 1 elem) 50 | * stride = 1 * 3 = 3 51 | * i=0: strides[0] = 3 (between rows, move 3 elems) 52 | * stride = 3 * 2 = 6 53 | * 54 | * result: strides = [3, 1] 55 | * 56 | * access examples: 57 | * element[row=1, col=2]: offset = 1*3 + 2*1 = 5 -> data[5] = f 58 | */ 59 | static uint64_t *strides(const uint64_t *shape, uint64_t ndim) { 60 | assert(ndim <= MAX_NDIM); 61 | 62 | if (ndim == 0) { 63 | return NULL; 64 | } 65 | assert(shape != NULL); 66 | 67 | uint64_t *strides = (uint64_t *)malloc((size_t)ndim * sizeof(uint64_t)); 68 | assert(strides != NULL && "malloc failed"); 69 | 70 | uint64_t stride = 1; 71 | for (int64_t i = (int64_t)ndim - 1; i >= 0; i--) { 72 | strides[i] = stride; 73 | if (shape[i] && stride > UINT64_MAX / shape[i]) { 74 | free(strides); 75 | assert(false && "stride calculation overflow"); 76 | } 77 | stride *= shape[i]; 78 | } 79 | 80 | return strides; 81 | } 82 | 83 | Tensor *tensor_create(const float32_t *data, const uint64_t *shape, uint64_t ndim, bool requires_grad) { 84 | assert(ndim <= MAX_NDIM); 85 | assert(shape != NULL || ndim == 0); 86 | 87 | Tensor *t = (Tensor *)malloc(sizeof(Tensor)); 88 | assert(t != NULL && "malloc failed"); 89 | 90 | t->ndim = ndim; 91 | t->requires_grad = requires_grad; 92 | t->grad = NULL; 93 | t->grad_fn = NULL; 94 | t->ref_count = 1; 95 | t->shape = NULL; 96 | t->strides = NULL; 97 | 98 | // scalar 99 | if (ndim == 0) { 100 | t->size = 1; 101 | t->data = (float32_t *)safe_aligned_alloc(sizeof(float32_t)); 102 | assert((uintptr_t)t->data % CACHELINE_SIZE == 0 && "data is not properly aligned"); 103 | if (data) { 104 | t->data[0] = data[0]; 105 | } else { 106 | t->data[0] = 0.0f; 107 | } 108 | assert(t->ndim == 0); 109 | assert(t->size == 1); 110 | assert(t->data != NULL); 111 | return t; 112 | } 113 | 114 | t->shape = (uint64_t *)malloc((size_t)ndim * sizeof(uint64_t)); 115 | assert(t->shape != NULL && "malloc failed"); 116 | memcpy(t->shape, shape, (size_t)ndim * sizeof(uint64_t)); 117 | 118 | t->strides = strides(t->shape, ndim); 119 | 120 | t->size = size(shape, ndim); 121 | 122 | // zero-size 123 | if (t->size == 0) { 124 | t->data = NULL; 125 | assert(t->ndim == ndim); 126 | assert(t->size == 0); 127 | assert(t->data == NULL); 128 | return t; 129 | } 130 | 131 | // data allocation must be aligned 132 | t->data = (float32_t *)safe_aligned_alloc(t->size * sizeof(float32_t)); 133 | assert((uintptr_t)t->data % CACHELINE_SIZE == 0 && "data is not properly aligned"); 134 | if (data) { 135 | memcpy(t->data, data, (size_t)t->size * sizeof(float32_t)); 136 | } else { 137 | memset(t->data, 0, (size_t)t->size * sizeof(float32_t)); 138 | } 139 | 140 | assert(t->ndim == ndim); 141 | assert(t->size == size(shape, ndim)); 142 | assert(t->data != NULL || t->size == 0); 143 | return t; 144 | } 145 | 146 | Tensor *tensor_zeros(const uint64_t *shape, uint64_t ndim, bool requires_grad) { return tensor_create(NULL, shape, ndim, requires_grad); } 147 | 148 | void tensor_free(Tensor *t) { 149 | if (!t) { 150 | return; 151 | } 152 | if (t->data) { 153 | free(t->data); 154 | } 155 | if (t->shape) { 156 | free(t->shape); 157 | } 158 | if (t->strides) { 159 | free(t->strides); 160 | } 161 | 
if (t->grad) { 162 | tensor_free(t->grad); 163 | } 164 | free(t); 165 | } 166 | 167 | Tensor *tensor_retain(Tensor *t) { 168 | if (t) { 169 | t->ref_count++; 170 | } 171 | return t; 172 | } 173 | 174 | void tensor_release(Tensor *t) { 175 | if (!t) { 176 | return; 177 | } 178 | assert(t->ref_count > 0 && "ref_count already zero"); 179 | t->ref_count--; 180 | if (t->ref_count == 0) { 181 | tensor_free(t); 182 | } 183 | } 184 | 185 | void tensor_zero_grad(Tensor *t) { 186 | if (!t) { 187 | return; 188 | } 189 | if (t->grad) { 190 | tensor_free(t->grad); 191 | t->grad = NULL; 192 | } 193 | } 194 | 195 | /* 196 | * converts a linear offset to multi-dimensional indices. 197 | * mutates out_multidim array. 198 | * 199 | * example: 200 | * 201 | * shape: [2, 3] (2 rows, 3 cols) 202 | * 203 | * memory: [a, b, c, d, e, f] 204 | * 205 | * logical: [[a, b, c], row 0 206 | * [d, e, f]] row 1 207 | * 208 | * algorithm (right-to-left): 209 | * given: lin=4 210 | * 211 | * d=1 (rightmost/col): 4 % 3 = 1 -> col 1 212 | * 4 / 3 = 1 -> carry to next dimension 213 | * 214 | * d=0 (leftmost/row): 1 % 2 = 1 -> row 1 215 | * 1 / 2 = 0 -> done 216 | * 217 | * result: [1, 1] -> element 'e' 218 | */ 219 | void linear_to_multidim_mut(uint64_t lin, const uint64_t *shape, uint64_t ndim, uint64_t *out_multidim) { 220 | assert(shape != NULL || ndim == 0); 221 | assert(out_multidim != NULL || ndim == 0); 222 | assert(ndim <= MAX_NDIM); 223 | 224 | uint64_t carry = lin; 225 | for (int64_t d = (int64_t)ndim - 1; d >= 0; d--) { 226 | out_multidim[d] = carry % shape[d]; 227 | carry /= shape[d]; 228 | } 229 | } 230 | 231 | /* 232 | * converts multi-dimensional coordinates to a linear memory offset. 233 | * 234 | * example: requesting element at [1, 2] from a tensor 235 | * 236 | * shape: [2, 3] (2 rows, 3 cols) 237 | * 238 | * memory: [a, b, c] 239 | * 240 | * is equivalent to: 241 | * [[a, b, c], // row 0 242 | * [a, b, c]] // row 1 (implicit broadcast) 243 | * 244 | * because the first dimension's size is 1, it behaves as if its shape 245 | * were [X, 3] for any X >= 1. implicitly broadcasting. 246 | * 247 | * calculation (right-aligned dimensions): 248 | * 249 | * - dimension 0 (rows): 250 | * - source size is 1. target requests 1. 251 | * - rule: source dimension size is 1 => broadcast! use index 0. 252 | * - offset += 0 * stride[0] (3) = 0 253 | * 254 | * - dimension 1 (columns): 255 | * - source size is 3. target requests 2. 256 | * - rule: source dimension size > 1 => no broadcast. use index 2. 257 | * - offset += 2 * stride[1] (1) = 2 258 | * 259 | * result: offset = 2 (value 'C'). 260 | */ 261 | uint64_t multidim_to_linear(const uint64_t *target, uint64_t target_ndim, const uint64_t *shape, uint64_t ndim, const uint64_t *strides) { 262 | assert(target != NULL || target_ndim == 0); 263 | assert(shape != NULL || ndim == 0); 264 | assert(strides != NULL || ndim == 0); 265 | assert(target_ndim >= ndim); 266 | assert(ndim <= MAX_NDIM); 267 | 268 | uint64_t offset = 0; 269 | for (uint64_t d = 0; d < ndim; d++) { 270 | uint64_t target_dim = d + (target_ndim - ndim); // align right 271 | uint64_t idx = (shape[d] == 1) ? 
0 : target[target_dim]; 272 | offset += idx * strides[d]; 273 | } 274 | return offset; 275 | } 276 | 277 | // 278 | // utils 279 | // 280 | 281 | static void tensor_print_recursive(const Tensor *t, uint64_t dim, uint64_t offset, uint64_t indent) { 282 | assert(dim <= MAX_NDIM && "recursion depth exceeds maximum"); 283 | assert(t != NULL); 284 | 285 | if (dim == t->ndim) { 286 | assert(offset < t->size && "offset out of bounds"); 287 | printf("%f", t->data[offset]); 288 | return; 289 | } 290 | 291 | if (dim == t->ndim - 1) { 292 | printf("["); 293 | for (uint64_t i = 0; i < t->shape[dim]; i++) { 294 | uint64_t data_offset = offset + i * t->strides[dim]; 295 | assert(data_offset < t->size && "offset out of bounds"); 296 | printf("%f", t->data[data_offset]); 297 | if (i < t->shape[dim] - 1) { 298 | printf(", "); 299 | } 300 | } 301 | printf("]"); 302 | return; 303 | } 304 | 305 | printf("["); 306 | for (uint64_t i = 0; i < t->shape[dim]; i++) { 307 | if (i > 0) { 308 | for (uint64_t j = 0; j < indent; j++) { 309 | printf(" "); 310 | } 311 | } 312 | tensor_print_recursive(t, dim + 1, offset + i * t->strides[dim], indent + 1); 313 | 314 | if (i < t->shape[dim] - 1) { 315 | printf(","); 316 | uint64_t newlines = t->ndim - dim - 1; 317 | for (uint64_t k = 0; k < newlines; k++) { 318 | printf("\n"); 319 | } 320 | } 321 | } 322 | printf("]"); 323 | } 324 | 325 | void tensor_print(const Tensor *t) { 326 | if (!t) { 327 | printf("Tensor(NULL)\n"); 328 | return; 329 | } 330 | printf("Tensor(shape=["); 331 | if (t->shape) { 332 | for (uint64_t i = 0; i < t->ndim; i++) { 333 | printf("%" PRIu64 "%s", t->shape[i], i < t->ndim - 1 ? ", " : ""); 334 | } 335 | } 336 | printf("], size=%" PRIu64 ", requires_grad=%s)\n", t->size, t->requires_grad ? "true" : "false"); 337 | 338 | if (t->data) { 339 | const uint64_t max_size = 1000; 340 | if (t->size <= max_size) { 341 | printf("Data: "); 342 | tensor_print_recursive(t, 0, 0, 6); 343 | printf("\n"); 344 | } else { 345 | printf("Data: ... (size > 1000)\n"); 346 | } 347 | } 348 | } 349 | 350 | // use stride to get offset in flat data array 351 | Tensor *tensor_get(const Tensor *t, const uint64_t *multidim) { 352 | if (!t) { 353 | return NULL; 354 | } 355 | assert(multidim != NULL); 356 | assert(t->data != NULL); 357 | assert(t->ndim <= MAX_NDIM); 358 | 359 | uint64_t offset = multidim_to_linear(multidim, t->ndim, t->shape, t->ndim, t->strides); 360 | assert(offset < t->size); 361 | 362 | // scalar tensor with 0 dim 363 | Tensor *val = tensor_create(&t->data[offset], NULL, 0, t->requires_grad); 364 | return val; 365 | } 366 | --------------------------------------------------------------------------------
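For orientation, a minimal usage sketch showing how the tensor, layer, and reduction APIs above compose into a forward pass. This is not a file in the repository; the file name, `main`, the shapes, and the input values are illustrative assumptions, and it presumes that layers.h, tensor.h, and ops/reductions.h declare the functions defined in the sources above.

// usage_sketch.c (illustrative only, not part of the repository)
#include "layers.h"
#include "ops/reductions.h"
#include "tensor.h"
#include <stdint.h>

int main(void) {
    // a small MLP: Linear(4 -> 8) -> Dropout(0.1) -> Linear(8 -> 2)
    Layer *stack[] = {
        layer_linear_create(4, 8, true),
        layer_dropout_create(0.1f),
        layer_linear_create(8, 2, true),
    };
    Layer *model = layer_sequential_create(stack, 3); // copies the layer pointers

    // one batch of 2 samples with 4 features each (values are arbitrary)
    float32_t x_data[8] = {0};
    uint64_t x_shape[] = {2, 4};
    Tensor *x = tensor_create(x_data, x_shape, 2, false);

    // forward pass in training mode, then reduce over the batch dimension
    Tensor *logits = layer_forward(model, x, true);       // shape [2, 2]
    Tensor *mean_logits = tensor_mean(logits, 0, false);  // shape [2]
    tensor_print(mean_logits);

    tensor_free(mean_logits);
    tensor_free(logits);
    tensor_free(x);
    layer_free(model); // sequential_free also frees the sub-layers
    return 0;
}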