├── docs ├── images │ ├── h36m_pred.png │ ├── mot16_pred.png │ ├── h36m_heatmap.png │ ├── h36m_pred.bak.png │ ├── mot16_heatmap.png │ └── heatmap_legend_vertical_0-100.png └── index.html ├── src ├── cuda │ ├── conv_kernel.cuh │ ├── deconv_kernel.cuh │ ├── other_nn_layers.cuh │ ├── common.cu │ ├── common.cuh │ ├── Utility.cuh │ └── conv_torch_wrapper.cpp └── deltacnn │ ├── __init__.py │ ├── filter_conversion.py │ ├── utils.py │ ├── cuda_kernels.py │ └── logging_layers.py ├── CONTRIBUTING.md ├── setup.py ├── CODE_OF_CONDUCT.md ├── example ├── mobilenetv2_webcam_example.py ├── mobilenet_original.py ├── mobilenet_deltacnn.py └── imagenet_classes.txt ├── readme.md └── LICENSE.txt /docs/images/h36m_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DeltaCNN/HEAD/docs/images/h36m_pred.png -------------------------------------------------------------------------------- /docs/images/mot16_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DeltaCNN/HEAD/docs/images/mot16_pred.png -------------------------------------------------------------------------------- /docs/images/h36m_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DeltaCNN/HEAD/docs/images/h36m_heatmap.png -------------------------------------------------------------------------------- /docs/images/h36m_pred.bak.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DeltaCNN/HEAD/docs/images/h36m_pred.bak.png -------------------------------------------------------------------------------- /docs/images/mot16_heatmap.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/DeltaCNN/HEAD/docs/images/mot16_heatmap.png -------------------------------------------------------------------------------- /docs/images/heatmap_legend_vertical_0-100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DeltaCNN/HEAD/docs/images/heatmap_legend_vertical_0-100.png -------------------------------------------------------------------------------- /src/cuda/conv_kernel.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | #include "common.cuh" 10 | #include 11 | 12 | 13 | template 14 | void deltacnn(scalar_t *input, scalar_t *output, scalar_t *filter, scalar_t *bias, uint32_t *mask, uint32_t *out_mask, Dimensions dim, ConvConfig config); 15 | 16 | void deltacnn_hp(half *input, half *output, half *filter, half *bias, uint32_t *mask, uint32_t *out_mask, Dimensions dim, ConvConfig config); 17 | 18 | void init_d_metrics_conv_kernels(); -------------------------------------------------------------------------------- /src/cuda/deconv_kernel.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #pragma once 7 | #include 8 | #include 9 | #include "common.cuh" 10 | #include 11 | 12 | 13 | template 14 | void delta_deconv(scalar_t *input, scalar_t *output, scalar_t *filter, scalar_t *bias, uint32_t *mask, uint32_t *out_mask, Dimensions dim, ConvConfig config); 15 | 16 | void delta_deconv_hp(half *input, half *output, half *filter, half *bias, uint32_t *mask, uint32_t *out_mask, Dimensions dim, ConvConfig config); 17 | 18 | void init_d_metrics_deconv_kernels(); 19 | -------------------------------------------------------------------------------- /src/deltacnn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .cuda_kernels import sparse_conv, sparse_deconv, sparse_pooling 7 | from .cuda_kernels import sparse_activation, sparsify, sparse_add_tensors, sparse_add_to_dense_tensor, sparse_upsample, sparse_concatenate, sparse_mul_add 8 | from .sparse_layers import DCConv2d, DCConvTranspose2d, DCMaxPooling, DCAdaptiveAveragePooling, DCDensify, DCAdd, DCActivation, DCUpsamplingNearest2d, DCSparsify, DCThreshold, DCBackend, DCModule, DCBatchNorm2d, DCConcatenate, DCTruncation 9 | from .cuda_kernels import DCPerformanceMetricsManager, DCPerformanceMetrics 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DeltaCNN 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. 
Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: <https://code.facebook.com/cla> 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to DeltaCNN, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from setuptools import setup, find_packages 8 | from torch.utils.cpp_extension import ( 9 | CUDAExtension, 10 | CUDA_HOME, 11 | BuildExtension, 12 | ) 13 | _DEBUG = False 14 | _DEBUG_LEVEL = 0 15 | 16 | # Common flags for both release and debug builds. 
17 | # extra_compile_args = sysconfig.get_config_var('CFLAGS').split() 18 | # extra_compile_args = ["-std=c++17"] 19 | extra_compile_args = [] 20 | extra_compile_args += ["-DNDEBUG", "-O3", "-lineinfo"] 21 | extra_compile_args = { 22 | 'gcc': extra_compile_args, 23 | 'nvcc': [*extra_compile_args, "--ptxas-options=-v"] 24 | } 25 | 26 | modules = [] 27 | 28 | if CUDA_HOME: 29 | modules.append( 30 | CUDAExtension( 31 | "deltacnn.cuda", 32 | [ 33 | "src/cuda/common.cu", 34 | "src/cuda/conv_torch_wrapper.cpp", 35 | "src/cuda/conv_kernel.cu", 36 | "src/cuda/deconv_kernel.cu", 37 | "src/cuda/other_nn_layers.cu", 38 | ], 39 | extra_compile_args=extra_compile_args, 40 | language='c++17' 41 | ) 42 | ) 43 | 44 | setup( 45 | name="torchdeltacnn", 46 | packages=find_packages(where="src"), 47 | package_dir={"": "src"}, 48 | ext_modules=modules, 49 | cmdclass={"build_ext": BuildExtension}, 50 | ) -------------------------------------------------------------------------------- /src/deltacnn/filter_conversion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
import torch


def convert_filter_out_channels_last(filter, transposed=False):
    """Return a contiguous copy of a 4-D filter with its axes reordered.

    Args:
        filter: 4-D weight tensor. (The name shadows the builtin; it is kept
            because it is part of the public signature.)
        transposed: selects the axis order of the result.
            ``False`` -> input axes ``(1, 2, 3, 0)`` (axis 0 moves last).
            ``True``  -> input axes ``(0, 2, 3, 1)`` (axis 1 moves last).

    Returns:
        A new contiguous tensor that does not share storage with ``filter``.
    """
    axis_order = (0, 2, 3, 1) if transposed else (1, 2, 3, 0)
    # .contiguous() may hand back the very same tensor when the permuted view
    # is already contiguous, so .clone() is what guarantees an independent copy.
    return filter.permute(*axis_order).contiguous().clone()


def convert_half_filter(x, pixel_wise=False, transposed=False):
    """Repack a 4-D filter into a half-precision tensor with even channel counts.

    Args:
        x: 4-D weight tensor; axis 0 is read as c_out and axis 1 as c_in
            (after the optional swap below).
        pixel_wise: if true, return only the axis-reordered copy of ``x``
            (input dtype, no padding, no half conversion).
        transposed: if true, axes 0 and 1 of ``x`` are swapped first.

    Returns:
        For ``pixel_wise``: tensor of shape ``(c_in, kh, kw, c_out)``.
        Otherwise: ``torch.half`` tensor of shape
        ``(c_in_even, kh, kw, c_out_even)`` where both channel counts are
        rounded up to the next even number and the padding is zero.
        For even output channels the values keep their position. For odd
        output channels each pair of neighbouring input channels is swapped:
        ``result[2k, .., 2j+1]`` holds ``src[2k+1, .., 2j+1]`` and
        ``result[2k+1, .., 2j+1]`` holds ``src[2k, .., 2j+1]``.
        NOTE(review): the swap presumably matches the access pattern of the
        half-precision CUDA kernels - confirm against the kernel code.
    """
    if transposed:
        x = torch.transpose(x, 0, 1)

    x_out_last = convert_filter_out_channels_last(x)
    if pixel_wise:
        # Early exit before allocating the padded buffer, which this path never uses.
        return x_out_last

    c_out, c_in = x.shape[0], x.shape[1]
    # Round both channel counts up to an even number (original note: "align to 64 bit").
    result = torch.zeros(
        (c_in + c_in % 2, x.shape[-2], x.shape[-1], c_out + c_out % 2),
        device=x.device,
        dtype=torch.half,
    )

    def add_available(dst, src):
        # dst is a strided view into the padded result and can be larger than
        # src along the channel axes; only the region src covers is written.
        dst[:src.shape[0], :, :, :src.shape[-1]].add_(src)

    add_available(result[::2, :, :, ::2], x_out_last[::2, :, :, ::2])
    add_available(result[::2, :, :, 1::2], x_out_last[1::2, :, :, 1::2])
    add_available(result[1::2, :, :, ::2], x_out_last[1::2, :, :, ::2])
    add_available(result[1::2, :, :, 1::2], x_out_last[::2, :, :, 1::2])

    return result
5 | 6 | #pragma once 7 | #include 8 | #include 9 | #include "common.cuh" 10 | 11 | void init_d_metrics_other_nn_layers(); 12 | 13 | template 14 | void activate_truncate(scalar_t* delta, scalar_t* prev_input, scalar_t* truncated, uint32_t *mask, float threshold, Dimensions dim, int activation, int truncation_mode); 15 | 16 | template 17 | void activate_truncate_hp(scalar_t* delta, scalar_t* prev_input, scalar_t* truncated, uint32_t *mask, float threshold, Dimensions dim, int activation, int truncation_mode); 18 | 19 | template 20 | void prepare_diff_mask(scalar_t* input, scalar_t* prev_input, scalar_t* delta, uint32_t *mask, float threshold, Dimensions dim); 21 | 22 | template 23 | void prepare_diff_mask_hp(scalar_t* input, scalar_t* prev_input, scalar_t* delta, uint32_t *mask, float threshold, Dimensions dim); 24 | 25 | template 26 | void sparse_add_tensors(scalar_t* a, scalar_t* b, scalar_t* prev_out, scalar_t* out, uint32_t *mask_a, uint32_t *mask_b, uint32_t *mask_out, scalar_t weight_a, scalar_t weight_b, Dimensions dim, int activation, bool dense_out); 27 | 28 | template 29 | void sparse_add_tensors_hp(scalar_t* a, scalar_t* b, scalar_t* prev_out, scalar_t* out, uint32_t *mask_a, uint32_t *mask_b, uint32_t *mask_out, float weight_a, float weight_b, Dimensions dim, int activation, bool dense_out); 30 | 31 | template 32 | void sparse_add_to_dense_tensor_sp(scalar_t* a, scalar_t* b, uint32_t *mask_a, Dimensions dim, int activation); 33 | 34 | template 35 | void sparse_add_to_dense_tensor_hp(scalar_t* a, scalar_t* b, uint32_t *mask_a, Dimensions dim, int activation); 36 | 37 | template 38 | void sparse_upsample(scalar_t* in, scalar_t* out, uint32_t *mask_in, uint32_t *mask_out, Dimensions dim, int scale); 39 | 40 | template 41 | void sparse_upsample_hp(scalar_t* in, scalar_t* out, uint32_t *mask_in, uint32_t *mask_out, Dimensions dim, int scale); 42 | 43 | template 44 | void sparse_concatenate(scalar_t* a, scalar_t* b, scalar_t* out, uint32_t *mask_a, uint32_t 
*mask_b, uint32_t *mask_out, Dimensions dim); 45 | 46 | template 47 | void sparse_concatenate_hp(scalar_t* a, scalar_t* b, scalar_t* out, uint32_t *mask_a, uint32_t *mask_b, uint32_t *mask_out, Dimensions dim); 48 | 49 | template 50 | void sparse_pool(scalar_t* input, scalar_t* prev_input, scalar_t* out, uint32_t *mask, uint32_t *out_mask, Dimensions dim, ConvConfig config, int pooling_mode); 51 | 52 | template 53 | void sparse_pool_hp(scalar_t* input, scalar_t* prev_input, scalar_t* out, uint32_t *mask, uint32_t *out_mask, Dimensions dim, ConvConfig config, int pooling_mode); 54 | 55 | template 56 | void sparse_mul_add(scalar_t* in, uint32_t *mask, scalar_t *out, uint32_t *mask_out, scalar_t *scale, scalar_t *bias, Dimensions dim); 57 | 58 | template 59 | void sparse_mul_add_hp(scalar_t* in, uint32_t *mask, scalar_t *out, uint32_t *mask_out, scalar_t *scale, scalar_t *bias, Dimensions dim); -------------------------------------------------------------------------------- /src/cuda/common.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "common.cuh" 7 | 8 | __device__ DCMetrics *d_metrics; 9 | 10 | __global__ void check_d_metrics_ptr() { 11 | if (threadIdx.x == 0) { 12 | printf("common.cu &d_metrics=%p\n", d_metrics); 13 | printf("common.cu &d_metrics.vals_read_dense = %p\n", &d_metrics->n_vals_read_dense); 14 | printf("common.cu &d_metrics.vals_written_dense = %p\n", &d_metrics->n_vals_written_dense); 15 | d_metrics->n_active_flops += 1; 16 | printf("common.cu d_metrics.n_active_flops=%i\n", d_metrics->n_active_flops); 17 | } 18 | } 19 | 20 | 21 | bool init_performance_metrics() { 22 | #ifdef ENABLE_METRICS 23 | HANDLE_ERROR(cudaMalloc(&d_metrics_ptr_copy, sizeof(DCMetrics))); 24 | HANDLE_ERROR(cudaMemset(d_metrics_ptr_copy, 0, sizeof(DCMetrics))); 25 | copy_performance_metrics_to_gpu(d_metrics); 26 | return true; 27 | #else 28 | return false; 29 | #endif 30 | } 31 | 32 | void copy_performance_metrics_to_gpu(DCMetrics*& d) { 33 | #ifdef ENABLE_METRICS 34 | HANDLE_ERROR(cudaMemcpyToSymbol(d, &d_metrics_ptr_copy, sizeof(DCMetrics*))); 35 | #endif 36 | } 37 | 38 | void reset_performance_metrics() { 39 | #ifdef ENABLE_METRICS 40 | HANDLE_ERROR(cudaMemset(d_metrics_ptr_copy, 0, sizeof(DCMetrics))); 41 | #endif 42 | } 43 | 44 | std::vector retrieve_metrics() { 45 | #ifdef ENABLE_METRICS 46 | DCMetrics h_d_metric; 47 | HANDLE_ERROR(cudaMemcpy(&h_d_metric, d_metrics_ptr_copy, sizeof(DCMetrics), cudaMemcpyDeviceToHost)); 48 | 49 | torch::Tensor tiles = torch::zeros({2}, torch::TensorOptions().dtype(torch::kInt64)); 50 | torch::Tensor inputs = torch::zeros({2}, torch::TensorOptions().dtype(torch::kInt64)); 51 | torch::Tensor mode = torch::zeros({2}, torch::TensorOptions().dtype(torch::kInt64)); 52 | torch::Tensor flops = torch::zeros({3}, torch::TensorOptions().dtype(torch::kInt64)); 53 | torch::Tensor memtransfer = torch::zeros({4}, torch::TensorOptions().dtype(torch::kInt64)); 54 | torch::Tensor histrogram = torch::zeros({DCMetrics::histogram_samples}, 
torch::TensorOptions().dtype(torch::kInt64)); 55 | 56 | tiles.data_ptr()[0] = int64_t(h_d_metric.n_active_tiles); 57 | tiles.data_ptr()[1] = int64_t(h_d_metric.n_tiles); 58 | inputs.data_ptr()[0] = int64_t(h_d_metric.n_active_inputs); 59 | inputs.data_ptr()[1] = int64_t(h_d_metric.n_inputs); 60 | mode.data_ptr()[0] = int64_t(h_d_metric.n_tiles_sparse_mode); 61 | mode.data_ptr()[1] = int64_t(h_d_metric.n_tiles_dense_mode); 62 | flops.data_ptr()[0] = int64_t(h_d_metric.n_active_flops); 63 | flops.data_ptr()[1] = int64_t(h_d_metric.n_theoretical_flops); 64 | flops.data_ptr()[2] = int64_t(h_d_metric.n_dense_flops); 65 | memtransfer.data_ptr()[0] = int64_t(h_d_metric.n_vals_read); 66 | memtransfer.data_ptr()[1] = int64_t(h_d_metric.n_vals_read_dense); 67 | memtransfer.data_ptr()[2] = int64_t(h_d_metric.n_vals_written); 68 | memtransfer.data_ptr()[3] = int64_t(h_d_metric.n_vals_written_dense); 69 | 70 | int64_t *histogram_ptr = histrogram.data_ptr(); 71 | for (int i = 0; i < DCMetrics::histogram_samples; i++) { 72 | histogram_ptr[i] = int64_t(h_d_metric.active_input_histogram[i]); 73 | } 74 | return {tiles, inputs, mode, flops, memtransfer, histrogram}; 75 | #endif 76 | return {}; 77 | } -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /src/cuda/common.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #pragma once 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | enum PaddingMode {zeros, repeat, mirror}; 13 | 14 | struct ConvConfig { 15 | uint16_t kernel_size[2]; 16 | uint16_t stride[2]; 17 | uint16_t dilation[2]; 18 | uint16_t padding[4]; 19 | PaddingMode padding_mode; 20 | uint16_t groups; 21 | bool sub_tile_sparsity; 22 | bool set_sparse_zero; 23 | }; 24 | 25 | struct ImageDimension{ 26 | uint16_t h; 27 | uint32_t w; 28 | uint16_t c; 29 | }; 30 | 31 | struct Dimensions { 32 | uint16_t batch_size; 33 | ImageDimension in; 34 | ImageDimension out; 35 | }; 36 | 37 | // #define ENABLE_METRICS 38 | 39 | 40 | 41 | struct DCMetrics { 42 | // these values are only tracked inside convolutional kernels 43 | uint64_t n_active_tiles; 44 | uint64_t n_tiles; 45 | uint64_t n_active_inputs; 46 | uint64_t n_inputs; 47 | uint64_t n_tiles_sparse_mode; 48 | uint64_t n_tiles_dense_mode; 49 | uint64_t n_active_flops; 50 | uint64_t n_theoretical_flops; 51 | uint64_t n_dense_flops; 52 | 53 | // these values are tracked in all layers 54 | static const bool track_filter_reads = true; 55 | uint64_t n_vals_read; 56 | uint64_t n_vals_read_dense; 57 | uint64_t n_vals_written; 58 | uint64_t n_vals_written_dense; 59 | 60 | static const uint64_t histogram_samples = 128; 61 | // histogram is only tracked inside convolutional kernels 62 | uint64_t active_input_histogram[histogram_samples]; 63 | 64 | static const uint64_t n_samples_total = histogram_samples + 13; 65 | }; 66 | 67 | #ifdef ENABLE_METRICS 68 | struct DCMetrics_ptrs { 69 | // these values are only tracked inside convolutional kernels 70 | uint64_t *n_active_tiles; 71 | uint64_t *n_tiles; 72 | uint64_t *n_active_inputs; 73 | uint64_t *n_inputs; 74 | uint64_t *n_tiles_sparse_mode; 75 | uint64_t *n_tiles_dense_mode; 76 | uint64_t *n_active_flops; 77 | uint64_t *n_theoretical_flops; 78 | uint64_t *n_dense_flops; 79 | 80 | // these values are tracked in all layers 81 | static const bool track_filter_reads = true; 82 
| uint64_t *n_vals_read; 83 | uint64_t *n_vals_read_dense; 84 | uint64_t *n_vals_written; 85 | uint64_t *n_vals_written_dense; 86 | 87 | static const uint64_t histogram_samples = 128; 88 | // histogram is only tracked inside convolutional kernels 89 | uint64_t *active_input_histogram; 90 | 91 | static const uint64_t n_samples_total = histogram_samples + 13; 92 | }; 93 | 94 | static DCMetrics *d_metrics_ptr_copy; 95 | static DCMetrics h_metrics; 96 | #endif 97 | 98 | bool init_performance_metrics(); 99 | void reset_performance_metrics(); 100 | std::vector retrieve_metrics(); 101 | void copy_performance_metrics_to_gpu(DCMetrics*& d); 102 | 103 | inline static void HandleError(cudaError_t err, 104 | const char *file, 105 | int line) 106 | { 107 | if (err != cudaSuccess) 108 | { 109 | printf("%s in %s at line %d\n", cudaGetErrorString(err), 110 | file, line); 111 | throw std::exception(); 112 | } 113 | } 114 | // #ifdef _DEBUG || NDEBUG || DEBUG 115 | #define HANDLE_ERROR(err) (HandleError(err, __FILE__, __LINE__)) 116 | // #else 117 | // #define HANDLE_ERROR(err) err 118 | // #endif -------------------------------------------------------------------------------- /example/mobilenetv2_webcam_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # based on the MobileNetv2 implementation from PyTorch 7 | # source: https://pytorch.org/vision/0.8/_modules/torchvision/models/mobilenet.html 8 | # and: https://pytorch.org/hub/pytorch_vision_mobilenet_v2/ 9 | 10 | # original license and copyright: 11 | # BSD 3-Clause License 12 | 13 | # Copyright (c) Soumith Chintala 2016, 14 | # All rights reserved. 
15 | 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are met: 18 | 19 | # * Redistributions of source code must retain the above copyright notice, this 20 | # list of conditions and the following disclaimer. 21 | 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | 26 | # * Neither the name of the copyright holder nor the names of its 27 | # contributors may be used to endorse or promote products derived from 28 | # this software without specific prior written permission. 29 | 30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 40 | 41 | 42 | ######################################################################################################## 43 | # MobileNetv2 Webcam example 44 | # 45 | # This is a very simple example that shows how DeltaCNN can be used as a replacement for torch.nn layers. 46 | # Please take a look at the changes made to mobilenet for the deltacnn version. 
# All changes are marked with a '# added' or '# replaced by' comment.
#
# This example uses weights pretrained on ImageNet.
# Also, the webcam is used as video input to avoid having to download videos and for being able to play
# around with the camera.
# Adjust the delta_threshold to see how it affects the predictions.
#
########################################################################################################



import torch
from torch import nn
from deltacnn.sparse_layers import DCBackend, DCConv2d, DCThreshold
from mobilenet_original import mobilenet_v2
from mobilenet_deltacnn import DeltaCNN_mobilenet_v2


def test():
    """Compare MobileNetV2 and its DeltaCNN variant on live webcam frames.

    Every captured frame is preprocessed, repeated into a batch of 32 and run
    through both models on ``cuda:0``. Per frame, one status line is printed
    (the cursor is reset with a carriage return first) containing the top-5
    classes of the DeltaCNN output, both inference times measured with CUDA
    events, and the mean absolute difference between the two models' outputs
    for the first batch entry. The loop ends when the camera stops delivering
    frames; the capture device is released on every exit path.
    """
    from PIL import Image
    from torchvision import transforms
    import cv2

    device = "cuda:0"

    original_model = mobilenet_v2(pretrained=True, progress=True)
    original_model.eval()
    original_model.to(device, memory_format=torch.channels_last)
    dc_model = DeltaCNN_mobilenet_v2(pretrained=True, progress=True)
    dc_model.eval()
    dc_model.to(device, memory_format=torch.channels_last)
    dc_model.process_filters()

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Read the categories (path is relative to the current working directory).
    with open("example/imagenet_classes.txt", "r") as f:
        categories = [line.strip() for line in f]

    time_start = torch.cuda.Event(enable_timing=True)
    time_end = torch.cuda.Event(enable_timing=True)

    camera = cv2.VideoCapture(0)
    try:
        while True:
            ret, input_image = camera.read()
            if not ret:
                break

            # cv2.imshow("cam", input_image)
            # if cv2.waitKey(1) == 27:
            #     break  # esc to quit

            # OpenCV delivers frames in BGR channel order, whereas PIL and the
            # per-channel normalisation constants above assume RGB.
            rgb_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
            input_tensor = preprocess(Image.fromarray(rgb_image))
            input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

            # increase batch size to increase workload
            input_batch = torch.repeat_interleave(input_batch, 32, dim=0)

            # move the input to the GPU in channels-last layout, matching the models
            input_batch = input_batch.to(device).contiguous(memory_format=torch.channels_last)

            # make sure the host-to-device copy is finished before timing starts
            torch.cuda.synchronize()
            with torch.no_grad():
                time_start.record()
                original_output = original_model(input_batch)
                time_end.record()
                torch.cuda.synchronize()
                duration_original = time_start.elapsed_time(time_end)

            with torch.no_grad():
                time_start.record()
                dc_output = dc_model(input_batch)
                time_end.record()
                torch.cuda.synchronize()
                duration_dc = time_start.elapsed_time(time_end)

            # The output has unnormalized scores; softmax over the first batch
            # entry turns them into probabilities.
            probabilities = torch.nn.functional.softmax(dc_output[0], dim=0)

            # Show top categories per image
            top5_prob, top5_catid = torch.topk(probabilities, 5)
            print("\r", end="")
            for i in range(top5_prob.size(0)):
                print(f"{categories[top5_catid[i]]:<16} {top5_prob[i].item():.3f} ", end="")

            print(f"original: {duration_original:.2f}ms, dc: {duration_dc:.2f}ms out_diff_mean={(dc_output[0]-original_output[0]).abs().mean():.3f} ", end="")
    finally:
        # Free the capture device even if inference raises or the user interrupts.
        camera.release()


if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("cuda not available. example.py requires cuda")
        exit(-1)

    # using a low default threshold of 0.05. play around with this value to see how it affects performance and accuracy.
149 | DCThreshold.t_default = 0.05 150 | DCConv2d.backend = DCBackend.deltacnn 151 | test() -------------------------------------------------------------------------------- /src/deltacnn/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from torch.nn.functional import conv2d 8 | 9 | 10 | def tile(a, dim, n_tile): 11 | if type(dim) != torch.Tensor: 12 | dim = torch.tensor(dim) 13 | if type(n_tile) != torch.Tensor: 14 | n_tile = torch.tensor(n_tile) 15 | return torch.repeat_interleave(a, n_tile.to(a.device), dim.to(a.device)) 16 | 17 | 18 | def scale_conv_mask(inactive_mask, sparse_dilation, kernel_size, padding, stride, bias, tile_size=-1, tile_thresh=0.0, dilation=1): 19 | if sparse_dilation is None: 20 | mask = inactive_mask 21 | padding = (kernel_size - 1 - 2 * padding) // 2 22 | if padding > 0: 23 | mask = mask[:, :, padding:-padding, padding:-padding] 24 | if type(stride) == tuple: 25 | mask = mask[:, :, ::stride[0], ::stride[1]] 26 | else: 27 | mask = mask[:, :, ::stride, ::stride] 28 | elif sparse_dilation == "natural": 29 | float_mask = torch.zeros_like(inactive_mask[0:1, 0:1], dtype=torch.float) 30 | float_mask[inactive_mask[0:1, 0:1]] += 1.0 31 | filter = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float, device=inactive_mask.device) 32 | conv_mask = tile(conv2d(float_mask, filter, bias=None, dilation=dilation, stride=stride, padding=(padding, padding)), dim=1, 33 | n_tile=inactive_mask.shape[1]) 34 | mask = conv_mask > 0.0 35 | elif sparse_dilation == "tile": 36 | mask = scale_conv_mask(inactive_mask, None, kernel_size, padding, stride, bias, tile_thresh=tile_thresh) 37 | 38 | if tile_size < 0: 39 | tile_size = mask.shape[2] // (-tile_size) 40 | 41 | shape = mask.shape 42 | if (shape[-1] % 
tile_size + shape[-2] % tile_size) != 0: 43 | return scale_conv_mask(inactive_mask, sparse_dilation, kernel_size, padding, stride, bias, tile_size // 2, tile_thresh=tile_thresh) 44 | 45 | if tile_thresh <= 0.0: 46 | mask = mask.reshape(*mask.shape[:3], -1, tile_size) 47 | mask = torch.max(mask, dim=-1, keepdim=False)[0] 48 | mask = torch.transpose(mask, dim0=2, dim1=-1) 49 | mask = mask.reshape(*mask.shape[:3], -1, tile_size) 50 | mask = torch.max(mask, dim=-1, keepdim=False)[0] 51 | mask = torch.transpose(mask, dim0=2, dim1=-1)[:, :, :, None, :, None] 52 | else: 53 | mask = mask.float() 54 | mask = mask.reshape(*mask.shape[:3], -1, tile_size) 55 | mask = torch.sum(mask, dim=-1, keepdim=False) 56 | mask = torch.transpose(mask, dim0=2, dim1=-1) 57 | mask = mask.reshape(*mask.shape[:3], -1, tile_size) 58 | mask = torch.sum(mask, dim=-1, keepdim=False) 59 | mask = torch.transpose(mask, dim0=2, dim1=-1)[:, :, :, None, :, None] 60 | mask = mask > (tile_thresh * tile_size * tile_size) 61 | 62 | mask = tile(mask, dim=-3, n_tile=tile_size) 63 | mask = tile(mask, dim=-1, n_tile=tile_size) 64 | mask = mask.reshape(shape) 65 | elif sparse_dilation == "natural-tile": 66 | mask = scale_conv_mask(inactive_mask, "natural", kernel_size, padding, 1, bias, tile_thresh=tile_thresh) 67 | mask = scale_conv_mask(mask, "tile", kernel_size, padding, stride, bias, tile_thresh=tile_thresh) 68 | else: 69 | raise Exception(f"dilation mode {sparse_dilation} unknown") 70 | 71 | return mask 72 | 73 | def inpaint_masked(t, mask, neighborhood=4, steps=1): 74 | assert neighborhood == 4 or neighborhood == 8 75 | in_mask_dilated = mask.clone() 76 | in_mask_dilated_float = in_mask_dilated.float() 77 | out = t.clone() 78 | out[mask] = 0.0 79 | for i in range(steps): 80 | neighbors = torch.zeros_like(out) 81 | neighbor_count = torch.zeros_like(out) 82 | neighbor_count[:, :, 1:, :] += (~in_mask_dilated_float[:, :, :-1, :]) 83 | neighbors[:, :, 1:, :] += out[:, :, :-1, :] 84 | neighbor_count[:, :, :-1, :] 
+= (~in_mask_dilated_float[:, :, 1:, :]) 85 | neighbors[:, :, :-1, :] += out[:, :, 1:, :] 86 | neighbor_count[:, :, :, 1:] += (~in_mask_dilated_float[:, :, :, :-1]) 87 | neighbors[:, :, :, 1:] += out[:, :, :, :-1] 88 | neighbor_count[:, :, :, :-1] += (~in_mask_dilated_float[:, :, :, 1:]) 89 | neighbors[:, :, :, :-1] += out[:, :, :, 1:] 90 | 91 | if neighborhood == 8: 92 | neighbor_count[:, :, 1:, 1:] += (~in_mask_dilated_float[:, :, :-1, :-1]) 93 | neighbors[:, :, 1:, 1:] += out[:, :, :-1, :-1] 94 | neighbor_count[:, :, :-1, 1:] += (~in_mask_dilated_float[:, :, 1:, :-1]) 95 | neighbors[:, :, :-1, 1:] += out[:, :, 1:, :-1] 96 | neighbor_count[:, :, 1:, :-1] += (~in_mask_dilated_float[:, :, :-1, 1:]) 97 | neighbors[:, :, 1:, :-1] += out[:, :, -1:, 1:] 98 | neighbor_count[:, :, :-1, :-1] += (~in_mask_dilated_float[:, :, 1:, 1:]) 99 | neighbors[:, :, :-1, :-1] += out[:, :, 1:, 1:] 100 | 101 | neighbor_count[neighbor_count == 0] = 100000 102 | neighbors /= neighbor_count 103 | neighbor_count[neighbor_count == 100000] = 0 104 | out[in_mask_dilated] = neighbors[in_mask_dilated] 105 | in_mask_dilated[neighbor_count > 0] = False 106 | in_mask_dilated_float = in_mask_dilated.float() 107 | 108 | return out 109 | 110 | 111 | def inpaint_binary_mask(mask, pixels=1): 112 | fmask = mask.float() 113 | filter = torch.ones((1, 1, pixels * 2 + 1, pixels * 2 + 1), device=mask.device, dtype=torch.float) 114 | inpainted_fmask = torch.conv2d(fmask[:, :1], filter, padding=pixels) 115 | inpainted_bmask = inpainted_fmask > 0 116 | 117 | if mask.shape[1] != inpainted_bmask.shape[1]: 118 | inpainted_bmask = tile(inpainted_bmask, dim=1, n_tile=mask.shape[1]) 119 | 120 | return inpainted_bmask 121 | 122 | 123 | def count_active_neighbors(mask, kernel_size=(3, 3), dilation=(1, 1), stride=(1, 1), padding=(0, 0), exclude_self=False): 124 | float_mask = mask.float() 125 | kx, ky = kernel_size 126 | filter = torch.ones((1, 1, ky, kx), device=mask.device) 127 | if exclude_self: 128 | filter[:, :, (ky 
- 1) // 2, (ky - 1) // 2] = 0 129 | 130 | if len(padding) == 4: 131 | float_mask = torch.nn.functional.pad(float_mask, padding) 132 | padding = 0 133 | 134 | return torch.round(torch.conv2d(float_mask, filter, None, stride, padding, dilation)).to(torch.int32) 135 | 136 | 137 | def count_inactive_neighbors(mask, kernel_size=(3, 3), dilation=(1, 1), stride=(1, 1), padding=(0, 0), exclude_self=False): 138 | return count_active_neighbors(~mask, kernel_size, dilation, stride, padding, exclude_self) 139 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 39 | 40 | DeltaCNN 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 |
59 |

DeltaCNN

60 |

End-to-End CNN Inference of Sparse Frame Differences in Videos

61 |
62 | Mathias Parger 1 63 | Chengcheng Tang 2 64 | Christopher D. Twigg 2 65 |
66 |
67 | Cem Keskin 2 68 | Robert Wang 2 69 | Markus Steinberger 1 70 |
71 |
72 | 1Graz University of Technology 73 | 2Meta Reality Labs 74 |
75 |
76 | 83 | 84 |
85 |
86 | 87 |
88 |
89 |
90 | 91 |
Prediction
92 |
93 |
94 | 95 |
Updates
96 |
97 |
98 |
99 |
100 |
101 |
102 | 103 |
Prediction
104 |
105 |
106 | 107 |
Updates
108 |
109 |
110 |
111 | 114 |
115 |
116 | 117 |
118 |

Abstract

119 |

Convolutional neural network inference on video data requires powerful hardware for real-time processing. Given the inherent coherence across consecutive frames, large parts of a video typically change little. By skipping identical image regions and truncating insignificant pixel updates, computational redundancy can in theory be reduced significantly. However, these theoretical savings have been difficult to translate into practice, as sparse updates hamper computational consistency and memory access coherence, which are key for efficiency on real hardware. With DeltaCNN, we present a sparse convolutional neural network framework that enables sparse frame-by-frame updates to accelerate video inference in practice. We provide sparse implementations for all typical CNN layers and propagate sparse feature updates end-to-end - without accumulating errors over time. DeltaCNN is applicable to all convolutional neural networks without retraining. To the best of our knowledge, we are the first to significantly outperform the dense reference, cuDNN, in practical settings, achieving speedups of up to 7x with only marginal differences in accuracy.

120 |
121 | 122 |
123 |

Video

124 |
125 | 126 |
127 |
128 | 129 |
130 |

Cite

131 |
132 | @article{parger2022deltacnn,
133 |     title = {DeltaCNN: End-to-End CNN Inference of Sparse Frame Differences in Videos},
134 |     author = {Mathias Parger, Chengcheng Tang, Christopher D. Twigg, Cem Keskin, Robert Wang, Markus Steinberger},
135 |     journal = {CVPR 2022},
136 |     year = {2022},
137 |     month = jun
138 | }
139 | 
140 |
141 | 142 | -------------------------------------------------------------------------------- /src/cuda/Utility.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | #include 10 | #include "common.cuh" 11 | 12 | static constexpr unsigned int FULL_MASK{0xFFFFFFFF}; 13 | static constexpr unsigned int WARP_SIZE{32}; 14 | 15 | void inline start_clock(cudaEvent_t &start, cudaEvent_t &end) 16 | { 17 | HANDLE_ERROR(cudaEventCreate(&start)); 18 | HANDLE_ERROR(cudaEventCreate(&end)); 19 | HANDLE_ERROR(cudaEventRecord(start, 0)); 20 | } 21 | 22 | // ############################################################################################################################################## 23 | // 24 | float inline end_clock(cudaEvent_t &start, cudaEvent_t &end) 25 | { 26 | float time; 27 | HANDLE_ERROR(cudaEventRecord(end, 0)); 28 | HANDLE_ERROR(cudaEventSynchronize(end)); 29 | HANDLE_ERROR(cudaEventElapsedTime(&time, start, end)); 30 | HANDLE_ERROR(cudaEventDestroy(start)); 31 | HANDLE_ERROR(cudaEventDestroy(end)); 32 | 33 | // Returns ms 34 | return time; 35 | } 36 | 37 | namespace Utils 38 | { 39 | 40 | template 41 | constexpr T constexpr_min(const T a, const T b) { 42 | return a > b ? b : a; 43 | } 44 | template 45 | constexpr T constexpr_max(const T a, const T b) { 46 | return a < b ? 
b : a; 47 | } 48 | // ############################################################################################################################################## 49 | // 50 | __device__ __forceinline__ int lane_id() 51 | { 52 | return (threadIdx.x & 31); 53 | } 54 | 55 | // ############################################################################################################################################## 56 | // 57 | __device__ __forceinline__ int warp_id() 58 | { 59 | return (threadIdx.x >> 5); 60 | } 61 | 62 | // ############################################################################################################################################## 63 | // 64 | template 65 | __host__ __device__ __forceinline__ T divup(T a, T2 b) 66 | { 67 | return (a + b - 1) / b; 68 | } 69 | 70 | // ############################################################################################################################################## 71 | // 72 | template 73 | static __device__ __forceinline__ int getNextPow2Pow(T n) 74 | { 75 | if ((n & (n - 1)) == 0) 76 | return 32 - __clz(n) - 1; 77 | else 78 | return 32 - __clz(n); 79 | } 80 | 81 | // ############################################################################################################################################## 82 | // 83 | template 84 | static __device__ __forceinline__ T getNextPow2(T n) 85 | { 86 | return 1 << (getNextPow2Pow(n)); 87 | } 88 | 89 | // ############################################################################################################################################## 90 | // 91 | template 92 | struct static_clz 93 | { 94 | static const int value = (X & 0x80000000) ? 
Completed : static_clz< (X << 1), Completed + 1 >::value; 95 | }; 96 | 97 | // ############################################################################################################################################## 98 | // 99 | template 100 | struct static_clz 101 | { 102 | static const int value = 32; 103 | }; 104 | 105 | // ############################################################################################################################################## 106 | // 107 | template 108 | static constexpr int static_getNextPow2Pow() 109 | { 110 | if ((n & (n - 1)) == 0) 111 | return 32 - static_clz(n)>::value - 1; 112 | else 113 | return 32 - static_clz(n)>::value; 114 | } 115 | 116 | // ############################################################################################################################################## 117 | // 118 | template 119 | static constexpr size_t static_getNextPow2() 120 | { 121 | return 1 << (static_getNextPow2Pow()); 122 | } 123 | 124 | // ############################################################################################################################################## 125 | // 126 | template 127 | static constexpr bool isPowerOfTwo(const T n) 128 | { 129 | return (n & (n - 1)) == 0; 130 | } 131 | 132 | // ############################################################################################################################################## 133 | // 134 | template 135 | static constexpr __forceinline__ __device__ int32_t modPower2(const int32_t value) 136 | { 137 | static_assert(isPowerOfTwo(size), "ModPower2 used with non-power of 2"); 138 | return value & (size - 1); 139 | } 140 | 141 | // ############################################################################################################################################## 142 | // 143 | template 144 | __host__ __device__ __forceinline__ T divPower2(T val) 145 | { 146 | return val >> static_getNextPow2Pow(); 147 | } 148 | 149 | // 
############################################################################################################################################## 150 | // 151 | template 152 | __host__ __device__ __forceinline__ T mulPower2(T val) 153 | { 154 | return val << static_getNextPow2Pow(); 155 | } 156 | 157 | // ############################################################################################################################################## 158 | // 159 | template 160 | __device__ __forceinline__ T warpReduceSum(T val, unsigned mask = FULL_MASK) { 161 | #pragma unroll 162 | for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { 163 | val += __shfl_down_sync(mask, val, offset); 164 | } 165 | return val; 166 | } 167 | 168 | // ############################################################################################################################################## 169 | // 170 | template<> __device__ __forceinline__ half2 warpReduceSum(half2 val, unsigned mask) { 171 | #pragma unroll 172 | for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { 173 | int tmp_val = __shfl_down_sync(mask, *reinterpret_cast(&val), offset); 174 | val = __hadd2(val, *reinterpret_cast(&tmp_val)); 175 | } 176 | return val; 177 | } 178 | 179 | // ############################################################################################################################################## 180 | // 181 | template<> __device__ __forceinline__ half warpReduceSum(half val, unsigned mask) { 182 | #pragma unroll 183 | for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { 184 | float tmp_val = __shfl_down_sync(mask, __half2float(val), offset); 185 | val = __hadd(val, __float2half(tmp_val)); 186 | } 187 | return val; 188 | } 189 | 190 | // ############################################################################################################################################## 191 | // 192 | template 193 | __device__ __forceinline__ T warpReduceSumGrouped(T val, int groups, 
unsigned mask = FULL_MASK) { 194 | for (int offset = WARP_SIZE / 2 / groups; offset > 0; offset /= 2) { 195 | val += __shfl_down_sync(mask, val, offset); 196 | } 197 | return val; 198 | } 199 | 200 | // ############################################################################################################################################## 201 | // 202 | template 203 | __device__ __forceinline__ T warpReduceMax(T val, unsigned mask = FULL_MASK) { 204 | for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { 205 | val = max(__shfl_down_sync(mask, val, offset), val); 206 | } 207 | return val; 208 | } 209 | 210 | // ############################################################################################################################################## 211 | // 212 | __device__ __forceinline__ static float atomicMax(float* address, float val) 213 | { 214 | int* address_as_i = (int*) address; 215 | int old = *address_as_i, assumed; 216 | do { 217 | assumed = old; 218 | old = ::atomicCAS(address_as_i, assumed, 219 | __float_as_int(::fmaxf(val, __int_as_float(assumed)))); 220 | } while (assumed != old); 221 | return __int_as_float(old); 222 | } 223 | } 224 | 225 | -------------------------------------------------------------------------------- /example/mobilenet_original.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # based on the MobileNetv2 implementation from PyTorch 7 | # source: https://pytorch.org/vision/0.8/_modules/torchvision/models/mobilenet.html 8 | # and: https://pytorch.org/hub/pytorch_vision_mobilenet_v2/ 9 | 10 | # original license and copyright: 11 | # BSD 3-Clause License 12 | 13 | # Copyright (c) Soumith Chintala 2016, 14 | # All rights reserved. 
15 | 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are met: 18 | 19 | # * Redistributions of source code must retain the above copyright notice, this 20 | # list of conditions and the following disclaimer. 21 | 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | 26 | # * Neither the name of the copyright holder nor the names of its 27 | # contributors may be used to endorse or promote products derived from 28 | # this software without specific prior written permission. 29 | 30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 40 | 41 | import torch 42 | from torch import nn 43 | from torch.hub import load_state_dict_from_url 44 | 45 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 46 | 47 | 48 | model_urls = { 49 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 50 | } 51 | 52 | 53 | def _make_divisible(v, divisor, min_value=None): 54 | """ 55 | This function is taken from the original tf repo. 
56 | It ensures that all layers have a channel number that is divisible by 8 57 | It can be seen here: 58 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 59 | :param v: 60 | :param divisor: 61 | :param min_value: 62 | :return: 63 | """ 64 | if min_value is None: 65 | min_value = divisor 66 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 67 | # Make sure that round down does not go down by more than 10%. 68 | if new_v < 0.9 * v: 69 | new_v += divisor 70 | return new_v 71 | 72 | 73 | 74 | class ConvBNReLU(nn.Sequential): 75 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None): 76 | padding = (kernel_size - 1) // 2 77 | if norm_layer is None: 78 | norm_layer = nn.BatchNorm2d 79 | super(ConvBNReLU, self).__init__( 80 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 81 | norm_layer(out_planes), 82 | nn.ReLU6(inplace=True) 83 | ) 84 | 85 | 86 | class InvertedResidual(nn.Module): 87 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None): 88 | super(InvertedResidual, self).__init__() 89 | self.stride = stride 90 | assert stride in [1, 2] 91 | 92 | if norm_layer is None: 93 | norm_layer = nn.BatchNorm2d 94 | 95 | hidden_dim = int(round(inp * expand_ratio)) 96 | self.use_res_connect = self.stride == 1 and inp == oup 97 | 98 | layers = [] 99 | if expand_ratio != 1: 100 | # pw 101 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 102 | layers.extend([ 103 | # dw 104 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 105 | # pw-linear 106 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 107 | norm_layer(oup), 108 | ]) 109 | self.conv = nn.Sequential(*layers) 110 | 111 | def forward(self, x): 112 | if self.use_res_connect: 113 | return x + self.conv(x) 114 | else: 115 | return self.conv(x) 116 | 117 | 118 | class MobileNetV2(nn.Module): 
119 | def __init__(self, 120 | num_classes=1000, 121 | width_mult=1.0, 122 | inverted_residual_setting=None, 123 | round_nearest=8, 124 | block=None, 125 | norm_layer=None): 126 | """ 127 | MobileNet V2 main class 128 | 129 | Args: 130 | num_classes (int): Number of classes 131 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 132 | inverted_residual_setting: Network structure 133 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 134 | Set to 1 to turn off rounding 135 | block: Module specifying inverted residual building block for mobilenet 136 | norm_layer: Module specifying the normalization layer to use 137 | 138 | """ 139 | super(MobileNetV2, self).__init__() 140 | 141 | if block is None: 142 | block = InvertedResidual 143 | 144 | if norm_layer is None: 145 | norm_layer = nn.BatchNorm2d 146 | 147 | input_channel = 32 148 | last_channel = 1280 149 | 150 | if inverted_residual_setting is None: 151 | inverted_residual_setting = [ 152 | # t, c, n, s 153 | [1, 16, 1, 1], 154 | [6, 24, 2, 2], 155 | [6, 32, 3, 2], 156 | [6, 64, 4, 2], 157 | [6, 96, 3, 1], 158 | [6, 160, 3, 2], 159 | [6, 320, 1, 1], 160 | ] 161 | 162 | # only check the first element, assuming user knows t,c,n,s are required 163 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 164 | raise ValueError("inverted_residual_setting should be non-empty " 165 | "or a 4-element list, got {}".format(inverted_residual_setting)) 166 | 167 | # building first layer 168 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 169 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 170 | features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] 171 | # building inverted residual blocks 172 | for t, c, n, s in inverted_residual_setting: 173 | output_channel = _make_divisible(c * width_mult, round_nearest) 174 | for i 
in range(n): 175 | stride = s if i == 0 else 1 176 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 177 | input_channel = output_channel 178 | # building last several layers 179 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 180 | # make it nn.Sequential 181 | self.features = nn.Sequential(*features) 182 | 183 | # building classifier 184 | self.classifier = nn.Sequential( 185 | nn.Dropout(0.2), 186 | nn.Linear(self.last_channel, num_classes), 187 | ) 188 | 189 | # weight initialization 190 | for m in self.modules(): 191 | if isinstance(m, nn.Conv2d): 192 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 193 | if m.bias is not None: 194 | nn.init.zeros_(m.bias) 195 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 196 | nn.init.ones_(m.weight) 197 | nn.init.zeros_(m.bias) 198 | elif isinstance(m, nn.Linear): 199 | nn.init.normal_(m.weight, 0, 0.01) 200 | nn.init.zeros_(m.bias) 201 | 202 | def _forward_impl(self, x): 203 | # This exists since TorchScript doesn't support inheritance, so the superclass method 204 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 205 | x = self.features(x) 206 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 207 | x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1) 208 | x = self.classifier(x) 209 | return x 210 | 211 | def forward(self, x): 212 | return self._forward_impl(x) 213 | 214 | 215 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 216 | """ 217 | Constructs a MobileNetV2 architecture from 218 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
219 | 220 | Args: 221 | pretrained (bool): If True, returns a model pre-trained on ImageNet 222 | progress (bool): If True, displays a progress bar of the download to stderr 223 | """ 224 | model = MobileNetV2(**kwargs) 225 | if pretrained: 226 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 227 | progress=progress) 228 | model.load_state_dict(state_dict) 229 | return model -------------------------------------------------------------------------------- /example/mobilenet_deltacnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # based on the MobileNetv2 implementation from PyTorch 7 | # source: https://pytorch.org/vision/0.8/_modules/torchvision/models/mobilenet.html 8 | # and: https://pytorch.org/hub/pytorch_vision_mobilenet_v2/ 9 | 10 | # original license and copyright: 11 | # BSD 3-Clause License 12 | 13 | # Copyright (c) Soumith Chintala 2016, 14 | # All rights reserved. 15 | 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are met: 18 | 19 | # * Redistributions of source code must retain the above copyright notice, this 20 | # list of conditions and the following disclaimer. 21 | 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | 26 | # * Neither the name of the copyright holder nor the names of its 27 | # contributors may be used to endorse or promote products derived from 28 | # this software without specific prior written permission. 
29 | 30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 40 | 41 | import torch 42 | from torch import nn 43 | from deltacnn.sparse_layers import DCDensify, DCModule, DCConv2d, DCBatchNorm2d, DCAdd, DCActivation, DCAdaptiveAveragePooling, DCSparsify 44 | from torch.hub import load_state_dict_from_url 45 | 46 | 47 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 48 | 49 | 50 | model_urls = { 51 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 52 | } 53 | 54 | 55 | def _make_divisible(v, divisor, min_value=None): 56 | """ 57 | This function is taken from the original tf repo. 58 | It ensures that all layers have a channel number that is divisible by 8 59 | It can be seen here: 60 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 61 | :param v: 62 | :param divisor: 63 | :param min_value: 64 | :return: 65 | """ 66 | if min_value is None: 67 | min_value = divisor 68 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 69 | # Make sure that round down does not go down by more than 10%. 
70 | if new_v < 0.9 * v: 71 | new_v += divisor 72 | return new_v 73 | 74 | # class ConvBNReLU(nn.Sequential): 75 | class DeltaCNN_ConvBNReLU(nn.Sequential, DCModule): 76 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None): 77 | padding = (kernel_size - 1) // 2 78 | if norm_layer is None: 79 | norm_layer = nn.BatchNorm2d 80 | super(DeltaCNN_ConvBNReLU, self).__init__( 81 | # nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), # replaced by: 82 | DCConv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 83 | norm_layer(out_planes), 84 | # nn.ReLU6(inplace=True) # replaced by: 85 | DCActivation(activation="relu6", inplace=True) 86 | ) 87 | 88 | 89 | # class InvertedResidual(nn.Module): # replaced by 90 | class DeltaCNN_InvertedResidual(DCModule): 91 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None): 92 | super(DeltaCNN_InvertedResidual, self).__init__() 93 | self.stride = stride 94 | assert stride in [1, 2] 95 | 96 | if norm_layer is None: 97 | # norm_layer = nn.BatchNorm2d # replaced by: 98 | norm_layer = DCBatchNorm2d 99 | 100 | hidden_dim = int(round(inp * expand_ratio)) 101 | self.use_res_connect = self.stride == 1 and inp == oup 102 | self.sparse_add = DCAdd() # added 103 | 104 | layers = [] 105 | if expand_ratio != 1: 106 | # pw 107 | layers.append(DeltaCNN_ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 108 | layers.extend([ 109 | # dw 110 | DeltaCNN_ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 111 | # pw-linear 112 | # nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), # replaced by: 113 | DCConv2d(hidden_dim, oup, 1, 1, 0, bias=False), 114 | norm_layer(oup), 115 | ]) 116 | self.conv = nn.Sequential(*layers) 117 | 118 | def forward(self, x): 119 | if self.use_res_connect: 120 | # return x + self.conv(x) # replaced by: 121 | return self.sparse_add(x, self.conv(x)) 
122 | else: 123 | return self.conv(x) 124 | 125 | 126 | # class MobileNetV2(nn.Module): # replaced by: 127 | class DeltaCNN_MobileNetV2(DCModule): 128 | def __init__(self, 129 | num_classes=1000, 130 | width_mult=1.0, 131 | inverted_residual_setting=None, 132 | round_nearest=8, 133 | block=None, 134 | norm_layer=None): 135 | """ 136 | MobileNet V2 main class 137 | 138 | Args: 139 | num_classes (int): Number of classes 140 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 141 | inverted_residual_setting: Network structure 142 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 143 | Set to 1 to turn off rounding 144 | block: Module specifying inverted residual building block for mobilenet 145 | norm_layer: Module specifying the normalization layer to use 146 | 147 | """ 148 | super(DeltaCNN_MobileNetV2, self).__init__() 149 | 150 | if block is None: 151 | block = DeltaCNN_InvertedResidual 152 | 153 | if norm_layer is None: 154 | # norm_layer = nn.BatchNorm2d # replaced by: 155 | norm_layer = DCBatchNorm2d 156 | 157 | input_channel = 32 158 | last_channel = 1280 159 | 160 | if inverted_residual_setting is None: 161 | inverted_residual_setting = [ 162 | # t, c, n, s 163 | [1, 16, 1, 1], 164 | [6, 24, 2, 2], 165 | [6, 32, 3, 2], 166 | [6, 64, 4, 2], 167 | [6, 96, 3, 1], 168 | [6, 160, 3, 2], 169 | [6, 320, 1, 1], 170 | ] 171 | 172 | # only check the first element, assuming user knows t,c,n,s are required 173 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 174 | raise ValueError("inverted_residual_setting should be non-empty " 175 | "or a 4-element list, got {}".format(inverted_residual_setting)) 176 | 177 | # building first layer 178 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 179 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 180 | features = [DeltaCNN_ConvBNReLU(3, 
input_channel, stride=2, norm_layer=norm_layer)] 181 | # building inverted residual blocks 182 | for t, c, n, s in inverted_residual_setting: 183 | output_channel = _make_divisible(c * width_mult, round_nearest) 184 | for i in range(n): 185 | stride = s if i == 0 else 1 186 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 187 | input_channel = output_channel 188 | # building last several layers 189 | features.append(DeltaCNN_ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 190 | # make it nn.Sequential 191 | self.features = nn.Sequential(*features) 192 | 193 | self.sparsify = DCSparsify() # added 194 | self.densify = DCDensify() # added 195 | self.adaptive_avg_pooling = DCAdaptiveAveragePooling(output_size=(1,1)) # added 196 | 197 | # building classifier 198 | self.classifier = nn.Sequential( 199 | nn.Dropout(0.2), 200 | nn.Linear(self.last_channel, num_classes), 201 | ) 202 | 203 | # weight initialization 204 | for m in self.modules(): 205 | if isinstance(m, nn.Conv2d): 206 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 207 | if m.bias is not None: 208 | nn.init.zeros_(m.bias) 209 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 210 | nn.init.ones_(m.weight) 211 | nn.init.zeros_(m.bias) 212 | elif isinstance(m, nn.Linear): 213 | nn.init.normal_(m.weight, 0, 0.01) 214 | nn.init.zeros_(m.bias) 215 | 216 | def _forward_impl(self, x): 217 | x = self.sparsify(x) # added 218 | x = self.features(x) 219 | 220 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 221 | # x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1) # replaced by: 222 | x = self.adaptive_avg_pooling(x) 223 | x = self.densify(x) 224 | x = x.reshape(x.shape[0], -1) 225 | 226 | x = self.classifier(x) 227 | return x 228 | 229 | def forward(self, x): 230 | return self._forward_impl(x) 231 | 232 | 233 | def DeltaCNN_mobilenet_v2(pretrained=False, progress=True, 
def DeltaCNN_mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_,
    built from DeltaCNN layers.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
            (the weights are fetched from ``model_urls['mobilenet_v2']``)
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded unchanged to ``DeltaCNN_MobileNetV2``

    Returns:
        The constructed ``DeltaCNN_MobileNetV2`` instance.
    """
    net = DeltaCNN_MobileNetV2(**kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                       progress=progress)
    net.load_state_dict(weights)
    return net
9 | 10 | Find more information about the project on the [Project Website](https://dabeschte.github.io/DeltaCNN) 11 | 12 | ![](docs/images/h36m_pred.png) | ![](docs/images/h36m_heatmap.png) | ![](docs/images/mot16_pred.png) | ![](docs/images/mot16_heatmap.png) 13 | :-----:|:-----:|:-----:|:-----: 14 | Prediction | Updates | Prediction | Updates 15 | 16 | ## Table of Contents 17 | 18 | * [1 Setup](#1-setup) 19 | * [2 Example Project](#2-example-project) 20 | * [3 Using DeltaCNN in your project](#3-using-deltacnn-in-your-project) 21 | * [3.1 Replacing Layers](#31-replacing-layers) 22 | * [3.2 Perform custom operations not supported by DeltaCNN](#32-perform-custom-operations-not-supported-by-deltacnn) 23 | * [3.3 Weights and features memory layout](#33-weights-and-features-memory-layout) 24 | * [3.4 Custom thresholds](#34-custom-thresholds) 25 | * [3.5 Tuning thresholds](#35-tuning-thresholds) 26 | * [Tips & Tricks](#tips--tricks) 27 | * [Cite](#cite) 28 | 29 | ## 1 Setup 30 | 31 | ### Prerequisites 32 | 33 | DeltaCNN depends on: 34 | 35 | * [Python](https://www.python.org/downloads/) / [Anaconda](https://www.anaconda.com/) 36 | * C++ compiler 37 | * Windows: Visual Studio (msvc) with "Desktop development for C++" 38 | * Linux: gcc/g++ 39 | * [CUDA Toolkit 11.3](https://developer.nvidia.com/cuda-11.3.0-download-archive) 40 | * [PyTorch v1.10 or newer](https://pytorch.org/get-started/locally/) 41 | 42 | Please install these packages before installing DeltaCNN. 43 | 44 | ### Install DeltaCNN Framework 45 | 46 | * Navigate to DeltaCNN root directory 47 | * Run `python setup.py install --user` 48 | (This can take a few minutes) 49 | 50 | ## 2 Example project 51 | 52 | [example/mobilenetv2_webcam_example.py](example/mobilenetv2_webcam_example.py) contains a simple example that showcases all steps needed to replace PyTorch's CNN layers with DeltaCNN. 53 | In this example, all steps required to port a network are highlighted with `# added` and `# replaced by`. 
54 | In the main file, we load the original CNN, and the DeltaCNN variant, and run both on webcam video input. 55 | Play around with the DCThreshold.t_default to see how the performance and accuracy change with different values. 56 | For the sake of simplicity, we avoided steps like fusing batch normalization layers together with convolutional layers or tuning thresholds for each layer individually. 57 | 58 | ## 3 Using DeltaCNN in your project 59 | 60 | Using DeltaCNN in your CNN project should in most cases be as easy as replacing all layers in the CNN with the DeltaCNN equivalent and adding a dense-to-sparse (DCSparsify()) layer at the beginning and a sparse-to-dense (DCDensify()) layer at the end. 61 | However, some things need to be considered when replacing the layers. 62 | 63 | ### 3.1 Replacing Layers 64 | 65 | Nonlinear layers need unique instances for every location they are used in the model as they cache input/output feature maps at the current stage. To be safe, create a unique instance for every use of a layer in the model. For example, this toy model can be converted as follows. 66 | 67 | ```python 68 | ####### PyTorch 69 | from torch import nn 70 | class CNN(nn.Module): 71 | def __init__(self): 72 | super(CNN, self).__init__() 73 | self.conv1 = nn.Conv2d(...) 74 | self.conv2 = nn.Conv2d(...) 75 | self.conv3 = nn.Conv2d(...) 76 | self.relu = nn.ReLU() 77 | 78 | def forward(self, x): 79 | x = self.relu(self.conv1(x)) 80 | x = self.relu(self.conv2(x)) 81 | return self.relu(self.conv3(x)) 82 | ``` 83 | 84 | ```python 85 | ####### DeltaCNN 86 | import deltacnn 87 | class CNN(deltacnn.DCModule): 88 | def __init__(self): 89 | super(CNN, self).__init__() 90 | self.sparsify = deltacnn.DCSparsify() 91 | self.conv1 = deltacnn.DCConv2d(...) 92 | self.conv2 = deltacnn.DCConv2d(...) 93 | self.conv3 = deltacnn.DCConv2d(...) 
94 | self.relu1 = deltacnn.DCActivation(activation="relu") 95 | self.relu2 = deltacnn.DCActivation(activation="relu") 96 | self.relu3 = deltacnn.DCActivation(activation="relu") 97 | self.densify = deltacnn.DCDensify() 98 | 99 | def forward(self, x): 100 | x = self.sparsify(x) 101 | x = self.relu1(self.conv1(x)) 102 | x = self.relu2(self.conv2(x)) 103 | return self.densify(self.relu3(self.conv3(x))) 104 | ``` 105 | 106 | or simply: 107 | 108 | ```python 109 | ####### DeltaCNN simplified 110 | import deltacnn 111 | class CNN(deltacnn.DCModule): 112 | def __init__(self): 113 | super(CNN, self).__init__() 114 | self.sparsify = deltacnn.DCSparsify() 115 | self.conv1 = deltacnn.DCConv2d(..., activation="relu") 116 | self.conv2 = deltacnn.DCConv2d(..., activation="relu") 117 | self.conv3 = deltacnn.DCConv2d(..., activation="relu", dense_out=True) 118 | 119 | def forward(self, x): 120 | x = self.sparsify(x) 121 | x = self.conv1(x) 122 | x = self.conv2(x) 123 | return self.conv3(x) 124 | ``` 125 | 126 | ### 3.2 Perform custom operations not supported by DeltaCNN 127 | 128 | If you want to add layers not included in DeltaCNN or apply operations on the feature maps directly, be aware of the feature maps used in DeltaCNN. 129 | DeltaCNN propagates only Delta updates between the layers. The output of a DeltaCNN layer consists of a Delta tensor and an update mask. Be careful when directly accessing these values, as skipped pixels are not initialized and contain random values. 130 | 131 | If you apply custom operations onto the feature maps, the safest way is to add a DCDensify() layer, apply your operation and then convert the features back to Delta features using DCSparsify(). 
For example: 132 | 133 | ```python 134 | ####### PyTorch 135 | from torch import nn 136 | class Normalize(nn.Module): 137 | def forward(self, x): 138 | return x / x.max() 139 | ``` 140 | 141 | ```python 142 | ####### DeltaCNN 143 | from deltacnn import DCDensify, DCSparsify, DCModule 144 | class Normalize(DCModule): 145 | def __init__(self): 146 | super(Normalize, self).__init__() 147 | self.densify = DCDensify() 148 | self.sparsify = DCSparsify() 149 | 150 | def forward(self, x): 151 | x = self.densify(x) 152 | x = x / x.max() 153 | return self.sparsify(x) 154 | ``` 155 | 156 | ### 3.3 Weights and features memory layout 157 | 158 | DeltaCNN kernels only support `torch.channels_last` memory format. Furthermore, they expect a specific memory layout for the weights used in convolutional layers. Thus, after loading the weights from disk, process the filters before the first call. And be sure to convert the network input to channels last memory format. 159 | 160 | ```python 161 | class MyDCModel(DCModule): 162 | ... 163 | 164 | device = "cuda:0" 165 | model = MyDCModel(...) 166 | load_weights(model, weights_path) # weights are stored in PyTorch standard format 167 | model.to(device, memory_format=torch.channels_last) # set the network in channels last mode 168 | model.process_filters() # convert filters into DeltaCNN format 169 | 170 | for frame in video: 171 | frame = frame.to(device).contiguous(memory_format=torch.channels_last) 172 | out = model(frame) 173 | ``` 174 | 175 | ### 3.4 Custom thresholds 176 | 177 | The easiest way to try DeltaCNN is to set a global threshold via the `DCThreshold.t_default` variable before instantiating the model. Good starting points are thresholds in the range between 0.05 and 0.3, but this can vary strongly depending on the network and the noise of the video. If video noise is an issue, specify a larger threshold for the DCSparsify layer using the delta_threshold parameter and compensate for it using update mask dilation. 
178 | For example: `DCSparsify(delta_threshold=0.3, dilation=15)`. 179 | Thresholds can also be loaded from JSON files containing the threshold index as the key. 180 | Set the path to the thresholds using `DCThreshold.path = threshold_path` and load the thresholds after predicting the first frame. 181 | For example: 182 | 183 | ```python 184 | for frame_idx, frame in enumerate(video): 185 | frame = frame.to(device).contiguous(memory_format=torch.channels_last) 186 | out = self.model(frame) 187 | 188 | if frame_idx == 0: 189 | DCThreshold.path = threshold_path 190 | DCThreshold.load_thresholds() 191 | ``` 192 | 193 | ### 3.5 Tuning thresholds 194 | 195 | On the first call, all buffers are allocated to match the size of the current input, the layers and logging layers are initialized and all truncation layers register their thresholds in the DCThreshold class. Optimizing the thresholds in a front-to-back manner can be done by iterating over all items stored in the ordered dictionary `DCThreshold.t`. 196 | For example: 197 | 198 | ```python 199 | sequence = load_video() 200 | DCThreshold.t_default = 0.0 201 | model = init_model() 202 | ref_loss = calc_loss(model, sequence) 203 | max_per_layer_loss_increase = 1.001 # some random number 204 | step_size = 2 # some random number 205 | 206 | for key in DCThreshold.t.keys(): 207 | start_loss = calc_loss(model, sequence) 208 | DCThreshold.t[key] = 0.001 # some random number 209 | 210 | while calc_loss(model, sequence) < start_loss * max_per_layer_loss_increase: 211 | DCThreshold.t[key] *= step_size 212 | DCThreshold.t[key] /= step_size # since loss with prev threshold was already too large, go back a step 213 | 214 | ``` 215 | 216 | For better ways to tune the thresholds, please read the respective section in the DeltaCNN paper. 217 | 218 | ## Supported operations 219 | 220 | DeltaCNN focuses on end-to-end sparse inference and therefore comes with common CNN layers besides convolutions. 
221 | Yet, being a small research project, it does not provide all layers you might need or even support the provided layers in all possible configurations. 222 | If you want to use a layer that is not included in DeltaCNN, please open an issue. 223 | If you have some experience with CUDA, you can add new layers to DeltaCNN yourself - please consider creating a pull request to make DeltaCNN even better. 224 | 225 | As a rough overview, DeltaCNN features the following layers: 226 | 227 | * `DCSparsify` / `DCDensify`: convert dense features to sparse delta features + update mask and back 228 | * `DCConv2d`: Kernel sizes of 1x1, 3x3 and 5x5. All convolutions support striding of 1x1 and 2x2 as well as dilation of any factor and depthwise convolutions. Additionally, a kernel size of 7x7 with a stride of 2x2 is implemented for ResNet. All kernels can be used in float16 and float32 mode. However, as DeltaCNN does not support Tensor Cores (which cuDNN automatically uses in 16 bit mode), performance comparisons against cuDNN should be done in 32 bit mode for apples-to-apples comparisons. 229 | * `DCActivation`: `ReLU`, `ReLU6`, `LeakyReLU`, `Sigmoid` and `Swish`. 230 | * `DCMaxPooling`, `DCAdaptiveAveragePooling`: Average and maximum are supported for different kernel sizes. 231 | * `DCUpsamplingNearest2d`: By factors of 2, 4, 8 or 16. 232 | * `DCBatchNorm2d`: BatchNorm parameters are converted into scale and offset on initialization. 233 | * `DCAdd`: Adding two tensors (e.g. skip connection) 234 | * `DCConcatenate`: Concatenating two tensors along channel dimension (e.g. skip connection) 235 | 236 | ## Tips & Tricks 237 | 238 | * As a starting point, we suggest using a small global threshold, or even 0, and iteratively increasing the threshold on the input until the accuracy decreases. Try to use an update mask dilation on the first layer together with high thresholds to compensate for noise. 
Afterwards, try increasing the global threshold to the maximum that does not significantly reduce accuracy. Use this threshold as baseline when fine tuning individual truncation thresholds. 239 | 240 | * Fusing batch normalization layers together with convolutional layers can have a large impact on performance. 241 | 242 | * Switch between DeltaCNN and cuDNN inference mode without changing the layers by setting `DCConv2d.backend` to `DCBackend.deltacnn` or `DCBackend.cudnn`. 243 | 244 | ## Cite 245 | 246 | ``` 247 | @article{parger2022deltacnn, 248 | title = {DeltaCNN: End-to-End CNN Inference of Sparse Frame Differences in Videos}, 249 | author = {Mathias Parger, Chengcheng Tang, Christopher D. Twigg, Cem Keskin, Robert Wang, Markus Steinberger}, 250 | journal = {CVPR 2022}, 251 | year = {2022}, 252 | month = jun 253 | } 254 | ``` 255 | 256 | ## License 257 | 258 | `DeltaCNN` is released under the CC BY-NC 4.0 license. See [LICENSE](LICENSE.txt) for additional details about it. 259 | See also our [Terms of Use](https://opensource.facebook.com/legal/terms) and [Privacy Policy](https://opensource.facebook.com/legal/privacy). -------------------------------------------------------------------------------- /src/deltacnn/cuda_kernels.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
def sparse_conv(x, filter, mask=None, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1, c_out:int=None, create_out_mask=False, sub_tile_sparsity=True, out_mask=None, out_shape=None) -> torch.Tensor:
    """Allocate the output buffers and launch the masked sparse convolution kernel.

    Args:
        x: Input features; ``x.shape[0]`` is used as the batch size and
            ``x.shape[2]`` / ``x.shape[3]`` as height / width.
        filter: Convolution weights. The kernel height / width are read from
            ``filter.shape[1]`` / ``filter.shape[2]`` and the number of output
            channels from ``filter.shape[3]``.
        mask: Update mask belonging to ``x``; forwarded to the kernel. When it
            is ``None`` no output mask is allocated or returned.
        bias: Bias forwarded to the kernel (may be ``None``).
        stride: ``(stride_h, stride_w)``.
        padding: ``(pad_h, pad_w)`` or a 4-element iterable; in the 4-element
            form elements 0-1 pad the height and elements 2-3 pad the width.
        dilation: ``(dilation_h, dilation_w)``.
        groups: Forwarded unchanged to the kernel.
        c_out: Overrides the number of output channels when given.
        create_out_mask: Allocate an output update mask (only when ``mask`` is
            given and ``out_mask`` is ``None``) and return it with the output.
        sub_tile_sparsity: Forwarded unchanged to the kernel.
        out_mask: Pre-allocated output mask to reuse instead of allocating one.
        out_shape: Optional output size cache. A 2-element sequence is used
            as ``(out_h, out_w)`` directly; otherwise the size is computed and,
            if ``out_shape`` is a list, appended to it for later calls.

    Returns:
        ``out`` alone, or the tuple ``(out, out_mask)`` when ``create_out_mask``
        is set and ``mask`` is not ``None``.

    Raises:
        ValueError: if the output size has to be computed and ``padding`` has
            neither 2 nor 4 elements.
    """
    # DeltaCNN currently only supports zero padding
    pad_mode_int = 0

    out_b = x.shape[0]
    if out_shape is not None and len(out_shape) == 2:
        out_h, out_w = out_shape[0], out_shape[1]
    else:
        if len(padding) == 2:
            pad_h, pad_w = 2 * padding[0], 2 * padding[1]
        elif len(padding) == 4:
            pad_h, pad_w = padding[0] + padding[1], padding[2] + padding[3]
        else:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError("Padding must be iterable of size 2 or 4")
        out_h = int((x.shape[2] + pad_h - dilation[0] * (filter.shape[1]-1) - 1) // stride[0] + 1)
        out_w = int((x.shape[3] + pad_w - dilation[1] * (filter.shape[2]-1) - 1) // stride[1] + 1)
        # Remember the computed size in the caller's list so it can be reused.
        if isinstance(out_shape, list):
            out_shape.extend([out_h, out_w])

    out_c = filter.shape[3] if c_out is None else c_out

    out = torch.empty((out_b, out_c, out_h, out_w), dtype=x.dtype, device=x.device, memory_format=torch.channels_last)
    if mask is not None and out_mask is None and create_out_mask:
        out_mask = torch.empty((out_b, 1, out_h, out_w), dtype=torch.int32, device=x.device, memory_format=torch.channels_last)

    sparse_conv_bias_wrapper_masked(x, filter, bias, out, mask, out_mask, stride, padding, dilation, groups, pad_mode_int, sub_tile_sparsity)

    if create_out_mask and mask is not None:
        return out, out_mask
    return out


def sparse_deconv(x, filter, mask=None, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1, c_out:int=None, create_out_mask=False, sub_tile_sparsity=True, out_mask=None, out_shape=None) -> torch.Tensor:
    """Allocate the output buffers and launch the masked sparse transposed-convolution kernel.

    The parameters mirror :func:`sparse_conv`; the differences are that
    ``padding`` must have exactly two elements and that the output size is
    computed as ``(in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1``.

    Returns:
        ``out`` alone, or the tuple ``(out, out_mask)`` when ``create_out_mask``
        is set and ``mask`` is not ``None``.

    Raises:
        ValueError: if the output size has to be computed and ``padding`` does
            not have exactly 2 elements.
    """
    # DeltaCNN currently only supports zero padding
    pad_mode_int = 0

    out_b = x.shape[0]
    if out_shape is not None and len(out_shape) == 2:
        out_h, out_w = out_shape[0], out_shape[1]
    else:
        if len(padding) != 2:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError("Padding must be iterable of size 2")
        out_h = int((x.shape[2] - 1) * stride[0] - 2 * padding[0] + dilation[0] * (filter.shape[1]-1) + 1)
        out_w = int((x.shape[3] - 1) * stride[1] - 2 * padding[1] + dilation[1] * (filter.shape[2]-1) + 1)
        # Remember the computed size in the caller's list so it can be reused.
        if isinstance(out_shape, list):
            out_shape.extend([out_h, out_w])

    out_c = filter.shape[3] if c_out is None else c_out

    out = torch.empty((out_b, out_c, out_h, out_w), dtype=x.dtype, device=x.device, memory_format=torch.channels_last)
    if mask is not None and out_mask is None and create_out_mask:
        out_mask = torch.empty((out_b, 1, out_h, out_w), dtype=torch.int32, device=x.device, memory_format=torch.channels_last)

    sparse_deconv_bias_wrapper_masked(x, filter, bias, out, mask, out_mask, stride, padding, dilation, groups, pad_mode_int, sub_tile_sparsity)

    if create_out_mask and mask is not None:
        return out, out_mask
    return out
def sparse_pooling(x, prev_x, kernel_size, mask=None, stride=(1,1), padding=(0,0), dilation=(1,1), create_out_mask=True, sub_tile_sparsity=True, out_mask=None, pooling_mode_int=0, out_shape=None) -> torch.Tensor:
    """Allocate the output buffers and launch the masked sparse pooling kernel.

    Args:
        x: Input features; ``x.shape[0]`` is the batch size, ``x.shape[1]`` the
            channel count and ``x.shape[2]`` / ``x.shape[3]`` height / width.
        prev_x: Forwarded unchanged to the kernel.
        kernel_size: ``(kernel_h, kernel_w)``.
        mask: Update mask belonging to ``x``; when ``None`` no output mask is
            allocated or returned.
        stride: ``(stride_h, stride_w)``.
        padding: ``(pad_h, pad_w)`` or a 4-element iterable; in the 4-element
            form elements 0-1 pad the height and elements 2-3 pad the width.
        dilation: ``(dilation_h, dilation_w)``.
        create_out_mask: Allocate an output update mask (only when ``mask`` is
            given and ``out_mask`` is ``None``) and return it with the output.
        sub_tile_sparsity: Forwarded unchanged to the kernel.
        out_mask: Pre-allocated output mask to reuse instead of allocating one.
        pooling_mode_int: Pooling mode forwarded to the kernel; it also selects
            how a 1x1 output buffer is initialised (see below).
        out_shape: Optional output size cache. A 2-element sequence is used
            as ``(out_h, out_w)`` directly; otherwise the size is computed and,
            if ``out_shape`` is a list, appended to it for later calls.

    Returns:
        ``out`` (in the dtype of ``x``) alone, or the tuple ``(out, out_mask)``
        when ``create_out_mask`` is set and ``mask`` is not ``None``.

    Raises:
        ValueError: if the output size has to be computed and ``padding`` has
            neither 2 nor 4 elements.
    """
    # DeltaCNN currently only supports zero padding
    pad_mode_int = 0
    out_b = x.shape[0]
    c = x.shape[1]
    if out_shape is not None and len(out_shape) == 2:
        out_h, out_w = out_shape[0], out_shape[1]
    else:
        if len(padding) == 2:
            pad_h, pad_w = 2 * padding[0], 2 * padding[1]
        elif len(padding) == 4:
            pad_h, pad_w = padding[0] + padding[1], padding[2] + padding[3]
        else:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError("Padding must be an iterable of size 2 or 4")
        out_h = int((x.shape[2] + pad_h - dilation[0] * (kernel_size[0]-1) - 1) // stride[0] + 1)
        out_w = int((x.shape[3] + pad_w - dilation[1] * (kernel_size[1]-1) - 1) // stride[1] + 1)
        # Remember the computed size in the caller's list so it can be reused.
        if isinstance(out_shape, list):
            out_shape.extend([out_h, out_w])

    dtype_original = x.dtype
    single_pixel_out = not (out_h > 1 or out_w > 1)

    if not single_pixel_out:
        out = torch.empty((out_b, c, out_h, out_w), dtype=x.dtype, device=x.device, memory_format=torch.channels_last)
    else:
        # A 1x1 output is produced in float32 even for float16 inputs and is
        # pre-initialised here; the result is cast back after the kernel call.
        dtype = torch.float32 if x.dtype == torch.float16 else x.dtype
        if pooling_mode_int == 0:
            out = torch.empty((out_b, c, out_h, out_w), dtype=dtype, device=x.device, memory_format=torch.channels_last)
            # NOTE(review): mode 0 is presumably max pooling; -1e-10 is a tiny
            # negative value, not a large one - confirm this seed is intended.
            out[:] = -1e-10
        else:
            out = torch.zeros((out_b, c, out_h, out_w), dtype=dtype, device=x.device).contiguous(memory_format=torch.channels_last)

    if mask is not None and out_mask is None and create_out_mask:
        if not single_pixel_out:
            out_mask = torch.empty((out_b, 1, out_h, out_w), dtype=torch.int32, device=x.device, memory_format=torch.channels_last)
        else:
            # this operation uses a special kernel which sets the update mask to one atomically. thus, it has to be set zero first
            out_mask = torch.zeros((out_b, 1, out_h, out_w), dtype=torch.int32, device=x.device).contiguous(memory_format=torch.channels_last)

    sparse_pooling_wrapper_masked(x, prev_x, out, mask, out_mask, kernel_size, stride, padding, dilation, pad_mode_int, pooling_mode_int, sub_tile_sparsity)

    if out.dtype != dtype_original:
        out = out.to(dtype=dtype_original)

    if create_out_mask and mask is not None:
        return out, out_mask
    return out
def sparse_activation(x, prev_x, truncated, mask, threshold, activation, truncation_mode):
    """Forward all arguments unchanged to ``deltacnn_activate_truncate``."""
    deltacnn_activate_truncate(x, prev_x, truncated, mask, threshold, activation, truncation_mode)


def sparsify(input, prev_in, delta, mask, threshold):
    """Forward all arguments unchanged to ``deltacnn_prepare_diff_mask_wrapper``."""
    deltacnn_prepare_diff_mask_wrapper(input, prev_in, delta, mask, threshold)


def sparse_add_tensors(a, b, prev_out, out, mask_a, mask_b, mask_out, weight_a, weight_b, activation, dense_out):
    """Forward all arguments unchanged to ``sparse_add_tensors_wrapper``."""
    sparse_add_tensors_wrapper(a, b, prev_out, out, mask_a, mask_b, mask_out, weight_a, weight_b, activation, dense_out)


def sparse_add_to_dense_tensor(a, b, mask_a, activation=0):
    """Forward all arguments unchanged to ``sparse_add_to_dense_tensor_wrapper``."""
    sparse_add_to_dense_tensor_wrapper(a, b, mask_a, activation)


def sparse_mul_add(x, x_mask, out, out_mask, scale, bias):
    """Call ``sparse_mul_add_wrapper``.

    Note the argument order differs from this function's signature: the
    wrapper receives ``(x, out, x_mask, out_mask, scale, bias)``.
    """
    sparse_mul_add_wrapper(x, out, x_mask, out_mask, scale, bias)


def sparse_upsample(input, mask_in, scale, out=None):
    """Launch the sparse upsampling kernel, allocating the outputs if needed.

    Args:
        input: Features to upsample; dimensions 2 and 3 are multiplied by ``scale``.
        mask_in: Update mask belonging to ``input``.
        scale: Integer factor applied to both spatial dimensions.
        out: Optional ``(out, mask_out)`` pair of pre-allocated buffers to reuse.

    Returns:
        The tuple ``(out, mask_out)``.
    """
    if out is not None:
        out, mask_out = out
    else:
        up_h = input.shape[2] * scale
        up_w = input.shape[3] * scale
        # Keep the memory format of the input: channels-last stays channels-last,
        # anything else gets the default contiguous layout.
        if input.is_contiguous(memory_format=torch.channels_last):
            layout = torch.channels_last
        else:
            layout = torch.contiguous_format
        out = torch.empty((input.shape[0], input.shape[1], up_h, up_w), device=input.device, dtype=input.dtype, memory_format=layout)
        mask_out = torch.empty((input.shape[0], 1, up_h, up_w), device=input.device, dtype=torch.int)

    sparse_upsample_wrapper(input, out, mask_in, mask_out, scale)

    return out, mask_out
def sparse_concatenate(a, b, out=None):
    """Launch the sparse concatenation kernel, allocating the outputs if needed.

    Args:
        a: ``(features, mask)`` pair of the first input.
        b: ``(features, mask)`` pair of the second input.
        out: Optional ``(out, mask_out)`` pair of pre-allocated buffers to reuse.

    Returns:
        The tuple ``(out, mask_out)``; a freshly allocated ``out`` has
        ``a.shape[1] + b.shape[1]`` channels and otherwise the shape of ``a``.
    """
    feat_a, mask_a = a
    feat_b, mask_b = b

    if out is not None:
        out, mask_out = out
    else:
        batch, height, width = feat_a.shape[0], feat_a.shape[2], feat_a.shape[3]
        channels = feat_a.shape[1] + feat_b.shape[1]
        # Keep the memory format of the first input: channels-last stays
        # channels-last, anything else gets the default contiguous layout.
        if feat_a.is_contiguous(memory_format=torch.channels_last):
            layout = torch.channels_last
        else:
            layout = torch.contiguous_format
        out = torch.empty((batch, channels, height, width), device=feat_a.device, dtype=feat_a.dtype, memory_format=layout)
        mask_out = torch.empty((batch, 1, height, width), device=feat_a.device, dtype=torch.int)

    sparse_concatenate_wrapper(feat_a, feat_b, out, mask_a, mask_b, mask_out)

    return out, mask_out
class DCPerformanceMetrics:
    """Per-frame performance counters plus derived ratios.

    Every raw counter is floor-divided by ``n_frames``; every ratio guards its
    denominator with ``max(..., 1)`` so that empty counters yield 0 instead of
    a division error or NaN.
    """

    def __init__(self, tiles, inputs, mode, flops, memtransfer, histogram, n_frames = 1):
        """Convert the raw counter tensors into Python numbers.

        Args:
            tiles: Indexable of scalar tensors ``[active, total]``.
            inputs: Indexable of scalar tensors ``[active, total]``.
            mode: Indexable of scalar tensors ``[sparse, dense]``.
            flops: Indexable of scalar tensors ``[actual, theoretical, dense]``.
            memtransfer: Indexable of scalar tensors
                ``[read_actual, read_dense, write_actual, write_dense]``.
            histogram: 1-D tensor of bin counts.
            n_frames: Number of frames the counters were accumulated over;
                values below 1 are treated as 1.
        """
        n_frames = max(n_frames, 1)

        def per_frame(counter):
            # Scalar tensor -> Python number, averaged (floor) over the frames.
            return counter.cpu().item() // n_frames

        def ratio(part, whole):
            # Guard against an empty denominator.
            return part / max(whole, 1)

        self.tiles_active = per_frame(tiles[0])
        self.tiles_total = per_frame(tiles[1])
        self.tiles_ratio = ratio(self.tiles_active, self.tiles_total)
        self.inputs_active = per_frame(inputs[0])
        self.inputs_total = per_frame(inputs[1])
        self.inputs_ratio = ratio(self.inputs_active, self.inputs_total)
        self.mode_sparse = per_frame(mode[0])
        self.mode_dense = per_frame(mode[1])
        self.mode_ratio = ratio(self.mode_sparse, self.mode_sparse + self.mode_dense)
        self.flops_actual = per_frame(flops[0])
        self.flops_theoretical = per_frame(flops[1])
        self.flops_dense = per_frame(flops[2])
        self.flops_ratio_actual = ratio(self.flops_actual, self.flops_dense)
        self.flops_ratio_theoretical = ratio(self.flops_theoretical, self.flops_dense)
        self.mem_read_actual = per_frame(memtransfer[0])
        self.mem_read_dense = per_frame(memtransfer[1])
        self.mem_read_ratio = ratio(self.mem_read_actual, self.mem_read_dense)
        self.mem_write_actual = per_frame(memtransfer[2])
        self.mem_write_dense = per_frame(memtransfer[3])
        self.mem_write_ratio = ratio(self.mem_write_actual, self.mem_write_dense)
        self.histogram = histogram.cpu().numpy() // n_frames
        # The histogram is normalised by the sum of all bins except the first.
        # Guard the denominator like every other ratio above: without it an
        # empty histogram produced NaN/inf, which json.dump writes as the
        # non-standard tokens NaN/Infinity.
        histogram_norm = max(histogram[1:].sum().cpu().item(), 1)
        self.histogram_ratio = (histogram / histogram_norm).cpu().numpy()

    def to_dict(self):
        """Return all metrics as a dict of plain Python numbers and lists."""
        metrics_dict = {
            "tiles_active": self.tiles_active,
            "tiles_total": self.tiles_total,
            "tiles_ratio": self.tiles_ratio,
            "inputs_active": self.inputs_active,
            "inputs_total": self.inputs_total,
            "inputs_ratio": self.inputs_ratio,
            "mode_sparse": self.mode_sparse,
            "mode_dense": self.mode_dense,
            "mode_ratio": self.mode_ratio,
            "flops_actual": self.flops_actual,
            "flops_theoretical": self.flops_theoretical,
            "flops_dense": self.flops_dense,
            "flops_ratio_actual": self.flops_ratio_actual,
            "flops_ratio_theoretical": self.flops_ratio_theoretical,
            "mem_read_actual": self.mem_read_actual,
            "mem_read_dense": self.mem_read_dense,
            "mem_read_ratio": self.mem_read_ratio,
            "mem_write_actual": self.mem_write_actual,
            "mem_write_dense": self.mem_write_dense,
            "mem_write_ratio": self.mem_write_ratio,
            # numpy scalars are not JSON serialisable; convert explicitly.
            "histogram": [int(x) for x in self.histogram],
            "histogram_ratio": [float(x) for x in self.histogram_ratio],
        }
        return metrics_dict

    def save_json(self, path):
        """Write :meth:`to_dict` as JSON to ``path`` (the file is overwritten)."""
        import json
        metrics_dict = self.to_dict()
        with open(path, "w+") as f:
            json.dump(metrics_dict, f)
class DCPerformanceMetricsManager:
    """Class-level access to the performance counters of the CUDA extension.

    All methods are classmethods; initialisation is attempted lazily and its
    result is cached in ``_initialized``.
    """

    # Result of the last initialisation attempt; False until one succeeds.
    _initialized = False

    @classmethod
    def init(cls):
        """Initialise the metrics once; return whether they are available."""
        if not cls._initialized:
            cls._initialized = deltacnn_init_performance_metrics()
            return cls._initialized
        return True

    @classmethod
    def reset(cls):
        """Reset the counters; does nothing when initialisation fails."""
        if cls.init():
            deltacnn_reset_performance_metrics()

    @classmethod
    def get_metrics(cls, n_frames=1):
        """Return the current counters as ``DCPerformanceMetrics``.

        Args:
            n_frames: Number of frames the counters were accumulated over;
                forwarded to ``DCPerformanceMetrics``.

        Returns:
            A ``DCPerformanceMetrics`` instance, or ``None`` when the metrics
            could not be initialised.
        """
        if not cls.init():
            return None

        tiles, inputs, mode, flops, memtransfer, histogram = deltacnn_retrieve_metrics()
        return DCPerformanceMetrics(tiles, inputs, mode, flops, memtransfer, histogram, n_frames=n_frames)
65 | green mamba 66 | sea snake 67 | horned viper 68 | diamondback 69 | sidewinder 70 | trilobite 71 | harvestman 72 | scorpion 73 | black and gold garden spider 74 | barn spider 75 | garden spider 76 | black widow 77 | tarantula 78 | wolf spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse 84 | prairie chicken 85 | peacock 86 | quail 87 | partridge 88 | African grey 89 | macaw 90 | sulphur-crested cockatoo 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted merganser 100 | goose 101 | black swan 102 | tusker 103 | echidna 104 | platypus 105 | wallaby 106 | koala 107 | wombat 108 | jellyfish 109 | sea anemone 110 | brain coral 111 | flatworm 112 | nematode 113 | conch 114 | snail 115 | slug 116 | sea slug 117 | chiton 118 | chambered nautilus 119 | Dungeness crab 120 | rock crab 121 | fiddler crab 122 | king crab 123 | American lobster 124 | spiny lobster 125 | crayfish 126 | hermit crab 127 | isopod 128 | white stork 129 | black stork 130 | spoonbill 131 | flamingo 132 | little blue heron 133 | American egret 134 | bittern 135 | crane 136 | limpkin 137 | European gallinule 138 | American coot 139 | bustard 140 | ruddy turnstone 141 | red-backed sandpiper 142 | redshank 143 | dowitcher 144 | oystercatcher 145 | pelican 146 | king penguin 147 | albatross 148 | grey whale 149 | killer whale 150 | dugong 151 | sea lion 152 | Chihuahua 153 | Japanese spaniel 154 | Maltese dog 155 | Pekinese 156 | Shih-Tzu 157 | Blenheim spaniel 158 | papillon 159 | toy terrier 160 | Rhodesian ridgeback 161 | Afghan hound 162 | basset 163 | beagle 164 | bloodhound 165 | bluetick 166 | black-and-tan coonhound 167 | Walker hound 168 | English foxhound 169 | redbone 170 | borzoi 171 | Irish wolfhound 172 | Italian greyhound 173 | whippet 174 | Ibizan hound 175 | Norwegian elkhound 176 | otterhound 177 | Saluki 178 | Scottish deerhound 179 | Weimaraner 180 | Staffordshire bullterrier 
181 | American Staffordshire terrier 182 | Bedlington terrier 183 | Border terrier 184 | Kerry blue terrier 185 | Irish terrier 186 | Norfolk terrier 187 | Norwich terrier 188 | Yorkshire terrier 189 | wire-haired fox terrier 190 | Lakeland terrier 191 | Sealyham terrier 192 | Airedale 193 | cairn 194 | Australian terrier 195 | Dandie Dinmont 196 | Boston bull 197 | miniature schnauzer 198 | giant schnauzer 199 | standard schnauzer 200 | Scotch terrier 201 | Tibetan terrier 202 | silky terrier 203 | soft-coated wheaten terrier 204 | West Highland white terrier 205 | Lhasa 206 | flat-coated retriever 207 | curly-coated retriever 208 | golden retriever 209 | Labrador retriever 210 | Chesapeake Bay retriever 211 | German short-haired pointer 212 | vizsla 213 | English setter 214 | Irish setter 215 | Gordon setter 216 | Brittany spaniel 217 | clumber 218 | English springer 219 | Welsh springer spaniel 220 | cocker spaniel 221 | Sussex spaniel 222 | Irish water spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old English sheepdog 231 | Shetland sheepdog 232 | collie 233 | Border collie 234 | Bouvier des Flandres 235 | Rottweiler 236 | German shepherd 237 | Doberman 238 | miniature pinscher 239 | Greater Swiss Mountain dog 240 | Bernese mountain dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull mastiff 245 | Tibetan mastiff 246 | French bulldog 247 | Great Dane 248 | Saint Bernard 249 | Eskimo dog 250 | malamute 251 | Siberian husky 252 | dalmatian 253 | affenpinscher 254 | basenji 255 | pug 256 | Leonberg 257 | Newfoundland 258 | Great Pyrenees 259 | Samoyed 260 | Pomeranian 261 | chow 262 | keeshond 263 | Brabancon griffon 264 | Pembroke 265 | Cardigan 266 | toy poodle 267 | miniature poodle 268 | standard poodle 269 | Mexican hairless 270 | timber wolf 271 | white wolf 272 | red wolf 273 | coyote 274 | dingo 275 | dhole 276 | African hunting dog 277 | hyena 278 | red fox 279 | kit fox 
280 | Arctic fox 281 | grey fox 282 | tabby 283 | tiger cat 284 | Persian cat 285 | Siamese cat 286 | Egyptian cat 287 | cougar 288 | lynx 289 | leopard 290 | snow leopard 291 | jaguar 292 | lion 293 | tiger 294 | cheetah 295 | brown bear 296 | American black bear 297 | ice bear 298 | sloth bear 299 | mongoose 300 | meerkat 301 | tiger beetle 302 | ladybug 303 | ground beetle 304 | long-horned beetle 305 | leaf beetle 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant 312 | grasshopper 313 | cricket 314 | walking stick 315 | cockroach 316 | mantis 317 | cicada 318 | leafhopper 319 | lacewing 320 | dragonfly 321 | damselfly 322 | admiral 323 | ringlet 324 | monarch 325 | cabbage butterfly 326 | sulphur butterfly 327 | lycaenid 328 | starfish 329 | sea urchin 330 | sea cucumber 331 | wood rabbit 332 | hare 333 | Angora 334 | hamster 335 | porcupine 336 | fox squirrel 337 | marmot 338 | beaver 339 | guinea pig 340 | sorrel 341 | zebra 342 | hog 343 | wild boar 344 | warthog 345 | hippopotamus 346 | ox 347 | water buffalo 348 | bison 349 | ram 350 | bighorn 351 | ibex 352 | hartebeest 353 | impala 354 | gazelle 355 | Arabian camel 356 | llama 357 | weasel 358 | mink 359 | polecat 360 | black-footed ferret 361 | otter 362 | skunk 363 | badger 364 | armadillo 365 | three-toed sloth 366 | orangutan 367 | gorilla 368 | chimpanzee 369 | gibbon 370 | siamang 371 | guenon 372 | patas 373 | baboon 374 | macaque 375 | langur 376 | colobus 377 | proboscis monkey 378 | marmoset 379 | capuchin 380 | howler monkey 381 | titi 382 | spider monkey 383 | squirrel monkey 384 | Madagascar cat 385 | indri 386 | Indian elephant 387 | African elephant 388 | lesser panda 389 | giant panda 390 | barracouta 391 | eel 392 | coho 393 | rock beauty 394 | anemone fish 395 | sturgeon 396 | gar 397 | lionfish 398 | puffer 399 | abacus 400 | abaya 401 | academic gown 402 | accordion 403 | acoustic guitar 404 | aircraft carrier 405 | airliner 406 | airship 407 | altar 
408 | ambulance 409 | amphibian 410 | analog clock 411 | apiary 412 | apron 413 | ashcan 414 | assault rifle 415 | backpack 416 | bakery 417 | balance beam 418 | balloon 419 | ballpoint 420 | Band Aid 421 | banjo 422 | bannister 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel 429 | barrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing cap 435 | bath towel 436 | bathtub 437 | beach wagon 438 | beacon 439 | beaker 440 | bearskin 441 | beer bottle 442 | beer glass 443 | bell cote 444 | bib 445 | bicycle-built-for-two 446 | bikini 447 | binder 448 | binoculars 449 | birdhouse 450 | boathouse 451 | bobsled 452 | bolo tie 453 | bonnet 454 | bookcase 455 | bookshop 456 | bottlecap 457 | bow 458 | bow tie 459 | brass 460 | brassiere 461 | breakwater 462 | breastplate 463 | broom 464 | bucket 465 | buckle 466 | bulletproof vest 467 | bullet train 468 | butcher shop 469 | cab 470 | caldron 471 | candle 472 | cannon 473 | canoe 474 | can opener 475 | cardigan 476 | car mirror 477 | carousel 478 | carpenter's kit 479 | carton 480 | car wheel 481 | cash machine 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello 488 | cellular telephone 489 | chain 490 | chainlink fence 491 | chain mail 492 | chain saw 493 | chest 494 | chiffonier 495 | chime 496 | china cabinet 497 | Christmas stocking 498 | church 499 | cinema 500 | cleaver 501 | cliff dwelling 502 | cloak 503 | clog 504 | cocktail shaker 505 | coffee mug 506 | coffeepot 507 | coil 508 | combination lock 509 | computer keyboard 510 | confectionery 511 | container ship 512 | convertible 513 | corkscrew 514 | cornet 515 | cowboy boot 516 | cowboy hat 517 | cradle 518 | crane 519 | crash helmet 520 | crate 521 | crib 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam 527 | desk 528 | desktop computer 529 | dial telephone 530 | diaper 531 | digital clock 532 | digital watch 533 | dining table 534 | 
dishrag 535 | dishwasher 536 | disk brake 537 | dock 538 | dogsled 539 | dome 540 | doormat 541 | drilling platform 542 | drum 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso maker 552 | face powder 553 | feather boa 554 | file 555 | fireboat 556 | fire engine 557 | fire screen 558 | flagpole 559 | flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster 566 | freight car 567 | French horn 568 | frying pan 569 | fur coat 570 | garbage truck 571 | gasmask 572 | gas pump 573 | goblet 574 | go-kart 575 | golf ball 576 | golfcart 577 | gondola 578 | gong 579 | gown 580 | grand piano 581 | greenhouse 582 | grille 583 | grocery store 584 | guillotine 585 | hair slide 586 | hair spray 587 | half track 588 | hammer 589 | hamper 590 | hand blower 591 | hand-held computer 592 | handkerchief 593 | hard disc 594 | harmonica 595 | harp 596 | harvester 597 | hatchet 598 | holster 599 | home theater 600 | honeycomb 601 | hook 602 | hoopskirt 603 | horizontal bar 604 | horse cart 605 | hourglass 606 | iPod 607 | iron 608 | jack-o'-lantern 609 | jean 610 | jeep 611 | jersey 612 | jigsaw puzzle 613 | jinrikisha 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat 619 | ladle 620 | lampshade 621 | laptop 622 | lawn mower 623 | lens cap 624 | letter opener 625 | library 626 | lifeboat 627 | lighter 628 | limousine 629 | liner 630 | lipstick 631 | Loafer 632 | lotion 633 | loudspeaker 634 | loupe 635 | lumbermill 636 | magnetic compass 637 | mailbag 638 | mailbox 639 | maillot 640 | maillot 641 | manhole cover 642 | maraca 643 | marimba 644 | mask 645 | matchstick 646 | maypole 647 | maze 648 | measuring cup 649 | medicine chest 650 | megalith 651 | microphone 652 | microwave 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt 657 | minivan 658 | missile 659 | mitten 660 | 
mixing bowl 661 | mobile home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito net 671 | motor scooter 672 | mountain bike 673 | mountain tent 674 | mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook 683 | obelisk 684 | oboe 685 | ocarina 686 | odometer 687 | oil filter 688 | organ 689 | oscilloscope 690 | overskirt 691 | oxcart 692 | oxygen mask 693 | packet 694 | paddle 695 | paddlewheel 696 | padlock 697 | paintbrush 698 | pajama 699 | palace 700 | panpipe 701 | paper towel 702 | parachute 703 | parallel bars 704 | park bench 705 | parking meter 706 | passenger car 707 | patio 708 | pay-phone 709 | pedestal 710 | pencil box 711 | pencil sharpener 712 | perfume 713 | Petri dish 714 | photocopier 715 | pick 716 | pickelhaube 717 | picket fence 718 | pickup 719 | pier 720 | piggy bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate 726 | pitcher 727 | plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow 732 | plunger 733 | Polaroid camera 734 | pole 735 | police van 736 | poncho 737 | pool table 738 | pop bottle 739 | pot 740 | potter's wheel 741 | power drill 742 | prayer rug 743 | printer 744 | prison 745 | projectile 746 | projector 747 | puck 748 | punching bag 749 | purse 750 | quill 751 | quilt 752 | racer 753 | racket 754 | radiator 755 | radio 756 | radio telescope 757 | rain barrel 758 | recreational vehicle 759 | reel 760 | reflex camera 761 | refrigerator 762 | remote control 763 | restaurant 764 | revolver 765 | rifle 766 | rocking chair 767 | rotisserie 768 | rubber eraser 769 | rugby ball 770 | rule 771 | running shoe 772 | safe 773 | safety pin 774 | saltshaker 775 | sandal 776 | sarong 777 | sax 778 | scabbard 779 | scale 780 | school bus 781 | schooner 782 | scoreboard 783 | screen 784 | screw 785 | screwdriver 786 | seat belt 787 | sewing machine 788 | 
shield 789 | shoe shop 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule 800 | sliding door 801 | slot 802 | snorkel 803 | snowmobile 804 | snowplow 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar dish 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | speedboat 816 | spider web 817 | spindle 818 | sports car 819 | spotlight 820 | stage 821 | steam locomotive 822 | steel arch bridge 823 | steel drum 824 | stethoscope 825 | stole 826 | stone wall 827 | stopwatch 828 | stove 829 | strainer 830 | streetcar 831 | stretcher 832 | studio couch 833 | stupa 834 | submarine 835 | suit 836 | sundial 837 | sunglass 838 | sunglasses 839 | sunscreen 840 | suspension bridge 841 | swab 842 | sweatshirt 843 | swimming trunks 844 | swing 845 | switch 846 | syringe 847 | table lamp 848 | tank 849 | tape player 850 | teapot 851 | teddy 852 | television 853 | tennis ball 854 | thatch 855 | theater curtain 856 | thimble 857 | thresher 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck 866 | toyshop 867 | tractor 868 | trailer truck 869 | tray 870 | trench coat 871 | tricycle 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus 876 | trombone 877 | tub 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle 882 | upright 883 | vacuum 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet 895 | wardrobe 896 | warplane 897 | washbasin 898 | washer 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool 913 | worm fence 914 | wreck 915 | 
yawl 916 | yurt 917 | web site 918 | comic book 919 | crossword puzzle 920 | street sign 921 | traffic light 922 | book jacket 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot 928 | trifle 929 | ice cream 930 | ice lolly 931 | French loaf 932 | bagel 933 | pretzel 934 | cheeseburger 935 | hotdog 936 | mashed potato 937 | head cabbage 938 | broccoli 939 | cauliflower 940 | zucchini 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber 945 | artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple 955 | banana 956 | jackfruit 957 | custard apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate sauce 962 | dough 963 | meat loaf 964 | pizza 965 | potpie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff 974 | coral reef 975 | geyser 976 | lakeside 977 | promontory 978 | sandbar 979 | seashore 980 | valley 981 | volcano 982 | ballplayer 983 | groom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper 988 | corn 989 | acorn 990 | hip 991 | buckeye 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn 996 | earthstar 997 | hen-of-the-woods 998 | bolete 999 | ear 1000 | toilet tissue -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. 
Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. 
Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. 
For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. 
NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. 
For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 
198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. 
indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 
291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. 
This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. 
To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 
406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | -------------------------------------------------------------------------------- /src/cuda/conv_torch_wrapper.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include "conv_kernel.cuh" 8 | #include "deconv_kernel.cuh" 9 | #include "other_nn_layers.cuh" 10 | #include 11 | #include 12 | #include 13 | #include "common.cuh" 14 | 15 | #define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor") 16 | // TODO add an actual check if the layer is channels last 17 | // #define CHECK_CONTIGUOUS(x) TORCH_CHECK(!(x.is_contiguous()), #x " must be channels last") 18 | #define CHECK_CONTIGUOUS(x) {} 19 | #define CHECK_INPUT(x) \ 20 | CHECK_CUDA(x); \ 21 | CHECK_CONTIGUOUS(x) 22 | #define CHECK_DEVICE(x, y) AT_ASSERTM((x).device().index() == (y).device().index(), #x " and " #y " must be in same CUDA device") 23 | 24 | bool deltacnn_init_performance_metrics() { 25 | static bool initialized = false; 26 | if (initialized) { 27 | return true; 28 | } 29 | bool success = init_performance_metrics(); 30 | if (!success) 31 | return success; 32 | initialized = true; 33 | init_d_metrics_conv_kernels(); 34 | init_d_metrics_deconv_kernels(); 35 | init_d_metrics_other_nn_layers(); 36 | return success; 37 | } 38 | 39 | ConvConfig params_to_conv_config(std::vector stride, std::vector padding, std::vector dilation, int groups, int padding_mode, bool sub_tile_sparsity=true) { 40 | ConvConfig config; 41 | config.stride[0] = (uint16_t) stride[0]; 42 | config.stride[1] = (uint16_t) stride[1]; 43 | if (padding.size() == 4) { 44 | config.padding[0] = (uint16_t) padding[0]; 45 | config.padding[1] = (uint16_t) padding[2]; 46 | config.padding[2] = (uint16_t) 
padding[1]; 47 | config.padding[3] = (uint16_t) padding[3]; 48 | } else { 49 | config.padding[0] = (uint16_t) padding[0]; 50 | config.padding[1] = (uint16_t) padding[1]; 51 | config.padding[2] = (uint16_t) padding[0]; 52 | config.padding[3] = (uint16_t) padding[1]; 53 | } 54 | config.dilation[0] = (uint16_t) dilation[0]; 55 | config.dilation[1] = (uint16_t) dilation[1]; 56 | config.groups = groups; 57 | config.padding_mode = (PaddingMode) padding_mode; 58 | config.sub_tile_sparsity = sub_tile_sparsity; 59 | return config; 60 | } 61 | 62 | void sparse_conv_bias_wrapper_masked( 63 | torch::Tensor input, 64 | torch::Tensor filter, 65 | at::optional bias, 66 | torch::Tensor out, 67 | at::optional mask, 68 | at::optional out_mask, 69 | std::vector stride, 70 | std::vector padding, 71 | std::vector dilation, 72 | int groups, 73 | int padding_mode, 74 | bool sub_tile_sparsity=true 75 | ) 76 | { 77 | #ifdef ENABLE_METRICS 78 | deltacnn_init_performance_metrics(); 79 | #endif 80 | 81 | CHECK_INPUT(input); 82 | CHECK_INPUT(filter); 83 | CHECK_INPUT(out); 84 | CHECK_DEVICE(input, filter); 85 | CHECK_DEVICE(input, out); 86 | 87 | uint32_t *out_mask_ptr = nullptr; 88 | uint32_t *mask_ptr = nullptr; 89 | void *bias_ptr = nullptr; 90 | 91 | if (bias) { 92 | CHECK_INPUT((*bias)); 93 | CHECK_DEVICE(input, (*bias)); 94 | if (input.dtype() == torch::kFloat32) { 95 | bias_ptr = (void*) (*bias).data_ptr(); 96 | } else { 97 | bias_ptr = (void*) (*bias).data_ptr(); 98 | } 99 | } 100 | if (mask) { 101 | CHECK_INPUT((*mask)); 102 | CHECK_DEVICE(input, (*mask)); 103 | mask_ptr = (uint32_t*) (*mask).data_ptr(); 104 | } 105 | if (out_mask) { 106 | CHECK_INPUT((*out_mask)); 107 | CHECK_DEVICE(input, (*out_mask)); 108 | out_mask_ptr = (uint32_t*) (*out_mask).data_ptr(); 109 | } 110 | 111 | Dimensions dim; 112 | dim.batch_size = input.size(0); 113 | dim.in.c = input.size(1); 114 | dim.in.h = input.size(2); 115 | dim.in.w = input.size(3); 116 | dim.out.c = out.size(1); 117 | dim.out.h = 
out.size(2); 118 | dim.out.w = out.size(3); 119 | 120 | ConvConfig config = params_to_conv_config(stride, padding, dilation, groups, padding_mode, sub_tile_sparsity); 121 | config.kernel_size[0] = (uint8_t) filter.size(1); 122 | config.kernel_size[1] = (uint8_t) filter.size(2); 123 | config.set_sparse_zero = out_mask_ptr == nullptr; 124 | 125 | 126 | if (input.dtype() == torch::kFloat32) { 127 | deltacnn(input.data_ptr(), out.data_ptr(), filter.data_ptr(), (float*) bias_ptr, (uint32_t*) mask_ptr, out_mask_ptr, dim, config); 128 | } else if (input.dtype() == torch::kFloat16) { 129 | deltacnn_hp((half*)input.data_ptr(), (half*)out.data_ptr(), (half*)filter.data_ptr(), (half*) bias_ptr, mask_ptr, out_mask_ptr, dim, config); 130 | } else { 131 | printf("unsupported datatype\n"); 132 | return; 133 | } 134 | } 135 | 136 | void sparse_deconv_bias_wrapper_masked( 137 | torch::Tensor input, 138 | torch::Tensor filter, 139 | at::optional bias, 140 | torch::Tensor out, 141 | at::optional mask, 142 | at::optional out_mask, 143 | std::vector stride, 144 | std::vector padding, 145 | std::vector dilation, 146 | int groups, 147 | int padding_mode, 148 | bool sub_tile_sparsity=true 149 | ) 150 | { 151 | #ifdef ENABLE_METRICS 152 | deltacnn_init_performance_metrics(); 153 | #endif 154 | 155 | CHECK_INPUT(input); 156 | CHECK_INPUT(filter); 157 | CHECK_INPUT(out); 158 | CHECK_DEVICE(input, filter); 159 | CHECK_DEVICE(input, out); 160 | 161 | uint32_t *out_mask_ptr = nullptr; 162 | uint32_t *mask_ptr = nullptr; 163 | void *bias_ptr = nullptr; 164 | 165 | if (bias) { 166 | CHECK_INPUT((*bias)); 167 | CHECK_DEVICE(input, (*bias)); 168 | if (input.dtype() == torch::kFloat32) { 169 | bias_ptr = (void*) (*bias).data_ptr(); 170 | } else { 171 | bias_ptr = (void*) (*bias).data_ptr(); 172 | } 173 | } 174 | if (mask) { 175 | CHECK_INPUT((*mask)); 176 | CHECK_DEVICE(input, (*mask)); 177 | mask_ptr = (uint32_t*) (*mask).data_ptr(); 178 | } 179 | if (out_mask) { 180 | CHECK_INPUT((*out_mask)); 181 
| CHECK_DEVICE(input, (*out_mask)); 182 | out_mask_ptr = (uint32_t*) (*out_mask).data_ptr(); 183 | } 184 | 185 | Dimensions dim; 186 | dim.batch_size = input.size(0); 187 | dim.in.c = input.size(1); 188 | dim.in.h = input.size(2); 189 | dim.in.w = input.size(3); 190 | dim.out.c = out.size(1); 191 | dim.out.h = out.size(2); 192 | dim.out.w = out.size(3); 193 | 194 | ConvConfig config = params_to_conv_config(stride, padding, dilation, groups, padding_mode, sub_tile_sparsity); 195 | config.kernel_size[0] = (uint8_t) filter.size(1); 196 | config.kernel_size[1] = (uint8_t) filter.size(2); 197 | config.set_sparse_zero = out_mask_ptr == nullptr; 198 | 199 | 200 | if (input.dtype() == torch::kFloat32) { 201 | delta_deconv(input.data_ptr(), out.data_ptr(), filter.data_ptr(), (float*) bias_ptr, (uint32_t*) mask_ptr, out_mask_ptr, dim, config); 202 | } 203 | else if (input.dtype() == torch::kFloat16) { 204 | delta_deconv_hp((half*)input.data_ptr(), (half*)out.data_ptr(), (half*)filter.data_ptr(), (half*) bias_ptr, mask_ptr, out_mask_ptr, dim, config); 205 | } 206 | else { 207 | printf("unsupported datatype\n"); 208 | return; 209 | } 210 | } 211 | 212 | 213 | void sparse_pooling_wrapper_masked( 214 | torch::Tensor input, 215 | at::optional prev_in, 216 | torch::Tensor out, 217 | at::optional mask, 218 | at::optional out_mask, 219 | std::vector kernel_size, 220 | std::vector stride, 221 | std::vector padding, 222 | std::vector dilation, 223 | int padding_mode, 224 | int pooling_mode, 225 | bool sub_tile_sparsity=true 226 | ) 227 | { 228 | #ifdef ENABLE_METRICS 229 | deltacnn_init_performance_metrics(); 230 | #endif 231 | 232 | CHECK_INPUT(input); 233 | CHECK_INPUT(out); 234 | CHECK_DEVICE(input, out); 235 | 236 | void *prev_in_ptr = nullptr; 237 | uint32_t *out_mask_ptr = nullptr; 238 | uint32_t *mask_ptr = nullptr; 239 | 240 | if (prev_in) { 241 | CHECK_INPUT((*prev_in)); 242 | CHECK_DEVICE(input, (*prev_in)); 243 | } 244 | if (mask) { 245 | CHECK_INPUT((*mask)); 246 | 
CHECK_DEVICE(input, (*mask)); 247 | mask_ptr = (uint32_t*) (*mask).data_ptr(); 248 | } 249 | if (out_mask) { 250 | CHECK_INPUT((*out_mask)); 251 | CHECK_DEVICE(input, (*out_mask)); 252 | out_mask_ptr = (uint32_t*) (*out_mask).data_ptr(); 253 | } 254 | 255 | void *out_data_ptr_hp = nullptr; 256 | if (out.dtype() == torch::kFloat32) { 257 | out_data_ptr_hp = (void*) out.data_ptr(); 258 | if (prev_in) { 259 | prev_in_ptr = (void*) (*prev_in).data_ptr(); 260 | } 261 | } else if (out.dtype() == torch::kFloat16) { 262 | out_data_ptr_hp = (void*) out.data_ptr(); 263 | if (prev_in) { 264 | prev_in_ptr = (void*) (*prev_in).data_ptr(); 265 | } 266 | } 267 | 268 | ConvConfig config = params_to_conv_config(stride, padding, dilation, 1, padding_mode, sub_tile_sparsity); 269 | config.kernel_size[0] = (uint16_t) kernel_size[0]; 270 | config.kernel_size[1] = (uint16_t) kernel_size[1]; 271 | config.set_sparse_zero = out_mask_ptr == nullptr; 272 | 273 | Dimensions dim; 274 | dim.batch_size = input.size(0); 275 | dim.in.c = input.size(1); 276 | dim.in.h = input.size(2); 277 | dim.in.w = input.size(3); 278 | dim.out.c = out.size(1); 279 | dim.out.h = out.size(2); 280 | dim.out.w = out.size(3); 281 | 282 | 283 | // printf("out ptr address=%p\n", (void*) out_data_ptr_hp); 284 | if (input.dtype() == torch::kFloat32) { 285 | sparse_pool(input.data_ptr(), (float*) prev_in_ptr, out.data_ptr(), (uint32_t*) mask_ptr, out_mask_ptr, dim, config, pooling_mode); 286 | } else if (input.dtype() == torch::kFloat16) { 287 | sparse_pool_hp((half*) input.data_ptr(), (half*) prev_in_ptr, (half*) out_data_ptr_hp, (uint32_t*) mask_ptr, out_mask_ptr, dim, config, pooling_mode); 288 | } else { 289 | printf("unsupported datatype\n"); 290 | return; 291 | } 292 | // printf("out ptr address=%p\n", (void*) out_data_ptr_hp); 293 | } 294 | 295 | void deltacnn_activate_truncate( 296 | torch::Tensor input, 297 | torch::Tensor prev_input, 298 | at::optional truncated, 299 | torch::Tensor mask, 300 | float threshold, 
301 | int activation, 302 | int truncation_mode 303 | ) 304 | { 305 | #ifdef ENABLE_METRICS 306 | deltacnn_init_performance_metrics(); 307 | #endif 308 | 309 | CHECK_INPUT(input); 310 | CHECK_INPUT(prev_input); 311 | float *truncated_ptr = nullptr; 312 | half *truncated_ptr_hp = nullptr; 313 | if (truncated) { 314 | CHECK_INPUT(*truncated); 315 | CHECK_DEVICE(input, *truncated); 316 | if (input.dtype() == torch::kFloat32) 317 | truncated_ptr = (float*) (*truncated).data_ptr(); 318 | if (input.dtype() == torch::kFloat16) 319 | truncated_ptr_hp = (half*) (*truncated).data_ptr(); 320 | } 321 | CHECK_DEVICE(input, prev_input); 322 | CHECK_INPUT(mask); 323 | CHECK_DEVICE(input, mask); 324 | Dimensions dim; 325 | dim.batch_size = input.size(0); 326 | dim.in.h = input.size(2); 327 | dim.in.w = input.size(3); 328 | dim.in.c = input.size(1); 329 | dim.out.h = input.size(2); 330 | dim.out.w = input.size(3); 331 | dim.out.c = input.size(1); 332 | 333 | 334 | 335 | if (input.dtype() == torch::kFloat32) { 336 | activate_truncate(input.data_ptr(), prev_input.data_ptr(), truncated_ptr, (uint32_t*) mask.data_ptr(), threshold, dim, activation, truncation_mode); 337 | } else if (input.dtype() == torch::kFloat16) { 338 | activate_truncate_hp((half*) input.data_ptr(), (half*) prev_input.data_ptr(), truncated_ptr_hp, (uint32_t*) mask.data_ptr(), threshold, dim, activation, truncation_mode); 339 | } 340 | else { 341 | printf("unsupported datatype\n"); 342 | return; 343 | } 344 | } 345 | 346 | 347 | void deltacnn_prepare_diff_mask_wrapper( 348 | torch::Tensor input, 349 | torch::Tensor prev_input, 350 | torch::Tensor delta, 351 | torch::Tensor mask, 352 | float threshold 353 | ) 354 | { 355 | #ifdef ENABLE_METRICS 356 | deltacnn_init_performance_metrics(); 357 | #endif 358 | 359 | CHECK_INPUT(input); 360 | CHECK_INPUT(prev_input); 361 | CHECK_INPUT(delta); 362 | CHECK_INPUT(mask); 363 | CHECK_DEVICE(input, prev_input); 364 | CHECK_DEVICE(input, delta); 365 | CHECK_DEVICE(input, mask); 
366 | 367 | Dimensions dim; 368 | dim.batch_size = input.size(0); 369 | dim.in.c = input.size(1); 370 | dim.in.h = input.size(2); 371 | dim.in.w = input.size(3); 372 | dim.out.c = input.size(1); 373 | dim.out.h = input.size(2); 374 | dim.out.w = input.size(3); 375 | 376 | if (input.dtype() == torch::kFloat32) { 377 | prepare_diff_mask(input.data_ptr(), prev_input.data_ptr(), delta.data_ptr(), (uint32_t*) mask.data_ptr(), threshold, dim); 378 | } else if (input.dtype() == torch::kFloat16) { 379 | prepare_diff_mask_hp((half*) input.data_ptr(), (half*) prev_input.data_ptr(), (half*) delta.data_ptr(), (uint32_t*) mask.data_ptr(), threshold, dim); 380 | } else { 381 | printf("unsupported datatype\n"); 382 | return; 383 | } 384 | } 385 | 386 | 387 | void sparse_add_tensors_wrapper( 388 | torch::Tensor a, 389 | torch::Tensor b, 390 | at::optional prev_out, 391 | torch::Tensor out, 392 | torch::Tensor mask_a, 393 | torch::Tensor mask_b, 394 | torch::Tensor mask_out, 395 | float weight_a, 396 | float weight_b, 397 | int activation, 398 | bool dense_out 399 | ) 400 | { 401 | #ifdef ENABLE_METRICS 402 | deltacnn_init_performance_metrics(); 403 | #endif 404 | 405 | CHECK_INPUT(a); 406 | CHECK_INPUT(b); 407 | CHECK_INPUT(out); 408 | CHECK_INPUT(mask_a); 409 | CHECK_INPUT(mask_b); 410 | CHECK_INPUT(mask_out); 411 | CHECK_DEVICE(a, b); 412 | CHECK_DEVICE(a, out); 413 | CHECK_DEVICE(a, mask_a); 414 | CHECK_DEVICE(a, mask_b); 415 | CHECK_DEVICE(a, mask_out); 416 | 417 | Dimensions dim; 418 | dim.batch_size = a.size(0); 419 | dim.in.c = a.size(1); 420 | dim.in.h = a.size(2); 421 | dim.in.w = a.size(3); 422 | dim.out.c = a.size(1); 423 | dim.out.h = a.size(2); 424 | dim.out.w = a.size(3); 425 | 426 | void *prev_out_ptr = nullptr; 427 | if (prev_out) { 428 | CHECK_INPUT((*prev_out)); 429 | CHECK_DEVICE(a, (*prev_out)); 430 | 431 | if ((*prev_out).dtype() == torch::kFloat32) { 432 | prev_out_ptr = (void*) (*prev_out).data_ptr(); 433 | } 434 | } 435 | 436 | if (a.dtype() == 
torch::kFloat32) { 437 | sparse_add_tensors(a.data_ptr(), b.data_ptr(), (float*) prev_out_ptr, out.data_ptr(), (uint32_t*) mask_a.data_ptr(), (uint32_t*) mask_b.data_ptr(), (uint32_t*) mask_out.data_ptr(), weight_a, weight_b, dim, activation, dense_out); 438 | } else if (a.dtype() == torch::kFloat16) { 439 | sparse_add_tensors_hp((half*) a.data_ptr(), (half*) b.data_ptr(), (half*) prev_out_ptr, (half*) out.data_ptr(), (uint32_t*) mask_a.data_ptr(), (uint32_t*) mask_b.data_ptr(), (uint32_t*) mask_out.data_ptr(), weight_a, weight_b, dim, activation, dense_out); 440 | } else { 441 | printf("unsupported datatype\n"); 442 | return; 443 | } 444 | } 445 | 446 | 447 | void sparse_add_to_dense_tensor_wrapper( 448 | torch::Tensor a, 449 | torch::Tensor b, 450 | torch::Tensor mask_a, 451 | int activation 452 | ) 453 | { 454 | #ifdef ENABLE_METRICS 455 | deltacnn_init_performance_metrics(); 456 | #endif 457 | 458 | CHECK_INPUT(a); 459 | CHECK_INPUT(b); 460 | CHECK_INPUT(mask_a); 461 | CHECK_DEVICE(a, b); 462 | CHECK_DEVICE(a, mask_a); 463 | 464 | Dimensions dim; 465 | dim.batch_size = a.size(0); 466 | dim.in.c = a.size(1); 467 | dim.in.h = a.size(2); 468 | dim.in.w = a.size(3); 469 | dim.out.c = a.size(1); 470 | dim.out.h = a.size(2); 471 | dim.out.w = a.size(3); 472 | 473 | if (a.dtype() == torch::kFloat32) { 474 | sparse_add_to_dense_tensor_sp(a.data_ptr(), b.data_ptr(), (uint32_t*) mask_a.data_ptr(), dim, activation); 475 | } else if (a.dtype() == torch::kFloat16) { 476 | sparse_add_to_dense_tensor_hp((half*) a.data_ptr(), (half*) b.data_ptr(),(uint32_t*) mask_a.data_ptr(), dim, activation); 477 | } else { 478 | printf("unsupported datatype\n"); 479 | return; 480 | } 481 | } 482 | 483 | 484 | void sparse_upsample_wrapper( 485 | torch::Tensor input, 486 | torch::Tensor out, 487 | torch::Tensor mask_in, 488 | torch::Tensor mask_out, 489 | int scale 490 | ) 491 | { 492 | #ifdef ENABLE_METRICS 493 | deltacnn_init_performance_metrics(); 494 | #endif 495 | 496 | 
CHECK_INPUT(input); 497 | CHECK_INPUT(out); 498 | CHECK_INPUT(mask_in); 499 | CHECK_INPUT(mask_out); 500 | CHECK_DEVICE(input, out); 501 | CHECK_DEVICE(input, mask_in); 502 | CHECK_DEVICE(input, mask_out); 503 | 504 | Dimensions dim; 505 | dim.batch_size = input.size(0); 506 | dim.in.c = input.size(1); 507 | dim.in.h = input.size(2); 508 | dim.in.w = input.size(3); 509 | dim.out.c = out.size(1); 510 | dim.out.h = out.size(2); 511 | dim.out.w = out.size(3); 512 | 513 | if (input.dtype() == torch::kFloat32) { 514 | sparse_upsample(input.data_ptr(), out.data_ptr(), (uint32_t*) mask_in.data_ptr(), (uint32_t*) mask_out.data_ptr(), dim, scale); 515 | } else if (input.dtype() == torch::kFloat16) { 516 | sparse_upsample_hp((half*) input.data_ptr(), (half*) out.data_ptr(), (uint32_t*) mask_in.data_ptr(), (uint32_t*) mask_out.data_ptr(), dim, scale); 517 | } else { 518 | printf("unsupported datatype\n"); 519 | return; 520 | } 521 | } 522 | 523 | 524 | void sparse_concatenate_wrapper( 525 | torch::Tensor a, 526 | torch::Tensor b, 527 | torch::Tensor out, 528 | torch::Tensor mask_a, 529 | torch::Tensor mask_b, 530 | torch::Tensor mask_out 531 | ) 532 | { 533 | #ifdef ENABLE_METRICS 534 | deltacnn_init_performance_metrics(); 535 | #endif 536 | 537 | CHECK_INPUT(a); 538 | CHECK_INPUT(b); 539 | CHECK_INPUT(out); 540 | CHECK_INPUT(mask_a); 541 | CHECK_INPUT(mask_b); 542 | CHECK_INPUT(mask_out); 543 | CHECK_DEVICE(a, b); 544 | CHECK_DEVICE(a, out); 545 | CHECK_DEVICE(a, mask_a); 546 | CHECK_DEVICE(a, mask_b); 547 | CHECK_DEVICE(a, mask_out); 548 | 549 | Dimensions dim; 550 | dim.batch_size = a.size(0); 551 | dim.in.c = a.size(1); 552 | dim.in.h = a.size(2); 553 | dim.in.w = a.size(3); 554 | dim.out.c = out.size(1); 555 | dim.out.h = out.size(2); 556 | dim.out.w = out.size(3); 557 | 558 | if (a.dtype() == torch::kFloat32) { 559 | sparse_concatenate(a.data_ptr(), b.data_ptr(), out.data_ptr(), (uint32_t*) mask_a.data_ptr(), (uint32_t*) mask_b.data_ptr(), (uint32_t*) 
mask_out.data_ptr(), dim); 560 | } else if (a.dtype() == torch::kFloat16) { 561 | sparse_concatenate_hp((half*) a.data_ptr(), (half*) b.data_ptr(), (half*) out.data_ptr(), (uint32_t*) mask_a.data_ptr(), (uint32_t*) mask_b.data_ptr(), (uint32_t*) mask_out.data_ptr(), dim); 562 | } else { 563 | printf("unsupported datatype\n"); 564 | return; 565 | } 566 | } 567 | 568 | 569 | void sparse_mul_add_wrapper( 570 | torch::Tensor in, 571 | torch::Tensor out, 572 | at::optional mask_in, 573 | at::optional mask_out, 574 | torch::Tensor scale, 575 | at::optional bias 576 | ) 577 | { 578 | #ifdef ENABLE_METRICS 579 | deltacnn_init_performance_metrics(); 580 | #endif 581 | 582 | CHECK_INPUT(in); 583 | CHECK_INPUT(out); 584 | CHECK_INPUT(scale); 585 | CHECK_DEVICE(in, out); 586 | CHECK_DEVICE(in, scale); 587 | 588 | void *bias_ptr = nullptr; 589 | if (bias) { 590 | CHECK_INPUT(*bias); 591 | CHECK_DEVICE(in, *bias); 592 | if (bias->dtype() == torch::kFloat32) { 593 | bias_ptr = (void*) bias->data_ptr(); 594 | } else if (bias->dtype() == torch::kFloat16) { 595 | bias_ptr = (void*) bias->data_ptr(); 596 | } 597 | } 598 | 599 | uint32_t *mask_in_ptr = nullptr; 600 | if (mask_in) { 601 | CHECK_INPUT(*mask_in); 602 | CHECK_DEVICE(in, *mask_in); 603 | mask_in_ptr = (uint32_t*) mask_in->data_ptr(); 604 | } 605 | 606 | uint32_t *mask_out_ptr = nullptr; 607 | if (mask_out) { 608 | CHECK_INPUT(*mask_out); 609 | CHECK_DEVICE(in, *mask_out); 610 | mask_out_ptr = (uint32_t*) mask_out->data_ptr(); 611 | } 612 | 613 | Dimensions dim; 614 | dim.batch_size = in.size(0); 615 | dim.in.c = in.size(1); 616 | dim.in.h = in.size(2); 617 | dim.in.w = in.size(3); 618 | dim.out.c = in.size(1); 619 | dim.out.h = in.size(2); 620 | dim.out.w = in.size(3); 621 | 622 | if (in.dtype() == torch::kFloat32) { 623 | sparse_mul_add(in.data_ptr(), mask_in_ptr, out.data_ptr(), mask_out_ptr, scale.data_ptr(), (float*) bias_ptr, dim); 624 | } else if (in.dtype() == torch::kFloat16) { 625 | sparse_mul_add_hp((half*) 
in.data_ptr(), mask_in_ptr, (half*) out.data_ptr(), mask_out_ptr, (half*) scale.data_ptr(), (half*) bias_ptr, dim); 626 | } else { 627 | printf("unsupported datatype\n"); 628 | return; 629 | } 630 | } 631 | 632 | void deltacnn_reset_performance_metrics() { 633 | reset_performance_metrics(); 634 | } 635 | 636 | std::vector deltacnn_retrieve_metrics() { 637 | return retrieve_metrics(); 638 | } 639 | 640 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 641 | { 642 | m.def("sparse_conv_bias_wrapper_masked", &sparse_conv_bias_wrapper_masked, "Sparse Convolution with Bias PyTorch wrapper with per pixel mask (CUDA)"); 643 | m.def("sparse_deconv_bias_wrapper_masked", &sparse_deconv_bias_wrapper_masked, "Sparse Transposed Convolution with Bias PyTorch wrapper with per pixel mask (CUDA)"); 644 | m.def("sparse_pooling_wrapper_masked", &sparse_pooling_wrapper_masked, "Sparse Pooling (CUDA)"); 645 | m.def("deltacnn_activate_truncate", &deltacnn_activate_truncate, "Aggregate previous and current inputs (CUDA)"); 646 | m.def("deltacnn_prepare_diff_mask_wrapper", &deltacnn_prepare_diff_mask_wrapper, "Calculate diff mask and update previous input (CUDA)"); 647 | m.def("sparse_add_tensors_wrapper", &sparse_add_tensors_wrapper, "Add 2 sparse tensors and create a union of their mask (CUDA)"); 648 | m.def("sparse_add_to_dense_tensor_wrapper", &sparse_add_to_dense_tensor_wrapper, "Add sparse tensor updates to dense tensor (CUDA)"); 649 | m.def("sparse_mul_add_wrapper", &sparse_mul_add_wrapper, "Apply scale and bias to sparse tensor (CUDA)"); 650 | m.def("sparse_upsample_wrapper", &sparse_upsample_wrapper, "Upsample an image with integer scales (CUDA)"); 651 | m.def("sparse_concatenate_wrapper", &sparse_concatenate_wrapper, "Concatenated Tensors A and B along the channels (dim=1) (CUDA)"); 652 | m.def("deltacnn_init_performance_metrics", &deltacnn_init_performance_metrics, "Init DeltaCNN performance metrics. 
def _log_pixel_masks_impl(self, channels, name, epoch, individual_channels=False, prefix=""):
    """Write pixel-mask images for one layer to the summary writer.

    The mean of ``channels`` over dim 0 is always written under
    ``{prefix}pixel-mask-acc/{name}/``. When ``individual_channels`` is set,
    every entry along dim 0 is additionally written under
    ``{prefix}pixel-mask-channels/{name}/{index}/``. All images are multiplied
    by 255 and converted to ``torch.uint8`` before they reach the writer.

    Args:
        channels: Tensor whose first dimension enumerates the channels.
        name: Layer name used inside the image tags.
        epoch: Global step passed to the writer.
        individual_channels: Also write one image per channel.
        prefix: Text prepended to every tag.
    """
    def as_uint8(values):
        return (values * 255).to(torch.uint8)

    accumulated = channels.mean(dim=0, keepdim=True)
    self.writer.add_image(f"{prefix}pixel-mask-acc/{name}/", as_uint8(accumulated), epoch)

    if not individual_channels:
        return

    for index, channel in enumerate(channels):
        # channel[None] restores the leading dimension of size 1.
        tag = f"{prefix}pixel-mask-channels/{name}/{index}/"
        self.writer.add_image(tag, as_uint8(channel[None]), epoch)
def log_diff_in_histogram(self, epoch, reset=False, disable=False):
    """Write tile-density histograms for every DifferentialPixelMaskLogger.

    For each layer the result of ``get_histogram(tile_size=(6, 6))`` is
    written under ``tile-density/<layer name>`` and its strictly positive
    entries under ``tile-density-non-empty/<layer name>``. The values of all
    layers are then combined and written under ``tile-density-total/all`` and
    ``tile-density-total/non-empty``.

    Args:
        epoch: Global step passed to the summary writer.
        reset: If true, call ``reset()`` on each layer after reading it.
        disable: If true, call ``disable_logging()`` on each layer afterwards.
    """
    # Shared by every histogram written here, including the totals after the loop.
    bins = range(0, 65)

    total_histogram = []
    for layer in self.get_all_logging_layers().get(DifferentialPixelMaskLogger, []):
        histogram = layer.get_histogram(tile_size=(6, 6))

        # len() instead of truthiness: the boolean-mask indexing below suggests
        # a tensor-like result, which has no truth value when it holds several
        # elements. NOTE(review): confirm the return type of get_histogram.
        if histogram is not None and len(histogram) > 0:
            self.writer.add_histogram(f"tile-density/{layer.name}", histogram, epoch, bins=bins)
            histogram_non_empty = histogram[histogram > 0]
            if len(histogram_non_empty) > 0:
                self.writer.add_histogram(f"tile-density-non-empty/{layer.name}", histogram_non_empty, epoch, bins=bins)
            total_histogram.extend(list(histogram))

        if reset:
            layer.reset()
        if disable:
            layer.disable_logging()

    if len(total_histogram) > 0:
        t_total_histogram = torch.tensor(total_histogram)
        self.writer.add_histogram("tile-density-total/all", t_total_histogram, epoch, bins=bins)
        t_total_histogram_non_empty = t_total_histogram[t_total_histogram > 0]
        if len(t_total_histogram_non_empty) > 0:
            self.writer.add_histogram("tile-density-total/non-empty", t_total_histogram_non_empty, epoch, bins=bins)
def log_computations(self, epoch, reset=False, disable=False):
    """Write computation counts and update masks of all ComputationsLogger layers.

    The scalar ``computations`` is written once, taken from the first layer.
    On every ``img_epochs``-th epoch each layer's mask is written under
    ``update-mask/<layer name>/``. All masks are then resized with
    nearest-neighbour interpolation to the largest height and width seen,
    averaged, and written as ``update-density``.

    Args:
        epoch: Global step passed to the summary writer.
        reset: If true, call ``reset()`` on each layer after reading it.
        disable: If true, call ``disable_logging()`` on each layer afterwards.
    """
    def to_img(x):
        return (x * 255).to(torch.uint8)

    is_image_epoch = epoch % LoggingController.img_epochs == (LoggingController.img_epochs - 1)

    all_masks = []
    # Tracked per dimension: replacing the whole shape whenever either
    # dimension grows could shrink the other one again.
    max_height, max_width = 0, 0

    logged_total_computations = False
    for layer in self.get_all_logging_layers().get(ComputationsLogger, []):
        if not logged_total_computations:
            self.writer.add_scalar("computations", layer.get(), epoch)
            logged_total_computations = True

        mask = layer.get_mask()
        if mask is not None and is_image_epoch:
            self.writer.add_image(f"update-mask/{layer.name}/", to_img(torch.mean(mask[None], dim=0, keepdim=True)), epoch)
            all_masks.append(mask)
            # Masks are handled as 2-D (height, width): they are expanded with
            # [None, None] and resized with a two-element size below.
            max_height = max(max_height, mask.shape[-2])
            max_width = max(max_width, mask.shape[-1])

        if reset:
            layer.reset()
        if disable:
            layer.disable_logging()

    if all_masks:
        resized = [
            interpolate(mask[None, None], size=(max_height, max_width), mode="nearest")
            for mask in all_masks
        ]
        mean_mask = to_img(torch.mean(torch.cat(resized, dim=1), dim=1))
        self.writer.add_image("update-density", mean_mask, epoch)
def log_gradient(self, epoch, reset=False, disable=False):
    """Write the image of every GradientLogger layer under the tag ``gradient``.

    Layers whose ``get()`` returns ``None`` write nothing. ``reset`` and
    ``disable`` are applied to every layer, including those without a result.

    Args:
        epoch: Global step passed to the summary writer.
        reset: If true, call ``reset()`` on each layer after reading it.
        disable: If true, call ``disable_logging()`` on each layer afterwards.
    """
    for layer in self.get_all_logging_layers().get(GradientLogger, []):
        result = layer.get()
        # Only the write is skipped for an empty result. Skipping the whole
        # iteration would leave such a layer un-reset and still enabled even
        # though the caller asked for reset/disable.
        if result is not None:
            self.writer.add_image("gradient", result, epoch)

        if reset:
            layer.reset()
        if disable:
            layer.disable_logging()
def get_all_logging_layers(self):
    """Return all logging layers of ``self.model`` grouped by their class.

    Returns:
        A dict mapping each concrete ``LoggingLayer`` subclass found in the
        model to the list of its instances, in ``model.modules()`` order.
        The result is computed once and cached; later calls return a shallow
        copy of the cached dict, so the per-class lists are shared between
        callers.
    """
    # If the mapping was already computed, return a copy of it to save time.
    if self._all_logging_layers is not None:
        return {**self._all_logging_layers}

    layers = {}
    # modules() already yields the model and every descendant exactly once,
    # so one pass finds every logging layer. Recursing into each child again
    # would only re-visit the same modules.
    for layer in self.model.modules():
        if isinstance(layer, LoggingLayer):
            layers.setdefault(type(layer), []).append(layer)

    self._all_logging_layers = {**layers}
    return layers
class LoggingLayer(nn.Module):
    """Base class for modules that record statistics about the data they see.

    Subclasses override ``get``, ``reset`` and ``reset_history`` as needed;
    the versions here do nothing and return ``None``.

    Args:
        name: Label used by the logging code when writing results.
        enabled: Whether the layer records anything.
    """

    # Counter shared by all logging layers. Subclasses prefix their name with
    # the current value on their first forward call and then increment it.
    id = 1

    def __init__(self, name="", enabled=False):
        super().__init__()
        self.name = name
        self.enabled = enabled
        self.added_id = False

    def enable_logging(self):
        """Start recording."""
        self.enabled = True

    def disable_logging(self):
        """Stop recording."""
        self.enabled = False

    def get(self):
        """Return the recorded result; the base class records nothing."""
        pass

    def reset(self):
        """Discard the recorded statistics; nothing to do in the base class."""
        pass

    def reset_history(self):
        """Discard state kept between inputs; nothing to do in the base class."""
        pass


class DensityLogger(LoggingLayer):
    """Records the fraction of elements whose magnitude exceeds a threshold.

    The input is returned unchanged by ``forward``.

    Args:
        threshold: An element counts as active when ``abs(x) > threshold``.
        **kwargs: Forwarded to ``LoggingLayer`` (``name``, ``enabled``).
    """

    def __init__(self, threshold=0.0, **kwargs):
        super().__init__(**kwargs)
        self.active_pixels = 0
        self.total_pixels = 0
        self.threshold = threshold

    def forward(self, x):
        # Give the layer a unique, ordered name the first time it is used.
        if not self.added_id:
            self.name = f"{LoggingLayer.id} {self.name}"
            LoggingLayer.id += 1
            self.added_id = True

        if self.enabled:
            self.active_pixels += torch.sum(torch.abs(x) > self.threshold)
            self.total_pixels += x.numel()
        return x

    def get(self):
        """Return active / total elements, or 0.0 when nothing was recorded."""
        # A layer that never ran while enabled, or was just reset, has no
        # samples; report an empty density instead of dividing by zero.
        if self.total_pixels == 0:
            return 0.0
        return float(self.active_pixels) / float(self.total_pixels)

    def reset(self):
        """Clear both counters."""
        self.active_pixels = 0
        self.total_pixels = 0
class PixelMaskLogger(LoggingLayer):
    """Accumulates, per position, how often an element exceeded a threshold.

    ``forward`` returns its input unchanged. While logging is enabled it adds,
    for every position of one sample, the number of samples in the batch whose
    magnitude at that position is above ``threshold``. ``get`` divides that
    sum by the number of samples seen.

    Args:
        threshold: An element counts as active when ``abs(x) > threshold``.
        **kwargs: Forwarded to ``LoggingLayer``.
    """

    def __init__(self, threshold=0.0, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold
        self.channels = None
        self.n_samples = 0

    def forward(self, x):
        # Prefix the name with the shared running id on first use.
        if not self.added_id:
            layer_id = LoggingLayer.id
            LoggingLayer.id = layer_id + 1
            self.name = f"{layer_id} {self.name}"
            self.added_id = True

        if not self.enabled:
            return x

        # The accumulator takes shape and dtype from a single sample.
        if self.channels is None:
            self.channels = torch.zeros_like(x[0])
        active = torch.abs(x) > self.threshold
        self.channels += active.sum(dim=0)
        self.n_samples += len(x)
        return x

    def get(self):
        """Return the accumulated counts divided by the number of samples."""
        return self.channels / self.n_samples

    def reset(self):
        """Drop the accumulator and the sample count."""
        self.n_samples = 0
        self.channels = None
= torch.sum(in_tile > 0) 442 | densities.append(density) 443 | 444 | return torch.tensor(densities) 445 | 446 | def get_mask(self): 447 | return self.channels 448 | 449 | 450 | class ComputationsLogger(LoggingLayer): 451 | n_computed = 0 452 | n_samples = 0 453 | 454 | def __init__(self, **kwargs): 455 | super(ComputationsLogger, self).__init__(**kwargs) 456 | self.computed_mask = None 457 | self.n_samples = 0 458 | 459 | def reset(self): 460 | ComputationsLogger.n_computed = 0 461 | ComputationsLogger.n_samples = 0 462 | self.computed_mask = None 463 | self.n_samples = 0 464 | 465 | def forward(self, x): 466 | if not self.enabled: 467 | return x 468 | 469 | if not self.added_id: 470 | self.name = f"{LoggingLayer.id} {self.name}" 471 | LoggingLayer.id += 1 472 | self.added_id = True 473 | 474 | ComputationsLogger.n_computed += torch.sum(x) 475 | ComputationsLogger.n_samples += x.numel() 476 | 477 | self.n_samples += x.shape[0] * x.shape[1] 478 | comp_mask = torch.sum(torch.sum(x.clone().float(), dim=0), dim=0) 479 | if self.computed_mask is None or self.computed_mask.shape != x.shape[2:]: 480 | self.computed_mask = comp_mask 481 | else: 482 | self.computed_mask += comp_mask 483 | 484 | return x 485 | 486 | def get(self): 487 | if self.n_samples == 0: 488 | return 0 489 | return ComputationsLogger.n_computed / ComputationsLogger.n_samples 490 | 491 | def get_mask(self): 492 | if self.computed_mask is None: 493 | return None 494 | return self.computed_mask / self.n_samples 495 | 496 | 497 | class MultiplicationsLogger(LoggingLayer): 498 | n_computed = 0 499 | n_samples = 0 500 | 501 | def __init__(self, **kwargs): 502 | super(MultiplicationsLogger, self).__init__(**kwargs) 503 | self.n_computed = 0 504 | self.n_samples = 0 505 | 506 | def reset(self): 507 | self.n_computed = 0 508 | self.n_samples = 0 509 | MultiplicationsLogger.n_computed = 0 510 | MultiplicationsLogger.n_samples = 0 511 | 512 | def forward(self, x, filter, in_mask=None, kernel_size=(3, 3), 
dilation=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 513 | if not self.added_id: 514 | self.name = f"{LoggingLayer.id} {self.name}" 515 | LoggingLayer.id += 1 516 | self.added_id = True 517 | 518 | if self.enabled: 519 | if in_mask is None: 520 | computed = torch.sum(x[:, 0]) * filter.numel() 521 | self.n_computed += computed 522 | MultiplicationsLogger.n_computed += computed 523 | else: 524 | n_filter = (filter[:, 0, 0, :] if filter.shape[1] == filter.shape[2] else filter[:, :, 0, 0]).numel() 525 | if filter.shape[2] == 1: 526 | computed = torch.sum(x[:, 0]) * n_filter 527 | self.n_computed += computed 528 | MultiplicationsLogger.n_computed += computed 529 | else: 530 | active_neighbors = count_active_neighbors(in_mask[:, :1], kernel_size=kernel_size, dilation=dilation, stride=stride, padding=padding) 531 | active_neighbors = active_neighbors[:, 0][x[:, 0]] 532 | 533 | # computed = torch.sum(active_neighbors) * filter[::groups, :, 0, 0].numel() 534 | computed = torch.sum(active_neighbors) * n_filter 535 | self.n_computed += computed 536 | MultiplicationsLogger.n_computed += computed 537 | samples = x[:, 0].numel() * filter.numel() 538 | self.n_samples += samples 539 | MultiplicationsLogger.n_samples += samples 540 | 541 | return x 542 | 543 | def get_self_rel(self): 544 | if self.n_samples > 0: 545 | return self.n_computed / self.n_samples 546 | return 1 547 | 548 | def get_self_abs(self): 549 | return self.n_computed 550 | 551 | def get_global_rel(self): 552 | if MultiplicationsLogger.n_samples > 0: 553 | return MultiplicationsLogger.n_computed / MultiplicationsLogger.n_samples 554 | else: 555 | return 1 556 | 557 | def get_global_abs(self): 558 | return MultiplicationsLogger.n_computed 559 | 560 | 561 | class PrevInBandwidthLogger(LoggingLayer): 562 | n_read = 0 563 | 564 | def __init__(self, **kwargs): 565 | super(PrevInBandwidthLogger, self).__init__(**kwargs) 566 | 567 | def reset(self): 568 | PrevInBandwidthLogger.n_read = 0 569 | 570 | def 
forward(self, mask: torch.Tensor, dilation: str, tile_size: int = 0): 571 | if not self.added_id: 572 | self.name = f"{LoggingLayer.id} {self.name}" 573 | LoggingLayer.id += 1 574 | self.added_id = True 575 | 576 | if self.enabled: 577 | if dilation == "tile": 578 | tile_size = tile_size if tile_size > 0 else mask.shape[2] // (-tile_size) 579 | mask = mask[:, :, ::tile_size, ::tile_size] 580 | inv_long_mask = (~mask).to(torch.long) 581 | inactive_neighbors = torch.zeros_like(inv_long_mask) 582 | inactive_neighbors[:, :, :-1] += inv_long_mask[:, :, 1:] 583 | inactive_neighbors[:, :, 1:] += inv_long_mask[:, :, :-1] 584 | inactive_neighbors[:, :, :, :-1] += inv_long_mask[:, :, :, 1:] 585 | inactive_neighbors[:, :, :, 1:] += inv_long_mask[:, :, :, :-1] 586 | inactive_neighbors[~mask] = 0 587 | n_read = torch.sum(inactive_neighbors).item() * tile_size 588 | else: 589 | inv_long_mask = (~mask).to(torch.long) 590 | inactive_neighbors = torch.zeros_like(inv_long_mask) 591 | inactive_neighbors[:, :, :-1] += inv_long_mask[:, :, 1:] 592 | inactive_neighbors[:, :, 1:] += inv_long_mask[:, :, :-1] 593 | inactive_neighbors[:, :, :, :-1] += inv_long_mask[:, :, :, 1:] 594 | inactive_neighbors[:, :, :, 1:] += inv_long_mask[:, :, :, :-1] 595 | inactive_neighbors[:, :, :-1, :-1] += inv_long_mask[:, :, 1:, 1:] 596 | inactive_neighbors[:, :, 1:, :-1] += inv_long_mask[:, :, :-1, 1:] 597 | inactive_neighbors[:, :, :-1, 1:] += inv_long_mask[:, :, 1:, :-1] 598 | inactive_neighbors[:, :, 1:, 1:] += inv_long_mask[:, :, :-1, :-1] 599 | inactive_neighbors[~mask] = 0 600 | n_read = torch.sum(inactive_neighbors).item() 601 | 602 | PrevInBandwidthLogger.n_read += n_read 603 | 604 | return mask 605 | 606 | def get(self): 607 | return PrevInBandwidthLogger.n_read 608 | 609 | 610 | class PrevInputSizeLogger(LoggingLayer): 611 | n_vals = 0 612 | 613 | def __init__(self, **kwargs): 614 | super(PrevInputSizeLogger, self).__init__(**kwargs) 615 | self.n_vals = 0 616 | 617 | def reset(self): 618 | 
PrevInputSizeLogger.n_vals = 0 619 | self.n_vals = 0 620 | 621 | def forward(self, x): 622 | if not self.added_id: 623 | self.name = f"{LoggingLayer.id} {self.name}" 624 | LoggingLayer.id += 1 625 | self.added_id = True 626 | 627 | if x is None: 628 | return x 629 | 630 | PrevInputSizeLogger.n_vals += x[0].numel() 631 | self.n_vals += x[0].numel() 632 | 633 | return x 634 | 635 | def get_global(self): 636 | return PrevInputSizeLogger.n_vals 637 | 638 | def get(self): 639 | return self.n_vals 640 | 641 | 642 | class PrevInputLogger(LoggingLayer): 643 | def __init__(self, **kwargs): 644 | super(PrevInputLogger, self).__init__(**kwargs) 645 | self.prev_input = None 646 | 647 | def reset(self): 648 | self.prev_input = None 649 | 650 | def forward(self, x): 651 | if not self.enabled: 652 | return x 653 | 654 | if not self.added_id: 655 | self.name = f"{LoggingLayer.id} {self.name}" 656 | LoggingLayer.id += 1 657 | self.added_id = True 658 | 659 | if x is None: 660 | return x 661 | 662 | prev_input = torch.sum(torch.sum(x.clone().float(), dim=0), dim=0) 663 | if self.prev_input is None or self.prev_input.shape != prev_input.shape: 664 | self.prev_input = prev_input 665 | else: 666 | self.prev_input += prev_input 667 | 668 | return x 669 | 670 | def get(self): 671 | if self.prev_input is None: 672 | return None 673 | return self.prev_input / torch.max(self.prev_input) 674 | 675 | 676 | class InputLogger(PrevInputLogger): 677 | def __init__(self, **kwargs): 678 | super(InputLogger, self).__init__(**kwargs) 679 | 680 | 681 | class OutputLogger(PrevInputLogger): 682 | def __init__(self, **kwargs): 683 | super(OutputLogger, self).__init__(**kwargs) 684 | 685 | 686 | class RemoveDifferentialThreshold(LoggingLayer): 687 | def __init__(self, threshold=0.05, name="", enable_training=False, enabled=True, **kwargs): 688 | super(RemoveDifferentialThreshold, self).__init__(name=name, enabled=enabled) 689 | self.threshold = threshold 690 | self.prev_frame = None 691 | 
self.enable_training = enable_training 692 | 693 | def forward(self, x): 694 | if self.enabled and (torch.is_grad_enabled() and self.enable_training or not torch.is_grad_enabled()): 695 | if self.prev_frame is None: 696 | self.prev_frame = x.detach() 697 | return x 698 | mask = torch.abs(x - self.prev_frame) < self.threshold 699 | x[mask] = self.prev_frame[mask] 700 | self.prev_frame[~mask] = x[~mask].detach() 701 | 702 | return x 703 | 704 | def disable_logging(self): 705 | pass 706 | 707 | def enable_logging(self): 708 | pass 709 | 710 | def reset_history(self): 711 | self.prev_frame = None 712 | 713 | 714 | class RemoveConsistentPixels(LoggingLayer): 715 | def __init__(self, threshold=0.05, name="", enable_training=False, enabled=True, **kwargs): 716 | super(RemoveConsistentPixels, self).__init__(name=name, enabled=enabled) 717 | self.threshold = threshold 718 | self.prev_frame = None 719 | self.enable_training = enable_training 720 | 721 | def forward(self, x): 722 | if self.enabled and (torch.is_grad_enabled() and self.enable_training or not torch.is_grad_enabled()): 723 | if self.prev_frame is None: 724 | self.prev_frame = x.detach() 725 | return x 726 | mask = torch.abs(x - self.prev_frame) < self.threshold 727 | x[mask] = 0.0 728 | self.prev_frame[~mask] = x[~mask].detach() 729 | 730 | return x 731 | 732 | def disable_logging(self): 733 | pass 734 | 735 | def enable_logging(self): 736 | pass 737 | 738 | def reset_history(self): 739 | self.prev_frame = None 740 | 741 | 742 | class GradientLogFunction(torch.autograd.Function): 743 | @staticmethod 744 | def forward(ctx, X, log_out): 745 | ctx.log_out = log_out 746 | return X 747 | 748 | @staticmethod 749 | def backward(ctx, grad): 750 | ctx.log_out['loss'] = grad 751 | return grad, None 752 | 753 | 754 | class GradientLogger(LoggingLayer): 755 | def __init__(self, **kwargs): 756 | super(GradientLogger, self).__init__(**kwargs) 757 | self.fn = GradientLogFunction() 758 | self.loss = {'loss': None} 759 | 760 | 
def reset(self): 761 | self.loss['loss'] = None 762 | 763 | def forward(self, x): 764 | self.fn(x, self.loss) 765 | return x 766 | 767 | def get(self): 768 | return self.loss['loss'] 769 | 770 | 771 | class Loggers(LoggingLayer): 772 | threshold = 0.0 773 | 774 | def __init__(self, name="", modes=None, enabled=False, threshold=None, **kwargs): 775 | super(Loggers, self).__init__(name, enabled, **kwargs) 776 | self.threshold = threshold if threshold is not None else Loggers.threshold 777 | 778 | if modes is None: 779 | modes = ["DiffMask", "DiffDensity", "Density", "Mask"] 780 | self.logs = {} 781 | if "RemoveDiff" in modes: 782 | self.logs["RemoveDiff"] = RemoveDifferentialThreshold(enable_training=False, enabled=True, threshold=self.threshold) 783 | if "RemoveConsistent" in modes: 784 | self.logs["RemoveConsistent"] = RemoveConsistentPixels(enable_training=False, enabled=True, threshold=self.threshold) 785 | if "Density" in modes: 786 | self.logs["Density"] = DensityLogger(name=name, enabled=enabled, threshold=self.threshold, **kwargs) 787 | if "Mask" in modes: 788 | self.logs["Mask"] = PixelMaskLogger(name=name, enabled=enabled, threshold=self.threshold, **kwargs) 789 | if "DiffDensity" in modes: 790 | self.logs["DiffDensity"] = DifferentialDensityLogger(name=name, enabled=enabled, threshold=self.threshold, **kwargs) 791 | if "DiffMask" in modes: 792 | self.logs["DiffMask"] = DifferentialPixelMaskLogger(name=name, enabled=enabled, threshold=self.threshold, **kwargs) 793 | 794 | self.active_layers = ["ReLU", "Input"] 795 | self.always_use_removediff = False 796 | 797 | def reset(self): 798 | for name, log in self.logs.items(): 799 | if issubclass(type(log), LoggingLayer): 800 | log.reset() 801 | 802 | def reset_history(self): 803 | for name, log in self.logs.items(): 804 | if issubclass(type(log), LoggingLayer): 805 | log.reset_history() 806 | 807 | def _is_active(self): 808 | if self.active_layers is None: 809 | return True 810 | 811 | for layer in 
self.active_layers: 812 | if layer in self.name: 813 | return True 814 | 815 | return False 816 | 817 | def forward(self, x): 818 | if not self.added_id: 819 | self.name = f"{LoggingLayer.id} {self.name}" 820 | for name, log in list(self.logs.items()): 821 | self.logs[f"{LoggingLayer.id} {name}"] = log 822 | self.logs.pop(name) 823 | self.added_id = True 824 | 825 | for name, log in self.logs.items(): 826 | log.name = name 827 | log.added_id = True 828 | 829 | LoggingLayer.id += 1 830 | 831 | if not self._is_active(): 832 | if self.always_use_removediff: 833 | for name, log in self.logs.items(): 834 | if "RemoveDiff" in name: 835 | x = log(x) 836 | return x 837 | 838 | for name, log in self.logs.items(): 839 | if self.enabled: 840 | log.enable_logging() 841 | else: 842 | log.disable_logging() 843 | x = log(x) 844 | 845 | return x 846 | 847 | def get(self): 848 | res = {} 849 | 850 | if not self._is_active(): 851 | return res 852 | 853 | for name, log in self.logs.items(): 854 | if issubclass(type(log), LoggingLayer): 855 | res[name] = log.get() 856 | 857 | return res --------------------------------------------------------------------------------