├── README.md ├── code ├── branch_attentions │ ├── dynamic_conv.py │ ├── highway.py │ ├── resnest_module.py │ └── sk_module.py ├── channel_attentions │ ├── dia_module.py │ ├── eca_module.py │ ├── enc_module.py │ ├── fcanet.py │ ├── gct_module.py │ ├── se_module.py │ └── soca_module.py ├── channel_spatial_attentions │ ├── bam.py │ ├── cbam.py │ ├── coordatt_module.py │ ├── danet.py │ ├── lka.py │ ├── scnet.py │ ├── simam_module.py │ ├── strip_pooling_module.py │ └── triplet_attention.py ├── spatial_attentions │ ├── attention_augmented_module.py │ ├── doub_attention.py │ ├── external_attention.py │ ├── gc_module.py │ ├── hamnet.py │ ├── mhsa.py │ ├── ocr.py │ ├── offset_module.py │ ├── segformer_module.py │ ├── self_attention.py │ └── stn.py ├── spatial_temporal_attentions │ └── dstt_module.py └── temporal_attentions │ └── gltr.py └── imgs ├── attention_category.png ├── fuse.png ├── fuse_fig.png └── timeline.png /README.md: -------------------------------------------------------------------------------- 1 | # This repo is built for paper: Attention Mechanisms in Computer Vision: A Survey [paper](https://arxiv.org/abs/2111.07624) 2 | 3 | ## 介绍该论文的中文版博客 [链接](https://mp.weixin.qq.com/s/0iOZ45NTK9qSWJQlcI3_kQ ) 4 | 5 | 6 | 7 | ## Citation 8 | 9 | If it is helpful for your work, please cite this paper: 10 | 11 | ``` 12 | @article{guo2022attention, 13 | title={Attention mechanisms in computer vision: A survey}, 14 | author={Guo, Meng-Hao and Xu, Tian-Xing and Liu, Jiang-Jiang and Liu, Zheng-Ning and Jiang, Peng-Tao and Mu, Tai-Jiang and Zhang, Song-Hai and Martin, Ralph R and Cheng, Ming-Ming and Hu, Shi-Min}, 15 | journal={Computational Visual Media}, 16 | pages={1--38}, 17 | year={2022}, 18 | publisher={Springer} 19 | } 20 | ``` 21 | 22 | 23 | ![image](https://github.com/MenghaoGuo/Awesome-Vision-Attentions/blob/main/imgs/fuse.png) 24 | 25 | 26 | 27 | 28 | 29 | 30 | - [Vision-Attention-Papers](#vision-attention-papers) 31 | * [Channel attention](#channel-attention) 32 | * [Spatial attention](#spatial-attention) 33 | * [Temporal attention](#temporal-attention) 34 | * [Branch attention](#branch-attention) 35 | * [Channel \& Spatial attention](#channelspatial-attention) 36 | * [Spatial \& Temporal attention](#spatialtemporal-attention) 37 | 38 | 39 | 40 | * Codes about different attention mechanisms based on [Jittor](https://github.com/Jittor/jittor) are released now 41 | * TODO : collect more related papers. Contributions are welcome. 
42 | 43 | 🔥 (citations > 200) 44 | 45 | 46 | ## Channel attention 47 | 48 | * Squeeze-and-Excitation Networks (CVPR 2018) [pdf](https://arxiv.org/pdf/1709.01507), (PAMI2019 version) [pdf](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8701503) 🔥 49 | * Image superresolution using very deep residual channel attention networks (ECCV 2018) [pdf](https://arxiv.org/pdf/1807.02758) 🔥 50 | * Context encoding for semantic segmentation (CVPR 2018) [pdf](https://arxiv.org/pdf/1803.08904) 🔥 51 | * Spatio-temporal channel correlation networks for action classification (ECCV 2018) [pdf](https://arxiv.org/pdf/1806.07754) 52 | * Global second-order pooling convolutional networks (CVPR 2019) [pdf](https://arxiv.org/pdf/1811.12006) 53 | * Srm : A style-based recalibration module for convolutional neural networks (ICCV 2019) [pdf](https://arxiv.org/pdf/1903.10829) 54 | * You look twice: Gaternet for dynamic filter selection in cnns (CVPR 2019) [pdf](https://arxiv.org/pdf/1811.11205) 55 | * Second-order attention network for single image super-resolution (CVPR 2019) [pdf](https://openaccess.thecvf.com/content_CVPR_2019/papers/Dai_Second-Order_Attention_Network_for_Single_Image_Super-Resolution_CVPR_2019_paper.pdf) 🔥 56 | * DIANet: Dense-and-Implicit Attention Network (AAAI 2020)[pdf](https://arxiv.org/pdf/1905.10671.pdf) 57 | * Spsequencenet: Semantic segmentation network on 4d point clouds (CVPR 2020) [pdf](https://openaccess.thecvf.com/content_CVPR_2020/html/Shi_SpSequenceNet_Semantic_Segmentation_Network_on_4D_Point_Clouds_CVPR_2020_paper.html) 58 | * Ecanet: Efficient channel attention for deep convolutional neural networks (CVPR 2020) [pdf](https://arxiv.org/pdf/1910.03151) 🔥 59 | * Gated channel transformation for visual recognition (CVPR2020) [pdf](https://arxiv.org/pdf/1909.11519) 60 | * Fcanet: Frequency channel attention networks (ICCV 2021) [pdf](https://arxiv.org/pdf/2012.11879) 61 | 62 | ## Spatial attention 63 | 64 | - Recurrent models of visual attention (NeurIPS 2014), [pdf](https://arxiv.org/pdf/1406.6247) 🔥 65 | - Show, attend and tell: Neural image caption generation with visual attention (PMLR 2015) [pdf](https://arxiv.org/pdf/1502.03044) 🔥 66 | - Draw: A recurrent neural network for image generation (ICML 2015) [pdf](https://arxiv.org/pdf/1502.04623) 🔥 67 | - Spatial transformer networks (NeurIPS 2015) [pdf](https://arxiv.org/pdf/1506.02025) 🔥 68 | - Multiple object recognition with visual attention (ICLR 2015) [pdf](https://arxiv.org/pdf/1412.7755) 🔥 69 | - Action recognition using visual attention (arXiv 2015) [pdf](https://arxiv.org/pdf/1511.04119) 🔥 70 | - Videolstm convolves, attends and flows for action recognition (arXiv 2016) [pdf](https://arxiv.org/pdf/1607.01794) 🔥 71 | - Look closer to see better: Recurrent attention convolutional neural network for fine-grained image recognition (CVPR 2017) [pdf](https://openaccess.thecvf.com/content_cvpr_2017/papers/Fu_Look_Closer_to_CVPR_2017_paper.pdf) 🔥 72 | - Learning multi-attention convolutional neural network for fine-grained image recognition (ICCV 2017) [pdf](http://openaccess.thecvf.com/content_ICCV_2017/papers/Zheng_Learning_Multi-Attention_Convolutional_ICCV_2017_paper.pdf) 🔥 73 | - Diversified visual attention networks for fine-grained object classification (TMM 2017) [pdf](https://arxiv.org/pdf/1606.08572) 🔥 74 | - High-Order Attention Models for Visual Question Answering (NeurIPS 2017) [pdf](https://arxiv.org/pdf/1711.04323) 75 | - Attentional pooling for action recognition (NeurIPS 2017) 
[pdf](https://arxiv.org/pdf/1711.01467) 🔥 76 | - Non-local neural networks (CVPR 2018) [pdf](https://arxiv.org/pdf/1711.07971) 🔥 77 | - Attentional shapecontextnet for point cloud recognition (CVPR 2018) [pdf](https://openaccess.thecvf.com/content_cvpr_2018/papers/Xie_Attentional_ShapeContextNet_for_CVPR_2018_paper.pdf) 78 | - Relation networks for object detection (CVPR 2018) [pdf](https://openaccess.thecvf.com/content_cvpr_2018/papers/Hu_Relation_Networks_for_CVPR_2018_paper.pdf) 🔥 79 | - a2-nets: Double attention networks (NeurIPS 2018) [pdf](https://arxiv.org/pdf/1810.11579) 🔥 80 | - Attention-aware compositional network for person re-identification (CVPR 2018) [pdf](https://arxiv.org/pdf/1805.03344) 🔥 81 | - Tell me where to look: Guided attention inference network (CVPR 2018) [pdf](https://arxiv.org/pdf/1802.10171) 🔥 82 | - Pedestrian alignment network for large-scale person re-identification (TCSVT 2018) [pdf](https://arxiv.org/pdf/1707.00408) 🔥 83 | - Learn to pay attention (ICLR 2018) [pdf](https://arxiv.org/pdf/1804.02391.pdf) 🔥 84 | - Attention U-Net: Learning Where to Look for the Pancreas (MIDL 2018) [pdf](https://arxiv.org/pdf/1804.03999.pdf) 🔥 85 | - Psanet: Point-wise spatial attention network for scene parsing (ECCV 2018) [pdf](https://openaccess.thecvf.com/content_ECCV_2018/html/Hengshuang_Zhao_PSANet_Point-wise_Spatial_ECCV_2018_paper.html) 🔥 86 | - Self attention generative adversarial networks (ICML 2019) [pdf](https://arxiv.org/pdf/1805.08318) 🔥 87 | - Attentional pointnet for 3d-object detection in point clouds (CVPRW 2019) [pdf](https://openaccess.thecvf.com/content_CVPRW_2019/papers/WAD/Paigwar_Attentional_PointNet_for_3D-Object_Detection_in_Point_Clouds_CVPRW_2019_paper.pdf) 88 | - Co-occurrent features in semantic segmentation (CVPR 2019) [pdf](http://openaccess.thecvf.com/content_CVPR_2019/papers/Zhang_Co-Occurrent_Features_in_Semantic_Segmentation_CVPR_2019_paper.pdf) 89 | - Factor Graph Attention (CVPR 2019) [pdf](https://arxiv.org/pdf/1904.05880) 90 | - Attention augmented convolutional networks (ICCV 2019) [pdf](https://arxiv.org/pdf/1904.09925) 🔥 91 | - Local relation networks for image recognition (ICCV 2019) [pdf](https://arxiv.org/pdf/1904.11491) 92 | - Latentgnn: Learning efficient nonlocal relations for visual recognition(ICML 2019) [pdf](https://arxiv.org/pdf/1905.11634) 93 | - Graph-based global reasoning networks (CVPR 2019) [pdf](https://arxiv.org/pdf/1811.12814) 🔥 94 | - Gcnet: Non-local networks meet squeeze-excitation networks and beyond (ICCVW 2019) [pdf](https://arxiv.org/pdf/1904.11492) 🔥 95 | - Asymmetric non-local neural networks for semantic segmentation (ICCV 2019) [pdf](https://arxiv.org/pdf/1908.07678) 🔥 96 | - Looking for the devil in the details: Learning trilinear attention sampling network for fine-grained image recognition (CVPR 2019) [pdf](https://arxiv.org/pdf/1903.06150) 97 | - Second-order non-local attention networks for person re-identification (ICCV 2019) [pdf](https://arxiv.org/pdf/1909.00295) 🔥 98 | - End-to-end comparative attention networks for person re-identification (ICCV 2019) [pdf](https://arxiv.org/pdf/1606.04404) 🔥 99 | - Modeling point clouds with self-attention and gumbel subset sampling (CVPR 2019) [pdf](https://arxiv.org/pdf/1904.03375) 100 | - Diagnose like a radiologist: Attention guided convolutional neural network for thorax disease classification (arXiv 2019) [pdf](https://arxiv.org/pdf/1801.09927) 101 | - L2g autoencoder: Understanding point clouds by local-to-global reconstruction with hierarchical 
self-attention (arXiv 2019) [pdf](https://arxiv.org/pdf/1908.00720) 102 | - Generative pretraining from pixels (PMLR 2020) [pdf](https://cdn.openai.com/papers/Generative_Pretraining_from_Pixels_V2.pdf) 103 | - Exploring self-attention for image recognition (CVPR 2020) [pdf](https://arxiv.org/pdf/2004.13621) 104 | - Cf-sis: Semantic-instance segmentation of 3d point clouds by context fusion with self attention (ACM MM 20) [pdf](https://dl.acm.org/doi/pdf/10.1145/3394171.3413829) 105 | - Disentangled non-local neural networks (ECCV 2020) [pdf](https://arxiv.org/pdf/2006.06668) 106 | - Relation-aware global attention for person re-identification (CVPR 2020) [pdf](https://arxiv.org/pdf/1904.02998) 107 | - Segmentation transformer: Object-contextual representations for semantic segmentation (ECCV 2020) [pdf](https://arxiv.org/pdf/1909.11065) 🔥 108 | - Spatial pyramid based graph reasoning for semantic segmentation (CVPR 2020) [pdf](https://arxiv.org/pdf/2003.10211) 109 | - Self-supervised Equivariant Attention Mechanism for Weakly Supervised Semantic Segmentation (CVPR 2020) [pdf](https://arxiv.org/pdf/2004.04581.pdf) 110 | - End-to-end object detection with transformers (ECCV 2020) [pdf](https://arxiv.org/pdf/2005.12872) 🔥 111 | - Pointasnl: Robust point clouds processing using nonlocal neural networks with adaptive sampling (CVPR 2020) [pdf](https://arxiv.org/pdf/2003.00492) 112 | - Rethinking semantic segmentation from a sequence-to-sequence perspective with transformers (CVPR 2021) [pdf](https://arxiv.org/pdf/2012.15840) 113 | - An image is worth 16x16 words: Transformers for image recognition at scale (ICLR 2021) [pdf](https://arxiv.org/pdf/2010.11929) 🔥 114 | - Is Attention Better Than Matrix Decomposition? (ICLR 2021) [pdf](https://arxiv.org/abs/2109.04553) 115 | - An empirical study of training selfsupervised vision transformers (CVPR 2021) [pdf](https://arxiv.org/pdf/2104.02057) 116 | - Ocnet: Object context network for scene parsing (IJCV 2021) [pdf](https://arxiv.org/pdf/1809.00916) 🔥 117 | - Point transformer (ICCV 2021) [pdf](https://arxiv.org/pdf/2012.09164) 118 | - PCT: Point Cloud Transformer (CVMJ 2021) [pdf](https://arxiv.org/pdf/2012.09688.pdf) 119 | - Pre-trained image processing transformer (CVPR 2021) [pdf](https://arxiv.org/pdf/2012.00364) 120 | - An empirical study of training self-supervised vision transformers (ICCV 2021) [pdf](https://arxiv.org/pdf/2104.02057) 121 | - Segformer: Simple and efficient design for semantic segmentation with transformers (arxiv 2021) [pdf](https://arxiv.org/pdf/2105.15203) 122 | - Beit: Bert pre-training of image transformers (arxiv 2021) [pdf](https://arxiv.org/pdf/2106.08254) 123 | - Beyond Self-attention: External attention using two linear layers for visual tasks (arxiv 2021) [pdf](https://arxiv.org/pdf/2105.02358) 124 | - Query2label: A simple transformer way to multi-label classification (arxiv 2021) [pdf](https://arxiv.org/pdf/2107.10834) 125 | - Transformer in transformer (arxiv 2021) [pdf](https://arxiv.org/pdf/2103.00112) 126 | 127 | ## Temporal attention 128 | 129 | - Jointly attentive spatial-temporal pooling networks for video-based person re-identification (ICCV 2017) [pdf](https://arxiv.org/pdf/1708.02286.pdf) 🔥 130 | - Video person reidentification with competitive snippet-similarity aggregation and co-attentive snippet embedding (CVPR 2018) [pdf](https://openaccess.thecvf.com/content_cvpr_2018/CameraReady/1036.pdf) 131 | - Scan: Self-and-collaborative attention network for video person re-identification (TIP 2019) 
[pdf](https://arxiv.org/pdf/1807.05688.pdf) 132 | 133 | ## Branch attention 134 | 135 | - Training very deep networks (NeurIPS 2015) [pdf](https://arxiv.org/pdf/1507.06228.pdf) 🔥 136 | - Selective kernel networks (CVPR 2019) [pdf](https://openaccess.thecvf.com/content_CVPR_2019/papers/Li_Selective_Kernel_Networks_CVPR_2019_paper.pdf) 🔥 137 | - CondConv: Conditionally Parameterized Convolutions for Efficient Inference (NeurIPS 2019) [pdf](https://arxiv.org/pdf/1904.04971.pdf) 138 | - Dynamic convolution: Attention over convolution kernels (CVPR 2020) [pdf](https://openaccess.thecvf.com/content_CVPR_2020/papers/Chen_Dynamic_Convolution_Attention_Over_Convolution_Kernels_CVPR_2020_paper.pdf) 139 | - ResNest: Split-attention networks (arXiv 2020) [pdf](https://arxiv.org/pdf/2004.08955.pdf) 🔥 140 | 141 | ## Channel+Spatial attention 142 | 143 | - Residual attention network for image classification (CVPR 2017) [pdf](https://openaccess.thecvf.com/content_cvpr_2017/papers/Wang_Residual_Attention_Network_CVPR_2017_paper.pdf) 🔥 144 | - SCA-CNN: spatial and channel-wise attention in convolutional networks for image captioning (CVPR 2017) [pdf](https://openaccess.thecvf.com/content_cvpr_2017/papers/Chen_SCA-CNN_Spatial_and_CVPR_2017_paper.pdf) 🔥 145 | - CBAM: convolutional block attention module (ECCV 2018) [pdf](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf) 🔥 146 | - Harmonious attention network for person re-identification (CVPR 2018) [pdf](https://arxiv.org/pdf/1802.08122.pdf) 🔥 147 | - Recalibrating fully convolutional networks with spatial and channel “squeeze and excitation” blocks (TMI 2018) [pdf](https://arxiv.org/pdf/1808.08127.pdf) 148 | - Mancs: A multi-task attentional network with curriculum sampling for person re-identification (ECCV 2018) [pdf](https://www.ecva.net/papers/eccv_2018/papers_ECCV/papers/Cheng_Wang_Mancs_A_Multi-task_ECCV_2018_paper.pdf) 🔥 149 | - Bam: Bottleneck attention module(BMVC 2018) [pdf](http://bmvc2018.org/contents/papers/0092.pdf) 🔥 150 | - Pvnet: A joint convolutional network of point cloud and multi-view for 3d shape recognition (ACM MM 2018) [pdf](https://arxiv.org/pdf/1808.07659.pdf) 151 | - Learning what and where to attend (ICLR 2019) [pdf](https://openreview.net/pdf?id=BJgLg3R9KQ) 152 | - Dual attention network for scene segmentation (CVPR 2019) [pdf](https://openaccess.thecvf.com/content_CVPR_2019/papers/Fu_Dual_Attention_Network_for_Scene_Segmentation_CVPR_2019_paper.pdf) 🔥 153 | - Abd-net: Attentive but diverse person re-identification (ICCV 2019) [pdf](https://openaccess.thecvf.com/content_ICCV_2019/papers/Chen_ABD-Net_Attentive_but_Diverse_Person_Re-Identification_ICCV_2019_paper.pdf) 154 | - Mixed high-order attention network for person re-identification (ICCV 2019) [pdf](https://arxiv.org/pdf/1908.05819.pdf) 155 | - Mlcvnet: Multi-level context votenet for 3d object detection (CVPR 2020) [pdf](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xie_MLCVNet_Multi-Level_Context_VoteNet_for_3D_Object_Detection_CVPR_2020_paper.pdf) 156 | - Improving convolutional networks with self-calibrated convolutions (CVPR 2020) [pdf](https://openaccess.thecvf.com/content_CVPR_2020/papers/Liu_Improving_Convolutional_Networks_With_Self-Calibrated_Convolutions_CVPR_2020_paper.pdf) 157 | - Relation-aware global attention for person re-identification (CVPR 2020) 
[pdf](https://openaccess.thecvf.com/content_CVPR_2020/papers/Zhang_Relation-Aware_Global_Attention_for_Person_Re-Identification_CVPR_2020_paper.pdf) 158 | - Strip Pooling: Rethinking spatial pooling for scene parsing (CVPR 2020) [pdf](https://openaccess.thecvf.com/content_CVPR_2020/papers/Hou_Strip_Pooling_Rethinking_Spatial_Pooling_for_Scene_Parsing_CVPR_2020_paper.pdf) 159 | - Rotate to attend: Convolutional triplet attention module, (WACV 2021) [pdf](https://arxiv.org/pdf/2010.03045.pdf) 160 | - Coordinate attention for efficient mobile network design (CVPR 2021) [pdf](https://openaccess.thecvf.com/content/CVPR2021/papers/Hou_Coordinate_Attention_for_Efficient_Mobile_Network_Design_CVPR_2021_paper.pdf) 161 | - Simam: A simple, parameter-free attention module for convolutional neural networks (ICML 2021) [pdf](http://proceedings.mlr.press/v139/yang21o/yang21o.pdf) 162 | 163 | ## Spatial+Temporal attention 164 | 165 | - An end-to-end spatio-temporal attention model for human action recognition from skeleton data (AAAI 2017) [pdf](https://arxiv.org/pdf/1611.06067.pdf) 🔥 166 | - Diversity regularized spatiotemporal attention for video-based person re-identification (arXiv 2018) 🔥 167 | - Interpretable spatio-temporal attention for video action recognition (ICCVW 2019) [pdf](https://openaccess.thecvf.com/content_ICCVW_2019/papers/HVU/Meng_Interpretable_Spatio-Temporal_Attention_for_Video_Action_Recognition_ICCVW_2019_paper.pdf) 168 | - A Simple Baseline for Audio-Visual Scene-Aware Dialog (CVPR 2019) [pdf](https://arxiv.org/pdf/1904.05876v1.pdf) 169 | - Hierarchical lstms with adaptive attention for visual captioning (TPAMI 2020) [pdf](https://arxiv.org/pdf/1812.11004.pdf) 170 | - Stat: Spatial-temporal attention mechanism for video captioning, (TMM 2020) [pdf](https://ieeexplore.ieee.org/abstract/document/8744407) 171 | - Gta: Global temporal attention for video action understanding (arXiv 2020) [pdf](https://arxiv.org/pdf/2012.08510.pdf) 172 | - Multi-granularity reference-aided attentive feature aggregation for video-based person re-identification (CVPR 2020) [pdf](https://arxiv.org/pdf/2003.12224.pdf) 173 | - Read: Reciprocal attention discriminator for image-to-video re-identification (ECCV 2020) [pdf](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123590324.pdf) 174 | - Decoupled spatial-temporal transformer for video inpainting (arXiv 2021) [pdf](https://arxiv.org/pdf/2104.06637.pdf) 175 | - Towards Coherent Visual Storytelling with Ordered Image Attention (arXiv 2021) [pdf](https://arxiv.org/pdf/2108.02180) 176 | -------------------------------------------------------------------------------- /code/branch_attentions/dynamic_conv.py: -------------------------------------------------------------------------------- 1 | # Dynamic convolution: Attention over convolution kernels (CVPR 2020) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class attention2d(nn.Module): 7 | def __init__(self, in_planes, ratios, K, temperature): 8 | super(attention2d, self).__init__() 9 | # for reducing τ temperature from 30 to 1 linearly in the first 10 epochs. 
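# Each call to update__temperature() lowers the temperature by 3 until it reaches 1,
# so the starting value must satisfy temperature % 3 == 1 (e.g. the default 34 -> 31 -> ... -> 4 -> 1).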
10 | assert temperature % 3 == 1 11 | self.avgpool = nn.AdaptiveAvgPool2d(1) 12 | 13 | if in_planes != 3: 14 | hidden_planes = int(in_planes * ratios) + 1 15 | else: 16 | hidden_planes = K 17 | 18 | self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False) 19 | # self.relu = nn.ReLU() 20 | self.fc2 = nn.Conv2d(hidden_planes, K, 1, bias=True) 21 | self.temperature = temperature 22 | 23 | def update__temperature(self): 24 | if self.temperature != 1: 25 | self.temperature -= 3 26 | 27 | def execute(self, z): 28 | z = self.avgpool(z) 29 | z = self.fc1(z) 30 | # z = self.relu(z) 31 | z = nn.relu(z) 32 | z = self.fc2(z) 33 | z = z.view(z.size(0), -1) 34 | # z = self.fc2(z).view(z.size(0), -1) 35 | 36 | return nn.softmax(z/self.temperature, 1) 37 | 38 | 39 | class Dynamic_conv2d(nn.Module): 40 | def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=34): 41 | super(Dynamic_conv2d, self).__init__() 42 | 43 | if in_planes % groups != 0: 44 | raise ValueError('Error : in_planes%groups != 0') 45 | self.in_planes = in_planes 46 | self.out_planes = out_planes 47 | self.kernel_size = kernel_size 48 | self.stride = stride 49 | self.padding = padding 50 | self.dilation = dilation 51 | self.groups = groups 52 | self.bias = bias 53 | self.K = K 54 | self.attention = attention2d(in_planes, ratio, K, temperature) 55 | self.weight = jt.random(( 56 | K, out_planes, in_planes//groups, kernel_size, kernel_size)) 57 | 58 | if bias: 59 | self.bias = jt.random((K, out_planes)) 60 | else: 61 | self.bias = None 62 | 63 | def update_temperature(self): 64 | self.attention.update__temperature() 65 | 66 | def execute(self, z): 67 | 68 | # Regard batch as a dimensional variable, perform group convolution, 69 | # because the weight of group convolution is different, 70 | # and the weight of dynamic convolution is also different 71 | softmax_attention = self.attention(z) 72 | batch_size, in_planes, height, width = z.size() 73 | # changing into dimension for group convolution 74 | z = z.view(1, -1, height, width) 75 | weight = self.weight.view(self.K, -1) 76 | 77 | # The generation of the weight of dynamic convolution, 78 | # which generates batch_size convolution parameters 79 | # (each parameter is different) 80 | aggregate_weight = jt.matmul(softmax_attention, weight).view(-1, self.in_planes, 81 | self.kernel_size, self.kernel_size) # expects two matrices (2D tensors) 82 | if self.bias is not None: 83 | aggregate_bias = jt.matmul(softmax_attention, self.bias).view(-1) 84 | output = nn.conv2d(z, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding, 85 | dilation=self.dilation, groups=self.groups * batch_size) 86 | else: 87 | output = nn.conv2d(z, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding, 88 | dilation=self.dilation, groups=self.groups * batch_size) 89 | output = output.view(batch_size, self.out_planes, 90 | output.size(-2), output.size(-1)) 91 | # print('2d-att-for') 92 | return output 93 | 94 | 95 | def main(): 96 | attention_block = Dynamic_conv2d(64, 64, 3, padding=1) 97 | input = jt.ones([4, 64, 32, 32]) 98 | output = attention_block(input) 99 | print(input.size(), output.size()) 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /code/branch_attentions/highway.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 
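# Highway network gating (Training very deep networks, NeurIPS 2015)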
import jittor.nn as nn 3 | 4 | 5 | class Highway(nn.Module): 6 | def __init__(self, dim, num_layers=2): 7 | 8 | super(Highway, self).__init__() 9 | 10 | self.num_layers = num_layers 11 | 12 | self.nonlinear = nn.ModuleList( 13 | [nn.Linear(dim, dim) for _ in range(num_layers)]) 14 | self.linear = nn.ModuleList([nn.Linear(dim, dim) 15 | for _ in range(num_layers)]) 16 | self.gate = nn.ModuleList([nn.Linear(dim, dim) 17 | for _ in range(num_layers)]) 18 | 19 | self.f = nn.ReLU() 20 | self.sigmoid = nn.Sigmoid() 21 | 22 | def execute(self, x): 23 | """ 24 | :param x: tensor with shape of [batch_size, size] 25 | :return: tensor with shape of [batch_size, size] 26 | applies σ(x) ⨀ (f(G(x))) + (1 - σ(x)) ⨀ (Q(x)) transformation | G and Q is affine transformation, 27 | f is non-linear transformation, σ(x) is affine transformation with sigmoid non-linearition 28 | and ⨀ is element-wise multiplication 29 | """ 30 | 31 | for layer in range(self.num_layers): 32 | gate = self.sigmoid(self.gate[layer](x)) 33 | nonlinear = self.f(self.nonlinear[layer](x)) 34 | linear = self.linear[layer](x) 35 | x = gate * nonlinear + (1 - gate) * linear 36 | print(x.size()) 37 | return x 38 | 39 | 40 | def main(): 41 | attention_block = Highway(32) 42 | input = jt.rand([4, 64, 32]) 43 | output = attention_block(input) 44 | print(input.size(), output.size()) 45 | 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /code/branch_attentions/resnest_module.py: -------------------------------------------------------------------------------- 1 | # ResNest: Split-attention networks (arXiv 2020) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 7 | min_value = min_value or divisor 8 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 9 | # Make sure that round down does not go down by more than 10%. 
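# new_v has been rounded to the nearest multiple of `divisor`; if that rounding lost more than
# (1 - round_limit) of the original value, step it back up by one divisor.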
10 | if new_v < round_limit * v: 11 | new_v += divisor 12 | return new_v 13 | 14 | 15 | class RadixSoftmax(nn.Module): 16 | def __init__(self, radix, cardinality): 17 | super(RadixSoftmax, self).__init__() 18 | self.radix = radix 19 | self.cardinality = cardinality 20 | 21 | def execute(self, x): 22 | batch = x.size(0) 23 | if self.radix > 1: 24 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 25 | x = nn.softmax(x, dim=1) 26 | x = x.reshape(batch, -1) 27 | else: 28 | x = x.sigmoid() 29 | return x 30 | 31 | 32 | class SplitAttn(nn.Module): 33 | """Split-Attention (aka Splat) 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, 37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 38 | act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): 39 | super(SplitAttn, self).__init__() 40 | out_channels = out_channels or in_channels 41 | self.radix = radix 42 | self.drop_block = drop_block 43 | mid_chs = out_channels * radix 44 | if rd_channels is None: 45 | attn_chs = make_divisible( 46 | in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) 47 | else: 48 | attn_chs = rd_channels * radix 49 | 50 | padding = kernel_size // 2 if padding is None else padding 51 | self.conv = nn.Conv2d( 52 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 53 | groups=groups * radix, bias=bias, **kwargs) 54 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() 55 | self.act0 = act_layer() 56 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 57 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() 58 | self.act1 = act_layer() 59 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 60 | self.rsoftmax = RadixSoftmax(radix, groups) 61 | 62 | def execute(self, x): 63 | x = self.conv(x) 64 | x = self.bn0(x) 65 | if self.drop_block is not None: 66 | x = self.drop_block(x) 67 | x = self.act0(x) 68 | 69 | B, RC, H, W = x.shape 70 | if self.radix > 1: 71 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 72 | x_gap = x.sum(dim=1) 73 | else: 74 | x_gap = x 75 | x_gap = x_gap.mean(2, keepdims=True).mean(3, keepdims=True) 76 | x_gap = self.fc1(x_gap) 77 | x_gap = self.bn1(x_gap) 78 | x_gap = self.act1(x_gap) 79 | x_attn = self.fc2(x_gap) 80 | 81 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 82 | if self.radix > 1: 83 | out = (x * x_attn.reshape((B, self.radix, 84 | RC // self.radix, 1, 1))).sum(dim=1) 85 | else: 86 | out = x * x_attn 87 | return out 88 | 89 | 90 | def main(): 91 | attention_block = SplitAttn(64) 92 | input = jt.ones([4, 64, 32, 32]) 93 | output = attention_block(input) 94 | print(input.size(), output.size()) 95 | 96 | 97 | if __name__ == '__main__': 98 | main() 99 | -------------------------------------------------------------------------------- /code/branch_attentions/sk_module.py: -------------------------------------------------------------------------------- 1 | 2 | import jittor as jt 3 | import jittor.nn as nn 4 | 5 | 6 | class SKModule(nn.Module): 7 | def __init__(self, features, M=2, G=32, r=16, stride=1, L=32): 8 | """ Constructor 9 | Args: 10 | features: input channel dimensionality. 11 | M: the number of branchs. 12 | G: num of convolution groups. 13 | r: the ratio for compute d, the length of z. 14 | stride: stride, default 1. 15 | L: the minimum dim of the vector z in paper, default 32. 
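The M branches share the 3x3 kernel but use increasing dilation; their outputs are fused
by a softmax over per-branch channel descriptors (the "select" step of SK attention).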
16 | """ 17 | super(SKModule, self).__init__() 18 | d = max(int(features/r), L) 19 | self.M = M 20 | self.features = features 21 | self.convs = nn.ModuleList([]) 22 | for i in range(M): 23 | self.convs.append(nn.Sequential( 24 | nn.Conv2d(features, features, kernel_size=3, stride=stride, 25 | padding=1+i, dilation=1+i, groups=G, bias=False), 26 | nn.BatchNorm2d(features), 27 | nn.ReLU() 28 | )) 29 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) 30 | self.fc = nn.Sequential(nn.Conv2d(features, d, kernel_size=1, stride=1, bias=False), 31 | nn.BatchNorm2d(d), 32 | nn.ReLU()) 33 | self.fcs = nn.ModuleList([]) 34 | for i in range(M): 35 | self.fcs.append( 36 | nn.Conv2d(d, features, kernel_size=1, stride=1) 37 | ) 38 | self.softmax = nn.Softmax(dim=1) 39 | 40 | def execute(self, x): 41 | 42 | batch_size = x.shape[0] 43 | 44 | feats = [conv(x) for conv in self.convs] 45 | feats = jt.concat(feats, dim=1) 46 | feats = feats.view(batch_size, self.M, self.features, 47 | feats.shape[2], feats.shape[3]) 48 | 49 | feats_U = jt.sum(feats, dim=1) 50 | feats_S = self.gap(feats_U) 51 | feats_Z = self.fc(feats_S) 52 | 53 | attention_vectors = [fc(feats_Z) for fc in self.fcs] 54 | attention_vectors = jt.concat(attention_vectors, dim=1) 55 | attention_vectors = attention_vectors.view( 56 | batch_size, self.M, self.features, 1, 1) 57 | attention_vectors = self.softmax(attention_vectors) 58 | 59 | feats_V = jt.sum(feats*attention_vectors, dim=1) 60 | 61 | return feats_V 62 | 63 | 64 | def main(): 65 | attention_block = SKModule(64) 66 | input = jt.rand([4, 64, 32, 32]) 67 | output = attention_block(input) 68 | print(input.size(), output.size()) 69 | 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /code/channel_attentions/dia_module.py: -------------------------------------------------------------------------------- 1 | # DIANet: Dense-and-Implicit Attention Network (AAAI 2020) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class small_cell(nn.Module): 7 | def __init__(self, input_size, hidden_size): 8 | """"Constructor of the class""" 9 | super(small_cell, self).__init__() 10 | self.seq = nn.Sequential(nn.Linear(input_size, input_size // 4), 11 | nn.ReLU(), 12 | nn.Linear(input_size // 4, 4 * hidden_size)) 13 | 14 | def execute(self, x): 15 | return self.seq(x) 16 | 17 | 18 | class LSTMCell(nn.Module): 19 | def __init__(self, input_size, hidden_size, nlayers, dropout=0.1): 20 | """"Constructor of the class""" 21 | super(LSTMCell, self).__init__() 22 | 23 | self.nlayers = nlayers 24 | self.dropout = nn.Dropout(p=dropout) 25 | 26 | ih, hh = [], [] 27 | for i in range(nlayers): 28 | if i == 0: 29 | # ih.append(nn.Linear(input_size, 4 * hidden_size)) 30 | ih.append(small_cell(input_size, hidden_size)) 31 | # hh.append(nn.Linear(hidden_size, 4 * hidden_size)) 32 | hh.append(small_cell(hidden_size, hidden_size)) 33 | else: 34 | ih.append(nn.Linear(hidden_size, 4 * hidden_size)) 35 | hh.append(nn.Linear(hidden_size, 4 * hidden_size)) 36 | self.w_ih = nn.ModuleList(ih) 37 | self.w_hh = nn.ModuleList(hh) 38 | 39 | def execute(self, input, hidden): 40 | """"Defines the forward computation of the LSTMCell""" 41 | hy, cy = [], [] 42 | for i in range(self.nlayers): 43 | hx, cx = hidden[0][i], hidden[1][i] 44 | gates = self.w_ih[i](input) + self.w_hh[i](hx) 45 | i_gate, f_gate, c_gate, o_gate = gates.chunk(4, 1) 46 | i_gate = i_gate.sigmoid() 47 | f_gate = f_gate.sigmoid() 48 | c_gate = jt.tanh(c_gate) 49 | o_gate = 
o_gate.sigmoid() 50 | ncx = (f_gate * cx) + (i_gate * c_gate) 51 | # nhx = o_gate * torch.tanh(ncx) 52 | nhx = o_gate * ncx.sigmoid() 53 | cy.append(ncx) 54 | hy.append(nhx) 55 | input = self.dropout(nhx) 56 | 57 | hy, cy = jt.stack(hy, 0), jt.stack( 58 | cy, 0) # number of layer * batch * hidden 59 | return hy, cy 60 | 61 | 62 | class Attention(nn.Module): 63 | def __init__(self, channel): 64 | super(Attention, self).__init__() 65 | self.lstm = LSTMCell(channel, channel, 1) 66 | 67 | self.GlobalAvg = nn.AdaptiveAvgPool2d((1, 1)) 68 | self.relu = nn.ReLU() 69 | 70 | def execute(self, x): 71 | org = x 72 | seq = self.GlobalAvg(x) 73 | seq = seq.view(seq.size(0), seq.size(1)) 74 | ht = jt.zeros((1, seq.size(0), seq.size( 75 | 1))) # 1 mean number of layers 76 | ct = jt.zeros((1, seq.size(0), seq.size(1))) 77 | ht, ct = self.lstm(seq, (ht, ct)) # 1 * batch size * length 78 | # ht = self.sigmoid(ht) 79 | x = x * (ht[-1].view(ht.size(1), ht.size(2), 1, 1)) 80 | x += org 81 | x = self.relu(x) 82 | 83 | return x # , list 84 | 85 | 86 | def main(): 87 | attention_block = Attention(64) 88 | input = jt.rand([4, 64, 32, 32]) 89 | output = attention_block(input) 90 | print(input.size(), output.size()) 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /code/channel_attentions/eca_module.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import jittor.nn as nn 3 | 4 | 5 | class ECALayer(nn.Module): 6 | """ 7 | Constructs a ECA module. 8 | Args: 9 | k_size: Adaptive selection of kernel size 10 | """ 11 | 12 | def __init__(self, k_size=3): 13 | super(ECALayer, self).__init__() 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, 16 | padding=(k_size - 1) // 2, bias=False) 17 | self.sigmoid = nn.Sigmoid() 18 | 19 | def execute(self, x): 20 | # feature descriptor on the global spatial information 21 | y = self.avg_pool(x) 22 | 23 | # Two different branches of ECA module 24 | y = self.conv(y.squeeze(-1).transpose(-1, -2) 25 | ).transpose(-1, -2).unsqueeze(-1) 26 | 27 | y = self.sigmoid(y) 28 | 29 | return x * y.expand_as(x) 30 | 31 | 32 | def main(): 33 | attention_block = ECALayer() 34 | input = jt.rand([4, 64, 32, 32]) 35 | output = attention_block(input) 36 | print(input.size(), output.size()) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /code/channel_attentions/enc_module.py: -------------------------------------------------------------------------------- 1 | # Context encoding for semantic segmentation (CVPR 2018) 2 | import jittor as jt 3 | from jittor import nn, init 4 | 5 | 6 | class Encoding(nn.Module): 7 | def __init__(self, channels, num_codes): 8 | super(Encoding, self).__init__() 9 | # init codewords and smoothing factor 10 | self.channels, self.num_codes = channels, num_codes 11 | std = 1. 
/ ((num_codes * channels)**0.5) 12 | # [num_codes, channels] 13 | self.codewords = init.uniform_( 14 | jt.random((num_codes, channels)), -std, std) 15 | # [num_codes] 16 | self.scale = init.uniform_(jt.random((num_codes,)), -1, 0) 17 | 18 | @staticmethod 19 | def scaled_l2(x, codewords, scale): 20 | num_codes, channels = codewords.size() 21 | batch_size = x.size(0) 22 | reshaped_scale = scale.view((1, 1, num_codes)) 23 | expanded_x = x.unsqueeze(2).expand( 24 | (batch_size, x.size(1), num_codes, channels)) 25 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 26 | 27 | scaled_l2_norm = reshaped_scale * ( 28 | expanded_x - reshaped_codewords).pow(2).sum(dim=3) 29 | return scaled_l2_norm 30 | 31 | @staticmethod 32 | def aggregate(assignment_weights, x, codewords): 33 | num_codes, channels = codewords.size() 34 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 35 | batch_size = x.size(0) 36 | 37 | expanded_x = x.unsqueeze(2).expand( 38 | (batch_size, x.size(1), num_codes, channels)) 39 | encoded_feat = (assignment_weights.unsqueeze(3) * 40 | (expanded_x - reshaped_codewords)).sum(dim=1) 41 | return encoded_feat 42 | 43 | def execute(self, x): 44 | assert x.ndim == 4 and x.size(1) == self.channels 45 | # [batch_size, channels, height, width] 46 | batch_size = x.size(0) 47 | # [batch_size, height x width, channels] 48 | x = x.view(batch_size, self.channels, -1).transpose(0, 2, 1) 49 | # assignment_weights: [batch_size, channels, num_codes] 50 | assignment_weights = nn.softmax( 51 | self.scaled_l2(x, self.codewords, self.scale), dim=2) 52 | # aggregate 53 | encoded_feat = self.aggregate(assignment_weights, x, self.codewords) 54 | return encoded_feat 55 | 56 | 57 | class EncModule(nn.Module): 58 | def __init__(self, in_channels, num_codes): 59 | super(EncModule, self).__init__() 60 | self.encoding_project = nn.Conv2d(in_channels, in_channels, 1) 61 | self.encoding = nn.Sequential( 62 | Encoding(channels=in_channels, num_codes=num_codes), 63 | nn.BatchNorm(num_codes), 64 | nn.ReLU()) 65 | self.fc = nn.Sequential( 66 | nn.Linear(in_channels, in_channels), nn.Sigmoid()) 67 | 68 | def execute(self, x): 69 | encoding_projection = self.encoding_project(x) 70 | encoding_feat = self.encoding(encoding_projection).mean(dim=1) 71 | batch_size, channels, _, _ = x.size() 72 | gamma = self.fc(encoding_feat) 73 | return x*gamma.view(batch_size, channels, 1, 1) 74 | 75 | 76 | def main(): 77 | attention_block = EncModule(64, 32) 78 | input = jt.rand([4, 64, 32, 32]) 79 | output = attention_block(input) 80 | print(input.size(), output.size()) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /code/channel_attentions/fcanet.py: -------------------------------------------------------------------------------- 1 | # Fcanet: Frequency channel attention networks (ICCV 2021) 2 | import math 3 | import jittor as jt 4 | from jittor import nn 5 | 6 | 7 | def get_freq_indices(method): 8 | assert method in ['top1', 'top2', 'top4', 'top8', 'top16', 'top32', 9 | 'bot1', 'bot2', 'bot4', 'bot8', 'bot16', 'bot32', 10 | 'low1', 'low2', 'low4', 'low8', 'low16', 'low32'] 11 | num_freq = int(method[3:]) 12 | if 'top' in method: 13 | all_top_indices_x = [0, 0, 6, 0, 0, 1, 1, 4, 5, 1, 3, 0, 0, 14 | 0, 3, 2, 4, 6, 3, 5, 5, 2, 6, 5, 5, 3, 3, 4, 2, 2, 6, 1] 15 | all_top_indices_y = [0, 1, 0, 5, 2, 0, 2, 0, 0, 6, 0, 4, 6, 16 | 3, 5, 2, 6, 3, 3, 3, 5, 1, 1, 2, 4, 2, 1, 1, 3, 0, 5, 3] 17 | mapper_x = 
all_top_indices_x[:num_freq] 18 | mapper_y = all_top_indices_y[:num_freq] 19 | elif 'low' in method: 20 | all_low_indices_x = [0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0, 21 | 1, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4] 22 | all_low_indices_y = [0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4, 23 | 3, 1, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3] 24 | mapper_x = all_low_indices_x[:num_freq] 25 | mapper_y = all_low_indices_y[:num_freq] 26 | elif 'bot' in method: 27 | all_bot_indices_x = [6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4, 28 | 6, 2, 5, 6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5, 3, 6] 29 | all_bot_indices_y = [6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2, 30 | 2, 5, 1, 4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3, 3, 3] 31 | mapper_x = all_bot_indices_x[:num_freq] 32 | mapper_y = all_bot_indices_y[:num_freq] 33 | else: 34 | raise NotImplementedError 35 | return mapper_x, mapper_y 36 | 37 | 38 | class MultiSpectralAttentionLayer(nn.Module): 39 | def __init__(self, channel, dct_h, dct_w, reduction=16, freq_sel_method='top16'): 40 | super(MultiSpectralAttentionLayer, self).__init__() 41 | self.reduction = reduction 42 | self.dct_h = dct_h 43 | self.dct_w = dct_w 44 | 45 | mapper_x, mapper_y = get_freq_indices(freq_sel_method) 46 | self.num_split = len(mapper_x) 47 | mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x] 48 | mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y] 49 | # make the frequencies in different sizes are identical to a 7x7 frequency space 50 | # eg, (2,2) in 14x14 is identical to (1,1) in 7x7 51 | 52 | self.dct_layer = MultiSpectralDCTLayer( 53 | dct_h, dct_w, mapper_x, mapper_y, channel) 54 | self.fc = nn.Sequential( 55 | nn.Linear(channel, channel // reduction, bias=False), 56 | nn.ReLU(), 57 | nn.Linear(channel // reduction, channel, bias=False), 58 | nn.Sigmoid() 59 | ) 60 | self.avgpool = nn.AdaptiveAvgPool2d((self.dct_h, self.dct_w)) 61 | 62 | def execute(self, x): 63 | n, c, h, w = x.shape 64 | x_pooled = x 65 | if h != self.dct_h or w != self.dct_w: 66 | x_pooled = self.avgpool(x) 67 | # If you have concerns about one-line-change, don't worry. :) 68 | # In the ImageNet models, this line will never be triggered. 69 | # This is for compatibility in instance segmentation and object detection. 
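# Multi-spectral pooling: instead of plain global average pooling (the DC component only),
# each channel group is projected onto a pre-selected 2D DCT basis to form the channel descriptor.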
70 | y = self.dct_layer(x_pooled) 71 | 72 | y = self.fc(y).view(n, c, 1, 1) 73 | return x * y.expand_as(x) 74 | 75 | 76 | class MultiSpectralDCTLayer(nn.Module): 77 | """ 78 | Generate dct filters 79 | """ 80 | 81 | def __init__(self, height, width, mapper_x, mapper_y, channel): 82 | super(MultiSpectralDCTLayer, self).__init__() 83 | 84 | assert len(mapper_x) == len(mapper_y) 85 | assert channel % len(mapper_x) == 0 86 | 87 | self.num_freq = len(mapper_x) 88 | 89 | # fixed DCT init 90 | self.weight = self.get_dct_filter( 91 | height, width, mapper_x, mapper_y, channel) 92 | 93 | def execute(self, x): 94 | assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + \ 95 | str(len(x.shape)) 96 | # n, c, h, w = x.shape 97 | 98 | x = x * self.weight 99 | result = jt.sum(jt.sum(x, dim=2), dim=2) 100 | return result 101 | 102 | def build_filter(self, pos, freq, POS): 103 | result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS) 104 | if freq == 0: 105 | return result 106 | else: 107 | return result * math.sqrt(2) 108 | 109 | def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel): 110 | dct_filter = jt.zeros((channel, tile_size_x, tile_size_y)) 111 | 112 | c_part = channel // len(mapper_x) 113 | 114 | for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)): 115 | for t_x in range(tile_size_x): 116 | for t_y in range(tile_size_y): 117 | dct_filter[i * c_part: (i+1)*c_part, t_x, t_y] = self.build_filter( 118 | t_x, u_x, tile_size_x) * self.build_filter(t_y, v_y, tile_size_y) 119 | 120 | return dct_filter 121 | 122 | 123 | def main(): 124 | attention_block = MultiSpectralAttentionLayer(64, 16, 16) 125 | input = jt.ones([4, 64, 32, 32]) 126 | output = attention_block(input) 127 | print(input.size(), output.size()) 128 | 129 | 130 | if __name__ == '__main__': 131 | main() 132 | -------------------------------------------------------------------------------- /code/channel_attentions/gct_module.py: -------------------------------------------------------------------------------- 1 | # Gated channel transformation for visual recognition (CVPR2020) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class GCT(nn.Module): 7 | 8 | def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False): 9 | super(GCT, self).__init__() 10 | 11 | self.alpha = jt.ones((1, num_channels, 1, 1)) 12 | self.gamma = jt.zeros((1, num_channels, 1, 1)) 13 | self.beta = jt.zeros((1, num_channels, 1, 1)) 14 | self.epsilon = epsilon 15 | self.mode = mode 16 | self.after_relu = after_relu 17 | 18 | def execute(self, x): 19 | 20 | if self.mode == 'l2': 21 | embedding = (x.pow(2).sum(2, keepdims=True).sum(3, keepdims=True) + 22 | self.epsilon).pow(0.5) * self.alpha 23 | norm = self.gamma / \ 24 | (embedding.pow(2).mean(dim=1, keepdims=True) + self.epsilon).pow(0.5) 25 | 26 | elif self.mode == 'l1': 27 | if not self.after_relu: 28 | _x = jt.abs(x) 29 | else: 30 | _x = x 31 | embedding = _x.sum(2, keepdims=True).sum( 32 | 3, keepdims=True) * self.alpha 33 | norm = self.gamma / \ 34 | (jt.abs(embedding).mean(dim=1, keepdims=True) + self.epsilon) 35 | else: 36 | print('Unknown mode!') 37 | 38 | gate = 1. 
+ jt.tanh(embedding * norm + self.beta) 39 | 40 | return x * gate 41 | 42 | 43 | def main(): 44 | attention_block = GCT(64) 45 | input = jt.rand([4, 64, 32, 32]) 46 | output = attention_block(input) 47 | print(input.size(), output.size()) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /code/channel_attentions/se_module.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import jittor.nn as nn 3 | 4 | 5 | class SELayer(nn.Module): 6 | def __init__(self, channel, reduction=16): 7 | super(SELayer, self).__init__() 8 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 9 | self.fc = nn.Sequential( 10 | nn.Linear(channel, channel // reduction, bias=False), 11 | nn.ReLU(), 12 | nn.Linear(channel // reduction, channel, bias=False), 13 | nn.Sigmoid() 14 | ) 15 | 16 | def execute(self, x): 17 | b, c, _, _ = x.size() 18 | y = self.avg_pool(x).view(b, c) 19 | y = self.fc(y).view(b, c, 1, 1) 20 | return x * y.expand_as(x) 21 | 22 | 23 | def main(): 24 | attention_block = SELayer(64) 25 | input = jt.rand([4, 64, 32, 32]) 26 | output = attention_block(input) 27 | print(input.size(), output.size()) 28 | 29 | 30 | if __name__ == '__main__': 31 | main() 32 | -------------------------------------------------------------------------------- /code/channel_attentions/soca_module.py: -------------------------------------------------------------------------------- 1 | # Second-Order Attention Network for Single Image Super-Resolution (CVPR 2019) 2 | import jittor as jt 3 | from jittor import nn, Function 4 | 5 | 6 | class Covpool(Function): 7 | def execute(self, input): 8 | x = input 9 | batchSize = x.data.shape[0] 10 | dim = x.data.shape[1] 11 | h = x.data.shape[2] 12 | w = x.data.shape[3] 13 | M = h*w 14 | x = x.reshape(batchSize, dim, M) 15 | I_hat = (-1./M/M)*jt.ones((M, M)) + (1./M)*jt.init.eye((M, M)) 16 | I_hat = I_hat.view(1, M, M).repeat(batchSize, 1, 1) 17 | y = nn.bmm(nn.bmm(x, I_hat), x.transpose(0, 2, 1)) 18 | self.save_vars = (input, I_hat) 19 | return y 20 | 21 | def grad(self, grad_output): 22 | input, I_hat = self.save_vars 23 | x = input 24 | batchSize = x.data.shape[0] 25 | dim = x.data.shape[1] 26 | h = x.data.shape[2] 27 | w = x.data.shape[3] 28 | M = h*w 29 | x = x.reshape(batchSize, dim, M) 30 | grad_input = grad_output + grad_output.transpose(0, 2, 1) 31 | grad_input = nn.bmm(nn.bmm(grad_input, x), I_hat) 32 | grad_input = grad_input.reshape(batchSize, dim, h, w) 33 | return grad_input 34 | 35 | 36 | class Sqrtm(Function): 37 | def execute(self, input, iterN): 38 | x = input 39 | batchSize = x.data.shape[0] 40 | dim = x.data.shape[1] 41 | I3 = 3.0*jt.init.eye((dim, dim)).view(1, 42 | dim, dim).repeat(batchSize, 1, 1) 43 | normA = (1.0/3.0)*x.matmul(I3).sum(dim=1).sum(dim=1) 44 | A = x.divide(normA.view(batchSize, 1, 1).expand_as(x)) 45 | Y = jt.zeros((batchSize, iterN, dim, dim)) 46 | Y.requires_grad = False 47 | Z = jt.init.eye((dim, dim)).view( 48 | 1, dim, dim).repeat(batchSize, iterN, 1, 1) 49 | if iterN < 2: 50 | ZY = 0.5*(I3 - A) 51 | Y[:, 0, :, :] = nn.bmm(A, ZY) 52 | else: 53 | ZY = 0.5*(I3 - A) 54 | Y[:, 0, :, :] = nn.bmm(A, ZY) 55 | Z[:, 0, :, :] = ZY 56 | for i in range(1, iterN-1): 57 | ZY = 0.5*nn.bmm(I3 - Z[:, i-1, :, :], Y[:, i-1, :, :]) 58 | Y[:, i, :, :] = nn.bmm(Y[:, i-1, :, :], ZY) 59 | Z[:, i, :, :] = nn.bmm(ZY, Z[:, i-1, :, :]) 60 | ZY = nn.bmm( 61 | nn.bmm(0.5*Y[:, iterN-2, :, :], I3 - Z[:, iterN-2, :, :]), Y[:, iterN-2, :, :]) 
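# The Newton-Schulz iteration above approximates the square root of the trace-normalized matrix A;
# rescaling by sqrt(normA) recovers the square root of the original covariance matrix.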
62 | y = ZY*jt.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 63 | self.save_vars = (input, A, ZY, normA, Y, Z) 64 | self.iterN = iterN 65 | return y 66 | 67 | def grad(self, grad_output): 68 | input, A, ZY, normA, Y, Z = self.save_vars 69 | iterN = self.iterN 70 | x = input 71 | batchSize = x.data.shape[0] 72 | dim = x.data.shape[1] 73 | der_postCom = grad_output * \ 74 | jt.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 75 | der_postComAux = ( 76 | grad_output*ZY).sum(dim=1).sum(dim=1).divide(2*jt.sqrt(normA)) 77 | I3 = 3.0*jt.init.eye((dim, dim)).view(1, dim, 78 | dim).repeat(batchSize, 1, 1) 79 | if iterN < 2: 80 | der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_sacleTrace)) 81 | else: 82 | dldY = 0.5*(nn.bmm(der_postCom, I3 - nn.bmm(Y[:, iterN-2, :, :], Z[:, iterN-2, :, :])) - nn.bmm( 83 | nn.bmm(Z[:, iterN-2, :, :], Y[:, iterN-2, :, :]), der_postCom)) 84 | dldZ = -0.5*nn.bmm( 85 | nn.bmm(Y[:, iterN-2, :, :], der_postCom), Y[:, iterN-2, :, :]) 86 | for i in range(iterN-3, -1, -1): 87 | YZ = I3 - nn.bmm(Y[:, i, :, :], Z[:, i, :, :]) 88 | ZY = nn.bmm(Z[:, i, :, :], Y[:, i, :, :]) 89 | dldY_ = 0.5*(nn.bmm(dldY, YZ) - 90 | nn.bmm(nn.bmm(Z[:, i, :, :], dldZ), Z[:, i, :, :]) - 91 | nn.bmm(ZY, dldY)) 92 | dldZ_ = 0.5*(nn.bmm(YZ, dldZ) - 93 | nn.bmm(nn.bmm(Y[:, i, :, :], dldY), Y[:, i, :, :]) - 94 | nn.bmm(dldZ, ZY)) 95 | dldY = dldY_ 96 | dldZ = dldZ_ 97 | der_NSiter = 0.5*(nn.bmm(dldY, I3 - A) - dldZ - nn.bmm(A, dldY)) 98 | grad_input = der_NSiter.divide( 99 | normA.view(batchSize, 1, 1).expand_as(x)) 100 | grad_aux = der_NSiter.matmul(x).sum(dim=1).sum(dim=1) 101 | for i in range(batchSize): 102 | grad_input[i, :, :] += (der_postComAux[i] 103 | - grad_aux[i] / (normA[i] * normA[i])) * jt.ones((dim,)).diag() 104 | return grad_input, None 105 | 106 | 107 | class SOCA(nn.Module): 108 | def __init__(self, channel, reduction=8): 109 | super().__init__() 110 | 111 | self.conv_du = nn.Sequential( 112 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 113 | nn.ReLU(), 114 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 115 | nn.Sigmoid() 116 | ) 117 | self.CovpoolLayer = Covpool() 118 | self.SqrtmLayer = Sqrtm() 119 | 120 | def execute(self, x): 121 | b, c, h, w = x.shape 122 | 123 | h1 = 1000 124 | w1 = 1000 125 | if h < h1 and w < w1: 126 | x_sub = x 127 | elif h < h1 and w > w1: 128 | W = (w - w1) // 2 129 | x_sub = x[:, :, :, W:(W + w1)] 130 | elif w < w1 and h > h1: 131 | H = (h - h1) // 2 132 | x_sub = x[:, :, H:H + h1, :] 133 | else: 134 | H = (h - h1) // 2 135 | W = (w - w1) // 2 136 | x_sub = x[:, :, H:(H + h1), W:(W + w1)] 137 | 138 | # MPN-COV 139 | cov_mat = self.CovpoolLayer(x_sub) 140 | cov_mat_sqrt = self.SqrtmLayer(cov_mat, 5) 141 | 142 | cov_mat_sum = jt.mean(cov_mat_sqrt, 1) 143 | cov_mat_sum = cov_mat_sum.view(b, c, 1, 1) 144 | 145 | y_cov = self.conv_du(cov_mat_sum) 146 | 147 | return y_cov*x 148 | 149 | 150 | def main(): 151 | attention_block = SOCA(64) 152 | input = jt.rand([4, 64, 32, 32]) 153 | output = attention_block(input) 154 | jt.grad(output, input) 155 | print(input.size(), output.size()) 156 | 157 | 158 | if __name__ == '__main__': 159 | main() 160 | -------------------------------------------------------------------------------- /code/channel_spatial_attentions/bam.py: -------------------------------------------------------------------------------- 1 | # Bam: Bottleneck attention module(BMVC 2018) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class Flatten(nn.Module): 7 | def execute(self, x): 8 | return 
x.view(x.size(0), -1) 9 | 10 | 11 | class ChannelGate(nn.Module): 12 | def __init__(self, gate_channel, reduction_ratio=16, num_layers=1): 13 | super(ChannelGate, self).__init__() 14 | self.gate_c = nn.ModuleList() 15 | self.gate_c.append(Flatten()) 16 | gate_channels = [gate_channel] 17 | gate_channels += [gate_channel // reduction_ratio] * num_layers 18 | gate_channels += [gate_channel] 19 | for i in range(len(gate_channels) - 2): 20 | self.gate_c.append(nn.Linear( 21 | gate_channels[i], gate_channels[i+1])) 22 | self.gate_c.append(nn.BatchNorm1d(gate_channels[i+1])) 23 | self.gate_c.append(nn.ReLU()) 24 | self.gate_c.append(nn.Linear( 25 | gate_channels[-2], gate_channels[-1])) 26 | 27 | def execute(self, x): 28 | avg_pool = nn.avg_pool2d( 29 | x, x.size(2), stride=x.size(2)) 30 | return self.gate_c(avg_pool).unsqueeze(2).unsqueeze(3).expand_as(x) 31 | 32 | 33 | class SpatialGate(nn.Module): 34 | def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4): 35 | super(SpatialGate, self).__init__() 36 | self.gate_s = nn.ModuleList() 37 | self.gate_s.append(nn.Conv2d( 38 | gate_channel, gate_channel//reduction_ratio, kernel_size=1)) 39 | self.gate_s.append(nn.BatchNorm2d( 40 | gate_channel//reduction_ratio)) 41 | self.gate_s.append(nn.ReLU()) 42 | for i in range(dilation_conv_num): 43 | self.gate_s.append(nn.Conv2d(gate_channel//reduction_ratio, gate_channel//reduction_ratio, kernel_size=3, 44 | padding=dilation_val, dilation=dilation_val)) 45 | self.gate_s.append(nn.BatchNorm2d( 46 | gate_channel//reduction_ratio)) 47 | self.gate_s.append(nn.ReLU()) 48 | self.gate_s.append(nn.Conv2d( 49 | gate_channel//reduction_ratio, 1, kernel_size=1)) 50 | 51 | def execute(self, x): 52 | return self.gate_s(x).expand_as(x) 53 | 54 | 55 | class BAM(nn.Module): 56 | def __init__(self, gate_channel): 57 | super(BAM, self).__init__() 58 | self.channel_att = ChannelGate(gate_channel) 59 | self.spatial_att = SpatialGate(gate_channel) 60 | self.sigmoid = nn.Sigmoid() 61 | 62 | def execute(self, x): 63 | att = 1 + self.sigmoid(self.channel_att(x) * self.spatial_att(x)) 64 | return att * x 65 | 66 | 67 | def main(): 68 | attention_block = BAM(64) 69 | input = jt.rand([4, 64, 32, 32]) 70 | output = attention_block(input) 71 | print(input.size(), output.size()) 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /code/channel_spatial_attentions/cbam.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import jittor.nn as nn 3 | 4 | 5 | class BasicConv(nn.Module): 6 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 7 | super(BasicConv, self).__init__() 8 | self.out_channels = out_planes 9 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 10 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 11 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, 12 | momentum=0.01, affine=True) if bn else None 13 | self.relu = nn.ReLU() if relu else None 14 | 15 | def execute(self, x): 16 | x = self.conv(x) 17 | if self.bn is not None: 18 | x = self.bn(x) 19 | if self.relu is not None: 20 | x = self.relu(x) 21 | return x 22 | 23 | 24 | class ChannelGate(nn.Module): 25 | def __init__(self, channel, reduction=16): 26 | super(ChannelGate, self).__init__() 27 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 28 | self.fc_avg = nn.Sequential( 
29 | nn.Linear(channel, channel // reduction, bias=False), 30 | nn.ReLU(), 31 | nn.Linear(channel // reduction, channel, bias=False), 32 | ) 33 | self.max_pool = nn.AdaptiveMaxPool2d(1) 34 | self.fc_max = nn.Sequential( 35 | nn.Linear(channel, channel // reduction, bias=False), 36 | nn.ReLU(), 37 | nn.Linear(channel // reduction, channel, bias=False), 38 | ) 39 | 40 | self.sigmoid = nn.Sigmoid() 41 | 42 | def execute(self, x): 43 | b, c, _, _ = x.size() 44 | y_avg = self.avg_pool(x).view(b, c) 45 | y_avg = self.fc_avg(y_avg).view(b, c, 1, 1) 46 | 47 | y_max = self.max_pool(x).view(b, c) 48 | y_max = self.fc_max(y_max).view(b, c, 1, 1) 49 | 50 | y = self.sigmoid(y_avg + y_avg) 51 | return x * y.expand_as(x) 52 | 53 | 54 | class ChannelPool(nn.Module): 55 | def __init__(self): 56 | super().__init__() 57 | 58 | def execute(self, x): 59 | x_max = jt.max(x, 1).unsqueeze(1) 60 | x_avg = jt.mean(x, 1).unsqueeze(1) 61 | x = jt.concat([x_max, x_avg], dim=1) 62 | return x 63 | 64 | 65 | class SpatialGate(nn.Module): 66 | def __init__(self): 67 | super(SpatialGate, self).__init__() 68 | kernel_size = 7 69 | self.compress = ChannelPool() 70 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=( 71 | kernel_size-1) // 2, relu=False) 72 | self.sigmoid = nn.Sigmoid() 73 | 74 | def execute(self, x): 75 | x_compress = self.compress(x) 76 | x_out = self.spatial(x_compress) 77 | scale = self.sigmoid(x_out) # broadcasting 78 | return x * scale 79 | 80 | 81 | class CBAM(nn.Module): 82 | def __init__(self, gate_channels, reduction_ratio=16): 83 | super(CBAM, self).__init__() 84 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio) 85 | self.SpatialGate = SpatialGate() 86 | 87 | def execute(self, x): 88 | x_out = self.ChannelGate(x) 89 | x_out = self.SpatialGate(x_out) 90 | return x_out 91 | 92 | 93 | def main(): 94 | attention_block = CBAM(64) 95 | input = jt.rand([4, 64, 32, 32]) 96 | output = attention_block(input) 97 | print(input.size(), output.size()) 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /code/channel_spatial_attentions/coordatt_module.py: -------------------------------------------------------------------------------- 1 | # Coordinate attention for efficient mobile network design (CVPR 2021) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class h_sigmoid(nn.Module): 7 | def __init__(self): 8 | super(h_sigmoid, self).__init__() 9 | self.relu = nn.ReLU6() 10 | 11 | def execute(self, x): 12 | return self.relu(x + 3) / 6 13 | 14 | 15 | class h_swish(nn.Module): 16 | def __init__(self): 17 | super(h_swish, self).__init__() 18 | self.sigmoid = h_sigmoid() 19 | 20 | def execute(self, x): 21 | return x * self.sigmoid(x) 22 | 23 | 24 | class CoordAtt(nn.Module): 25 | def __init__(self, inp, oup, reduction=32): 26 | super(CoordAtt, self).__init__() 27 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) 28 | self.pool_w = nn.AdaptiveAvgPool2d((1, None)) 29 | 30 | mip = max(8, inp // reduction) 31 | 32 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0) 33 | self.bn1 = nn.BatchNorm2d(mip) 34 | self.act = h_swish() 35 | 36 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 37 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 38 | 39 | def execute(self, x): 40 | identity = x 41 | 42 | n, c, h, w = x.size() 43 | x_h = self.pool_h(x) 44 | x_w = self.pool_w(x).permute(0, 1, 3, 2) 45 | 46 | y = jt.concat([x_h, x_w], dim=2) 47 
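# The pooled H-strip and W-strip are concatenated so one shared 1x1 conv + BN + h_swish
# encodes both directions, then split back into the two directional attention maps.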
| y = self.conv1(y) 48 | y = self.bn1(y) 49 | y = self.act(y) 50 | 51 | x_h, x_w = jt.split(y, [h, w], dim=2) 52 | x_w = x_w.permute(0, 1, 3, 2) 53 | 54 | a_h = self.conv_h(x_h).sigmoid() 55 | a_w = self.conv_w(x_w).sigmoid() 56 | 57 | out = identity * a_w * a_h 58 | 59 | return out 60 | 61 | 62 | def main(): 63 | attention_block = CoordAtt(64, 64) 64 | input = jt.rand([4, 64, 32, 32]) 65 | output = attention_block(input) 66 | print(input.size(), output.size()) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /code/channel_spatial_attentions/danet.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import jittor.nn as nn 3 | 4 | 5 | class DANetHead(nn.Module): 6 | 7 | def __init__(self, in_channels, out_channels): 8 | super(DANetHead, self).__init__() 9 | inter_channels = in_channels // 4 10 | self.conv5a = nn.Sequential(nn.Conv(in_channels, inter_channels, 3, padding=1, bias=False), 11 | nn.BatchNorm(inter_channels), 12 | nn.ReLU()) 13 | 14 | self.conv5c = nn.Sequential(nn.Conv(in_channels, inter_channels, 3, padding=1, bias=False), 15 | nn.BatchNorm(inter_channels), 16 | nn.ReLU()) 17 | 18 | self.sa = PAM_Module(inter_channels) 19 | self.sc = CAM_Module(inter_channels) 20 | self.conv51 = nn.Sequential(nn.Conv(inter_channels, inter_channels, 3, padding=1, bias=False), 21 | nn.BatchNorm(inter_channels), 22 | nn.ReLU()) 23 | self.conv52 = nn.Sequential(nn.Conv(inter_channels, inter_channels, 3, padding=1, bias=False), 24 | nn.BatchNorm(inter_channels), 25 | nn.ReLU()) 26 | 27 | self.conv8 = nn.Sequential(nn.Dropout( 28 | 0.1, False), nn.Conv(inter_channels, out_channels, 1)) 29 | 30 | def execute(self, x): 31 | 32 | feat1 = self.conv5a(x) 33 | sa_feat = self.sa(feat1) 34 | sa_conv = self.conv51(sa_feat) 35 | 36 | feat2 = self.conv5c(x) 37 | sc_feat = self.sc(feat2) 38 | sc_conv = self.conv52(sc_feat) 39 | 40 | feat_sum = sa_conv+sc_conv 41 | 42 | sasc_output = self.conv8(feat_sum) 43 | 44 | return sasc_output 45 | 46 | 47 | class PAM_Module(nn.Module): 48 | """ Position attention module""" 49 | # Ref from SAGAN 50 | 51 | def __init__(self, in_dim): 52 | super(PAM_Module, self).__init__() 53 | self.chanel_in = in_dim 54 | 55 | self.query_conv = nn.Conv( 56 | in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 57 | self.key_conv = nn.Conv( 58 | in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 59 | self.value_conv = nn.Conv( 60 | in_channels=in_dim, out_channels=in_dim, kernel_size=1) 61 | self.gamma = jt.zeros(1) 62 | 63 | self.softmax = nn.Softmax(dim=-1) 64 | 65 | def execute(self, x): 66 | """ 67 | inputs : 68 | x : input feature maps( B X C X H X W) 69 | returns : 70 | out : attention value + input feature 71 | attention: B X (HxW) X (HxW) 72 | """ 73 | m_batchsize, C, height, width = x.size() 74 | proj_query = self.query_conv(x).reshape( 75 | m_batchsize, -1, width*height).transpose(0, 2, 1) 76 | proj_key = self.key_conv(x).reshape(m_batchsize, -1, width*height) 77 | energy = nn.bmm(proj_query, proj_key) 78 | attention = self.softmax(energy) 79 | proj_value = self.value_conv(x).reshape(m_batchsize, -1, width*height) 80 | 81 | out = nn.bmm(proj_value, attention.transpose(0, 2, 1)) 82 | out = out.reshape(m_batchsize, C, height, width) 83 | 84 | out = self.gamma*out + x 85 | return out 86 | 87 | 88 | class CAM_Module(nn.Module): 89 | """ Channel attention module""" 90 | 91 | def __init__(self, in_dim): 92 | super(CAM_Module, 
self).__init__() 93 | self.chanel_in = in_dim 94 | self.gamma = jt.zeros(1) 95 | self.softmax = nn.Softmax(dim=-1) 96 | 97 | def execute(self, x): 98 | """ 99 | inputs : 100 | x : input feature maps( B X C X H X W) 101 | returns : 102 | out : attention value + input feature 103 | attention: B X C X C 104 | """ 105 | m_batchsize, C, height, width = x.size() 106 | proj_query = x.reshape(m_batchsize, C, -1) 107 | proj_key = x.reshape(m_batchsize, C, -1).transpose(0, 2, 1) 108 | energy = nn.bmm(proj_query, proj_key) 109 | #energy_new = jt.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy 110 | attention = self.softmax(energy) 111 | proj_value = x.reshape(m_batchsize, C, -1) 112 | 113 | out = nn.bmm(attention, proj_value) 114 | out = out.reshape(m_batchsize, C, height, width) 115 | 116 | out = self.gamma*out + x 117 | return out 118 | 119 | 120 | def main(): 121 | attention_block = DANetHead(64, 64) 122 | input = jt.rand([4, 64, 32, 32]) 123 | output = attention_block(input) 124 | print(input.size(), output.size()) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /code/channel_spatial_attentions/lka.py: -------------------------------------------------------------------------------- 1 | # Visual Attention Network 2 | import jittor as jt 3 | import jittor.nn as nn 4 | 5 | 6 | class AttentionModule(nn.Module): 7 | def __init__(self, dim): 8 | super().__init__() 9 | self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) 10 | self.conv_spatial = nn.Conv2d( 11 | dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3) 12 | self.conv1 = nn.Conv2d(dim, dim, 1) 13 | 14 | def execute(self, x): 15 | u = x.clone() 16 | attn = self.conv0(x) 17 | attn = self.conv_spatial(attn) 18 | attn = self.conv1(attn) 19 | 20 | return u * attn 21 | 22 | 23 | class SpatialAttention(nn.Module): 24 | def __init__(self, d_model): 25 | super().__init__() 26 | 27 | self.proj_1 = nn.Conv2d(d_model, d_model, 1) 28 | self.activation = nn.GELU() 29 | self.spatial_gating_unit = AttentionModule(d_model) 30 | self.proj_2 = nn.Conv2d(d_model, d_model, 1) 31 | 32 | def execute(self, x): 33 | shorcut = x.clone() 34 | x = self.proj_1(x) 35 | x = self.activation(x) 36 | x = self.spatial_gating_unit(x) 37 | x = self.proj_2(x) 38 | x = x + shorcut 39 | return x 40 | 41 | 42 | def main(): 43 | attention_block = SpatialAttention(64) 44 | input = jt.rand([4, 64, 32, 32]) 45 | output = attention_block(input) 46 | print(input.size(), output.size()) 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /code/channel_spatial_attentions/scnet.py: -------------------------------------------------------------------------------- 1 | # Improving convolutional networks with self-calibrated convolutions (CVPR 2020) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class SCConv(nn.Module): 7 | def __init__(self, inplanes, planes, stride, padding, dilation, groups, pooling_r, norm_layer): 8 | super(SCConv, self).__init__() 9 | self.k2 = nn.Sequential( 10 | nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r), 11 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, 12 | padding=padding, dilation=dilation, 13 | groups=groups, bias=False), 14 | norm_layer(planes), 15 | ) 16 | self.k3 = nn.Sequential( 17 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, 18 | padding=padding, dilation=dilation, 19 | groups=groups, bias=False), 20 | 
norm_layer(planes), 21 | ) 22 | self.k4 = nn.Sequential( 23 | nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 24 | padding=padding, dilation=dilation, 25 | groups=groups, bias=False), 26 | norm_layer(planes), 27 | ) 28 | 29 | def execute(self, x): 30 | identity = x 31 | 32 | out = jt.sigmoid(jt.add(identity, nn.interpolate( 33 | self.k2(x), identity.size()[2:]))) # sigmoid(identity + k2) 34 | out = jt.multiply(self.k3(x), out) # k3 * sigmoid(identity + k2) 35 | out = self.k4(out) # k4 36 | 37 | return out 38 | 39 | 40 | def main(): 41 | attention_block = SCConv(64, 64, stride=1, 42 | padding=2, dilation=2, groups=1, pooling_r=4, norm_layer=nn.BatchNorm2d) 43 | input = jt.rand([4, 64, 32, 32]) 44 | output = attention_block(input) 45 | print(input.size(), output.size()) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /code/channel_spatial_attentions/simam_module.py: -------------------------------------------------------------------------------- 1 | # Simam: A simple, parameter-free attention module for convolutional neural networks (ICML 2021) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class simam_module(nn.Module): 7 | def __init__(self, e_lambda=1e-4): 8 | super(simam_module, self).__init__() 9 | 10 | self.activaton = nn.Sigmoid() 11 | self.e_lambda = e_lambda 12 | 13 | def execute(self, x): 14 | 15 | b, c, h, w = x.size() 16 | 17 | n = w * h - 1 18 | 19 | x_minus_mu_square = ( 20 | x - x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)).pow(2) 21 | y = x_minus_mu_square / \ 22 | (4 * (x_minus_mu_square.sum(dim=2, 23 | keepdims=True).sum(dim=3, 24 | keepdims=True) / n + self.e_lambda)) + 0.5 25 | 26 | return x * self.activaton(y) 27 | 28 | 29 | def main(): 30 | attention_block = simam_module() 31 | input = jt.ones([4, 64, 32, 32]) 32 | output = attention_block(input) 33 | print(input.size(), output.size()) 34 | 35 | 36 | if __name__ == '__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /code/channel_spatial_attentions/strip_pooling_module.py: -------------------------------------------------------------------------------- 1 | # Strip Pooling: Rethinking spatial pooling for scene parsing (CVPR 2020) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class StripPooling(nn.Module): 7 | """ 8 | Reference: 9 | """ 10 | 11 | def __init__(self, in_channels, pool_size, norm_layer, up_kwargs): 12 | super(StripPooling, self).__init__() 13 | self.pool1 = nn.AdaptiveAvgPool2d(pool_size[0]) 14 | self.pool2 = nn.AdaptiveAvgPool2d(pool_size[1]) 15 | self.pool3 = nn.AdaptiveAvgPool2d((1, None)) 16 | self.pool4 = nn.AdaptiveAvgPool2d((None, 1)) 17 | 18 | inter_channels = int(in_channels/4) 19 | self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False), 20 | norm_layer(inter_channels), 21 | nn.ReLU()) 22 | self.conv1_2 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False), 23 | norm_layer(inter_channels), 24 | nn.ReLU()) 25 | self.conv2_0 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), 26 | norm_layer(inter_channels)) 27 | self.conv2_1 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), 28 | norm_layer(inter_channels)) 29 | self.conv2_2 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), 30 | norm_layer(inter_channels)) 31 | self.conv2_3 = nn.Sequential(nn.Conv2d(inter_channels, 
inter_channels, (1, 3), 1, (0, 1), bias=False), 32 | norm_layer(inter_channels)) 33 | self.conv2_4 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 1), 1, (1, 0), bias=False), 34 | norm_layer(inter_channels)) 35 | self.conv2_5 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), 36 | norm_layer(inter_channels), 37 | nn.ReLU()) 38 | self.conv2_6 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), 39 | norm_layer(inter_channels), 40 | nn.ReLU()) 41 | self.conv3 = nn.Sequential(nn.Conv2d(inter_channels*2, in_channels, 1, bias=False), 42 | norm_layer(in_channels)) 43 | # bilinear interpolate options 44 | self._up_kwargs = up_kwargs 45 | 46 | def execute(self, x): 47 | _, _, h, w = x.size() 48 | x1 = self.conv1_1(x) 49 | x2 = self.conv1_2(x) 50 | x2_1 = self.conv2_0(x1) 51 | x2_2 = nn.interpolate(self.conv2_1(self.pool1(x1)), 52 | (h, w), **self._up_kwargs) 53 | x2_3 = nn.interpolate(self.conv2_2(self.pool2(x1)), 54 | (h, w), **self._up_kwargs) 55 | x2_4 = nn.interpolate(self.conv2_3(self.pool3(x2)), 56 | (h, w), **self._up_kwargs) 57 | x2_5 = nn.interpolate(self.conv2_4(self.pool4(x2)), 58 | (h, w), **self._up_kwargs) 59 | x1 = self.conv2_5(nn.relu(x2_1 + x2_2 + x2_3)) 60 | x2 = self.conv2_6(nn.relu(x2_5 + x2_4)) 61 | out = self.conv3(jt.concat([x1, x2], dim=1)) 62 | return nn.relu(x + out) 63 | 64 | 65 | def main(): 66 | attention_block = StripPooling( 67 | 64, (20, 12), nn.BatchNorm2d, {'mode': 'bilinear', 'align_corners': True}) 68 | input = jt.rand([4, 64, 32, 32]) 69 | output = attention_block(input) 70 | print(input.size(), output.size()) 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /code/channel_spatial_attentions/triplet_attention.py: -------------------------------------------------------------------------------- 1 | # Rotate to attend: Convolutional triplet attention module (WACV 2021) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class BasicConv(nn.Module): 7 | def __init__( 8 | self, 9 | in_planes, 10 | out_planes, 11 | kernel_size, 12 | stride=1, 13 | padding=0, 14 | dilation=1, 15 | groups=1, 16 | relu=True, 17 | bn=True, 18 | bias=False, 19 | ): 20 | super(BasicConv, self).__init__() 21 | self.out_channels = out_planes 22 | self.conv = nn.Conv2d( 23 | in_planes, 24 | out_planes, 25 | kernel_size=kernel_size, 26 | stride=stride, 27 | padding=padding, 28 | dilation=dilation, 29 | groups=groups, 30 | bias=bias, 31 | ) 32 | self.bn = ( 33 | nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) 34 | if bn 35 | else None 36 | ) 37 | self.relu = nn.ReLU() if relu else None 38 | 39 | def execute(self, x): 40 | x = self.conv(x) 41 | if self.bn is not None: 42 | x = self.bn(x) 43 | if self.relu is not None: 44 | x = self.relu(x) 45 | return x 46 | 47 | 48 | class ZPool(nn.Module): 49 | def execute(self, x): 50 | return jt.concat( 51 | (x.max(1).unsqueeze(1), x.mean(1).unsqueeze(1)), dim=1 52 | ) 53 | 54 | 55 | class AttentionGate(nn.Module): 56 | def __init__(self): 57 | super(AttentionGate, self).__init__() 58 | kernel_size = 7 59 | self.compress = ZPool() 60 | self.conv = BasicConv( 61 | 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False 62 | ) 63 | 64 | def execute(self, x): 65 | x_compress = self.compress(x) 66 | x_out = self.conv(x_compress) 67 | scale = x_out.sigmoid() 68 | return x * scale 69 | 70 | 71 | class TripletAttention(nn.Module): 72 | def __init__(self, 
no_spatial=False): 73 | super(TripletAttention, self).__init__() 74 | self.cw = AttentionGate() 75 | self.hc = AttentionGate() 76 | self.no_spatial = no_spatial 77 | if not no_spatial: 78 | self.hw = AttentionGate() 79 | 80 | def execute(self, x): 81 | x_perm1 = x.permute(0, 2, 1, 3) 82 | x_out1 = self.cw(x_perm1) 83 | x_out11 = x_out1.permute(0, 2, 1, 3) 84 | x_perm2 = x.permute(0, 3, 2, 1) 85 | x_out2 = self.hc(x_perm2) 86 | x_out21 = x_out2.permute(0, 3, 2, 1) 87 | if not self.no_spatial: 88 | x_out = self.hw(x) 89 | x_out = 1 / 3 * (x_out + x_out11 + x_out21) 90 | else: 91 | x_out = 1 / 2 * (x_out11 + x_out21) 92 | return x_out 93 | 94 | 95 | def main(): 96 | attention_block = TripletAttention() 97 | input = jt.ones([4, 64, 32, 32]) 98 | output = attention_block(input) 99 | print(input.size(), output.size()) 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /code/spatial_attentions/attention_augmented_module.py: -------------------------------------------------------------------------------- 1 | # Attention augmented convolutional networks (ICCV 2019) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class AugmentedConv(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, dk, dv, Nh, relative): 8 | super(AugmentedConv, self).__init__() 9 | self.in_channels = in_channels 10 | self.out_channels = out_channels 11 | self.kernel_size = kernel_size 12 | self.dk = dk 13 | self.dv = dv 14 | self.Nh = Nh 15 | self.relative = relative 16 | 17 | self.conv_out = nn.Conv2d( 18 | self.in_channels, self.out_channels - self.dv, self.kernel_size, padding=1) 19 | 20 | self.qkv_conv = nn.Conv2d( 21 | self.in_channels, 2 * self.dk + self.dv, kernel_size=1) 22 | 23 | self.attn_out = nn.Conv2d(self.dv, self.dv, 1) 24 | 25 | def execute(self, x): 26 | # Input x 27 | # (batch_size, channels, height, width) 28 | batch, _, height, width = x.size() 29 | 30 | # conv_out 31 | # (batch_size, out_channels, height, width) 32 | conv_out = self.conv_out(x) 33 | 34 | # flat_q, flat_k, flat_v 35 | # (batch_size, Nh, height * width, dvh or dkh) 36 | # dvh = dv / Nh, dkh = dk / Nh 37 | # q, k, v 38 | # (batch_size, Nh, height, width, dv or dk) 39 | flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv( 40 | x, self.dk, self.dv, self.Nh) 41 | logits = jt.matmul(flat_q.transpose(2, 3), flat_k) 42 | 43 | if self.relative: 44 | h_rel_logits, w_rel_logits = self.relative_logits(q) 45 | logits += h_rel_logits 46 | logits += w_rel_logits 47 | weights = nn.softmax(logits, dim=-1) 48 | 49 | # attn_out 50 | # (batch, Nh, height * width, dvh) 51 | attn_out = jt.matmul(weights, flat_v.transpose(2, 3)) 52 | attn_out = jt.reshape( 53 | attn_out, (batch, self.Nh, self.dv // self.Nh, height, width)) 54 | # combine_heads_2d 55 | # (batch, out_channels, height, width) 56 | attn_out = self.combine_heads_2d(attn_out) 57 | attn_out = self.attn_out(attn_out) 58 | return jt.concat((conv_out, attn_out), dim=1) 59 | 60 | def compute_flat_qkv(self, x, dk, dv, Nh): 61 | N, _, H, W = x.size() 62 | qkv = self.qkv_conv(x) 63 | q, k, v = jt.split(qkv, [dk, dk, dv], dim=1) 64 | q = self.split_heads_2d(q, Nh) 65 | k = self.split_heads_2d(k, Nh) 66 | v = self.split_heads_2d(v, Nh) 67 | 68 | dkh = dk // Nh 69 | q *= dkh ** -0.5 70 | flat_q = jt.reshape(q, (N, Nh, dk // Nh, H * W)) 71 | flat_k = jt.reshape(k, (N, Nh, dk // Nh, H * W)) 72 | flat_v = jt.reshape(v, (N, Nh, dv // Nh, H * W)) 73 | return flat_q, flat_k, flat_v, q, k, 
v 74 | 75 | def split_heads_2d(self, x, Nh): 76 | batch, channels, height, width = x.size() 77 | ret_shape = (batch, Nh, channels // Nh, height, width) 78 | split = jt.reshape(x, ret_shape) 79 | return split 80 | 81 | def combine_heads_2d(self, x): 82 | batch, Nh, dv, H, W = x.size() 83 | ret_shape = (batch, Nh * dv, H, W) 84 | return jt.reshape(x, ret_shape) 85 | 86 | def relative_logits(self, q): 87 | B, Nh, dk, H, W = q.size() 88 | q = jt.transpose(q, 2, 4).transpose(2, 3) 89 | 90 | key_rel_w = jt.randn((2 * W - 1, dk)) 91 | rel_logits_w = self.relative_logits_1d(q, key_rel_w, H, W, Nh, "w") 92 | 93 | key_rel_h = jt.randn((2 * H - 1, dk)) 94 | rel_logits_h = self.relative_logits_1d( 95 | jt.transpose(q, 2, 3), key_rel_h, W, H, Nh, "h") 96 | 97 | return rel_logits_h, rel_logits_w 98 | 99 | def relative_logits_1d(self, q, rel_k, H, W, Nh, case): 100 | rel_logits = jt.matmul(q, rel_k.transpose(0, 1)) 101 | rel_logits = jt.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1)) 102 | rel_logits = self.rel_to_abs(rel_logits) 103 | 104 | rel_logits = jt.reshape(rel_logits, (-1, Nh, H, W, W)) 105 | rel_logits = jt.unsqueeze(rel_logits, dim=3) 106 | rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1)) 107 | 108 | if case == "w": 109 | rel_logits = jt.transpose(rel_logits, 3, 4) 110 | elif case == "h": 111 | rel_logits = jt.transpose( 112 | rel_logits, 2, 4).transpose(4, 5).transpose(3, 5) 113 | rel_logits = jt.reshape(rel_logits, (-1, Nh, H * W, H * W)) 114 | return rel_logits 115 | 116 | def rel_to_abs(self, x): 117 | B, Nh, L, _ = x.size() 118 | 119 | col_pad = jt.zeros((B, Nh, L, 1)) 120 | x = jt.concat((x, col_pad), dim=3) 121 | 122 | flat_x = jt.reshape(x, (B, Nh, L * 2 * L)) 123 | flat_pad = jt.zeros((B, Nh, L - 1)) 124 | flat_x_padded = jt.concat((flat_x, flat_pad), dim=2) 125 | 126 | final_x = jt.reshape(flat_x_padded, (B, Nh, L + 1, 2 * L - 1)) 127 | final_x = final_x[:, :, :L, L - 1:] 128 | return final_x 129 | 130 | 131 | def main(): 132 | attention_block = AugmentedConv(64, 64, 3, 40, 4, 4, True) 133 | input = jt.rand([4, 64, 32, 32]) 134 | output = attention_block(input) 135 | print(input.size(), output.size()) 136 | 137 | 138 | if __name__ == '__main__': 139 | main() 140 | -------------------------------------------------------------------------------- /code/spatial_attentions/doub_attention.py: -------------------------------------------------------------------------------- 1 | # A2-Nets: Double Attention Networks (NIPS 2018) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class DoubleAtten(nn.Module): 7 | def __init__(self, in_c): 8 | super(DoubleAtten, self).__init__() 9 | self.in_c = in_c 10 | self.convA = nn.Conv2d(in_c, in_c, kernel_size=1) 11 | self.convB = nn.Conv2d(in_c, in_c, kernel_size=1) 12 | self.convV = nn.Conv2d(in_c, in_c, kernel_size=1) 13 | 14 | def execute(self, input): 15 | 16 | feature_maps = self.convA(input) 17 | atten_map = self.convB(input) 18 | b, _, h, w = feature_maps.shape 19 | 20 | feature_maps = feature_maps.view(b, 1, self.in_c, h*w) 21 | atten_map = atten_map.view(b, self.in_c, 1, h*w) 22 | global_descriptors = jt.mean( 23 | (feature_maps * nn.softmax(atten_map, dim=-1)), dim=-1) 24 | 25 | v = self.convV(input) 26 | atten_vectors = nn.softmax( 27 | v.view(b, self.in_c, h*w), dim=-1) 28 | out = nn.bmm(atten_vectors.permute(0, 2, 1), 29 | global_descriptors).permute(0, 2, 1) 30 | 31 | return out.view(b, _, h, w) 32 | 33 | 34 | def main(): 35 | attention_block = DoubleAtten(64) 36 | input = jt.rand([4, 64, 32, 32]) 37 | output = 
attention_block(input) 38 | jt.grad(output, input) 39 | print(input.size(), output.size()) 40 | 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /code/spatial_attentions/external_attention.py: -------------------------------------------------------------------------------- 1 | # Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks (CVMJ2021) 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class External_attention(nn.Module): 7 | ''' 8 | Arguments: 9 | c (int): The input and output channel number. 10 | ''' 11 | 12 | def __init__(self, c): 13 | super(External_attention, self).__init__() 14 | 15 | self.conv1 = nn.Conv2d(c, c, 1) 16 | 17 | self.k = 64 18 | self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False) 19 | 20 | self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False) 21 | self.linear_1.weight = self.linear_0.weight.permute(1, 0, 2) 22 | 23 | self.conv2 = nn.Sequential( 24 | nn.Conv2d(c, c, 1, bias=False), 25 | nn.BatchNorm(c)) 26 | 27 | self.relu = nn.ReLU() 28 | 29 | def execute(self, x): 30 | idn = x 31 | x = self.conv1(x) 32 | 33 | b, c, h, w = x.size() 34 | n = h*w 35 | x = x.view(b, c, h*w) # b * c * n 36 | 37 | attn = self.linear_0(x) # b, k, n 38 | attn = nn.softmax(attn, dim=-1) # b, k, n 39 | 40 | attn = attn / (1e-9 + attn.sum(dim=1, keepdims=True)) # b, k, n 41 | x = self.linear_1(attn) # b, c, n 42 | 43 | x = x.view(b, c, h, w) 44 | x = self.conv2(x) 45 | x = x + idn 46 | x = self.relu(x) 47 | return x 48 | 49 | 50 | def main(): 51 | attention_block = External_attention(64) 52 | input = jt.rand([4, 64, 32, 32]) 53 | output = attention_block(input) 54 | print(input.size(), output.size()) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /code/spatial_attentions/gc_module.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import nn 3 | 4 | 5 | class GlobalContextBlock(nn.Module): 6 | def __init__(self, 7 | inplanes, 8 | ratio): 9 | super(GlobalContextBlock, self).__init__() 10 | self.inplanes = inplanes 11 | self.ratio = ratio 12 | self.planes = int(inplanes * ratio) 13 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1) 14 | self.softmax = nn.Softmax(dim=2) 15 | 16 | self.channel_add_conv = nn.Sequential( 17 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 18 | nn.LayerNorm([self.planes, 1, 1]), 19 | nn.ReLU(), # yapf: disable 20 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 21 | 22 | def spatial_pool(self, x): 23 | batch, channel, height, width = x.size() 24 | 25 | input_x = x 26 | # [N, C, H * W] 27 | input_x = input_x.view(batch, channel, height * width) 28 | # [N, 1, C, H * W] 29 | input_x = input_x.unsqueeze(1) 30 | # [N, 1, H, W] 31 | context_mask = self.conv_mask(x) 32 | # [N, 1, H * W] 33 | context_mask = context_mask.view(batch, 1, height * width) 34 | # [N, 1, H * W] 35 | context_mask = self.softmax(context_mask) 36 | # [N, 1, H * W, 1] 37 | context_mask = context_mask.unsqueeze(-1) 38 | # [N, 1, C, 1] 39 | context = jt.matmul(input_x, context_mask) 40 | # [N, C, 1, 1] 41 | context = context.view(batch, channel, 1, 1) 42 | 43 | return context 44 | 45 | def execute(self, x): 46 | # [N, C, 1, 1] 47 | context = self.spatial_pool(x) 48 | 49 | out = x 50 | # [N, C, 1, 1] 51 | channel_add_term = self.channel_add_conv(context) 52 | out = out + channel_add_term 53 | 54 
| return out 55 | 56 | 57 | def main(): 58 | attention_block = GlobalContextBlock(64, 1/4) 59 | input = jt.rand([4, 64, 32, 32]) 60 | output = attention_block(input) 61 | print(input.size(), output.size()) 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /code/spatial_attentions/hamnet.py: -------------------------------------------------------------------------------- 1 | # Is Attention Better Than Matrix Decomposition? (ICLR 2021) 2 | import jittor as jt 3 | from jittor import nn 4 | from contextlib import contextmanager 5 | 6 | 7 | @contextmanager 8 | def null_context(): 9 | yield 10 | 11 | 12 | class NMF(nn.Module): 13 | def __init__( 14 | self, 15 | dim, 16 | n, 17 | ratio=8, 18 | K=6, 19 | eps=2e-8 20 | ): 21 | super().__init__() 22 | r = dim // ratio 23 | 24 | self.D = jt.zeros((dim, r)).uniform_(0, 1) 25 | self.C = jt.zeros((r, n)).uniform_(0, 1) 26 | 27 | self.K = K 28 | 29 | self.eps = eps 30 | 31 | def execute(self, x): 32 | b, D, C, eps = x.shape[0], self.D, self.C, self.eps 33 | 34 | # x is made non-negative with relu as proposed in paper 35 | x = nn.relu(x) 36 | D = D.unsqueeze(0).repeat(b, 1, 1) 37 | C = C.unsqueeze(0).repeat(b, 1, 1) 38 | 39 | # transpose 40 | def t(tensor): return tensor.transpose(1, 2) 41 | 42 | for k in reversed(range(self.K)): 43 | # only calculate gradients on the last step, per propose 'One-step Gradient' 44 | context = null_context if k == 0 else jt.no_grad 45 | with context(): 46 | C_new = C * ((t(D) @ x) / ((t(D) @ D @ C) + eps)) 47 | D_new = D * ((x @ t(C)) / ((D @ C @ t(C)) + eps)) 48 | C, D = C_new, D_new 49 | 50 | return D @ C 51 | 52 | 53 | class Hamburger(nn.Module): 54 | def __init__( 55 | self, 56 | dim, 57 | n, 58 | inner_dim, 59 | ratio=8, 60 | K=6 61 | ): 62 | super().__init__() 63 | 64 | self.lower_bread = nn.Conv1d(dim, inner_dim, 1, bias=False) 65 | self.ham = NMF(inner_dim, n, ratio=ratio, K=K) 66 | self.upper_bread = nn.Conv1d(inner_dim, dim, 1, bias=False) 67 | 68 | def execute(self, x): 69 | input = x 70 | shape = x.shape 71 | x = x.flatten(2) 72 | 73 | x = self.lower_bread(x) 74 | x = self.ham(x) 75 | x = self.upper_bread(x) 76 | return input + x.reshape(shape) 77 | 78 | 79 | def main(): 80 | attention_block = Hamburger(64, 32*32, 64, 8, 6) 81 | input = jt.rand([4, 64, 32, 32]) 82 | output = attention_block(input) 83 | print(input.size(), output.size()) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /code/spatial_attentions/mhsa.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import jittor.nn as nn 3 | 4 | 5 | class MHSA(nn.Module): 6 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 7 | super(MHSA, self).__init__() 8 | self.num_heads = num_heads 9 | head_dim = dim // num_heads 10 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 11 | self.scale = qk_scale or head_dim ** -0.5 12 | 13 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 14 | self.attn_drop = nn.Dropout(attn_drop) 15 | self.proj = nn.Linear(dim, dim) 16 | self.proj_drop = nn.Dropout(proj_drop) 17 | 18 | def execute(self, x): 19 | b, n, c = x.shape 20 | qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // 21 | self.num_heads).permute(2, 0, 3, 1, 4) 22 | 23 | q, k, v = qkv[0], qkv[1], qkv[2] 24 | 25 | # attn = 
nn.bmm(q,k.transpose(0,1,3,2))*self.scale 26 | attn = nn.bmm_transpose(q, k)*self.scale 27 | 28 | attn = nn.softmax(attn, dim=-1) 29 | 30 | attn = self.attn_drop(attn) 31 | 32 | out = nn.bmm(attn, v) 33 | out = out.transpose(0, 2, 1, 3).reshape(b, n, c) 34 | out = self.proj(out) 35 | out = self.proj_drop(out) 36 | 37 | return out 38 | 39 | 40 | def main(): 41 | attention_block = MHSA(64) 42 | input = jt.rand([4, 128, 64]) 43 | output = attention_block(input) 44 | print(input.size(), output.size()) 45 | 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /code/spatial_attentions/ocr.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import jittor.nn as nn 3 | from jittor import init 4 | 5 | class OCRHead(nn.Module): 6 | def __init__(self, in_channels, n_cls=19): 7 | super(OCRHead, self).__init__() 8 | self.relu = nn.ReLU() 9 | self.in_channels = in_channels 10 | self.softmax = nn.Softmax(dim = 2) 11 | self.conv_1x1 = nn.Conv(in_channels, in_channels, kernel_size=1) 12 | self.last_conv = nn.Conv(in_channels * 2, n_cls, kernel_size=3, stride=1, padding=1) 13 | self._zero_init_conv() 14 | def _zero_init_conv(self): 15 | self.conv_1x1.weight = init.constant([self.in_channels, self.in_channels, 1, 1], 'float', value=0.0) 16 | 17 | def execute(self, context, feature): 18 | batch_size, c, h, w = feature.shape 19 | origin_feature = feature 20 | feature = feature.reshape(batch_size, c, -1).transpose(0, 2, 1) # b, h*w, c 21 | context = context.reshape(batch_size, context.shape[1], -1) # b, n_cls, h*w 22 | attention = self.softmax(context) 23 | ocr_context = nn.bmm(attention, feature).transpose(0, 2, 1) # b, c, n_cls 24 | relation = nn.bmm(feature, ocr_context).transpose(0, 2, 1) # b, n_cls, h*w 25 | attention = self.softmax(relation) #b , n_cls, h*w 26 | result = nn.bmm(ocr_context, attention).reshape(batch_size, c, h, w) 27 | result = self.conv_1x1(result) 28 | result = jt.concat ([result, origin_feature], dim=1) 29 | result = self.last_conv (result) 30 | return result 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /code/spatial_attentions/offset_module.py: -------------------------------------------------------------------------------- 1 | # PCT: Point Cloud Transformer (CVMJ 2021) 2 | 3 | import jittor as jt 4 | from jittor import nn 5 | 6 | 7 | class SA_Layer(nn.Module): 8 | def __init__(self, channels): 9 | super(SA_Layer, self).__init__() 10 | self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 11 | self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 12 | self.q_conv.weight = self.k_conv.weight 13 | self.v_conv = nn.Conv1d(channels, channels, 1) 14 | self.trans_conv = nn.Conv1d(channels, channels, 1) 15 | self.after_norm = nn.BatchNorm1d(channels) 16 | self.act = nn.ReLU() 17 | self.softmax = nn.Softmax(dim=-1) 18 | 19 | def execute(self, x): 20 | x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c 21 | x_k = self.k_conv(x) # b, c, n 22 | x_v = self.v_conv(x) 23 | energy = nn.bmm(x_q, x_k) # b, n, n 24 | attention = self.softmax(energy) 25 | attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) 26 | x_r = nn.bmm(x_v, attention) # b, c, n 27 | x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) 28 | x = x + x_r 29 | return x 30 | 31 | 32 | def main(): 33 | attention_block = SA_Layer(64) 34 | input = jt.rand([4, 64, 32]) 35 | output = 
attention_block(input) 36 | print(input.size(), output.size()) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /code/spatial_attentions/segformer_module.py: -------------------------------------------------------------------------------- 1 | # SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 2 | import jittor as jt 3 | from jittor import nn 4 | 5 | 6 | class EfficientAttention(nn.Module): 7 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 8 | super().__init__() 9 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 10 | 11 | self.dim = dim 12 | self.num_heads = num_heads 13 | head_dim = dim // num_heads 14 | self.scale = qk_scale or head_dim ** -0.5 15 | 16 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 17 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 18 | self.attn_drop = nn.Dropout(attn_drop) 19 | self.proj = nn.Linear(dim, dim) 20 | self.proj_drop = nn.Dropout(proj_drop) 21 | 22 | self.sr_ratio = sr_ratio 23 | if sr_ratio > 1: 24 | self.sr = nn.Conv2d( 25 | dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 26 | self.norm = nn.LayerNorm(dim) 27 | 28 | def execute(self, x, H, W): 29 | B, N, C = x.shape 30 | q = self.q(x).reshape(B, N, self.num_heads, C // 31 | self.num_heads).permute(0, 2, 1, 3) 32 | 33 | if self.sr_ratio > 1: 34 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 35 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 36 | x_ = self.norm(x_) 37 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, 38 | C // self.num_heads).permute(2, 0, 3, 1, 4) 39 | else: 40 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // 41 | self.num_heads).permute(2, 0, 3, 1, 4) 42 | k, v = kv[0], kv[1] 43 | 44 | attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale 45 | attn = nn.softmax(attn, dim=-1) 46 | attn = self.attn_drop(attn) 47 | 48 | x = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C) 49 | x = self.proj(x) 50 | x = self.proj_drop(x) 51 | 52 | return x 53 | 54 | 55 | def main(): 56 | attention_block = EfficientAttention(64) 57 | input = jt.rand([4, 128, 64]) 58 | output = attention_block(input, 8, 8) 59 | print(input.size(), output.size()) 60 | 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /code/spatial_attentions/self_attention.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import jittor.nn as nn 3 | 4 | 5 | class SelfAttention(nn.Module): 6 | """ self attention module""" 7 | 8 | def __init__(self, in_dim): 9 | super(SelfAttention, self).__init__() 10 | self.chanel_in = in_dim 11 | 12 | self.query = nn.Conv(in_channels=in_dim, 13 | out_channels=in_dim, kernel_size=1) 14 | self.key = nn.Conv(in_channels=in_dim, 15 | out_channels=in_dim, kernel_size=1) 16 | self.value = nn.Conv(in_channels=in_dim, 17 | out_channels=in_dim, kernel_size=1) 18 | 19 | self.softmax = nn.Softmax(dim=-1) 20 | 21 | def execute(self, x): 22 | """ 23 | inputs : 24 | x : input feature maps( B X C X H X W) 25 | returns : 26 | out : attention value + input feature 27 | attention: B X (HxW) X (HxW) 28 | """ 29 | m_batchsize, C, height, width = x.size() 30 | proj_query = self.query(x).reshape( 31 | m_batchsize, -1, width*height).transpose(0, 2, 1) 32 | proj_key = self.key(x).reshape(m_batchsize, -1, width*height) 33 | energy = nn.bmm(proj_query, 
proj_key) 34 | attention = self.softmax(energy) 35 | proj_value = self.value(x).reshape(m_batchsize, -1, width*height) 36 | 37 | out = nn.bmm(proj_value, attention.transpose(0, 2, 1)) 38 | out = out.reshape(m_batchsize, C, height, width) 39 | 40 | return out 41 | 42 | 43 | def main(): 44 | attention_block = SelfAttention(64) 45 | input = jt.rand([4, 64, 32, 32]) 46 | output = attention_block(input) 47 | print(input.size(), output.size()) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /code/spatial_attentions/stn.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import nn 3 | import numpy as np 4 | 5 | 6 | def get_pixel_value(img, x, y): 7 | B, C, H, W = img.shape 8 | return img.reindex([B, C, H, W], ['i0', 'i1', '@e0(i0, i2, i3)','@e1(i0, i2, i3)'], extras=[x, y]) 9 | 10 | 11 | def affine_grid_generator(height, width, theta): 12 | num_batch = theta.shape[0] 13 | 14 | # create normalized 2D grid 15 | x = jt.linspace(-1.0, 1.0, width) 16 | y = jt.linspace(-1.0, 1.0, height) 17 | x_t, y_t = jt.meshgrid(x, y) 18 | 19 | # flatten 20 | x_t_flat = x_t.reshape(-1) 21 | y_t_flat = y_t.reshape(-1) 22 | print(x_t.shape) 23 | # reshape to [x_t, y_t , 1] - (homogeneous form) 24 | ones = jt.ones_like(x_t_flat) 25 | sampling_grid = jt.stack([x_t_flat, y_t_flat, ones]) 26 | 27 | # repeat grid num_batch times 28 | sampling_grid = sampling_grid.unsqueeze(0).expand(num_batch, -1, -1) 29 | 30 | 31 | # transform the sampling grid - batch multiply 32 | batch_grids = jt.matmul(theta, sampling_grid) 33 | 34 | # reshape to (num_batch, H, W, 2) 35 | batch_grids = batch_grids.reshape(num_batch, 2, height, width) 36 | return batch_grids 37 | 38 | 39 | def bilinear_sampler(img, x, y): 40 | B, C, H ,W = img.shape 41 | max_y = H - 1 42 | max_x = W - 1 43 | 44 | # rescale x and y to [0, W-1/H-1] 45 | x = 0.5 * (x + 1.0) * (max_x-1) 46 | y = 0.5 * (y + 1.0) * (max_y-1) 47 | 48 | # grab 4 nearest corner points for each (x_i, y_i) 49 | x0 = jt.floor(x).astype('int32') 50 | x1 = x0 + 1 51 | y0 = jt.floor(y).astype('int32') 52 | y1 = y0 + 1 53 | 54 | x0 = jt.minimum(jt.maximum(0, x0), max_x) 55 | x1 = jt.minimum(jt.maximum(0, x1), max_x) 56 | y0 = jt.minimum(jt.maximum(0, y0), max_y) 57 | y1 = jt.minimum(jt.maximum(0, y1), max_y) 58 | 59 | # get pixel value at corner coords 60 | Ia = get_pixel_value(img, x0, y0) 61 | Ib = get_pixel_value(img, x0, y1) 62 | Ic = get_pixel_value(img, x1, y0) 63 | Id = get_pixel_value(img, x1, y1) 64 | 65 | # calculate deltas 66 | wa = (x1-x) * (y1-y) 67 | wb = (x1-x) * (y-y0) 68 | wc = (x-x0) * (y1-y) 69 | wd = (x-x0) * (y-y0) 70 | 71 | # compute output 72 | out = wa*Ia + wb*Ib + wc*Ic + wd*Id 73 | return out 74 | 75 | class STN(nn.Module): 76 | def __init__(self): 77 | super(STN, self).__init__() 78 | 79 | def execute(self, x1, theta): 80 | B, C, H, W = x1.shape 81 | theta = theta.reshape(-1, 2, 3) 82 | 83 | batch_grids = affine_grid_generator(H, W, theta) 84 | 85 | x_s = batch_grids[:, 0, :, :] 86 | y_s = batch_grids[:, 1, :, :] 87 | 88 | out_fmap = bilinear_sampler(x1, x_s, y_s) 89 | 90 | return out_fmap 91 | 92 | 93 | def main(): 94 | stn = STN() 95 | x = jt.randn(1, 3, 224, 224) 96 | theta = jt.array(np.random.uniform(0,1,(1,6))) 97 | y = stn(x, theta) 98 | print(y) 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- 
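A minimal usage sketch for the STN defined in stn.py above (not part of the repository): `theta` is a flattened 2x3 affine matrix per sample, as implied by `theta.reshape(-1, 2, 3)` in `STN.execute`, so `[1, 0, 0, 0, 1, 0]` encodes the identity transform and the sampled output should closely match the input, up to the sampler's border handling.

```
import jittor as jt

# Assumes the STN class from stn.py above is in scope.
stn = STN()
x = jt.randn(1, 3, 64, 64)
# Identity affine transform: [[1, 0, 0], [0, 1, 0]] flattened to 6 values.
theta = jt.array([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
y = stn(x, theta)
print(x.shape, y.shape)  # both [1, 3, 64, 64]
```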
/code/spatial_temporal_attentions/dstt_module.py: -------------------------------------------------------------------------------- 1 | # Decoupled spatial-temporal transformer for video inpainting (arXiv 2021) 2 | import math 3 | import jittor as jt 4 | from jittor import nn 5 | 6 | 7 | class Attention(nn.Module): 8 | """ 9 | Compute 'Scaled Dot Product Attention 10 | """ 11 | 12 | def __init__(self, p=0.1): 13 | super(Attention, self).__init__() 14 | self.dropout = nn.Dropout(p=p) 15 | 16 | def execute(self, query, key, value): 17 | scores = jt.matmul(query, key.transpose(-2, -1) 18 | ) / math.sqrt(query.size(-1)) 19 | p_attn = nn.softmax(scores, dim=-1) 20 | p_attn = self.dropout(p_attn) 21 | p_val = jt.matmul(p_attn, value) 22 | return p_val, p_attn 23 | 24 | 25 | class MultiHeadedAttention(nn.Module): 26 | """ 27 | Take in model size and number of heads. 28 | """ 29 | 30 | def __init__(self, tokensize, d_model, head, mode, p=0.1): 31 | super().__init__() 32 | self.mode = mode 33 | self.query_embedding = nn.Linear(d_model, d_model) 34 | self.value_embedding = nn.Linear(d_model, d_model) 35 | self.key_embedding = nn.Linear(d_model, d_model) 36 | self.output_linear = nn.Linear(d_model, d_model) 37 | self.attention = Attention(p=p) 38 | self.head = head 39 | self.h, self.w = tokensize 40 | 41 | def execute(self, x, t): 42 | bt, n, c = x.size() 43 | b = bt // t 44 | c_h = c // self.head 45 | key = self.key_embedding(x) 46 | query = self.query_embedding(x) 47 | value = self.value_embedding(x) 48 | if self.mode == 's': 49 | key = key.view(b, t, n, self.head, c_h).permute(0, 1, 3, 2, 4) 50 | query = query.view(b, t, n, self.head, c_h).permute(0, 1, 3, 2, 4) 51 | value = value.view(b, t, n, self.head, c_h).permute(0, 1, 3, 2, 4) 52 | att, _ = self.attention(query, key, value) 53 | att = att.permute(0, 1, 3, 2, 4).view(bt, n, c) 54 | elif self.mode == 't': 55 | key = key.view(b, t, 2, self.h//2, 2, self.w//2, self.head, c_h) 56 | key = key.permute(0, 2, 4, 6, 1, 3, 5, 7).view( 57 | b, 4, self.head, -1, c_h) 58 | query = query.view(b, t, 2, self.h//2, 2, 59 | self.w//2, self.head, c_h) 60 | query = query.permute(0, 2, 4, 6, 1, 3, 5, 7).view( 61 | b, 4, self.head, -1, c_h) 62 | value = value.view(b, t, 2, self.h//2, 2, 63 | self.w//2, self.head, c_h) 64 | value = value.permute(0, 2, 4, 6, 1, 3, 5, 7).view( 65 | b, 4, self.head, -1, c_h) 66 | att, _ = self.attention(query, key, value) 67 | att = att.view(b, 2, 2, self.head, t, self.h//2, self.w//2, c_h) 68 | att = att.permute(0, 4, 1, 5, 2, 6, 3, 69 | 7).view(bt, n, c) 70 | output = self.output_linear(att) 71 | return output 72 | 73 | 74 | def main(): 75 | attention_block_s = MultiHeadedAttention( 76 | tokensize=[4, 8], d_model=64, head=4, mode='s') 77 | attention_block_t = MultiHeadedAttention( 78 | tokensize=[4, 8], d_model=64, head=4, mode='t') 79 | input = jt.rand([8, 32, 64]) 80 | output = attention_block_s(input, 2) 81 | output = attention_block_t(output, 2) 82 | print(input.size(), output.size()) 83 | 84 | 85 | if __name__ == '__main__': 86 | main() 87 | -------------------------------------------------------------------------------- /code/temporal_attentions/gltr.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/Awesome-Vision-Attentions/35f6f5e7fc034358e7f51683fc27101a75e24fac/code/temporal_attentions/gltr.py -------------------------------------------------------------------------------- /imgs/attention_category.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/Awesome-Vision-Attentions/35f6f5e7fc034358e7f51683fc27101a75e24fac/imgs/attention_category.png -------------------------------------------------------------------------------- /imgs/fuse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/Awesome-Vision-Attentions/35f6f5e7fc034358e7f51683fc27101a75e24fac/imgs/fuse.png -------------------------------------------------------------------------------- /imgs/fuse_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/Awesome-Vision-Attentions/35f6f5e7fc034358e7f51683fc27101a75e24fac/imgs/fuse_fig.png -------------------------------------------------------------------------------- /imgs/timeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/Awesome-Vision-Attentions/35f6f5e7fc034358e7f51683fc27101a75e24fac/imgs/timeline.png --------------------------------------------------------------------------------
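All of the modules above preserve the spatial size and channel count of their input, so they can be dropped in after any convolutional stage. A minimal sketch (not part of the repository, and the exact import path is an assumption) of wrapping one of them, here the CBAM block from code/channel_spatial_attentions/cbam.py, inside a small Jittor conv block:

```
import jittor as jt
from jittor import nn

# Assumes the CBAM class from cbam.py above is in scope
# (e.g. copied into the current file or imported from the repo).

class ConvAttnBlock(nn.Module):
    # conv -> BN -> ReLU, followed by a shape-preserving attention module
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()
        self.attn = CBAM(out_c)

    def execute(self, x):
        x = self.relu(self.bn(self.conv(x)))
        return self.attn(x)


block = ConvAttnBlock(3, 64)
y = block(jt.rand([2, 3, 32, 32]))
print(y.shape)  # expected: [2, 64, 32, 32]
```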