├── README.md
├── base
│   ├── base_function.py
│   └── functions_class.py
├── eval.py
├── functions.py
├── utils.py
└── variables.py
/README.md:
--------------------------------------------------------------------------------
1 | # Autograd in PyTorch
2 | 
3 | ## This is a re-implementation of PyTorch's autograd (`torch.autograd`).
4 | 
5 | PyTorch consists of three major components:
6 | + `tensor` can be seen as a replacement for `numpy` arrays that runs on both GPU and CPU behind a unified API.
7 | + `autograd.Variable` is an automatic differentiation tool: given a forward formulation, it computes the gradients automatically.
8 | + `nn` is a deep learning framework built on top of `tensor` and `autograd`.
9 | 
10 | This project re-implements the `autograd` part because:
11 | + PyTorch's autograd functions are mostly implemented in `C/C++` (for performance), which makes it much harder to create new autograd functions.
12 | + Instead of hand-deriving the backward pass of a complex function, we only write the forward pass; the backward pass is built automatically by autograd.
13 | + Understanding clearly how autograd works is very important for any serious deep learning learner.
14 | + You don't need to construct the complex computational graph yourself to do back-propagation in deep learning!
15 | 
16 | The reasons we choose PyTorch tensors over numpy arrays:
17 | + PyTorch tensors support the GPU.
18 | + You can easily validate your own autograd functions against PyTorch's autograd, since both expose the same API.
19 | + Getting familiar with PyTorch's API is a big plus for later deep learning projects.
20 | 
21 | Requirement: `pytorch >= 0.4`
22 | 
23 | ----
24 | 
25 | Overview of the autograd mechanism:
26 | 
27 | ![](https://cdn-images-1.medium.com/max/1600/1*wE1f2i7L8QRw8iuVx5mOpw.png)
28 | 
29 | In `variables.py`:
30 | + `data` is the underlying PyTorch tensor
31 | + `grad` is a Variable containing the gradient of the current Variable
32 | + `grad_fn` is the Function that created the current Variable; if `grad_fn = None`, the Variable is a leaf node of the graph
33 | + `backward()` receives the `in_grad` from the function in which the current Variable is used, accumulates it into `grad`, and calls the Function that created the Variable
34 | + Other methods overload operators and tensor-like behavior of Variable
35 | 
36 | In `base/base_function.py` and `base/functions_class.py`:
37 | + `forward()` receives the `input` and computes the `output`
38 | + `backward()` receives the `in_grad` from the Variable it created and computes an `out_grad` for each of its `input` Variables
39 | 
40 | 
41 | In `functions.py`:
42 | + The functional interface used to build graphs
43 | + Contains the most common PyTorch tensor functions (not all of them), but enough to build any graph used in deep learning
44 | 
45 | In `eval.py`:
46 | + `check_example()`: a test case for checking whether your own autograd function works properly
47 | 
48 | ---
49 | 
50 | If you want to define your own autograd function, follow these steps (a worked end-to-end sketch appears at the end of this listing):
51 | + If your autograd function can be composed from existing functions (in `functions.py`) or classes (in `functions_class.py`), just reuse them, as in the sketch right below this item.
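  For example, here is a minimal sketch (not part of the repository) of a softplus op composed entirely from existing ops, so the backward pass comes for free. The `softplus` name is illustrative only, and the script assumes it is run from the project root so `variables.py` and `functions.py` are importable:

  ```python
  import torch
  from variables import Variable


  def softplus(x):
      # softplus(x) = log(1 + exp(x)); every step uses ops that Variable already
      # overloads, so the graph (and therefore the backward pass) is built automatically
      return (1 + x.exp()).log()


  a = Variable(torch.Tensor([[-4., 2.], [8., 4.]]), requires_grad=True)
  y = softplus(a).sum()   # reduce to a scalar so backward() needs no explicit gradient
  y.backward()

  # d/dx softplus(x) = sigmoid(x), so a.grad should match torch.sigmoid(a.data)
  print(a.grad.data)
  print(torch.sigmoid(a.data))
  ```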
52 | + First, define your new function class in `functions_class.py`, the main purpose of each function: 53 | - `forward`: store `input` to `self.input`, compute `output` and store it to `self.output`, and store other useful information for backward pass in `self.context` dictionary 54 | - `backward`: store `in_grad` to `self.in_grad`, compute `grad` for each `input`, multiply `in_grad` with `grad` to get `out_grad` (according to chain rule) and call each `input` variable's `backward()` function to back-propagate the gradient until the leaf node. 55 | + Second, define your interface in `functions.py` see other functions for more information 56 | + Third, if the `Variable` supports that function, define new method in `Variable` class in `variables.py` 57 | + Finally, test the new function with your own test case in `eval.py` with command `python eval.py` -------------------------------------------------------------------------------- /base/base_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import get_broadcast_dims 3 | 4 | 5 | class Function: 6 | def __init__(self): 7 | self.context = dict() 8 | self.in_grad = None 9 | self.output = None 10 | 11 | def __call__(self, *args, **kwargs): 12 | return self.forward(*args, **kwargs) 13 | 14 | def forward(self, *tensors): 15 | raise NotImplementedError 16 | 17 | def forward_(self): 18 | raise NotImplementedError 19 | 20 | def backward_(self): 21 | raise NotImplementedError 22 | 23 | def backward__(self): 24 | raise NotImplementedError 25 | 26 | def backward(self, in_grad): 27 | self.in_grad = in_grad.data 28 | self.backward_() 29 | self.backward__() 30 | 31 | def get_args(self, list_options): 32 | 33 | list_args = [] 34 | 35 | for i, option in enumerate(list_options): 36 | var = None 37 | if self.args is not None: 38 | if len(self.args) > i: 39 | var = self.args[i] 40 | 41 | if self.kwargs is not None: 42 | if option in self.kwargs: 43 | var = self.kwargs[option] 44 | 45 | self.context[option] = var 46 | 47 | list_args.append(var) 48 | 49 | return list_args 50 | 51 | 52 | class BinaryFunction(Function): 53 | def __init__(self): 54 | super().__init__() 55 | self.input1 = None 56 | self.input2 = None 57 | self.input1_var = None 58 | self.input2_var = None 59 | self.grad1 = None 60 | self.grad2 = None 61 | self.args = None 62 | self.kwargs = None 63 | 64 | def forward(self, input1, input2, *args, **kwargs): 65 | self.input1 = input1.data 66 | self.input2 = input2.data 67 | self.input1_var = input1 68 | self.input2_var = input2 69 | self.args = args 70 | self.kwargs = kwargs 71 | self.forward_() 72 | return self.output 73 | 74 | def forward_(self): 75 | self.output = getattr(torch, self.op)(self.input1, self.input2, *self.args, **self.kwargs) 76 | 77 | def backward__(self): 78 | if self.input1.data.is_cuda: 79 | self.grad1 = self.grad1.cuda() 80 | 81 | if self.input2.data.is_cuda: 82 | self.grad2 = self.grad2.cuda() 83 | 84 | if self.op == 'matmul': 85 | out_grad1 = self.in_grad @ self.grad1.t() 86 | out_grad2 = self.grad2.t() @ self.in_grad 87 | else: 88 | out_grad1 = self.in_grad * self.grad1 89 | out_grad2 = self.in_grad * self.grad2 90 | 91 | if self.input1.data.dim() != 0: # not a scalar value, it needs gradient 92 | if out_grad1.shape != self.input1.shape: # apply broadcast in forward pass 93 | list_dims, list_not_keeps = get_broadcast_dims(self.input1, out_grad1) 94 | 95 | for dim in list_dims: 96 | if dim in list_not_keeps: 97 | out_grad1 = out_grad1.sum(dim, 
keepdim=False) 98 | else: 99 | out_grad1 = out_grad1.sum(dim, keepdim=True) 100 | 101 | out_grad1 = out_grad1.view_as(self.input1) 102 | 103 | self.input1_var.backward(out_grad1) 104 | 105 | if self.input2.data.dim() != 0: # not a scalar value, it needs gradient 106 | if out_grad2.shape != self.input2.shape: # apply broadcast in forward pass 107 | list_dims, list_not_keeps = get_broadcast_dims(self.input2, out_grad2) 108 | 109 | for dim in list_dims: 110 | if dim in list_not_keeps: 111 | out_grad2 = out_grad2.sum(dim, keepdim=False) 112 | else: 113 | out_grad2 = out_grad2.sum(dim, keepdim=True) 114 | 115 | out_grad2 = out_grad2.view_as(self.input2) 116 | 117 | self.input2_var.backward(out_grad2) 118 | 119 | 120 | class UnaryFunction(Function): 121 | def __init__(self): 122 | super().__init__() 123 | self.input = None 124 | self.input_var = None 125 | self.grad = None 126 | self.args = None 127 | self.kwargs = None 128 | 129 | def forward_(self): 130 | self.output = getattr(torch, self.op)(self.input, *self.args, **self.kwargs) 131 | 132 | def forward(self, input, *args, **kwargs): 133 | self.input = input.data 134 | self.input_var = input 135 | self.args = args 136 | self.kwargs = kwargs 137 | self.forward_() 138 | 139 | return self.output 140 | 141 | def backward__(self): 142 | if self.input.data.is_cuda: 143 | self.grad = self.grad.cuda() 144 | 145 | if self.input.numel() != self.in_grad.numel(): # reduce op 146 | if self.op == 'getitem': 147 | self.grad.__setitem__(*self.args, 1, **self.kwargs) 148 | else: 149 | dim = self.context['dim'] 150 | if dim is not None: 151 | self.in_grad = self.in_grad.unsqueeze(dim) 152 | else: 153 | self.in_grad = self.in_grad.unsqueeze(-1) 154 | 155 | out_grad = self.in_grad * self.grad 156 | 157 | self.input_var.backward(out_grad) 158 | -------------------------------------------------------------------------------- /base/functions_class.py: -------------------------------------------------------------------------------- 1 | from base.base_function import BinaryFunction, UnaryFunction 2 | import torch 3 | from utils import get_broadcast_dims 4 | 5 | 6 | class ReLU(UnaryFunction): 7 | op = 'ReLU' 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward_(self): 13 | pos_mask = self.input > 0 14 | self.context['pos_mask'] = pos_mask 15 | return self.input[pos_mask] 16 | 17 | def backward_(self): 18 | pos_mask = self.context['pos_mask'] 19 | out = torch.zeros_like(pos_mask) 20 | out[pos_mask] = 1 21 | return out 22 | 23 | 24 | # ------- Reduction Function --------- # 25 | 26 | class Sum(UnaryFunction): 27 | op = 'sum' 28 | 29 | def backward_(self): 30 | dim, keepdim = self.get_args(['dim', 'keepdim']) 31 | self.grad = torch.ones_like(self.input) 32 | 33 | 34 | class Max(UnaryFunction): 35 | op = 'max' 36 | 37 | def backward_(self): 38 | self.get_args(['dim', 'keepdim']) 39 | 40 | self.grad = torch.zeros_like(self.input).float() 41 | 42 | if isinstance(self.output, (tuple, list)): 43 | index = self.output[1] 44 | self.grad[index] = 1. 45 | else: # max of all elements 46 | index = getattr(self.input.view(-1, 1), self.op)(0)[1] 47 | self.grad = self.grad.view(-1, 1) 48 | self.grad[index] = 1. 
49 | self.grad = self.grad.view(*self.input.size()) 50 | 51 | 52 | class Min(Max): 53 | op = 'min' 54 | 55 | 56 | class Prod(UnaryFunction): 57 | op = 'prod' 58 | 59 | def backward_(self): 60 | dim, keepdim = self.get_args(['dim', 'keepdim']) 61 | if dim is not None and not keepdim: 62 | self.grad = self.output.unsqueeze(dim).expand_as(self.input) / self.input 63 | else: 64 | self.grad = self.output.expand_as(self.input) / self.input 65 | 66 | 67 | class Mean(UnaryFunction): 68 | op = 'mean' 69 | 70 | def backward_(self): 71 | dim, keepdim = self.get_args(['dim', 'keepdim']) 72 | 73 | if dim is not None: 74 | n = self.input.size(dim) 75 | else: 76 | n = self.input.numel() 77 | 78 | self.grad = torch.ones_like(self.input) / n 79 | 80 | 81 | class Var(UnaryFunction): 82 | op = 'var' 83 | 84 | def backward_(self): 85 | dim, keepdim, unbiased = self.get_args(['dim', 'keepdim', 'unbiased']) 86 | 87 | if dim is not None: 88 | n = self.input.size(dim) 89 | else: 90 | n = self.input.numel() 91 | 92 | mean_grad = torch.ones_like(self.input) / n 93 | temp_kwargs = dict(self.kwargs) 94 | temp_kwargs.pop('unbiased', None) 95 | 96 | mean = self.input.mean(*self.args, **temp_kwargs) 97 | if dim is not None: 98 | mean = mean.unsqueeze(dim) 99 | 100 | temp_grad = 2 * (self.input - mean) * (1 - mean_grad) 101 | 102 | if unbiased: # use n - 1, else 103 | self.grad = temp_grad / (n - 1) 104 | else: # use n 105 | self.grad = temp_grad / n 106 | 107 | 108 | class Norm(UnaryFunction): 109 | op = 'norm' 110 | 111 | def backward_(self): 112 | p, dim, keepdim = self.get_args(['p', 'dim', 'keepdim']) 113 | 114 | if p is None: 115 | p = 2 116 | 117 | if dim is None: 118 | if p == 2: 119 | self.grad = self.input / self.output 120 | else: 121 | pow = self.input.abs().pow(p - 2) 122 | scale = 1. 
/ self.output ** (p - 1) 123 | self.grad = self.input * pow * scale 124 | else: 125 | if keepdim is False: 126 | self.in_grad = self.in_grad.unsqueeze(dim) 127 | self.output = self.output.unsqueeze(dim) 128 | 129 | self.in_grad = self.in_grad.expand_as(self.input) 130 | if p == 2: 131 | big_output = self.output.expand_as(self.input) 132 | self.grad = self.input / big_output 133 | else: 134 | pow = self.input.abs().pow(p - 2) 135 | big_output = self.output.pow(p - 1).expand_as(self.input) 136 | self.grad = self.input * pow / big_output 137 | 138 | 139 | class T(UnaryFunction): 140 | op = 't' 141 | 142 | def backward_(self): 143 | self.grad = torch.ones_like(self.input) 144 | self.in_grad = self.in_grad.t() 145 | 146 | 147 | class Transpose(UnaryFunction): 148 | op = 'transpose' 149 | 150 | def backward_(self): 151 | self.grad = torch.ones_like(self.input) 152 | self.in_grad = self.in_grad.transpose(*self.args, **self.kwargs) 153 | 154 | 155 | class Permute(UnaryFunction): 156 | op = 'permute' 157 | 158 | def forward_(self): 159 | self.output = self.input.permute(*self.args, **self.kwargs) 160 | 161 | def backward_(self): 162 | self.grad = torch.ones_like(self.input) 163 | self.in_grad = self.in_grad.permute(*self.args, **self.kwargs) 164 | 165 | 166 | class Exp(UnaryFunction): 167 | op = 'exp' 168 | 169 | def backward_(self): 170 | self.grad = self.output 171 | 172 | 173 | class Sqrt(UnaryFunction): 174 | op = 'sqrt' 175 | 176 | def backward_(self): 177 | self.grad = 1 / (2 * self.output) 178 | 179 | 180 | class Abs(UnaryFunction): 181 | op = 'abs' 182 | 183 | def backward_(self): 184 | self.grad = self.input.sign() 185 | 186 | 187 | class Sin(UnaryFunction): 188 | op = 'sin' 189 | 190 | def backward_(self): 191 | self.grad = self.input.cos() 192 | 193 | 194 | class Cos(UnaryFunction): 195 | op = 'cos' 196 | 197 | def backward_(self): 198 | self.grad = -self.input.sin() 199 | 200 | 201 | class Tan(UnaryFunction): 202 | op = 'tan' 203 | 204 | def backward_(self): 205 | self.grad = 1 + self.input.tan()**2 206 | 207 | 208 | class Tanh(UnaryFunction): 209 | op = 'tanh' 210 | 211 | def backward_(self): 212 | self.grad = 1 - self.input.tan()**2 213 | 214 | 215 | class Sigmoid(UnaryFunction): 216 | op = 'sigmoid' 217 | 218 | def backward_(self): 219 | self.grad = self.output * (1 - self.output) 220 | 221 | 222 | class Log(UnaryFunction): 223 | op = 'log' 224 | 225 | def backward_(self): 226 | self.grad = torch.reciprocal(self.input) 227 | 228 | 229 | class Add(BinaryFunction): 230 | op = 'add' 231 | 232 | def backward_(self): 233 | self.grad1 = torch.ones_like(self.input1) 234 | self.grad2 = torch.ones_like(self.input2) 235 | 236 | 237 | class Mul(BinaryFunction): 238 | op = 'mul' 239 | 240 | def backward_(self): 241 | self.grad1 = self.input2 242 | self.grad2 = self.input1 243 | 244 | 245 | class MatMul(BinaryFunction): 246 | op = 'matmul' 247 | 248 | def backward_(self): 249 | self.grad1 = self.input2 250 | self.grad2 = self.input1 251 | 252 | 253 | class Pow(BinaryFunction): 254 | op = 'pow' 255 | 256 | def backward_(self): 257 | self.grad1 = self.input2 * self.input1 ** (self.input2 - 1) 258 | self.grad2 = torch.log(self.input1) * self.input1 ** self.input2 259 | 260 | 261 | class Max2(BinaryFunction): 262 | op = 'max' 263 | 264 | def backward_(self): 265 | input1 = self.input1 266 | input2 = self.input2 267 | 268 | if input1.numel() > input2.numel(): # use broadcast 269 | input2 = input2.expand_as(input1).contiguous() 270 | else: 271 | input1 = input1.expand_as(input2).contiguous() 272 | 273 
| input1 = input1.view(-1) 274 | input2 = input2.view(-1) 275 | input = torch.stack([input1, input2], 1) 276 | 277 | grad = torch.zeros_like(input).float() 278 | index = getattr(input, self.op)(1)[1] 279 | 280 | grad[range(len(index)), index] = 1. 281 | 282 | grad1, grad2 = grad[:, 0], grad[:, 1] 283 | 284 | if grad1.numel() != self.input1.numel(): # input1 is broadcasted 285 | grad1 = grad1.view_as(self.input2) 286 | 287 | list_dims, list_not_keeps = get_broadcast_dims(self.input1, grad1) 288 | 289 | for dim in list_dims: 290 | if dim in list_not_keeps: 291 | grad1 = grad1.sum(dim, keepdim=False) / grad1.size(dim) 292 | else: 293 | grad1 = grad1.sum(dim, keepdim=True) / grad1.size(dim) 294 | 295 | self.grad1 = grad1.view_as(self.input1) 296 | 297 | if grad2.numel() != self.input2.numel(): 298 | grad2 = grad2.view_as(self.input1) 299 | 300 | list_dims, list_not_keeps = get_broadcast_dims(self.input2, grad2) 301 | 302 | for dim in list_dims: 303 | if dim in list_not_keeps: 304 | grad2 = grad2.sum(dim, keepdim=False) / grad2.size(dim) 305 | else: 306 | grad2 = grad2.sum(dim, keepdim=True) / grad2.size(dim) 307 | 308 | self.grad2 = grad2.view_as(self.input2) 309 | 310 | 311 | class Min2(Max2): 312 | op = 'min' 313 | 314 | 315 | class Cat(UnaryFunction): 316 | op = 'cat' 317 | 318 | def __init__(self): 319 | super().__init__() 320 | self.input = None 321 | self.grad = None 322 | self.args = None 323 | self.kwargs = None 324 | 325 | def forward(self, input, *args, **kwargs): 326 | self.input = input 327 | self.args = args 328 | self.kwargs = kwargs 329 | self.output = getattr(torch, self.op)([x.data for x in self.input], *args, **kwargs) 330 | 331 | return self.output 332 | 333 | def backward_(self): 334 | self.grad = torch.ones_like(self.in_grad) 335 | 336 | def backward__(self): 337 | if self.input[0].data.is_cuda: 338 | self.grad = self.grad.cuda() 339 | 340 | dim, = self.get_args(['dim']) 341 | 342 | if dim is None: 343 | dim = 0 344 | 345 | list_lengths = [x.size(dim) for x in self.input] 346 | 347 | out_grad = self.in_grad * self.grad 348 | 349 | out_grads = out_grad.split(list_lengths, dim) 350 | 351 | for input_, grad in zip(self.input, out_grads): 352 | input_.backward(grad) 353 | 354 | 355 | class Stack(Cat): 356 | op = 'stack' 357 | 358 | 359 | class Squeeze(UnaryFunction): 360 | op = 'squeeze' 361 | 362 | def backward_(self): 363 | self.grad = torch.ones_like(self.input) 364 | dim, = self.get_args(['dim']) 365 | 366 | if dim is not None: 367 | self.in_grad = self.in_grad.unsqueeze(*self.args, **self.kwargs) 368 | else: 369 | squeeze_dims = [i for i, x in enumerate(self.input.size()) if x == 1] 370 | for dim in squeeze_dims: 371 | self.in_grad = self.in_grad.unsqueeze(dim) 372 | 373 | 374 | class Unsqueeze(UnaryFunction): 375 | op = 'unsqueeze' 376 | 377 | def backward_(self): 378 | self.grad = torch.ones_like(self.input) 379 | self.in_grad = torch.squeeze(self.in_grad.data, *self.args, **self.kwargs) 380 | 381 | 382 | class Split(UnaryFunction): 383 | op = 'split' 384 | 385 | def backward_(self): 386 | self.grad = torch.ones_like(self.input) 387 | 388 | def backward(self, in_grad): 389 | dim, = self.get_args(['dim']) 390 | 391 | if dim is None: 392 | dim = 0 393 | 394 | self.in_grad = torch.cat([x.data for x in in_grad], dim) 395 | self.backward_() 396 | self.backward__() 397 | 398 | 399 | class Getitem(UnaryFunction): 400 | op = '__getitem__' 401 | 402 | def forward_(self): 403 | self.output = self.input.__getitem__(*self.args, **self.kwargs) 404 | 405 | def backward_(self): 406 | 
self.grad = torch.zeros_like(self.input) 407 | 408 | 409 | class View(UnaryFunction): 410 | op = 'view' 411 | 412 | def forward_(self): 413 | self.output = self.input.view(*self.args, **self.kwargs) 414 | 415 | def backward_(self): 416 | self.grad = torch.ones_like(self.input) 417 | self.in_grad = self.in_grad.view(*self.input.size()) 418 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import check 3 | 4 | 5 | def check_example(pytorch=False): 6 | 7 | if pytorch: 8 | from torch.autograd import Variable 9 | else: 10 | from variables import Variable 11 | 12 | a = Variable(torch.Tensor([[-4, 2], [8, 4], [5, 6]]).float(), requires_grad=True) 13 | b = Variable(torch.Tensor([[2, 3, 4], [5, 6, 7]]).float(), requires_grad=True) 14 | c = Variable(torch.Tensor([[4, 5, 6]]).float(), requires_grad=True) 15 | 16 | if pytorch: 17 | import torch as F 18 | else: 19 | import functions as F 20 | 21 | d = (a * b.permute(1, 0)).view(2, 3) 22 | g = d.sum() 23 | 24 | # k = a + 1 25 | # d = (k @ b).tanh() 26 | # e = d @ c.t() 27 | # f = e ** a.sum(1) 28 | # g = f.norm(dim=0, keepdim=True) 29 | # g = g.sum() 30 | 31 | # k.retain_grad() 32 | # c.retain_grad() 33 | d.retain_grad() 34 | # e.retain_grad() 35 | # f.retain_grad() 36 | g.retain_grad() 37 | 38 | g.backward() 39 | 40 | local = locals() 41 | vars = list('abcdg') 42 | ret_tuple = tuple(local[x] for x in vars) 43 | return ret_tuple, vars 44 | 45 | 46 | if __name__ == '__main__': 47 | ret1, vars = check_example() 48 | 49 | ret2 = check_example(pytorch=True)[0] 50 | 51 | check(ret1, ret2, vars) -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | from base.functions_class import * 2 | 3 | 4 | def binary_op(op1, op2, op_): 5 | if op1.__class__ == op2.__class__: 6 | Variable = op2.__class__ 7 | else: 8 | if isinstance(op1, (float, int, bool)): 9 | Variable = op2.__class__ 10 | op1 = Variable(torch.tensor(float(op1))) 11 | else: 12 | Variable = op1.__class__ 13 | op2 = Variable(torch.tensor(float(op2))) 14 | 15 | op = op_() 16 | out = op(op1, op2) 17 | return Variable(out, op) 18 | 19 | 20 | def unary_op(op1, op_, *args, **kwargs): 21 | if not isinstance(op1, (tuple, list)): 22 | Variable = op1.__class__ 23 | else: 24 | Variable = op1[0].__class__ 25 | 26 | op = op_() 27 | out = op(op1, *args, **kwargs) 28 | 29 | if isinstance(out, torch.Tensor): 30 | return Variable(out, op) 31 | elif isinstance(out, (tuple, list)): # return 2 values 32 | return [Variable(o, op) for o in out] 33 | 34 | 35 | def add(op1, op2): 36 | return binary_op(op1, op2, Add) 37 | 38 | 39 | def matmul(op1, op2): 40 | return binary_op(op1, op2, MatMul) 41 | 42 | 43 | def mul(op1, op2): 44 | return binary_op(op1, op2, Mul) 45 | 46 | 47 | def add(op1, op2): 48 | return binary_op(op1, op2, Add) 49 | 50 | 51 | def sub(op1, op2): 52 | return binary_op(op1, -1 * op2, Add) 53 | 54 | 55 | def truediv(op1, op2): 56 | return binary_op(op1, op2 ** -1, Mul) 57 | 58 | 59 | def pow(op1, power, modulo=None): 60 | return binary_op(op1, power, Pow) 61 | 62 | 63 | def t(op): 64 | return unary_op(op, T) 65 | 66 | 67 | def transpose(op, *args, **kwargs): 68 | return unary_op(op, Transpose, *args, **kwargs) 69 | 70 | 71 | def permute(op, *args, **kwargs): 72 | return unary_op(op, Permute, *args, **kwargs) 73 | 74 | 75 | def abs(op): 
76 | return unary_op(op, Abs) 77 | 78 | 79 | def sin(op): 80 | return unary_op(op, Sin) 81 | 82 | 83 | def cos(op): 84 | return unary_op(op, Cos) 85 | 86 | 87 | def tan(op): 88 | return unary_op(op, Tan) 89 | 90 | 91 | def tanh(op): 92 | return unary_op(op, Tanh) 93 | 94 | 95 | def neg(op): 96 | return binary_op(op, -1, Mul) 97 | 98 | 99 | def sum(op, *args, **kwargs): 100 | return unary_op(op, Sum, *args, **kwargs) 101 | 102 | 103 | def exp(op): 104 | return unary_op(op, Exp) 105 | 106 | 107 | def log(op): 108 | return unary_op(op, Log) 109 | 110 | 111 | def sqrt(op): 112 | return unary_op(op, Sqrt) 113 | 114 | 115 | def sigmoid(op): 116 | return unary_op(op, Sigmoid) 117 | 118 | 119 | def norm(op, *args, **kwargs): 120 | return unary_op(op, Norm, *args, **kwargs) 121 | 122 | 123 | def mean(op1, *args, **kwargs): 124 | return unary_op(op1, Mean, *args, **kwargs) 125 | 126 | 127 | def prod(op1, *args, **kwargs): 128 | return unary_op(op1, Prod, *args, **kwargs) 129 | 130 | 131 | def var(op1, *args, **kwargs): 132 | return unary_op(op1, Var, *args, **kwargs) 133 | 134 | 135 | def std(op1, *args, **kwargs): 136 | return unary_op(op1, Var, *args, **kwargs).sqrt() 137 | 138 | 139 | def max(op1, *args, **kwargs): 140 | if isinstance(args[0], op1.__class__): 141 | return binary_op(op1, args[0], Max2) 142 | else: 143 | return unary_op(op1, Max, *args, **kwargs) 144 | 145 | 146 | def min(op1, *args, **kwargs): 147 | if isinstance(args[0], op1.__class__): 148 | return binary_op(op1, args[0], Min2) 149 | else: 150 | return unary_op(op1, Min, *args, **kwargs) 151 | 152 | 153 | def cat(list_ops, *args, **kwargs): 154 | return unary_op(list_ops, Cat, *args, **kwargs) 155 | 156 | 157 | def stack(list_ops, *args, **kwargs): 158 | return unary_op(list_ops, Stack, *args, **kwargs) 159 | 160 | 161 | def squeeze(op1, *args, **kwargs): 162 | return unary_op(op1, Squeeze, *args, **kwargs) 163 | 164 | 165 | def unsqueeze(op1, *args, **kwargs): 166 | return unary_op(op1, Unsqueeze, *args, **kwargs) 167 | 168 | 169 | def split(op1, *args, **kwargs): 170 | return unary_op(op1, Split, *args, **kwargs) 171 | 172 | 173 | def getitem(op1, *args, **kwargs): 174 | return unary_op(op1, Getitem, *args, **kwargs) 175 | 176 | 177 | def view(op1, *args, **kwargs): 178 | return unary_op(op1, View, *args, **kwargs) 179 | 180 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_broadcast_dims(input, output): 5 | list_dims = [] 6 | list_not_keeps = [] 7 | 8 | if input.dim() < output.dim(): 9 | table = torch.zeros(input.dim(), output.dim()) 10 | for i, v_i in enumerate(input.size()): 11 | for j, v_j in enumerate(output.size()): 12 | if v_i == v_j and all(table[i, :j] == 0): # just accept one-to-one mapping 13 | table[i, j] = 1 14 | 15 | for k in range(output.dim()): 16 | if all(table[:, k] == 0): # add dimension here 17 | input.unsqueeze(k) 18 | list_not_keeps.append(k) 19 | 20 | for i, (l1, l2) in enumerate(zip(input.size(), output.size())): 21 | if l1 < l2: 22 | list_dims.append(i) 23 | 24 | return list_dims, set(list_not_keeps) 25 | 26 | 27 | def check(list1, list2, vars): 28 | for i, (e1, e2) in enumerate(zip(list1, list2)): 29 | print("Compare data of variable {}:".format(vars[i])) 30 | if torch.equal(e1.data, e2.data): 31 | print('Correct') 32 | else: 33 | print('Incorrect: \n\tcomputed: {} with shape: {}\n\texpected: {} with shape: {}' 34 | .format(e1.data, e1.shape, 
e2.data, e2.shape)) 35 | 36 | print("Compare grad of variable {}:".format(vars[i])) 37 | if (e1.grad is None and e2.grad is None) or torch.equal(e1.grad.data, e2.grad.data)\ 38 | or any(torch.isnan(e2.grad.data).tolist()): 39 | print('Correct') 40 | else: 41 | print('Incorrect: \n\tcomputed: {} with shape: {}\n\texpected: {} with shape: {}' 42 | .format(e1.grad.data, e1.grad.shape, e2.grad.data, e2.grad.shape)) 43 | 44 | print('---') 45 | -------------------------------------------------------------------------------- /variables.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functions as F 3 | 4 | 5 | class Variable: 6 | def __init__(self, data, grad_fn=None, requires_grad=None): 7 | self.data = data 8 | self.grad = None 9 | self.grad_fn = grad_fn 10 | 11 | def __matmul__(self, other): 12 | return F.matmul(self, other) 13 | 14 | def __mul__(self, other): 15 | return F.mul(self, other) 16 | 17 | def __add__(self, other): 18 | return F.add(self, other) 19 | 20 | def __sub__(self, other): 21 | return F.sub(self, other) 22 | 23 | def __rsub__(self, other): 24 | return F.sub(other, self) 25 | 26 | def __truediv__(self, other): 27 | return F.truediv(self, other) 28 | 29 | def __rtruediv__(self, other): 30 | return F.truediv(other, self) 31 | 32 | def __pow__(self, power, modulo=None): 33 | return F.pow(self, power) 34 | 35 | def __rpow__(self, other): 36 | return F.pow(other, self) 37 | 38 | def __abs__(self): 39 | return F.abs(self) 40 | 41 | def __neg__(self): 42 | return F.neg(self) 43 | 44 | def __getitem__(self, *args, **kwargs): 45 | return F.getitem(self, *args, **kwargs) 46 | 47 | def size(self, *args): 48 | return self.data.size(*args) 49 | 50 | @property 51 | def shape(self): 52 | return self.data.size() 53 | 54 | @property 55 | def T(self): 56 | return self.t() 57 | 58 | def t(self): 59 | return F.t(self) 60 | 61 | def transpose(self, *args, **kwargs): 62 | return F.transpose(self, *args, **kwargs) 63 | 64 | def permute(self, *args, **kwargs): 65 | return F.permute(self, *args, **kwargs) 66 | 67 | def split(self, *args, **kwargs): 68 | return F.split(self, *args, **kwargs) 69 | 70 | def squeeze(self, *args, **kwargs): 71 | return F.squeeze(self, *args, **kwargs) 72 | 73 | def unsqueeze(self, *args, **kwargs): 74 | return F.unsqueeze(self, *args, **kwargs) 75 | 76 | def view(self, *args, **kwargs): 77 | return F.view(self, *args, **kwargs) 78 | 79 | def dim(self): 80 | return self.data.dim() 81 | 82 | def numel(self): 83 | return self.data.numel() 84 | 85 | def abs(self): 86 | return F.abs(self) 87 | 88 | def sin(self): 89 | return F.sin(self) 90 | 91 | def cos(self): 92 | return F.cos(self) 93 | 94 | def tan(self): 95 | return F.tan(self) 96 | 97 | def tanh(self): 98 | return F.tanh(self) 99 | 100 | def exp(self): 101 | return F.exp(self) 102 | 103 | def log(self): 104 | return F.log(self) 105 | 106 | def sqrt(self): 107 | return F.sqrt(self) 108 | 109 | def sigmoid(self): 110 | return F.sigmoid(self) 111 | 112 | # Reduction op 113 | 114 | def sum(self, *args, **kwargs): 115 | return F.sum(self, *args, **kwargs) 116 | 117 | def norm(self, *args, **kwargs): 118 | return F.norm(self, *args, **kwargs) 119 | 120 | def mean(self, *args, **kwargs): 121 | return F.mean(self, *args, **kwargs) 122 | 123 | def var(self, *args, **kwargs): 124 | return F.var(self, *args, **kwargs) 125 | 126 | def std(self, *args, **kwargs): 127 | return F.std(self, *args, **kwargs) 128 | 129 | def max(self, *args, **kwargs): 130 | return 
F.max(self, *args, **kwargs) 131 | 132 | def min(self, *args, **kwargs): 133 | return F.min(self, *args, **kwargs) 134 | 135 | def prod(self, *args, **kwargs): 136 | return F.prod(self, *args, **kwargs) 137 | 138 | __radd__ = __add__ 139 | __rmul__ = __mul__ 140 | __iadd__ = __add__ 141 | 142 | def retain_grad(self): 143 | pass 144 | 145 | def backward(self, in_grad=None): 146 | 147 | if in_grad is None: 148 | if self.data.size() != torch.Size([]): 149 | raise RuntimeError('grad can be implicitly created only for scalar outputs') 150 | temp_grad = Variable(torch.tensor(1.).cuda() if self.data.is_cuda else torch.tensor(1.)) 151 | else: 152 | temp_grad = Variable(in_grad) 153 | 154 | self.grad = (self.grad if self.grad is not None else 0) + temp_grad 155 | 156 | if self.grad_fn: 157 | self.grad_fn.backward(self.grad) 158 | 159 | def __repr__(self): 160 | return 'Variable containing: {}'.format(self.data.__repr__()) 161 | 162 | 163 | class Parameter(Variable): 164 | def __init__(self, tensor): 165 | super().__init__(tensor) 166 | 167 | def uniform_(self, a, b): 168 | self.data = torch.rand(*self.data.size()) * (b - a) + a 169 | --------------------------------------------------------------------------------
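Appendix: a minimal end-to-end sketch (not part of the repository) of the four steps described in the README for adding your own autograd function. It uses a hypothetical `log1p` op; the comments note which file each piece would normally live in, and `Variable` is patched at runtime only so the sketch can run standalone from the project root.

import torch
from base.base_function import UnaryFunction
from functions import unary_op
from variables import Variable


# Step 1 (base/functions_class.py): the forward pass comes for free from torch.log1p
# via UnaryFunction.forward_; backward_ only supplies the local derivative
# d/dx log(1 + x) = 1 / (1 + x).
class Log1p(UnaryFunction):
    op = 'log1p'

    def backward_(self):
        self.grad = torch.reciprocal(self.input + 1)


# Step 2 (functions.py): a thin interface that wraps the class, mirroring the existing wrappers.
def log1p(op):
    return unary_op(op, Log1p)


# Step 3 (variables.py): expose it as a Variable method; inside the class it would read
# `def log1p(self): return F.log1p(self)`. Here it is attached dynamically instead.
Variable.log1p = log1p

# Step 4 (eval.py): a quick check in the spirit of check_example(), comparing against
# PyTorch's own autograd.
a = Variable(torch.Tensor([[1., 2.], [3., 4.]]), requires_grad=True)
a.log1p().sum().backward()

a_ref = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
torch.log1p(a_ref).sum().backward()

print(a.grad.data)   # gradient from this project's autograd
print(a_ref.grad)    # gradient from PyTorch's autograd; the two should match
--------------------------------------------------------------------------------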