├── README.md
├── build-with-mpi.md
├── build_with_nccl.md
├── dispatch.md
├── profile.md
├── tensor_add_explained.md
└── walkthrough.md

/README.md:
--------------------------------------------------------------------------------
# pytorch-learning
Learning notes taken while reading the source code of PyTorch.

--------------------------------------------------------------------------------
/build-with-mpi.md:
--------------------------------------------------------------------------------
Build PyTorch with CUDA-aware MPI support
====
Advantages:
- low CPU usage
- low latency
- no extra copy from/to system memory
- easy to use

Requirements:
- UVA support (available since )

Build MPI with CUDA support (Open MPI)
---

Dependencies:
- automake
- flex

```bash
apt install automake flex
```
> [color=#10d19a]**Currently only openmpi-1.10 is supported by PyTorch's build system.**
> [name=Stone sky]
> [time=Thu, Apr 5, 2018 9:56 PM]
>
> [color=#10d19a]openmpi-3.1.1 compiles against the PyTorch master branch.
> [name=Stone sky]
> [time=Wed, Jul 25, 2018 11:39 PM]

Download the source tarball; you may refer to the [up-to-date download page](https://www.open-mpi.org/software/ompi/v3.0/).
```bash
wget https://www.open-mpi.org/software/ompi/v2.1/downloads/openmpi-2.1.3.tar.gz # TODO: change to URL of 1.10.7
tar xvf openmpi-1.10.7.tar.gz
```

The newest release at this time is *3.0.1*, but I failed to build it due to a build bug, so I build 1.10.7 from source with CUDA support:

```
cd openmpi-1.10.7
mkdir build && cd build
../configure --with-cuda --enable-mpi-thread-multiple # it's not tab-completed by zsh
```
If your CUDA is not installed at `/usr/local/cuda`, or you want to compile against a non-default CUDA version, follow the [official CUDA build FAQ](https://www.open-mpi.org/faq/?category=buildcuda) for customized build options.

Build PyTorch with the new Open MPI
----

1. Build with a system-wide MPI (older version)

PyTorch's build system looks for `libmpi` and `libmpicxx` under `/usr/lib`, while the default install path of Open MPI is `/usr/local/lib`, so a plain build raises an `mpi not found` error. You can either copy/link the `.so` libraries or pass extra linking flags to compile successfully.

Workaround for PyTorch:
```bash
sudo cp /usr/local/lib/libmpi* /usr/lib
# compile pytorch
python setup.py clean
python setup.py build develop

# (optional, delete the redundant files)
sudo rm /usr/lib/libmpi*
```

Then add the library path to `LD_LIBRARY_PATH` in case of a `file not found` error.

```bash
export LD_LIBRARY_PATH="/usr/local/lib:$LD_LIBRARY_PATH"
```

2. Build with an arbitrary version of MPI

PyTorch uses the *FindMPI* module bundled with *CMake*. Recent CMake versions can automatically detect MPI's lib and include paths if an MPI-compatible compiler is specified, e.g.:
```bash
python setup.py clean
CMAKE_C_COMPILER=$(which mpicc) CMAKE_CXX_COMPILER=$(which mpicxx) python setup.py build develop
```


How to check the CUDA-aware MPI support
---

1. Simply list the dynamic libraries linked into PyTorch's runtime: `ldd torch/*.so`.
If compiled with MPI, you can find `libmpi.so`. If compiled with CUDA-aware MPI, you can find `libopen-rte.so`.

2. Run test code:


```python
import torch
import torch.distributed as dist

dist.init_process_group(backend='mpi')

t = torch.zeros(5,5).fill_(dist.get_rank()).cuda()

dist.all_reduce(t) # ???
```
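To see whether the all-reduce actually goes through the CUDA-aware path, a slightly extended version of this test can be launched with something like `mpirun -np 2 python test_mpi.py` (the script name and process count here are placeholders of mine, not from the original notes). It pins each rank to its own GPU and checks the reduced value; a wrong sum or a crash at the `all_reduce` usually points at an MPI build without CUDA support:

```python
# test_mpi.py -- minimal sketch of a CUDA-aware MPI check (hypothetical file name)
import torch
import torch.distributed as dist

dist.init_process_group(backend='mpi')
rank = dist.get_rank()
world_size = dist.get_world_size()

# one GPU per rank; assumes at least `world_size` visible GPUs on the node
torch.cuda.set_device(rank % torch.cuda.device_count())

t = torch.zeros(5, 5).fill_(rank).cuda()
dist.all_reduce(t)  # default reduction op is SUM

# after the sum every entry should equal 0 + 1 + ... + (world_size - 1)
expected = world_size * (world_size - 1) / 2
print('rank %d: ok = %s' % (rank, bool((t == expected).all())))
```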
--------------------------------------------------------------------------------
/build_with_nccl.md:
--------------------------------------------------------------------------------
Build PyTorch with NCCL2
=====

Pre
---

#### What is NCCL2?
Check this out:
https://github.com/PaddlePaddle/Paddle/wiki/NCCL2-Survey

Check your NCCL version:
- PyTorch comes with NCCL 1 bundled, but it will detect a system-wide NCCL for better performance (**compile time only**).
- `vim /usr/include/nccl.h` to check the NCCL version.

Step by Step
---

1. Download the NCCL2 runtime and header files.
NCCL2 is not open-sourced, so you have to download the compiled packages from NVIDIA's website, or directly from the repo file server: http://developer.download.nvidia.com/compute/machine-learning/repos
Warning: you have to choose the right system version and hardware architecture, as well as the right NCCL version and CUDA version. Here I downloaded `libnccl2_2.2.12-1+cuda8.0_amd64.deb` and `libnccl-dev_2.2.12-1+cuda8.0_amd64.deb`.


2. Install NCCL2.
It's quite straightforward (ubuntu-16.04 LTS, for example):
```bash
sudo dpkg -i libnccl2_2.2.12-1+cuda8.0_amd64.deb libnccl-dev_2.2.12-1+cuda8.0_amd64.deb
```
> I struggled to find the location of the installed files (header and lib); simply running `dpkg -c *.deb` solves the problem.


3. Compile PyTorch from source.
After installing the new NCCL, you have to build PyTorch from a clean directory.
```bash
cd pytorch

# both of the following two lines are required, or the cached NCCL1 will
# be used instead of NCCL2
rm build -rf
rm torch/lib/tmp_install -rf
# or you can run this for a completely clean directory:
# python setup.py build clean

python setup.py build develop
```

If your NCCL2 is installed into a customized directory, you can pass the location through environment variables (I haven't tested this myself):
```
NCCL_LIB_DIR=~/.local/lib NCCL_INCLUDE_DIR=~/.local/include python setup.py build develop
```
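Before running anything distributed, it is worth confirming that the rebuilt binaries really link NCCL2 rather than the bundled NCCL1. One quick check from Python is sketched below; the exact return format of `torch.cuda.nccl.version()` differs between PyTorch releases, so treat the expected output as an assumption:

```python
import torch
import torch.cuda.nccl as nccl

print(torch.cuda.is_available())
# version of the NCCL library the build was compiled against;
# e.g. 2212 stands for 2.2.12 (newer PyTorch returns a tuple instead)
print(nccl.version())
```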
4. Run a test script.

> on node-0

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl",
                        init_method="file://distributed_test",
                        world_size=2,
                        rank=0)
print('haha')
tensor_list = []
for dev_idx in range(2):
    tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))

dist.all_reduce_multigpu(tensor_list)
```

> on node-1

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl",
                        init_method="file://distributed_test",
                        world_size=2,
                        rank=1)
print('haha')
tensor_list = []
for dev_idx in range(2):
    tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))

dist.all_reduce_multigpu(tensor_list)
```
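The two scripts above differ only in the `rank` argument, so they can be collapsed into a single file that takes the rank on the command line. The sketch below is my own consolidation (the file name `node.py` and the argv handling are assumptions); the `file://` rendezvous still has to point at a path both nodes can see:

```python
# node.py -- run as `python node.py 0` on node-0 and `python node.py 1` on node-1
import sys
import torch
import torch.distributed as dist

rank = int(sys.argv[1])

dist.init_process_group(backend="nccl",
                        init_method="file://distributed_test",
                        world_size=2,
                        rank=rank)

# one tensor per local GPU; all_reduce_multigpu sums over every GPU of every node,
# so with 2 nodes x 2 GPUs each entry should end up as 4
tensor_list = [torch.FloatTensor([1]).cuda(dev_idx) for dev_idx in range(2)]
dist.all_reduce_multigpu(tensor_list)
print('rank %d done' % rank)
```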
#### how to run:

- **critical**: choose the right NIC for inter-node communication

Sometimes NCCL fails to set up the inter-node connections and raises an *unhandled system error*. On my system it looks like:
```log
RuntimeError: NCCL error in: /slwork/users/kys10/Workspace/pytorch/torch/lib/THD/base/data_channels/DataChannelNccl.cpp:322, unhandled system error
```
Note that if you run this inter-node test on NCCL1, which only supports intra-node communication, the error message is also `unhandled system error`.

In this case you have to specify a NIC that is viable for the inter-node connection.

Typically, on a machine with docker installed, the error can be worked around by:
`NCCL_SOCKET_IFNAME=^docker0 python node-0.py`

ref:
- [nvidia-forum](https://devtalk.nvidia.com/default/topic/1023946/gpu-accelerated-libraries/nccl-2-0-support-inter-node-communication-using-sockets-/)
- [NCCL2 doc](https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/index.html#ncclknobs)

--------------------------------------------------------------------------------
/dispatch.md:
--------------------------------------------------------------------------------
# how pytorch dispatches `Variable.index_select` and its autograd

This post assumes readers know the basic concepts of `Function` and `Variable` in *PyTorch*.

PyTorch is a deep learning framework famous for its simplicity in prototyping and debugging. There are no separate graph-definition and graph-execution phases; instead, the computation graph is built on the fly.

## The *C/C++* backend

For PyTorch, as for other deep learning frameworks such as **mxnet** and **tensorflow**, the backend is written in *C/C++* for performance. It does no more than maintaining tensor metadata and doing tensor (matrix) math.

The backend of PyTorch is somewhat fragile because it reuses a lot of code from the **torch** project. Fortunately the PyTorch community is working on a new unified tensor library named *ATen*, which has been in use since version *0.3.0*. At the current stage, *ATen* defines data structures such as `Tensor`, `Storage` and `TensorInfo`, with more and more native operations implemented independently of the `TH` or `THC` libraries.

Currently the computation is dispatched from *ATen* to the corresponding methods defined in `TH` (CPU), `THC` (GPU) and `THS` (sparse matrix operations).
Below, the more complicated `THC` backend is used to clarify the call path from *ATen* down to the actual compute kernels.

### invoking CUDA kernels

- kernel wrappers
The wrappers perform some routine tasks before and after kernel launch, such as error checking, data preparation and setting the kernel runtime parameters, e.g. the `blockSize`, `gridSize` and `stream ID`.
```C
// pytorch/aten/src/THC/generic/THCTensorIndex.cu
void THCTensor_(gather)(THCState* state, THCTensor *tensor,
                        THCTensor *src, int dim, THCudaLongTensor *index) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
  THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));

  THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 4,
             "Index tensor must have same dimensions as input tensor");
  THLongStorage *indexSize = THCudaLongTensor_newSizeOf(state, index);
  THArgCheck(THCTensor_(isSize)(state, tensor, indexSize), 4,
             "Index tensor must have the same size as output tensor.");
  THLongStorage_free(indexSize);
  // to invoke CUDA kernel
```

- CUDA kernels
The CUDA kernels are the code running on CUDA-compatible GPUs; they are generally the critical part to optimize.
```C
// pytorch/aten/src/THC/THCTensorIndex.cu
template <typename IndexType, typename Real, int Dims>
__global__ void THCudaTensor_gatherKernel(
    TensorInfo<Real, IndexType> tensor,
    TensorInfo<Real, IndexType> src,
    TensorInfo<int64_t, IndexType> index,
    const int dim,
    const IndexType totalElements) {
  for (IndexType linearId = blockIdx.x * blockDim.x + threadIdx.x;
       linearId < totalElements;
       linearId += gridDim.x * blockDim.x) {
    IndexType tensorOffset = 0;
    IndexType srcOffset = 0;
    IndexType indexOffset = 0;

    IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(linearId, dim,
                                                                index, &indexOffset,
                                                                tensor, &tensorOffset,
                                                                src, &srcOffset);

    int64_t indexValue = index.data[indexOffset] - TH_INDEX_BASE;
    assert(indexValue >= 0 && indexValue < src.sizes[dim]);
    srcOffset += indexValue * src.strides[dim];

    tensor.data[tensorOffset] = src.data[srcOffset];
  }
}
```


## Binding the `Python` and `C/C++` APIs
There are two code-generation schemes in PyTorch: native methods and `TH`/`THC`-dependent methods.
For native methods, open `aten/src/ATen/native/native_functions.yaml`; it is pretty easy to understand (variants: function, method; backend: CPU, GPU).
For methods backed by the `TH`/`THC` libraries, PyTorch uses `cwrap` declarations to bind a Python call to its corresponding *ATen* backend. Simply searching for *.cwrap* files in the PyTorch repo will give you some hints (there may be only one cwrap file now).

In detail, the *cwrap* files are used at build time (when running `python setup.py build`) to generate the `PyMethods` of the Python type object `torch.Tensor`. `PyMethod` is part of CPython's object machinery: it lets a Python object call native C functions, here the ATen functions of PyTorch. After the proper `PyMethod` is defined, a `tensor_GPU.index_select` call in Python finally invokes the corresponding C backend function `THCTensor_(indexSelect)`, and `tensor_CPU.index_select` calls `THTensor_(indexSelect)`.
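This dispatch is invisible from Python: the same call works on CPU and GPU tensors, and only the backend symbol that finally runs differs. A small sanity check (nothing here is version-specific beyond `index_select` itself; the comments restate the mapping described above, not something the snippet can observe):

```python
import torch

x = torch.randn(4, 3)
idx = torch.LongTensor([0, 2])

# both spellings reach the same backend, THTensor_(indexSelect) on CPU
a = x.index_select(0, idx)
b = torch.index_select(x, 0, idx)
print((a == b).all())

if torch.cuda.is_available():
    # the same Python call now ends up in THCTensor_(indexSelect)
    c = x.cuda().index_select(0, idx.cuda())
    print(c.cpu())
```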
### how to read a cwrap file
**name**: the method name in *ATen* and the *Python API*
**cname**: the name of the backend library call; the same as **name** if not specified

The **cname** is important: you can search for it in the PyTorch tree to find the function definition/declaration. In case too many unrelated results show up, limit the search scope to `.c/.cpp/.h` files.

See docs:
`pytorch/aten/src/ATen/native/README.md`

And source code:
`pytorch/aten/src/ATen/native/native_functions.yaml`
`pytorch/aten/src/ATen/Declarations.cwrap`

## Defining the `forward`/`backward` pair
In early versions, PyTorch used hand-written `forward` and `backward` methods in each `Function` class. Now a separate declaration file defines the `forward`/`backward` pairs for simplicity.

Like the *Python/C binding*, this declaration file is used at build time to generate the `Variable` methods.
E.g. for `Variable.index_select`, the gradient of `self` is obtained by calling `grad.type().zeros(self.sizes()).index_add_(dim, index, grad)`.

Code snippet for the `forward`/`backward` binding:
```yaml
# pytorch/tools/autograd/derivatives.yaml
- name: index_select(Tensor self, int64_t dim, Tensor index)
  self: grad.type().zeros(self.sizes()).index_add_(dim, index, grad)

- name: kthvalue(Tensor self, int64_t k, int64_t dim, bool keepdim)
  self: select_backward(grad, dim, indices, self.sizes(), keepdim)
```
Here I extract some useful comments from `derivatives.yaml`:
> Each entry consists of:
> - A 'name', which specifies the ATen name of the function you
>   are defining derivatives for, and an argument specification.
> - One or more gradients entries, mapping differentiable input
>   names to a formula specifying how to compute its gradient.
>   Note that a single gradient entry can specify the gradient
>   formula for multiple input names, by specifying a key
>   "self, other" (see atan2 for an example).

The values in this yaml file are standard *C++* (C++11, to be exact) expressions without trailing semicolons, which are invoked by the **backward engine** to apply the chain rule.
There are two approaches to defining a backward function: a simple one-liner, or a more complex function defined in `pytorch/tools/autograd/templates/Functions.cpp`. For example, the backward function for `kthvalue()` is `select_backward()`:

```C
// pytorch/tools/autograd/templates/Functions.cpp
Tensor sum_backward(const Tensor & grad, IntList sizes, int64_t dim, bool keepdim) {
#ifdef WITH_SCALARS
  if (!keepdim && sizes.size() > 0) {
#else
  if (!keepdim && sizes.size() > 1) {
#endif
    return grad.unsqueeze(dim).expand(sizes);
  } else {
    return grad.expand(sizes);
  }
}
```

Meanwhile, users can still write their own `Function` in Python by subclassing it and defining their own `forward` and `backward` methods.
See source code:
`pytorch/tools/autograd/derivatives.yaml`
`pytorch/tools/autograd/templates/Functions.cpp`

--------------------------------------------------------------------------------
/profile.md:
--------------------------------------------------------------------------------
How to trace down PyTorch's CUDA performance bottleneck
==

python level
---

#### requirements
1. snakeviz
2. line-profiler (optional)
```bash
pip install snakeviz
```
> Note: if you profile the program with Python 3's cProfile, you must use the snakeviz of Python 3. Otherwise snakeviz raises a warning saying the profiler output is not a cProfile file.

Since the CUDA tensor operations in PyTorch are all asynchronous, the results returned by a naive cProfile run will be misleading. Fortunately we can force the CUDA calls to be synchronous by simply setting the environment variable `CUDA_LAUNCH_BLOCKING=1`.

```bash
CUDA_LAUNCH_BLOCKING=1 python -m cProfile -o program.prof program.py
```

```bash
# if you run locally, the command will automatically open a web browser
snakeviz program.prof

# normally we run our programs on a remote server; additional parameters are
# required to access it over the network
snakeviz -s -H 0.0.0.0 -p <port>

# you may get detailed information from the help page
snakeviz --help
```

- (optional) `line-profiler` for pure CLI profiling
Instead of `snakeviz`, you can use `line-profiler` to get the time spent on each line of code. As noted before, you should also set `CUDA_LAUNCH_BLOCKING=1` for sensible profiling results.

To use `line-profiler`, you have to modify a few lines in your code: simply place a `@profile` decorator above the functions or methods of interest.
```bash
pip install line-profiler
CUDA_LAUNCH_BLOCKING=1 kernprof -lv program.py
```

kernel level
---

#### requirements
1. nvprof

After finding the most time-consuming calls, you can inspect deeper details by using the `profiler` provided by PyTorch. The built-in profiler gives you the detailed elapsed time of basic tensor operations (such as `Addmm`, `IndexSelect`).

Once the tensor operation is identified, I highly recommend writing a simple test script in which only that specific tensor operation is called.
If you are able to locate the most expensive function call, you can use nvprof to check whether the existing kernel provided by PyTorch fits your particular use case.

```bash
nvprof --analysis-metrics -o test.nvprof --print-gpu-trace python test.py 2>> nvprof.log
```
You can also modify the CUDA kernel code and then build PyTorch from source.

##### tuning CUDA kernels
Some refs: [CUDA performance guide] [CUDA C programming guide]
##### build pytorch from source
```bash
# make sure you uninstall the released version of pytorch
pip uninstall torch

# pull the source code from github
git clone https://github.com/pytorch/pytorch
# pull the submodule dependencies
git submodule update --init
# build from source
python setup.py build [develop]
```
The optional `develop` specifies whether the Python files are mapped from the git repo onto `PYTHONPATH` or just copied there.
--------------------------------------------------------------------------------
/tensor_add_explained.md:
--------------------------------------------------------------------------------
Figure out what `b.add_(b)` actually does
===
> This is a post about how to find the code behind a **Tensor** method, here the `add` method and the corresponding `torch.add` function.

Why I dive into `Tensor.add_()`
---

These days I have been investigating the GFlops gap between the theoretical peak and what PyTorch's `add` method actually achieves. In brief, the `add` method performs element-wise addition of two **Tensors**.
On my test bed, equipped with 2 Xeon E5 2620v4 CPUs that can reach 256 GFlops peak performance according to Intel, `b.add_(b)` (in-place add of self) achieves only about 6 GFlops at best.

I tried to inspect the underlying code path of a simple matrix add.
Finally, after a hard day, I found that for such a **level-1 BLAS** ($O(N)$) operation, memory bandwidth is necessarily the bottleneck of the overall performance, no matter which advanced **SIMD** instructions (*AVX*, *SSE*, *NEON*) are used. The measured **6** GFlops then makes sense.

> bandwidth of the currently fastest DDR4 memory: 25.6 GB/s (×2 for load/store) [wikipedia](https://en.wikipedia.org/wiki/DDR4_SDRAM)
> loading single-precision floats: $\frac{25.6}{4} = 6.4$ GFlops
> for non-inplace addition, which needs to load 2 floats per addition, the Flops is halved.


How to trace the code down to the *AVX2* instructions
---

### Python binding and ATen binding
For a comprehensive study of the auto-generation build system, please refer to the *dispatch* post.
ATen is a `C/C++` tensor library born out of PyTorch's needs; it is object-oriented and serves as the backend of PyTorch's Python API.

Actually the Python API simply does some `Python/C` type conversion and error checking, then invokes the corresponding method of the ATen **Tensor**.

The Python binding and the ATen lib are generated at build time.
Since ATen is still under active development, PyTorch currently uses two schemes concurrently for automatic code generation, *native* and *cwrap*.

The declaration of `add` in the cwrap file:
```
[[
  name: add
  variants:
    - method
    - function
  return: argument 0
  options:
    - cname: add_scaled
      arguments:
        - arg: THTensor* result
          output: True
        - THTensor* self
        - real other
        - arg: real alpha
          default: AS_REAL(1)
          kwarg_only: True
    - cname: cadd
      aten_sparse: True
      arguments:
        - arg: THTensor* result
          output: True
        - arg: THTensor* self
          broadcast: other fallback
        - arg: real alpha
          default: AS_REAL(1)
          kwarg_only: True
        - THTensor* other
    - sparse: True
      cname: spcadd
      aten_dense_sparse: True
      arguments:
        - arg: THTensor* result
          output: True
        - THTensor* self
        - arg: real alpha
          default: AS_REAL(1)
          kwarg_only: True
        - THSTensor* other
]]
```

As you can see, the `Tensor.add` method may be dispatched to three different backend functions, depending on the arguments. Specifically, `Tensor.add(5)` calls `add_scaled`, `Tensor.add(Tensor)` calls `cadd`, and if the other tensor is sparse, `spcadd` is invoked.

`add_` is simply the in-place alias of `add`, in which the first argument serves as both input and output tensor.
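These three routes can be exercised directly from Python. The sketch below uses the old-style sparse constructor that matches the era of these notes (newer releases keep it working but may warn); the backend each call reaches is simply the mapping read off the cwrap entry above:

```python
import torch

t = torch.randn(4, 4)

t.add(5)      # Tensor + scalar        -> cname add_scaled
t.add(t)      # Tensor + dense Tensor  -> cname cadd
t.add_(t)     # in-place alias of the same cadd path; this is the `b.add_(b)` case

# dense + sparse -> cname spcadd
i = torch.LongTensor([[0, 1], [1, 0]])   # 2 x nnz index matrix
v = torch.FloatTensor([3, 4])
s = torch.sparse.FloatTensor(i, v, torch.Size([4, 4]))
t.add(s)
```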
So the backend name we are looking for, for `b.add_(b)`, is **cadd**.

### Find the backend code of the cadd method

Now search for the definition of `cadd` in TH (CPU), THC (GPU) or THS (sparse). Because the actual name is wrapped by PyTorch's `THTensor_()` or `THCudaTensor_()` prefix, you have to figure it out depending on your backend. Since I'm looking for the CPU code, the wanted signature is `void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src)`, which is located in `aten/src/TH/generic/THTensorMath.c`.

```clike=
void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src)
{
  THTensor_(resizeAs)(r_, t);
  int64_t r_Size = THTensor_(nElement)(r_);
  int64_t srcSize = THTensor_(nElement)(src);
  int r_Contig = THTensor_(isContiguous)(r_);
  int tContig = THTensor_(isContiguous)(t);
  int srcContig = THTensor_(isContiguous)(src);
  int serial_path = 0;
  if (srcSize == r_Size){
    if (r_Contig && tContig && srcContig) {
      if(r_ == t) {
        THBlas_(axpy)(THTensor_(nElement)(t), value, THTensor_(data)(src), 1, THTensor_(data)(r_), 1);
      } else {
        TH_TENSOR_APPLY3_CONTIG(real, r_, real, t, real, src, THVector_(cadd)(r__data, t_data, src_data, value, r__len););
      }
  else // Non-contiguous case
```

The `THBlas_(axpy)` call goes to a standard BLAS function. Since I don't want to debug the BLAS library source, I force the code to run line *15* by inserting a 0 into the condition at line *12*.

### TH_TENSOR_APPLY3_CONTIG
`TH_TENSOR_APPLY3_CONTIG` is a macro that applies an element-wise operation. It has two versions, depending on whether *OpenMP* support is enabled at compile time. When the size of the *Tensor* is larger than the threshold (`TH_OMP_OVERHEAD_THRESHOLD`), the *Tensor* is split into `OMP_NUM_THREADS` parts, each processed by one thread.

```clike=
#ifdef _OPENMP
#define TH_TENSOR_APPLY3_CONTIG(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, CODE) \
{ \
  int inOmp = omp_in_parallel(); \
  ptrdiff_t TH_TENSOR_size = THTensor_(nElement)(TENSOR1); \
  PRAGMA(omp parallel if ((TH_TENSOR_size > TH_OMP_OVERHEAD_THRESHOLD) && (!inOmp))) \
  { \
    size_t num_threads = omp_get_num_threads(); \
    size_t tid = omp_get_thread_num(); \
    ptrdiff_t TH_TENSOR_offset = tid * (TH_TENSOR_size / num_threads); \
    ptrdiff_t TH_TENSOR_end = tid == num_threads - 1 ? TH_TENSOR_size : \
      TH_TENSOR_offset + TH_TENSOR_size / num_threads; \
    ptrdiff_t TENSOR1##_len = TH_TENSOR_end - TH_TENSOR_offset; \
    TYPE1 *TENSOR1##_data = THTensor_(data)(TENSOR1) + TH_TENSOR_offset; \
    TYPE2 *TENSOR2##_data = THTensor_(data)(TENSOR2) + TH_TENSOR_offset; \
    TYPE3 *TENSOR3##_data = THTensor_(data)(TENSOR3) + TH_TENSOR_offset; \
    CODE \
  } \
}
```


### THVector_(cadd)

Since modern CPUs benefit from various SIMD (Single Instruction Multiple Data) instruction sets such as *SSE*, *AVX* and *NEON*, the backend instruction set of `THVector_(cadd)` is determined at **run time** by traversing all the supported sets.

`THVector_(cadd)` is bound dynamically via a function pointer in `aten/src/TH/generic/THVectorDispatch.cpp`, with some useful macros defined in `aten/src/TH/generic/simd/simd.h`.

The code for the different SIMD sets lives in `aten/src/TH/vector`.
Since my CPU, a *Xeon E5 2620v4*, supports *AVX2*, which processes 256 bits (8 single-precision floats) per instruction, I'm going to inspect `AVX2.cpp`.
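Before reading the AVX2 kernel, the bandwidth argument from the beginning of this note is easy to re-check with a rough measurement. The tensor size and iteration count below are arbitrary choices of mine; if the operation really is memory-bound, the reported number should land in the same few-GFlops range as estimated above:

```python
import time
import torch

n = 1 << 26                 # ~67M floats (~256 MB), far larger than the CPU caches
b = torch.randn(n)

for _ in range(3):          # warm-up
    b.add_(b)

iters = 20
start = time.time()
for _ in range(iters):
    b.add_(b)               # one addition per element, in place
elapsed = time.time() - start

# one flop per element per iteration
print('achieved: %.2f GFlops' % (n * iters / elapsed / 1e9))
```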
### THFloat_cadd_AVX2

Nothing special here except loop unrolling. Actually I didn't see any performance improvement from this unrolling, because the major bottleneck is memory access, not branching or misprediction.

Unlike on a GPU, loop unrolling on a CPU does not always yield a speedup, because a CPU is very good at sequential work and branching compared with a GPU. Another reason is that the branch predictor fails at a very low rate once the tensor size grows.

Refer to [this Stack Overflow Q&A](https://stackoverflow.com/questions/24196076/is-gcc-loop-unrolling-flag-really-effective).
```clike=
void THFloatVector_cadd_AVX2(float *z, const float *x, const float *y, const float c, const ptrdiff_t n) {
  ptrdiff_t i;
  __m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
  __m256 YMM0, YMM1, YMM2, YMM3;
  for (i=0; i<=((n)-16); i+=16) {
    YMM0 = _mm256_loadu_ps(y+i);
    YMM1 = _mm256_loadu_ps(y+i+8);
    YMM2 = _mm256_loadu_ps(x+i);
    YMM3 = _mm256_loadu_ps(x+i+8);
    YMM2 = _mm256_fmadd_ps(YMM0, YMM15, YMM2);
    YMM3 = _mm256_fmadd_ps(YMM1, YMM15, YMM3);
    _mm256_storeu_ps(z+i, YMM2);
    _mm256_storeu_ps(z+i+8, YMM3);
  }
  for (; i<(n); i++) {
    z[i] = x[i] + y[i] * c;
  }
}
```

--------------------------------------------------------------------------------
/walkthrough.md:
--------------------------------------------------------------------------------
From Python to C/C++: a PyTorch walk-through
===

Based on `github.com/pytorch/pytorch` version `v0.2.0`.

**PyTorch**: a fast-prototyping deep learning framework in Python
- dynamic
- automatic gradient computation

Here we take a basic look at how PyTorch implements its dynamic graph and automatic gradient computation, based on a simple but complete example.

A simple snippet for PyTorch:
```python
import torch
from torch.autograd import Variable

a = Variable(torch.FloatTensor([2]), requires_grad=True) # Parameters
b = Variable(torch.FloatTensor([3])) # Data
loss = torch.mm(a, b).sum()

loss.backward()
```
This snippet covers the whole pipeline of PyTorch's **auto gradient** and **dynamic graph** features.

I'm going to walk through the snippet line by line.

Python
---

1. `import torch`
    - Load extension modules: `numpy`, `NVTX` (NVIDIA toolkit), many functions written in C
    - Import interface functions defined in Python (these may override C-defined functions)
    - Define basic utilities: `is_tensor()`, `load/store`
    - Define `Storage` and `Tensor` classes: `FloatTensor`, `DoubleTensor`
    - Import the most common subpackages
    > related file:
    > - top-level module initialization: `torch/__init__.py`

2. `from torch.autograd import Variable`

3. `a = Variable(torch.FloatTensor([2]))`
    - Python list `[2]`
    - `FloatStorage` with the single element {2}
    - `FloatTensor` of one dimension, size `(1,)`, stride=1, offset=1
    - `Variable`: its data is the previous tensor, plus some properties about the gradient requirement (`requires_grad=True`).

    > How is `Variable` implemented in PyTorch?
    > It is a pure Python mixin class that subclasses a C-based `Variable`.
    >
    > related files:
    > - Python class interface: `torch/autograd/variable.py`
    > - C base class: `torch/csrc/autograd/variable.h(cpp)`
    > - detailed doc: `torch/autograd/README.md`

4. `b = Variable(torch.FloatTensor([3]))`
    The same as above.

5. `loss = torch.mm(a, b).sum()`
    - `torch.mm(a, b)` and `a.mm(b)`: many operations in PyTorch come in these two alternative forms, and the operands can be either `Variable` or `Tensor`.
    - `torch.mm(a, b)` is dispatched to the corresponding static method of class `Variable`, e.g.
    > `torch.mm(a, b)` is equivalent to `Variable.mm(a, b)` or `Tensor.mm(a, b)`, depending on the type of `a`. It's implemented by a **dispatcher** in PyTorch's top-level module definition, `torch/csrc/Module.h(cpp)`; it calls the function `a._torch.mm` of instance `a` and raises an error if the function does not exist.
    > `Variable._torch` is a container class holding all the static methods of `Variable`. The member methods are dynamically added to `_torch` at **import** time. (TODO: moved to C on the newest master, Jan-26-2018)
```python
for method in dir(Variable):
    # This will also wrap some methods that normally aren't part of the
    # functional interface, but we don't care, as they won't ever be used
    if method.startswith('_') or method.endswith('_'):
        continue
    if hasattr(Variable._torch, method):
        continue
    as_static = staticmethod(getattr(Variable, method))
    setattr(Variable._torch, method, as_static)
```

6. `loss.backward()`
    - A `Variable` has an attribute named `_grad_fn`, a pointer to the creator of this variable. By following the creator chain, the chain rule is naturally applied. Thus, by implementing a corresponding `backward` method for each function, automatic differentiation is achieved.

C
---
1. `Variable`

    A `Variable` is a wrapper around a `Tensor` that stores extra information. Computation performed on a `Variable` is recorded for future backpropagation. `torch/autograd/variable.py, ./_functions/*.py`

    `Variable` subclasses a base class, `torch._C._VariableBase`, which is defined in `torch/csrc/autograd/python_variable.cpp(h)` and `./variable.cpp(h)`.
    There are two critical C types here, `THPVariable` and `Variable`. `THPVariable` is simply a Python wrapper around the C-typed `Variable`, and the `THPVariable` type is exposed as the Python class `torch._C._VariableBase`.

    The reason to wrap the `cdata` is to access the variable conveniently from C code.

    > THPVariable wrapper:
```clike
struct THPVariable {
    PyObject_HEAD
    // Payload
    std::shared_ptr<Variable> cdata;
    // Tensor this wraps (corresponds to Python attr 'data').
    // It is assumed that a THPVariable is *uniquely* identified by the
    // tensor it wraps.
    // Invariant: v->data == v->cdata->data
    PyObject* data;
    // Hooks to be run on backwards pass (corresponds to Python attr
    // '_backwards_hooks', set by 'register_hook')
    PyObject* backward_hooks;
};
```
    > Variable extra attributes:
```clike
std::unique_ptr<Tensor> data;
std::shared_ptr<Function> grad_fn;
std::shared_ptr<Variable> grad;
std::unique_ptr<VariableVersion> version_counter;
std::vector<std::shared_ptr<FunctionPreHook>> hooks;
std::weak_ptr<Function> grad_accumulator;
std::mutex grad_accumulator_lock;
bool requires_grad;
bool is_volatile;
// The "output number" of this variable; e.g., if this variable
// was the second output of a function, then output_nr == 1.
// We use this to make sure we can setup the backwards trace
// correctly when this variable is passed to another function.
int output_nr;
```

2. `Variable.backward()`

    The method `backward` calls the function `backward` in `torch/autograd/__init__.py` (implicitly passing *self* as the first argument).
    After some book-keeping it reaches the C part, which PyTorch calls the `ImperativeEngine` (in `torch/csrc/autograd/python_engine.cpp(h)`).
    This engine traverses the whole computation graph and computes the gradients w.r.t. the original caller `Variable`.


    > The **engine** was moved from Python to C++ in version *0.2.0* for performance. Manipulating Python objects from C++ code is tedious, so I prefer to read the Python engine of version *0.1.1* for simplicity. [see discussions](https://discuss.pytorch.org/t/how-to-understand-pytorchs-source-code/7600/2)

## The logic behind `autograd.Variable.backward`, `Function` and the `Engine`

| ![The forward and backward graph built in Pytorch](http://image.ibb.co/iXC5uR/pytorch_autograd.png) |
|:--:|
| *The forward & backward graph built in Pytorch* |

1. Computation graph for reverse-mode automatic differentiation
    The [*reverse mode*](https://justindomke.wordpress.com/2009/03/24/a-simple-explanation-of-reverse-mode-automatic-differentiation/) of automatic differentiation is a perfect fit for back-propagation. The process starts from the final output and computes the partial derivatives w.r.t. every input element. Such an algorithm requires a data structure that can traverse the whole graph starting from the output variable, so a directed edge from output to input is the natural choice (the graph with edges from input to output is what you state explicitly in the forward step of your `nn` structure, while the graph with reversed edges is built at forward time).

2. Basic components in PyTorch
    The basic components of the computation graph in PyTorch are `Variable` (edge) and `Function` (node). The dependencies between two `Function`s (nodes) are introduced by the `Variable`s (edges). The autograd-related data structures of both `Variable` and `Function` are introduced below (version 0.1.1; slightly different in the newest release, especially the naming conventions).
    > Variable:
    > - `creator` (renamed `grad_fn` in version 0.2.0): a reference to the producer of this `Variable`, which is also responsible for computing the gradients of the output w.r.t. each input `Variable`

    > Function:
    > - needs_input_grad: a tuple indicating which inputs have `requires_grad`; this property can be used to avoid useless gradient computation.
    > - requires_grad: False if none of the inputs needs gradients; it is also there for performance.
    > - previous_functions: a list of tuples (input_var.creator, id(input_var)); the `id` is used for position lookup in the creator's `output_ids`.
    > - output_ids: a map from variable id to the position of the output variable.

| ![Function object in Pytorch](http://image.ibb.co/hav3g6/pytorch_autograd_function.png) |
|:--:|
| *Function object in Pytorch* |

3. Build the reverse graph at forward time
```python
class Function(object):

    def __init__(self):
        self.previous_functions = None
        self.output_ids = None
        self.needs_input_grad = None
        self.backward_hooks = OrderedDict()

    def __call__(self, *input):
        return self._do_forward(*input)

    def _do_forward(self, *input):
        unpacked_input = tuple(arg.data for arg in input)
        raw_output = self.forward(*unpacked_input)
        if not isinstance(raw_output, tuple):
            raw_output = (raw_output,)
        self.needs_input_grad = tuple(arg.creator.requires_grad for arg in input)
        self.requires_grad = any(self.needs_input_grad)
        output = tuple(Variable(tensor, self) for tensor in raw_output)

        self.previous_functions = [(arg.creator, id(arg)) for arg in input]
        self.output_ids = {id(var): i for i, var in enumerate(output)}
        return output
```
    - A `Function` does not hold any `Variable`; instead, `Tensor`s are stored for future gradient computation.
    - A `Variable` is unpacked into its underlying `Tensor` at the `Function` level. The subsequent operations then invoke the `Tensor` versions.
    > e.g. `Variable.mm(a, b)` will invoke `Tensor.mm(a.data, b.data)` at forward time.
    - As the `Variable`s are all unpacked into `Tensor`s, the history information would be lost if it were not stored explicitly. So the creator of each input `Variable` is stored in `self.previous_functions` at forward time.
    - For a `Function` with multiple outputs, the order information is unknown at backward time, because the user may permute the outputs of a function just for fun. Even if the engine can get the creator of a `Variable`, it cannot assign the gradient to the proper output of that creator. `self.output_ids` maps the id of an output to the position of that variable among the creator's outputs.

4. Backward of `Function` and the `Engine`

    The backward pass is quite straightforward for a `Function`: it simply computes the gradients of the input `Variable`s given the gradient of the output `Variable`s (w.r.t. the starting `Variable` on which `.backward()` was called).
```python
class Function(object):
    # ...
    def _do_backward(self, grad_output):
        grad_input = self.backward(grad_output)
        # ... some trivial post-processing including `hook`s and packing the results
        return grad_input

    def backward(self, grad_output):
        raise NotImplementedError
```

    The `backward engine` does the dirty work behind this simple `Function` API, including dispatching the gradients to the proper `creator` and making sure the gradients are completely accumulated before applying the chain rule to the next `Function`.

```python
class ExecutionEngine(object):

    def _compute_dependencies(self, function):
        """Traverse the computation graph to get the dependencies.

        BFS-traverse the functions starting from the final output. The
        dependencies are a collection of counters.
        """
        # compute the dependencies
        return dependencies

    def _free_backward_dependency(self, dependencies, prev_fn, fn, arg_id):
        """Update the dependencies after backward-ing one function.

        Return:
            output_idx: the position of the arg in the outputs of prev_fn
        """
        # Update dependencies and return the position
        return output_idx


    def _is_ready_for_backward(self, dependencies, function):
        """Check if the node function is ready.

        The ready status is determined by the in-degree of the node function,
        a.k.a. the dependencies[function][output] are all 0s.
        """
        for deps in dependencies[function]:
            if len(deps) > 0:
                return False
        return True

    def run_backward(self, variable, grad):
        """The core of the backward engine.

        It calls the backward method of each Function and dispatches the gradients
        to the correct previous functions. The method returns only after all
        functions have been `backward`ed successfully.
        """

        # set up the starting point; the grad is 1 when `Variable.backward()`
        # is called without an explicit gradient
        ready = [(variable.creator, (grad,))]  # Functions ready for BP
        not_ready = {}  # Functions whose gradients are not yet fully accumulated

        dependencies = self._compute_dependencies(variable.creator)

        while len(ready) > 0:
            fn, grad = ready.pop()
            # TODO: double-buffering
            grad_input = fn._do_backward(*grad)

            # Update the dependencies of all the input Variables of function fn.
            # arg_id is used for position look-up.
            for (prev_fn, arg_id), d_prev_fn in zip(fn.previous_functions, grad_input):

                # skip useless gradient computation
                if not prev_fn.requires_grad:
                    assert d_prev_fn is None
                    continue
                output_nr = self._free_backward_dependency(dependencies, prev_fn, fn, arg_id)
                is_ready = self._is_ready_for_backward(dependencies, prev_fn)

                # the following code is a little bit messy, as
                # the two branches partially share the same underlying logic:
                # accumulating the current gradient into the existing one.

                # if `prev_fn` is ready for backward, then move
                # it from `not_ready` to `ready`
                if is_ready:
                    if prev_fn in not_ready:
                        prev_grad = not_ready[prev_fn]
                        if not prev_grad[output_nr]:
                            prev_grad[output_nr] = d_prev_fn
                        else:
                            prev_grad[output_nr].add_(d_prev_fn)
                        del not_ready[prev_fn]
                    else:
                        # `prev_fn` is ready when first seen, so
                        # there must be only one output of `prev_fn`
                        assert output_nr == 0
                        prev_grad = (d_prev_fn,)
                    ready.append((prev_fn, prev_grad))
                else:
                    if prev_fn in not_ready:
                        prev_grad = not_ready[prev_fn]
                    else:
                        prev_grad = [None for _ in prev_fn.output_ids]

                    if not prev_grad[output_nr]:
                        prev_grad[output_nr] = d_prev_fn
                    else:
                        prev_grad[output_nr].add_(d_prev_fn)

                    not_ready[prev_fn] = prev_grad
```

The auxiliary methods are listed below in detail:
```python
class ExecutionEngine(object):
    def __init__(self):
        pass

    def _compute_dependencies(self, function):
        """Traverse the computation graph to get the dependencies.

        BFS-traverse the functions starting from the final output. The
        dependencies are a collection of counters.
        """
        dependencies = {}
        seen = {function}
        queue = [function]
        while len(queue) > 0:
            fn = queue.pop()
            for prev_fn, arg_id in fn.previous_functions:
                if prev_fn not in dependencies:
                    dependencies[prev_fn] = [Counter() for _ in prev_fn.output_ids]
                output_idx = prev_fn.output_ids[arg_id]

                # I think there's no need to store a counter per current function;
                # a simple counter for each prev_fn output should be enough.
                dependencies[prev_fn][output_idx][fn] += 1
                if prev_fn not in seen:
                    queue.append(prev_fn)
                    seen.add(prev_fn)
        return dependencies

    def _free_backward_dependency(self, dependencies, prev_fn, fn, arg_id):
        """Update the dependencies after backward-ing one function.

        Return:
            output_idx: the position of the arg in the outputs of prev_fn
        """
        deps = dependencies[prev_fn]
        output_idx = prev_fn.output_ids[arg_id]
        output_deps = deps[output_idx]
        output_deps[fn] -= 1
        if output_deps[fn] == 0:
            del output_deps[fn]
        return output_idx


    def _is_ready_for_backward(self, dependencies, function):
        """Check if the node function is ready.

        The ready status is determined by the in-degree of the node function,
        a.k.a. the dependencies[function][output] are all 0s.
        """
        for deps in dependencies[function]:
            if len(deps) > 0:
                return False
        return True
```
--------------------------------------------------------------------------------