├── README.md ├── RemoteSAM.pdf ├── arc ├── __init__.py ├── adaptive_rotated_conv.py ├── routing_function.py └── weight_init.py ├── args.py ├── assets ├── HKUST.jpg ├── NUIST.jpg ├── Radar.png ├── RemoteSAM.png ├── RemoteSAM270K.png ├── SEU.png ├── demo.jpg └── hhu_logo.png ├── bert ├── activations.py ├── configuration_bert.py ├── configuration_utils.py ├── file_utils.py ├── generation_utils.py ├── modeling_bert.py ├── modeling_utils.py ├── tokenization_bert.py ├── tokenization_utils.py └── tokenization_utils_base.py ├── data ├── DiverseDataset.py ├── all_dataset.py ├── dataset_refer_bert.py └── datasets │ ├── CAP.py │ ├── CNT.py │ ├── DET.py │ ├── MCC.py │ ├── MLC.py │ ├── VG.py │ └── __init__.py ├── lib ├── _utils.py ├── backbone.py ├── cross_scale_interaction.py ├── logger.py ├── mask_predictor.py ├── mmcv_custom │ ├── __init__.py │ └── checkpoint.py ├── sa │ ├── functional.py │ ├── functions │ │ ├── __init__.py │ │ ├── aggregation_refpad.py │ │ ├── aggregation_zeropad.py │ │ ├── subtraction2_refpad.py │ │ ├── subtraction2_zeropad.py │ │ ├── subtraction_refpad.py │ │ ├── subtraction_zeropad.py │ │ └── utils.py │ └── modules │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── subtraction.py │ │ └── subtraction2.py ├── segmentation.py ├── transformer.py └── various_receptive.py ├── loss └── loss.py ├── refer └── refer.py ├── requirements.txt ├── tasks ├── CAP.sh ├── CNT.sh ├── DET.sh ├── DET_DOTA.sh ├── MCC.sh ├── MLC.sh ├── REF.sh ├── SEG.sh ├── VG.sh └── code │ ├── RuleBasedCaptioning.py │ ├── eval │ ├── CAP.py │ ├── CNT.py │ ├── DET.py │ ├── DET_DOTA.py │ ├── MCC.py │ ├── MLC.py │ ├── REF.py │ ├── SEG.py │ └── VG.py │ ├── metric │ ├── RunningScore.py │ └── cidereval │ │ ├── cider │ │ ├── __init__.py │ │ ├── cider.py │ │ ├── cider_scorer.py │ │ └── data │ │ │ ├── __init__.py │ │ │ └── coco-val.p │ │ ├── eval.py │ │ ├── scorers.py │ │ └── tokenizer │ │ ├── __init__.py │ │ ├── ptbtokenizer.py │ │ ├── simpletokenizer.py │ │ └── stanford-corenlp-3.4.1.jar │ └── model.py ├── train.py ├── transforms.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # RemoteSAM: Towards Segment Anything for Earth Observation 4 | 5 | 6 | 7 | [Liang Yao (姚亮)*](https://multimodality.group/author/%E5%A7%9A%E4%BA%AE/) 8 | Logo,     9 | [Fan Liu (刘凡)*](https://multimodality.group/author/%E5%88%98%E5%87%A1/) ✉ 10 | Logo,     11 | [Delong Chen (陈德龙)*](https://chendelong.world/) 12 | Logo,     13 | 14 | [Chuanyi Zhang (张传一)](https://ai.hhu.edu.cn/2023/0809/c17670a264073/page.htm) 15 | Logo,     16 | [Yijun Wang (王翌骏)](https://multimodality.group/author/%E7%8E%8B%E7%BF%8C%E9%AA%8F/) 17 | Logo,     18 | [Ziyun Chen (陈子赟)](https://multimodality.group/author/%E9%99%88%E5%AD%90%E8%B5%9F/) 19 | Logo,     20 | 21 | [Wei Xu (许玮)](https://multimodality.group/author/%E8%AE%B8%E7%8E%AE/) 22 | Logo,     23 | [Shimin Di (邸世民)](https://cs.seu.edu.cn/shimindi/main.htm) 24 | Logo,     25 | [Yuhui Zheng (郑钰辉)](https://faculty.nuist.edu.cn/zhengyuhui/en/index.htm) 26 | Logo 27 | 28 | \* *Equal Contribution* ✉ *Corresponding Author* 29 | 30 | Model : 🤗[RemoteSAM](https://huggingface.co/1e12Leon/RemoteSAM) 31 | 32 | Dataset : 🤗[RemoteSAM-270K](https://huggingface.co/datasets/1e12Leon/RemoteSAM_270K) 33 |
34 |
35 |
36 | ## News
37 |
38 | - **2025/5/7**: We have released the model and dataset! You can download RemoteSAM-270K from 🤗[RemoteSAM-270K](https://huggingface.co/datasets/1e12Leon/RemoteSAM_270K) and the checkpoint from 🤗[RemoteSAM](https://huggingface.co/1e12Leon/RemoteSAM).
39 | - **2025/5/3**: Welcome to RemoteSAM! The preprint of our paper is available, and the dataset and model are open-sourced in this repository.
40 |
41 |
42 |
43 | ## Introduction
44 | Welcome to the official repository of our paper "RemoteSAM: Towards Segment Anything for Earth Observation"!
45 |
46 | ![](assets/RemoteSAM.png)
47 |
48 | Recent advances in AI have revolutionized Earth observation, yet most remote sensing tasks still rely on specialized models with fragmented interfaces. To address this, we present **RemoteSAM**, a vision foundation model that unifies pixel-, region-, and image-level tasks through a novel architecture centered on Referring Expression Segmentation (RES). Unlike existing paradigms (task-specific heads with limited knowledge sharing, or text-based models that struggle with dense outputs), RemoteSAM leverages pixel-level predictions as atomic units, enabling upward compatibility to higher-level tasks while eliminating computationally heavy language model backbones. This design achieves an order-of-magnitude parameter reduction (billions to millions), enabling efficient high-resolution data processing.
49 |
50 | ![](assets/RemoteSAM270K.png)
51 |
52 | We also build the **RemoteSAM-270K** dataset, a large-scale collection of 270K Image-Text-Mask triplets generated via an automated pipeline powered by vision-language models (VLMs). This dataset surpasses existing resources in semantic diversity, covering 1,000+ object categories and rich attributes (e.g., color, spatial relations) through linguistically varied prompts. We further introduce RSVocab-1K, a hierarchical semantic vocabulary used to quantify dataset coverage and adaptability.
53 |
54 | ![](assets/Radar.png)
55 |
56 | ## Setting Up
57 |
58 | The code has been verified to work with PyTorch v1.13.0 and Python 3.8.
59 | 1. Clone this repository.
60 | 2. Change directory to the root of this repository.
61 |
62 | ### Package Dependencies
63 | 1. Create a new Conda environment with Python 3.8 and activate it:
64 | ```shell
65 | conda create -n RemoteSAM python==3.8
66 | conda activate RemoteSAM
67 | ```
68 |
69 | 2. Install PyTorch v1.13.0 with a CUDA version that works on your cluster/machine (CUDA 11.6 is used in this example):
70 | ```shell
71 | pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116
72 | ```
73 |
74 | 3. Install `mmcv-full` from OpenMMLab:
75 | ```shell
76 | pip install mmcv-full==1.7.1 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.13.0/index.html
77 | ```
78 | 4. Install the packages in `requirements.txt` via `pip`:
79 | ```shell
80 | pip install -r requirements.txt
81 | ```
82 | ### The Initialization Weights for Training
83 | 1. Create the `./pretrained_weights` directory where the weights will be stored.
84 | ```shell
85 | mkdir ./pretrained_weights
86 | ```
87 | 2. Download the [pre-trained classification weights of the Swin Transformer](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)
88 | and put the `.pth` file in `./pretrained_weights`.
89 | These weights are needed to initialize the model for training.
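If you prefer the command line, the weights can be fetched directly into that directory; this is only a convenience sketch, assuming `wget` is available (the URL is the same release link given above):

```shell
wget -P ./pretrained_weights https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
```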
90 |
91 | ## Data Preparation
92 | We perform all experiments on our proposed RemoteSAM-270K dataset.
93 |
94 | ### Usage
95 | 1. Download our dataset from [HuggingFace](https://huggingface.co/datasets/1e12Leon/RemoteSAM_270K).
96 | 2. Copy all the downloaded files to `./refer/data/`. The dataset folder should look like this:
97 | ```
98 | $DATA_PATH
99 | ├── RemoteSAM-270K
100 | │   ├── JPEGImages
101 | │   ├── Annotations
102 | │   ├── refs(unc).p
103 | │   └── instances.json
104 | ```
105 |
106 |
107 | ## RemoteSAM
108 |
109 | ### Training
110 | We use PyTorch's DistributedDataParallel for training. Additional training settings can be changed in `args.py`.
111 | To run on 8 GPUs on a single node:
112 | ```shell
113 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
114 | python -m torch.distributed.launch \
115 | --nproc_per_node 8 --master_port 12345 train.py \
116 | --epochs 40 --img_size 896 2>&1 | tee ./output
117 | ```
118 | ### Getting Started
119 |
120 | To get started with RemoteSAM, first initialize a model and load the RemoteSAM checkpoint with a few lines of code:
121 |
122 | ```python
123 | from tasks.code.model import RemoteSAM, init_demo_model
124 | import cv2
125 | import numpy as np
126 |
127 | device = 'cuda:0'
128 | checkpoint = "./pretrained_weights/checkpoint.pth"
129 |
130 | model = init_demo_model(checkpoint, device)
131 | model = RemoteSAM(model, device, use_EPOC=True)
132 | ```
133 |
134 | Then, you can explore the different tasks supported by RemoteSAM:
135 |
136 | - **Referring Expression Segmentation**
137 |
138 | ```python
139 | image = cv2.imread("./assets/demo.jpg")
140 | mask = model.referring_seg(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), sentence="the airplane on the right")
141 | ```
142 |
143 | - **Semantic Segmentation**
144 |
145 | ```python
146 | image = cv2.imread("./assets/demo.jpg")
147 | result = model.semantic_seg(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'])
148 | for classname in ["airplane", "vehicle"]:
149 |     mask = result[classname]
150 | ```
151 |
152 | - **Object Detection**
153 |
154 | ```python
155 | image = cv2.imread("./assets/demo.jpg")
156 | result = model.detection(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'])
157 | for classname in ["airplane", "vehicle"]:
158 |     boxes = result[classname]
159 | ```
160 |
161 | - **Visual Grounding**
162 |
163 | ```python
164 | image = cv2.imread("./assets/demo.jpg")
165 | box = model.visual_grounding(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), sentence="the airplane on the right")
166 | ```
167 |
168 | - **Multi-Label Classification**
169 |
170 | ```python
171 | image = cv2.imread("./assets/demo.jpg")
172 | result = model.multi_label_cls(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'])
173 | print(result)
174 | ```
175 |
176 | - **Image Classification**
177 |
178 | ```python
179 | image = cv2.imread("./assets/demo.jpg")
180 | result = model.multi_class_cls(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'])
181 | print(result)
182 | ```
183 |
184 | - **Image Captioning**
185 |
186 | ```python
187 | image = cv2.imread("./assets/demo.jpg")
188 | result = model.captioning(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'], region_split=9)
189 | print(result)
190 | ```
191 |
192 | - **Object Counting**
193 |
194 | ```python
195 | image = cv2.imread("./assets/demo.jpg")
196 | result = model.counting(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'])
197 | for classname in ["airplane", "vehicle"]:
198 |     print("{}: {}".format(classname, result[classname]))
199 | ```
200 |
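The snippets above return raw NumPy-style outputs. For quick visual inspection, a mask or a set of boxes can be overlaid on the input image with standard OpenCV calls. The sketch below is illustrative only: it reuses `model` from the initialization block above and assumes that `referring_seg` returns a binary H×W mask and that `detection` returns per-class lists of `[x1, y1, x2, y2]` boxes; adapt it if your outputs use a different format. The `visualize` helper is not part of the repository.

```python
import cv2

def visualize(image_bgr, mask=None, boxes=None, out_path="./demo_vis.jpg"):
    """Overlay an optional binary mask and draw optional boxes on a BGR image, then save it."""
    vis = image_bgr.copy()
    if mask is not None:
        # Paint masked pixels red at 50% opacity (assumes mask is H x W with values in {0, 1}).
        overlay = vis.copy()
        overlay[mask.astype(bool)] = (0, 0, 255)
        vis = cv2.addWeighted(overlay, 0.5, vis, 0.5, 0)
    if boxes is not None:
        # Draw each box as a green rectangle (assumes [x1, y1, x2, y2] pixel coordinates).
        for x1, y1, x2, y2 in boxes:
            cv2.rectangle(vis, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
    cv2.imwrite(out_path, vis)

image = cv2.imread("./assets/demo.jpg")
mask = model.referring_seg(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), sentence="the airplane on the right")
boxes = model.detection(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane'])['airplane']
visualize(image, mask=mask, boxes=boxes)
```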
201 | ### Evaluation
202 |
203 | - **Evaluation of Referring Expression Segmentation**
204 |
205 | ```shell
206 | bash tasks/REF.sh
207 | ```
208 |
209 | - **Evaluation of Semantic Segmentation**
210 |
211 | ```shell
212 | bash tasks/SEG.sh
213 | ```
214 |
215 | - **Evaluation of Object Detection**
216 |
217 | ```shell
218 | bash tasks/DET.sh
219 | ```
220 | - **Evaluation of Visual Grounding**
221 |
222 | ```shell
223 | bash tasks/VG.sh
224 | ```
225 | - **Evaluation of Multi-Label Classification**
226 |
227 | ```shell
228 | bash tasks/MLC.sh
229 | ```
230 | - **Evaluation of Image Classification**
231 |
232 | ```shell
233 | bash tasks/MCC.sh
234 | ```
235 | - **Evaluation of Image Captioning**
236 |
237 | ```shell
238 | bash tasks/CAP.sh
239 | ```
240 | - **Evaluation of Object Counting**
241 |
242 | ```shell
243 | bash tasks/CNT.sh
244 | ```
245 |
246 | ## Acknowledgements
247 | - Thanks to Lu Wang (王璐) for his efforts on the RemoteSAM-270K dataset.
248 | - Code in this repository is built on [RMSIN](https://github.com/Lsan2401/RMSIN). We'd like to thank the authors for open-sourcing their project.
249 |
250 | ## Contact
251 | Please contact yaoliang@hhu.edu.cn.
252 |
253 |
254 | ## Cite
255 | If you find this work useful, please cite our paper as:
256 | ```bibtex
257 | @misc{yao2025RemoteSAM,
258 |       title={RemoteSAM: Towards Segment Anything for Earth Observation},
259 |       author={Liang Yao and Fan Liu and Delong Chen and Chuanyi Zhang and Yijun Wang and Ziyun Chen and Wei Xu and Shimin Di and Yuhui Zheng},
260 |       year={2025},
261 |       eprint={2505.18022},
262 |       archivePrefix={arXiv},
263 |       primaryClass={cs.CV},
264 |       url={https://arxiv.org/abs/2505.18022},
265 | }
266 | ```
267 |
--------------------------------------------------------------------------------
/RemoteSAM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/RemoteSAM.pdf
--------------------------------------------------------------------------------
/arc/__init__.py:
--------------------------------------------------------------------------------
1 | from .adaptive_rotated_conv import AdaptiveRotatedConv2d
2 | from .routing_function import RountingFunction
3 |
4 | __all__ = [
5 | 'AdaptiveRotatedConv2d', 'RountingFunction',
6 | ]
7 |
--------------------------------------------------------------------------------
/arc/adaptive_rotated_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | __all__ = ['AdaptiveRotatedConv2d']
7 |
8 |
9 | def _get_rotation_matrix(thetas):
10 | bs, g = thetas.shape
11 | device = thetas.device
12 | thetas = thetas.reshape(-1) # [bs, n] --> [bs x n]
13 |
14 | x = torch.cos(thetas)
15 | y = torch.sin(thetas)
16 | x = x.unsqueeze(0).unsqueeze(0) # shape = [1, 1, bs * g]
17 | y = y.unsqueeze(0).unsqueeze(0)
18 | a = x - y
19 | b = x * y
20 | c = x + y
21 |
22 | rot_mat_positive = torch.cat((
23 | torch.cat((a, 1-a, torch.zeros(1, 7, bs*g, device=device)), dim=1),
24 | torch.cat((torch.zeros(1, 1, bs*g, device=device), x-b, b, torch.zeros(1, 1, bs*g, device=device), 1-c+b, y-b, torch.zeros(1, 3, bs*g, device=device)), dim=1),
25 | torch.cat((torch.zeros(1, 2, bs*g, device=device), a, torch.zeros(1, 2, 
bs*g, device=device), 1-a, torch.zeros(1, 3, bs*g, device=device)), dim=1), 26 | torch.cat((b, y-b, torch.zeros(1,1 , bs*g, device=device), x-b, 1-c+b, torch.zeros(1, 4, bs*g, device=device)), dim=1), 27 | torch.cat((torch.zeros(1, 4, bs*g, device=device), torch.ones(1, 1, bs*g, device=device), torch.zeros(1, 4, bs*g, device=device)), dim=1), 28 | torch.cat((torch.zeros(1, 4, bs*g, device=device), 1-c+b, x-b, torch.zeros(1, 1, bs*g, device=device), y-b, b), dim=1), 29 | torch.cat((torch.zeros(1, 3, bs*g, device=device), 1-a, torch.zeros(1, 2, bs*g, device=device), a, torch.zeros(1, 2, bs*g, device=device)), dim=1), 30 | torch.cat((torch.zeros(1, 3, bs*g, device=device), y-b, 1-c+b, torch.zeros(1, 1, bs*g, device=device), b, x-b, torch.zeros(1, 1, bs*g, device=device)), dim=1), 31 | torch.cat((torch.zeros(1, 7, bs*g, device=device), 1-a, a), dim=1) 32 | ), dim=0) # shape = [k^2, k^2, bs*g] 33 | 34 | rot_mat_negative = torch.cat(( 35 | torch.cat((c, torch.zeros(1, 2, bs*g, device=device), 1-c, torch.zeros(1, 5, bs*g, device=device)), dim=1), 36 | torch.cat((-b, x+b, torch.zeros(1, 1, bs*g, device=device), b-y, 1-a-b, torch.zeros(1, 4, bs*g, device=device)), dim=1), 37 | torch.cat((torch.zeros(1, 1, bs*g, device=device), 1-c, c, torch.zeros(1, 6, bs*g, device=device)), dim=1), 38 | torch.cat((torch.zeros(1, 3, bs*g, device=device), x+b, 1-a-b, torch.zeros(1, 1, bs*g, device=device), -b, b-y, torch.zeros(1, 1, bs*g, device=device)), dim=1), 39 | torch.cat((torch.zeros(1, 4, bs*g, device=device), torch.ones(1, 1, bs*g, device=device), torch.zeros(1, 4, bs*g, device=device)), dim=1), 40 | torch.cat((torch.zeros(1, 1, bs*g, device=device), b-y, -b, torch.zeros(1, 1, bs*g, device=device), 1-a-b, x+b, torch.zeros(1, 3, bs*g, device=device)), dim=1), 41 | torch.cat((torch.zeros(1, 6, bs*g, device=device), c, 1-c, torch.zeros(1, 1, bs*g, device=device)), dim=1), 42 | torch.cat((torch.zeros(1, 4, bs*g, device=device), 1-a-b, b-y, torch.zeros(1, 1, bs*g, device=device), x+b, -b), dim=1), 43 | torch.cat((torch.zeros(1, 5, bs*g, device=device), 1-c, torch.zeros(1, 2, bs*g, device=device), c), dim=1) 44 | ), dim=0) # shape = [k^2, k^2, bs*g] 45 | 46 | mask = (thetas >= 0).unsqueeze(0).unsqueeze(0) 47 | mask = mask.float() # shape = [1, 1, bs*g] 48 | rot_mat = mask * rot_mat_positive + (1 - mask) * rot_mat_negative # shape = [k*k, k*k, bs*g] 49 | rot_mat = rot_mat.permute(2, 0, 1) # shape = [bs*g, k*k, k*k] 50 | rot_mat = rot_mat.reshape(bs, g, rot_mat.shape[1], rot_mat.shape[2]) # shape = [bs, g, k*k, k*k] 51 | return rot_mat 52 | 53 | 54 | def batch_rotate_multiweight(weights, lambdas, thetas): 55 | """ 56 | Let 57 | batch_size = b 58 | kernel_number = n 59 | kernel_size = 3 60 | Args: 61 | weights: tensor, shape = [kernel_number, Cout, Cin, k, k] 62 | thetas: tensor of thetas, shape = [batch_size, kernel_number] 63 | Return: 64 | weights_out: tensor, shape = [batch_size x Cout, Cin // groups, k, k] 65 | """ 66 | assert(thetas.shape == lambdas.shape) 67 | assert(lambdas.shape[1] == weights.shape[0]) 68 | 69 | b = thetas.shape[0] 70 | n = thetas.shape[1] 71 | k = weights.shape[-1] 72 | _, Cout, Cin, _, _ = weights.shape 73 | 74 | if k == 3 : 75 | # Stage 1: 76 | # input: thetas: [b, n] 77 | # lambdas: [b, n] 78 | # output: rotation_matrix: [b, n, 9, 9] (with gate) --> [b*9, n*9] 79 | 80 | # Sub_Stage 1.1: 81 | # input: [b, n] kernel 82 | # output: [b, n, 9, 9] rotation matrix 83 | rotation_matrix = _get_rotation_matrix(thetas) 84 | 85 | # Sub_Stage 1.2: 86 | # input: [b, n, 9, 9] rotation matrix 87 | # 
[b, n] lambdas 88 | # --> [b, n, 1, 1] lambdas 89 | # --> [b, n, 1, 1] lambdas dot [b, n, 9, 9] rotation matrix 90 | # --> [b, n, 9, 9] rotation matrix with gate (done) 91 | # output: [b, n, 9, 9] rotation matrix with gate 92 | lambdas = lambdas.unsqueeze(2).unsqueeze(3) 93 | rotation_matrix = torch.mul(rotation_matrix, lambdas) 94 | 95 | # Sub_Stage 1.3: Reshape 96 | # input: [b, n, 9, 9] rotation matrix with gate 97 | # output: [b*9, n*9] rotation matrix with gate 98 | rotation_matrix = rotation_matrix.permute(0, 2, 1, 3) 99 | rotation_matrix = rotation_matrix.reshape(b*k*k, n*k*k) 100 | 101 | # Stage 2: Reshape 102 | # input: weights: [n, Cout, Cin, 3, 3] 103 | # --> [n, 3, 3, Cout, Cin] 104 | # --> [n*9, Cout*Cin] done 105 | # output: weights: [n*9, Cout*Cin] 106 | weights = weights.permute(0, 3, 4, 1, 2) 107 | weights = weights.contiguous().view(n*k*k, Cout*Cin) 108 | 109 | 110 | # Stage 3: torch.mm 111 | # [b*9, n*9] x [n*9, Cout*Cin] 112 | # --> [b*9, Cout*Cin] 113 | weights = torch.mm(rotation_matrix, weights) 114 | 115 | # Stage 4: Reshape Back 116 | # input: [b*9, Cout*Cin] 117 | # --> [b, 3, 3, Cout, Cin] 118 | # --> [b, Cout, Cin, 3, 3] 119 | # --> [b * Cout, Cin, 3, 3] done 120 | # output: [b * Cout, Cin, 3, 3] 121 | weights = weights.contiguous().view(b, k, k, Cout, Cin) 122 | weights = weights.permute(0, 3, 4, 1, 2) 123 | weights = weights.reshape(b * Cout, Cin, k, k) 124 | else: 125 | thetas = thetas.reshape(-1) # [bs, n] --> [bs x n] 126 | 127 | x = torch.cos(thetas) 128 | y = torch.sin(thetas) 129 | rotate_matrix = torch.tensor([[x, -y, 0], [y, x, 0]]) 130 | rotate_matrix = rotate_matrix.unsqueeze(0).repeat(n, 1, 1) 131 | 132 | weights = weights.contiguous().view(n, Cout*Cin, k, k) 133 | 134 | grid = F.affine_grid(rotate_matrix, weights.shape) 135 | weights = F.grid_sample(weights, grid, mode='biliner') 136 | 137 | return weights 138 | 139 | 140 | class AdaptiveRotatedConv2d(nn.Module): 141 | 142 | def __init__(self, in_channels, out_channels, kernel_size, 143 | stride=1, padding=1, dilation=1, groups=1, bias=False, 144 | kernel_number=1, rounting_func=None, rotate_func=batch_rotate_multiweight): 145 | super().__init__() 146 | self.kernel_number = kernel_number 147 | self.in_channels = in_channels 148 | self.out_channels = out_channels 149 | self.kernel_size = kernel_size 150 | self.stride = stride 151 | self.padding = padding 152 | self.dilation = dilation 153 | self.groups = groups 154 | self.bias = bias 155 | 156 | self.rounting_func = rounting_func 157 | self.rotate_func = rotate_func 158 | 159 | self.weight = nn.Parameter( 160 | torch.Tensor( 161 | kernel_number, 162 | out_channels, 163 | in_channels // groups, 164 | kernel_size, 165 | kernel_size, 166 | ) 167 | ) 168 | nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu') 169 | 170 | def forward(self, x): 171 | # get alphas, angles 172 | # # [bs, Cin, h, w] --> [bs, n_theta], [bs, n_theta] 173 | alphas, angles = self.rounting_func(x) 174 | 175 | # rotate weight 176 | # # [Cout, Cin, k, k] --> [bs * Cout, Cin, k, k] 177 | # print(self.weight.shape) 178 | rotated_weight = self.rotate_func(self.weight, alphas, angles) 179 | 180 | # reshape images 181 | bs, Cin, h, w = x.shape 182 | x = x.reshape(1, bs * Cin, h, w) # [1, bs * Cin, h, w] 183 | 184 | # adaptive conv over images using group conv 185 | out = F.conv2d(input=x, weight=rotated_weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=(self.groups * bs)) 186 | 187 | # reshape back 188 | out = out.reshape(bs, 
self.out_channels, *out.shape[2:]) 189 | return out 190 | 191 | def extra_repr(self): 192 | s = ('{in_channels}, {out_channels}, kernel_number={kernel_number}' 193 | ', kernel_size={kernel_size}, stride={stride}, bias={bias}') 194 | 195 | if self.padding != (0,) * len([self.padding]): 196 | s += ', padding={padding}' 197 | if self.dilation != (1,) * len([self.dilation]): 198 | s += ', dilation={dilation}' 199 | if self.groups != 1: 200 | s += ', groups={groups}' 201 | return s.format(**self.__dict__) 202 | -------------------------------------------------------------------------------- /arc/routing_function.py: -------------------------------------------------------------------------------- 1 | import math 2 | import einops 3 | import torch 4 | import torch.nn as nn 5 | from .weight_init import trunc_normal_ 6 | 7 | 8 | class LayerNormProxy(nn.Module): 9 | # copy from https://github.com/LeapLabTHU/DAT/blob/main/models/dat_blocks.py 10 | def __init__(self, dim): 11 | super().__init__() 12 | self.norm = nn.LayerNorm(dim) 13 | 14 | def forward(self, x): 15 | x = einops.rearrange(x, 'b c h w -> b h w c') 16 | x = self.norm(x) 17 | return einops.rearrange(x, 'b h w c -> b c h w') 18 | 19 | 20 | class RountingFunction(nn.Module): 21 | 22 | def __init__(self, in_channels, kernel_number, dropout_rate=0.2, proportion=40.0): 23 | super().__init__() 24 | self.kernel_number = kernel_number 25 | self.dwc = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, 26 | groups=in_channels, bias=False) 27 | self.norm = LayerNormProxy(in_channels) 28 | self.relu = nn.ReLU(inplace=True) 29 | 30 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 31 | 32 | self.dropout1 = nn.Dropout(dropout_rate) 33 | self.fc_alpha = nn.Linear(in_channels, kernel_number, bias=True) 34 | 35 | self.dropout2= nn.Dropout(dropout_rate) 36 | self.fc_theta = nn.Linear(in_channels, kernel_number, bias=False) 37 | 38 | self.act_func = nn.Softsign() 39 | self.proportion = proportion / 180.0 * math.pi 40 | 41 | # init weights 42 | trunc_normal_(self.dwc.weight, std=.02) 43 | trunc_normal_(self.fc_alpha.weight, std=.02) 44 | trunc_normal_(self.fc_theta.weight, std=.02) 45 | 46 | def forward(self, x): 47 | 48 | x = self.dwc(x) 49 | x = self.norm(x) 50 | x = self.relu(x) 51 | 52 | x = self.avg_pool(x).squeeze(dim=-1).squeeze(dim=-1) # avg_x.shape = [batch_size, Cin] 53 | 54 | alphas = self.dropout1(x) 55 | alphas = self.fc_alpha(alphas) 56 | alphas = torch.sigmoid(alphas) 57 | 58 | angles = self.dropout2(x) 59 | angles = self.fc_theta(angles) 60 | angles = self.act_func(angles) 61 | angles = angles * self.proportion 62 | 63 | return alphas, angles 64 | 65 | def extra_repr(self): 66 | s = (f'kernel_number={self.kernel_number}') 67 | return s.format(**self.__dict__) 68 | -------------------------------------------------------------------------------- /arc/weight_init.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # get from https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/layers/weight_init.py 3 | # -------------------------------------------------------- 4 | import torch 5 | import math 6 | import warnings 7 | 8 | 9 | def _trunc_normal_(tensor, mean, std, a, b): 10 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 11 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 12 | def norm_cdf(x): 13 | # Computes standard normal cumulative 
distribution function 14 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 15 | 16 | if (mean < a - 2 * std) or (mean > b + 2 * std): 17 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 18 | "The distribution of values may be incorrect.", 19 | stacklevel=2) 20 | 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 52 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are 53 | applied while sampling the normal with mean/std applied, therefore a, b args 54 | should be adjusted to match the range of mean, std args. 55 | Args: 56 | tensor: an n-dimensional `torch.Tensor` 57 | mean: the mean of the normal distribution 58 | std: the standard deviation of the normal distribution 59 | a: the minimum cutoff value 60 | b: the maximum cutoff value 61 | Examples: 62 | >>> w = torch.empty(3, 5) 63 | >>> nn.init.trunc_normal_(w) 64 | """ 65 | with torch.no_grad(): 66 | return _trunc_normal_(tensor, mean, std, a, b) 67 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(): 5 | parser = argparse.ArgumentParser(description='RemoteSAM training and testing') 6 | parser.add_argument('--amsgrad', action='store_true', 7 | help='if true, set amsgrad to True in an Adam or AdamW optimizer.') 8 | parser.add_argument('-b', '--batch-size', default=2, type=int) 9 | parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer') 10 | parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights') 11 | parser.add_argument('--dataset', default='rrsisd', help='dataset name') 12 | parser.add_argument('--ddp_trained_weights', action='store_true', 13 | help='Only needs specified when testing,' 14 | 'whether the weights to be loaded are from a DDP-trained model') 15 | parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine 16 | parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run') 17 | parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs') 18 | parser.add_argument('--img_size', default=896, type=int, help='input image size') 19 | 
parser.add_argument("--local_rank", type=int,default=0,help='local rank for DistributedDataParallel') 20 | parser.add_argument('--lr', default=0.00003, type=float, help='the initial learning rate') # 0.00003 21 | parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4,' 22 | 'where a, b, c, and d refer to the numbers of heads in stage-1,' 23 | 'stage-2, stage-3, and stage-4 PWAMs') 24 | parser.add_argument('--model', default='lavt_one', help='model: lavt, lavt_one') 25 | parser.add_argument('--model_id', default='RemoteSAM', help='name to identify the model') 26 | parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights') 27 | parser.add_argument('--pin_mem', action='store_true', 28 | help='If true, pin memory when using the data loader.') 29 | parser.add_argument('--pretrained_swin_weights', default='./pretrained_weights/swin_base_patch4_window12_384_22k.pth', 30 | help='path to pre-trained Swin backbone weights') 31 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency') 32 | parser.add_argument('--refer_data_root', default='/data/dishimin/RemoteSAM-270K', help='REFER dataset root directory') 33 | parser.add_argument('--resume', default='', help='resume from checkpoint') 34 | parser.add_argument('--split', default='test', help='only used when testing') 35 | parser.add_argument('--splitBy', default='unc', help='change to umd or google when the datasset is G-Ref (RefCOCOg)') 36 | parser.add_argument('--swin_type', default='base', 37 | help='tiny, small, base, or large variants of the Swin Transformer') 38 | parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay', 39 | dest='weight_decay') 40 | parser.add_argument('--window12', action='store_true', 41 | help='only needs specified when testing,' 42 | 'when training, window size is inferred from pre-trained weights file name' 43 | '(containing \'window12\'). 
Initialize Swin with window size 12 instead of the default 7.') 44 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers') 45 | parser.add_argument("--imageFolder", default='/data/RemoteSAM-270K/JPEGImages/', type=str, 46 | help="imageFolder path") 47 | parser.add_argument("--ref_file_path", default='/data/RemoteSAM-270K/refs(unc)_RemoteSAM.p', type=str, 48 | help="ref_file_path") 49 | parser.add_argument("--instances_path", default='/data/RemoteSAM-270K/instances.json', type=str, 50 | help="instances.json's path") 51 | parser.add_argument("--eval", action="store_true", help="Only run evaluation") 52 | parser.add_argument('--EPOC', action='store_true', help='if true, use EPOC') 53 | parser.add_argument('--save_path', default='', help='save_path') 54 | parser.add_argument('--task', type=str, choices=['CNT', 'DET', 'CAP', 'VG', 'MCC', 'MLC','SEG','REF'], 55 | help='For tasks other than referring segmentation and semantic segmentation, specify the task') 56 | parser.add_argument('--annoFolder', default='/data/dishimin/iSAID/rrsisd/Annotations/', type=str, help="annoFolder path") 57 | parser.add_argument('--sentence', default='', help='sentence input') 58 | parser.add_argument('--_class', default='', help='class input') 59 | 60 | return parser 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = get_parser() 65 | args_dict = parser.parse_args() 66 | -------------------------------------------------------------------------------- /assets/HKUST.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/HKUST.jpg -------------------------------------------------------------------------------- /assets/NUIST.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/NUIST.jpg -------------------------------------------------------------------------------- /assets/Radar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/Radar.png -------------------------------------------------------------------------------- /assets/RemoteSAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/RemoteSAM.png -------------------------------------------------------------------------------- /assets/RemoteSAM270K.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/RemoteSAM270K.png -------------------------------------------------------------------------------- /assets/SEU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/SEU.png -------------------------------------------------------------------------------- /assets/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/demo.jpg 
-------------------------------------------------------------------------------- /assets/hhu_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/hhu_logo.png -------------------------------------------------------------------------------- /bert/activations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def swish(x): 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | def _gelu_python(x): 16 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created. 17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 19 | This is now written in C in torch.nn.functional 20 | Also see https://arxiv.org/abs/1606.08415 21 | """ 22 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 23 | 24 | 25 | def gelu_new(x): 26 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 27 | Also see https://arxiv.org/abs/1606.08415 28 | """ 29 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 30 | 31 | 32 | if torch.__version__ < "1.4.0": 33 | gelu = _gelu_python 34 | else: 35 | gelu = F.gelu 36 | 37 | 38 | def gelu_fast(x): 39 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 40 | 41 | 42 | ACT2FN = { 43 | "relu": F.relu, 44 | "swish": swish, 45 | "gelu": gelu, 46 | "tanh": torch.tanh, 47 | "gelu_new": gelu_new, 48 | "gelu_fast": gelu_fast, 49 | } 50 | 51 | 52 | def get_activation(activation_string): 53 | if activation_string in ACT2FN: 54 | return ACT2FN[activation_string] 55 | else: 56 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) 57 | -------------------------------------------------------------------------------- /bert/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ BERT model configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 28 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 29 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 30 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 31 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 32 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 33 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 34 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 35 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 36 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 37 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 38 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 39 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 40 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 41 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 42 | "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json", 43 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json", 44 | "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json", 45 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json", 46 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json", 47 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json", 48 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json", 49 | # See all BERT models at https://huggingface.co/models?filter=bert 50 | } 51 | 52 | 53 | class BertConfig(PretrainedConfig): 54 | r""" 55 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel`. 
56 | It is used to instantiate an BERT model according to the specified arguments, defining the model 57 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 58 | the BERT `bert-base-uncased `__ architecture. 59 | 60 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 61 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` 62 | for more information. 63 | 64 | 65 | Args: 66 | vocab_size (:obj:`int`, optional, defaults to 30522): 67 | Vocabulary size of the BERT model. Defines the different tokens that 68 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`. 69 | hidden_size (:obj:`int`, optional, defaults to 768): 70 | Dimensionality of the encoder layers and the pooler layer. 71 | num_hidden_layers (:obj:`int`, optional, defaults to 12): 72 | Number of hidden layers in the Transformer encoder. 73 | num_attention_heads (:obj:`int`, optional, defaults to 12): 74 | Number of attention heads for each attention layer in the Transformer encoder. 75 | intermediate_size (:obj:`int`, optional, defaults to 3072): 76 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 77 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): 78 | The non-linear activation function (function or string) in the encoder and pooler. 79 | If string, "gelu", "relu", "swish" and "gelu_new" are supported. 80 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): 81 | The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. 82 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): 83 | The dropout ratio for the attention probabilities. 84 | max_position_embeddings (:obj:`int`, optional, defaults to 512): 85 | The maximum sequence length that this model might ever be used with. 86 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 87 | type_vocab_size (:obj:`int`, optional, defaults to 2): 88 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`. 89 | initializer_range (:obj:`float`, optional, defaults to 0.02): 90 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 91 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): 92 | The epsilon used by the layer normalization layers. 93 | gradient_checkpointing (:obj:`bool`, optional, defaults to False): 94 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
95 | 96 | Example:: 97 | 98 | >>> from transformers import BertModel, BertConfig 99 | 100 | >>> # Initializing a BERT bert-base-uncased style configuration 101 | >>> configuration = BertConfig() 102 | 103 | >>> # Initializing a model from the bert-base-uncased style configuration 104 | >>> model = BertModel(configuration) 105 | 106 | >>> # Accessing the model configuration 107 | >>> configuration = model.config 108 | """ 109 | model_type = "bert" 110 | 111 | def __init__( 112 | self, 113 | vocab_size=30522, 114 | hidden_size=768, 115 | num_hidden_layers=12, 116 | num_attention_heads=12, 117 | intermediate_size=3072, 118 | hidden_act="gelu", 119 | hidden_dropout_prob=0.1, 120 | attention_probs_dropout_prob=0.1, 121 | max_position_embeddings=512, 122 | type_vocab_size=2, 123 | initializer_range=0.02, 124 | layer_norm_eps=1e-12, 125 | pad_token_id=0, 126 | gradient_checkpointing=False, 127 | **kwargs 128 | ): 129 | super().__init__(pad_token_id=pad_token_id, **kwargs) 130 | 131 | self.vocab_size = vocab_size 132 | self.hidden_size = hidden_size 133 | self.num_hidden_layers = num_hidden_layers 134 | self.num_attention_heads = num_attention_heads 135 | self.hidden_act = hidden_act 136 | self.intermediate_size = intermediate_size 137 | self.hidden_dropout_prob = hidden_dropout_prob 138 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 139 | self.max_position_embeddings = max_position_embeddings 140 | self.type_vocab_size = type_vocab_size 141 | self.initializer_range = initializer_range 142 | self.layer_norm_eps = layer_norm_eps 143 | self.gradient_checkpointing = gradient_checkpointing 144 | -------------------------------------------------------------------------------- /data/DiverseDataset.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | 3 | def DiverseDataset(args, image_transforms=None, max_tokens=20): 4 | 5 | # return image, gt, image_name, foo(optional) 6 | 7 | task = args.task 8 | dataset = args.dataset 9 | 10 | assert task in ['DET', 'CNT', 'MLC', 'MCC', 'CAP', 'VG'], "task must be in ['DET', 'CNT', 'MLC', 'MCC', 'CAP', 'VG'], but got {}".format(task) 11 | 12 | if task == 'CNT': 13 | assert dataset in ['DIOR', 'DOTA'], "dataset must be in ['DIOR', 'DOTA'], but got {}".format(dataset) 14 | if dataset == 'DIOR': 15 | # foo := classname 16 | return DIOR_CNT(args, max_tokens=max_tokens) 17 | elif dataset == 'DOTA': 18 | # foo := classname 19 | return DOTA_CNT(args, max_tokens=max_tokens) 20 | 21 | elif task == 'VG': 22 | assert dataset in ['RSVG'], "dataset must be in ['RSVG'], but got {}".format(dataset) 23 | if dataset == 'RSVG': 24 | return RSVG_VG(args, max_tokens=max_tokens) 25 | 26 | elif task == 'DET': 27 | assert dataset in ['DIOR', 'DOTAv2', 'DOTAv1'], "dataset must be in ['DIOR', 'DOTAv2', 'DOTAv1'], but got {}".format(dataset) 28 | if dataset == 'DIOR': 29 | # foo := (image_id, classindex) 30 | return DIOR_DET(args, max_tokens=max_tokens) 31 | elif dataset == 'DOTAv2': 32 | return DOTAv2_DET(args, max_tokens=max_tokens) 33 | elif dataset == 'DOTAv1': 34 | return DOTAv1_DET(args, max_tokens=max_tokens) 35 | 36 | elif task == 'MLC': 37 | assert dataset in ['DIOR', 'DOTAv2'], "dataset must be in ['DIOR', 'DOTAv2'], but got {}".format(dataset) 38 | if dataset == 'DIOR': 39 | return DIOR_MLC(args, max_tokens=max_tokens) 40 | elif dataset == 'DOTAv2': 41 | return DOTAv2_MLC(args, max_tokens=max_tokens) 42 | 43 | elif task == 'MCC': 44 | assert dataset in ['UCM', 'AID'], "dataset must be in ['UCM', 
'AID'], but got {}".format(dataset) 45 | if dataset == 'UCM': 46 | return UCM_MCC(args, max_tokens=max_tokens) 47 | if dataset == 'AID': 48 | return AID_MCC(args, max_tokens=max_tokens) 49 | 50 | elif task == 'CAP': 51 | assert dataset in ['UCM'], "dataset must be in ['UCM'], but got {}".format(dataset) 52 | if dataset == 'UCM': 53 | return UCM_CAP(args, max_tokens=max_tokens) -------------------------------------------------------------------------------- /data/all_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import cv2 5 | from PIL import Image 6 | import torch.utils.data as data 7 | import transformers 8 | from tqdm import tqdm # 导入tqdm库用于显示进度条 9 | 10 | import pickle 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import Dataset 14 | from torchvision import transforms 15 | import random 16 | from pycocotools import mask 17 | 18 | class All_Dataset(data.Dataset): 19 | def __init__(self, 20 | args, 21 | image_transforms=None, 22 | max_tokens=20, 23 | split='train', 24 | eval_mode=True, 25 | logger=None) -> None: 26 | """ 27 | parameters: 28 | args: argparse obj 29 | image_transforms: transforms apply to image and mask 30 | max_tokens: determined the max length of token 31 | split: ['train','val','testA','testB'] 32 | eval_mode: whether in training or evaluating 33 | """ 34 | 35 | self.classes = [] 36 | self.image_transforms = image_transforms 37 | self.split = split 38 | self.args = args 39 | self.eval_mode = eval_mode 40 | self.max_tokens = max_tokens 41 | self.image_path = args.imageFolder 42 | self.ref_file = args.ref_file_path 43 | self.instances_path=args.instances_path 44 | self.tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased') 45 | 46 | self.maskToSentence,self.maskToimageName,self.maskTosplit,self.maskToCategory,self.maskToDecodedRLE,self.masktoArea,self.list_mask = self.load_captions() 47 | self.refToInput,self.refToAttention,self.refToExist=self.get_sent_embeddings() 48 | if logger: 49 | logger.info(f"=> loaded successfully '{args.dataset}', split by {args.splitBy}, split {split}") 50 | #load ref ,find caption ,image,according to mask 51 | def load_captions(self): 52 | 53 | maskToSentence = {} 54 | maskToimageName = {} 55 | maskTosplit = {} 56 | maskToCategory = {} 57 | masktoRLE={} 58 | masktoArea={} 59 | maskToDecodedRLE={} 60 | list_mask=[] 61 | with open(self.ref_file, 'rb') as f: 62 | data = pickle.load(f) 63 | total_lines = len(data) # 获取总数据行数,用于进度条 64 | pbar = tqdm(data, desc="Loading captions", total=total_lines) # 创建进度条对象 65 | for line in pbar: 66 | split_line = line["split"] 67 | ann_id = line["ann_id"] 68 | img_name = line["file_name"] 69 | img_id = line["image_id"] 70 | category_id = line["category_id"] 71 | if split_line == self.split : 72 | list_mask.append(ann_id) 73 | 74 | 75 | maskToSentence[ann_id] = line["sentences"][0]["sent"] 76 | maskToimageName[ann_id]=img_name 77 | maskTosplit[ann_id]=split_line 78 | maskToCategory[ann_id]=category_id 79 | 80 | # Only load the ann_id of the corresponding training set/validation set/test set. 
81 | pbar.close() # 关闭进度条 82 | with open(self.instances_path, 'r') as fi: 83 | data_json = json.load(fi) 84 | annotations = data_json['annotations'] 85 | total_annotations = len(annotations) # 获取总注释数量,用于进度条 86 | pbar = tqdm(annotations, desc="Loading instance data", total=total_annotations) # 创建进度条对象 87 | for data in pbar: 88 | ann_id = data["id"] 89 | 90 | rle = data['segmentation'] 91 | # decoded_mask = mask.decode(rle) # 假设mask.decode是解码RLE的函数 92 | maskToDecodedRLE[ann_id] = rle 93 | 94 | masktoArea[ann_id] = data['area'] 95 | pbar.close() # 关闭进度条 96 | 97 | return maskToSentence,maskToimageName,maskTosplit,maskToCategory,maskToDecodedRLE,masktoArea,list_mask 98 | def get_sent_embeddings(self): 99 | refToInput={} 100 | refToAttention={} 101 | refToExist={} 102 | total_mask_ids = len(self.list_mask) # 获取mask_id的总数,用于进度条 103 | pbar = tqdm(self.list_mask, desc="Generating sentence embeddings", total=total_mask_ids) # 创建进度条对象 104 | for mask_id in pbar: 105 | attention_mask = [0] * self.max_tokens 106 | padded_input_id = [0] * self.max_tokens 107 | ref=self.maskToSentence[mask_id] 108 | input_id=self.tokenizer.encode(text=ref, add_special_tokens=True) 109 | input_id = input_id[:self.max_tokens] 110 | 111 | padded_input_id[:len(input_id)] = input_id 112 | attention_mask[:len(input_id)] = [1] * len(input_id) 113 | 114 | input_id=torch.tensor(padded_input_id).unsqueeze(0) 115 | attention_mask=torch.tensor(attention_mask).unsqueeze(0) 116 | exist = torch.Tensor([True]) 117 | refToInput[mask_id]=input_id 118 | refToAttention[mask_id]=attention_mask 119 | refToExist[mask_id]=exist 120 | pbar.close() # 关闭进度条 121 | return refToInput,refToAttention,refToExist 122 | def get_exist(self, ref, sent_index): 123 | if "exist" in ref["sentences"][sent_index].keys(): 124 | exist = torch.Tensor([ref["sentences"][sent_index]["exist"]]) 125 | else: 126 | exist = torch.Tensor([True]) 127 | return exist 128 | 129 | def __len__(self): 130 | return len(self.list_mask) 131 | 132 | def __getitem__(self, index): 133 | mask_id=self.list_mask[index] 134 | img_name = self.maskToimageName[mask_id] 135 | # ref=self.maskToSentence[mask_id] 136 | # split=self.maskTosplit[mask_id] 137 | try: 138 | img = Image.open(os.path.join(self.image_path, img_name)).convert("RGB") 139 | except (OSError, ValueError) as e: 140 | # print(f"Error loading image {img_name}: {e}") 141 | # 尝试更换扩展名 142 | if img_name.lower().endswith('.png'): 143 | new_img_name = img_name[:-4] + '.jpg' # 替换为 .jpg 144 | img = Image.open(os.path.join(self.image_path, new_img_name)).convert("RGB") 145 | elif img_name.lower().endswith('.jpg'): 146 | new_img_name = img_name[:-4] + '.png' # 替换为 .png 147 | img = Image.open(os.path.join(self.image_path, new_img_name)).convert("RGB") 148 | else: 149 | print("Unsupported file type.") 150 | return None 151 | 152 | rle = self.maskToDecodedRLE[mask_id] 153 | m = mask.decode(rle) # 假设mask.decode是解码RLE的函数 154 | 155 | m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs) 156 | ref_mask = m.astype(np.uint8) # convert to np.uint8 157 | # compute area 158 | 159 | 160 | annot = np.zeros(ref_mask.shape) 161 | annot[ref_mask == 1] = 1 162 | # convert it to a Pillow image 163 | annot = Image.fromarray(annot.astype(np.uint8), mode="P") 164 | 165 | # tensor_embeddings,attention_mask,exist = self.get_sent_embeddings(ref) 166 | tensor_embeddings=self.refToInput[mask_id] 167 | attention_mask=self.refToAttention[mask_id] 168 | exist=self.refToExist[mask_id] 169 | 170 | if self.image_transforms is not 
None: 171 | # involves transform from PIL to tensor and mean and std normalization 172 | img, target = self.image_transforms(img, annot) 173 | else: 174 | target = annot 175 | if self.eval_mode: 176 | 177 | # tensor_embeddings=tensor_embeddings.unsqueeze(-1) 178 | # attention_mask=attention_mask.unsqueeze(-1) 179 | 180 | exist=exist.unsqueeze(1) 181 | 182 | 183 | return img, target, tensor_embeddings, attention_mask, exist 184 | 185 | # from dataset.transform import get_transform 186 | # import os 187 | # import sys 188 | # import torch.utils.data as data 189 | # import torch 190 | # import numpy as np 191 | # from PIL import Image 192 | # import transformers 193 | # import argparse 194 | # 195 | # from torchvision import transforms 196 | # 197 | # def main(): 198 | # # 模拟命令行参数,你可以根据实际需求调整这些参数值 199 | # args = argparse.Namespace( 200 | # imageFolder='/home/amax/yaoliang/rmsin/refer/data/images/rrsisd/JPEGImages', # 替换为真实的 refer 数据根目录路径 201 | # ref_file_path='/home/amax/yaoliang/rmsin/refer/data/rrsisd/refs(unc).p', # 替换为实际的数据集名称 202 | # instances_path='/home/amax/yaoliang/rmsin/refer/data/rrsisd/instances.json' # 替换为实际的划分方式 203 | # ) 204 | # # 定义简单的图像变换,这里仅示例,你可以根据实际需求完善或替换 205 | # image_transforms = transforms.Compose([ 206 | # transforms.Resize((480, 480)), 207 | # transforms.ToTensor() 208 | # ]) 209 | # 210 | # # 实例化 ReferDataset 211 | # dataset = All_Dataset( 212 | # args, 213 | # image_transforms=image_transforms, 214 | # max_tokens=20, 215 | # split='val', 216 | # eval_mode=True 217 | # ) 218 | # 219 | # print(f"数据集长度: {len(dataset)}") 220 | # 221 | # # 尝试获取第一个数据样本进行查看 222 | # if len(dataset) > 0: 223 | # img, target, tensor_embeddings, attention_mask, exist = dataset[0] 224 | # 225 | # print("图像数据形状:", img.shape) 226 | # print("目标数据类型:", type(target)) 227 | # if isinstance(target, torch.Tensor): 228 | # print("目标数据形状:", target.shape) 229 | # print("嵌入向量形状:", tensor_embeddings.shape) 230 | # print("注意力掩码形状:", attention_mask.shape) 231 | # print("存在性张量形状:", exist.shape) 232 | # 233 | # if __name__ == "__main__": 234 | # main() -------------------------------------------------------------------------------- /data/datasets/CAP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | import json 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | 9 | class UCM_CAP(torch.utils.data.Dataset): 10 | 11 | # potential classes 12 | classes = ["road", "highway", "bank", "tenniscourt", "roof", "vehicle", "parking lot", "airplane", "stand", "farmland", "waste land", "basketballcourt", "storagetank", "pedestrian", "villa", "beach", "boat", "tree", "grass", "water", "trail", "park", "bush", "golffield", "house", "airport", "people", "overpass", "baseballfield", "swimming-pool", "bunker", "building", "harbor", "residential", "sand", "stadium", "intersection"] 13 | 14 | def __init__(self, args, max_tokens=20) : 15 | 16 | self.args = args 17 | self.split = args.split 18 | self.max_tokens = max_tokens 19 | 20 | self.image_path = args.imageFolder 21 | self.anno_path = args.annoFolder 22 | 23 | self.image_list, self.image2groundtruth, self.image2id = self.load_data() 24 | 25 | 26 | def load_data(self): 27 | 28 | image_list = [] 29 | image2groundtruth = {} 30 | image2id = {} 31 | 32 | with open(self.anno_path, 'r') as f: 33 | data = json.load(f) 34 | 35 | for item in tqdm(data['images'], desc="Collecting groundtruth", total=len(data['images'])): 36 | if item['split'] != self.split: 37 | continue 38 | 39 | 
image_name = item["filename"] 40 | if image_name not in image2groundtruth: 41 | image2groundtruth[image_name] = [] 42 | image_list.append(image_name) 43 | image2id[image_name] = item["imgid"] 44 | 45 | for gt in item["sentences"]: 46 | caption = gt["raw"] 47 | image2groundtruth[image_name].append(caption) 48 | 49 | return sorted(image_list), image2groundtruth, image2id 50 | 51 | 52 | def __getitem__(self, idx): 53 | 54 | image_name = self.image_list[idx] 55 | origin_image = Image.open(os.path.join(self.image_path, image_name)).convert('RGB') 56 | 57 | # return image, gt, image_name 58 | return torch.tensor(np.array(origin_image)), self.image2groundtruth[image_name], image_name 59 | 60 | 61 | def __len__(self): 62 | return len(self.image_list) -------------------------------------------------------------------------------- /data/datasets/CNT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | import json 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | 9 | class Falcon_CNT(torch.utils.data.Dataset): 10 | 11 | convert_name = None 12 | prior_area_info = None 13 | 14 | def __init__(self, args, max_tokens=20) : 15 | 16 | self.args = args 17 | self.split = args.split 18 | self.max_tokens = max_tokens 19 | 20 | self.image_path = args.imageFolder 21 | self.anno_path = args.annoFolder 22 | 23 | self.image_list, self.groundtruth, self.classnames, self.foo = self.load_data() 24 | 25 | 26 | def load_data(self): 27 | 28 | image_list = [] 29 | groundtruth = [] 30 | classnames = [] 31 | foo = [] 32 | 33 | with open(self.anno_path, 'r') as f: 34 | data = json.load(f) 35 | 36 | for item in tqdm(data, desc="Generating sentence embeddings", total=len(data)): 37 | classname = item["conversations"][0]["content"].split("the number of")[-1].split(".")[0].strip() 38 | if classname in self.convert_name: 39 | classname = self.convert_name[classname] 40 | 41 | classnames.append(classname) 42 | 43 | image_list.append(os.path.basename(item["images"][0])) 44 | groundtruth.append(int(item["conversations"][1]["content"])) 45 | 46 | if self.prior_area_info is not None: 47 | foo.append(self.prior_area_info[classname]) 48 | else: 49 | foo.append(None) 50 | 51 | return image_list, groundtruth, classnames, foo 52 | 53 | 54 | def __getitem__(self, idx): 55 | image_name = self.image_list[idx] 56 | origin_image = Image.open(os.path.join(self.image_path, image_name)).convert('RGB') 57 | 58 | # return image, gt, image_name, classname, foo 59 | return torch.tensor(np.array(origin_image)), self.groundtruth[idx], image_name, self.classnames[idx], self.foo[idx] 60 | 61 | 62 | def __len__(self): 63 | return len(self.image_list) 64 | 65 | 66 | class DIOR_CNT(Falcon_CNT): 67 | 68 | # unify classnames 69 | convert_name = { 70 | 'railway station': 'trainstation', 71 | 'baseball field': 'baseballfield', 72 | 'basketball court': 'basketballcourt', 73 | 'tennis court': 'tenniscourt', 74 | 'ground track field': 'groundtrackfield', 75 | 'expressway toll station': 'expressway-toll-station', 76 | 'storage tank': 'storagetank', 77 | 'golf field': 'golffield', 78 | 'expressway service area': 'expressway-service-area' 79 | } 80 | 81 | prior_area_info = {'airplane': 36, 'airport': 1824, 'baseballfield': 20, 'basketballcourt': 88, 'bridge': 9, 'chimney': 12, 'dam': 264, 'expressway-service-area': 110, 'expressway-toll-station': 36, 'golffield': 2728, 'groundtrackfield': 132, 'harbor': 24, 'overpass': 9, 'ship': 14, 'stadium': 255, 'storagetank': 4, 
'tenniscourt': 36, 'trainstation': 1170, 'vehicle': 6, 'windmill': 27} 82 | 83 | 84 | class DOTA_CNT(Falcon_CNT): 85 | 86 | # unify classnames 87 | convert_name = { 88 | "baseball field": "baseballfield", 89 | "basketball court": "basketballcourt", 90 | "crane": "container-crane", 91 | "ground track field": "groundtrackfield", 92 | "soccer ball field": "soccer-ball-field", 93 | "storage tank": "storagetank", 94 | "swimming pool": "swimming-pool", 95 | "tennis court": "tenniscourt", 96 | } 97 | 98 | prior_area_info = {'airplane': 35, 'ship': 15, 'storagetank': 8, 'baseballfield': 81, 'tenniscourt': 153, 'basketballcourt': 576, 'groundtrackfield': 210, 'harbor': 64, 'bridge': 24, 'vehicle': 21, 'helicopter': 162, 'roundabout': 16, 'soccer-ball-field': 72, 'swimming-pool': 20, "container-crane": 187, "airport": 1337, "helipad":121} -------------------------------------------------------------------------------- /data/datasets/DET.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | import json 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | 9 | class COCO_style(torch.utils.data.Dataset): 10 | 11 | classes = None 12 | 13 | def __init__(self, args, max_tokens=20) : 14 | 15 | assert args._class in self.classes, "class not in the list of classes" 16 | 17 | self.args = args 18 | self.split = args.split 19 | self.max_tokens = max_tokens 20 | 21 | self.image_path = args.imageFolder 22 | self.anno_path = args.annoFolder 23 | 24 | self.image_list, self.image2groundtruth, self.image2id = self.load_data() 25 | 26 | 27 | def load_data(self): 28 | 29 | image_list = [] 30 | image2groundtruth = {} 31 | image2id = {} 32 | 33 | with open(self.anno_path, 'r') as f: 34 | data = json.load(f) 35 | 36 | id2image = {} 37 | for item in data['images']: 38 | id2image[item["id"]] = item["file_name"] 39 | 40 | for item in tqdm(data['annotations'], desc="Collecting groundtruth", total=len(data['annotations'])): 41 | image_name = id2image[item["image_id"]] 42 | if image_name not in image2groundtruth: 43 | image2groundtruth[image_name] = [] 44 | image_list.append(image_name) 45 | image2id[image_name] = item["image_id"] 46 | 47 | if item["category_id"] == self.classes.index(self.args._class): 48 | x, y, w, h = item["bbox"] 49 | box = [x, y, x+w, y+h] 50 | image2groundtruth[image_name].append(box) 51 | 52 | return sorted(image_list), image2groundtruth, image2id 53 | 54 | 55 | def __getitem__(self, idx): 56 | 57 | image_name = self.image_list[idx] 58 | origin_image = Image.open(os.path.join(self.image_path, image_name)).convert('RGB') 59 | 60 | # return image, gt, image_name 61 | return torch.tensor(np.array(origin_image)), torch.tensor(self.image2groundtruth[image_name]), image_name 62 | 63 | 64 | def __len__(self): 65 | return len(self.image_list) 66 | 67 | 68 | class DOTA_style(torch.utils.data.Dataset): 69 | 70 | classes = None 71 | convert_name = None 72 | 73 | def __init__(self, args, max_tokens=20) : 74 | 75 | assert args._class in self.classes, "class not in the list of classes" 76 | 77 | self.args = args 78 | self.split = args.split 79 | self.max_tokens = max_tokens 80 | 81 | self.image_path = args.imageFolder 82 | self.anno_path = args.annoFolder 83 | 84 | self.image_list, self.image2groundtruth = self.load_data() 85 | 86 | 87 | def load_data(self): 88 | 89 | image_list = os.listdir(self.image_path) 90 | image2groundtruth = {} 91 | 92 | for image_name in image_list: 93 | image2groundtruth[image_name] = [] 94 | 95 | 
# skip for no gt available in DOTA test set 96 | if self.split != 'test': 97 | for image_name in tqdm(image_list, desc="Collecting groundtruth", total=len(image_list)): 98 | anno_name = os.path.splitext(image_name)[0] + ".txt" 99 | with open(os.path.join(self.anno_path, anno_name), 'r') as f: 100 | for line in f: 101 | data = line.strip().split(' ') 102 | if data[8] in self.convert_name: 103 | data[8] = self.convert_name[data[8]] 104 | if data[8] == self.args._class: 105 | image2groundtruth[image_name].append([int(float(a)) for a in data[:8]]) 106 | 107 | return sorted(image_list), image2groundtruth 108 | 109 | 110 | def __getitem__(self, idx): 111 | 112 | image_name = self.image_list[idx] 113 | origin_img = Image.open(os.path.join(self.image_path, image_name)).convert('RGB') 114 | 115 | # return image, gt, image_name 116 | return torch.tensor(np.array(origin_img)), torch.tensor(self.image2groundtruth[image_name]), image_name 117 | 118 | 119 | def __len__(self): 120 | return len(self.image_list) 121 | 122 | 123 | class DIOR_DET(COCO_style): 124 | 125 | classes = ['background', 'airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge', 'chimney', 'dam', 'expressway-service-area', 'expressway-toll-station', 'golffield', 'groundtrackfield', 'harbor', 'overpass', 'ship', 'stadium', 'storagetank', 'tenniscourt', 'trainstation', 'vehicle', 'windmill'] 126 | 127 | 128 | class DOTAv2_DET(DOTA_style): 129 | 130 | classes = ['airplane', 'ship', 'storagetank', 'baseballfield', 'tenniscourt', 'basketballcourt', 'groundtrackfield', 'harbor', 'bridge', 'large-vehicle', 'small-vehicle', 'helicopter', 'roundabout', 'soccer-ball-field', 'swimming-pool', "container-crane", "airport", "helipad"] 131 | 132 | # unify classname 133 | convert_name = { 134 | 'plane': 'airplane', 135 | 'storage-tank': 'storagetank', 136 | 'baseball-diamond': 'baseballfield', 137 | 'tennis-court': 'tenniscourt', 138 | 'basketball-court': 'basketballcourt', 139 | 'ground-track-field': 'groundtrackfield', 140 | } 141 | 142 | 143 | class DOTAv1_DET(DOTA_style): 144 | 145 | classes = ['plane', 'ship', 'storage-tank', 'baseball-diamond', 'tennis-court', 'basketball-court', 'ground-track-field', 'harbor', 'bridge', 'large-vehicle', 'small-vehicle', 'helicopter', 'roundabout', 'soccer-ball-field', 'swimming-pool'] 146 | 147 | # unify classname 148 | convert_name = { 149 | 'plane': 'airplane', 150 | 'storage-tank': 'storagetank', 151 | 'baseball-diamond': 'baseballfield', 152 | 'tennis-court': 'tenniscourt', 153 | 'basketball-court': 'basketballcourt', 154 | 'ground-track-field': 'groundtrackfield', 155 | } -------------------------------------------------------------------------------- /data/datasets/MCC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | import json 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | 9 | class MCC(torch.utils.data.Dataset): 10 | 11 | classes = None 12 | convert_name = None 13 | 14 | def __init__(self, args, max_tokens=20) : 15 | 16 | self.args = args 17 | self.split = args.split 18 | self.max_tokens = max_tokens 19 | 20 | self.image_path = args.imageFolder 21 | self.anno_path = args.annoFolder 22 | 23 | self.image_list, self.image2groundtruth = self.load_data() 24 | 25 | 26 | def load_data(self): 27 | image_list = [] 28 | image2groundtruth = {} 29 | 30 | for index, classname in enumerate(self.classes): 31 | if classname in self.convert_name: 32 | classname = self.convert_name[classname] 33 | for 
image in os.listdir(os.path.join(self.anno_path, classname)): 34 | image_path = os.path.join(classname, image) 35 | image2groundtruth[image_path] = index 36 | image_list.append(image_path) 37 | 38 | return image_list, image2groundtruth 39 | 40 | 41 | def __getitem__(self, idx): 42 | 43 | image_name = self.image_list[idx] 44 | origin_image = Image.open(os.path.join(self.image_path, image_name)).convert('RGB') 45 | 46 | # return image, gt, image_name 47 | return torch.tensor(np.array(origin_image)), torch.tensor(self.image2groundtruth[image_name]), image_name 48 | 49 | 50 | def __len__(self): 51 | return len(self.image_list) 52 | 53 | 54 | class UCM_MCC(MCC): 55 | 56 | classes = ['agricultural', 'airplane', 'baseballfield', 'beach', 'buildings', 'chaparral', 'denseresidential', 'forest', 'freeway', 'golffield', 'harbor', 'intersection', 'mediumresidential', 'mobilehomepark', 'overpass', 'parkinglot', 'river', 'runway', 'sparseresidential', 'storagetank', 'tenniscourt'] 57 | 58 | convert_name = { 59 | 'baseballfield': 'baseballdiamond', 60 | 'golffield': 'golfcourse', 61 | 'storagetank': 'storagetanks' 62 | } 63 | 64 | 65 | class AID_MCC(MCC): 66 | 67 | classes = ['airport', 'bareland', 'baseballfield', 'beach', 'bridge', 'center', 'church', 'commercial', 'denseresidential', 'desert', 'farmland', 'forest', 'industrial', 'meadow', 'mediumresidential', 'mountain', 'park', 'parking', 'playground', 'pond', 'port', 'railwaystation', 'resort', 'river', 'school', 'sparseresidential', 'square', 'stadium', 'storagetank', 'viaduct'] 68 | 69 | convert_name = { 70 | 'storagetank': 'storagetanks' 71 | } -------------------------------------------------------------------------------- /data/datasets/MLC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | import json 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | 9 | class COCO_style(torch.utils.data.Dataset): 10 | 11 | classes = None 12 | 13 | def __init__(self, args, max_tokens=20) : 14 | 15 | self.args = args 16 | self.split = args.split 17 | self.max_tokens = max_tokens 18 | 19 | self.image_path = args.imageFolder 20 | self.anno_path = args.annoFolder 21 | 22 | self.image_list, self.image2groundtruth = self.load_data() 23 | 24 | 25 | def load_data(self): 26 | 27 | image_list = [] 28 | image2groundtruth = {} 29 | 30 | with open(self.anno_path, 'r') as f: 31 | data = json.load(f) 32 | 33 | id2image = {} 34 | for item in data['images']: 35 | id2image[item["id"]] = item["file_name"] 36 | 37 | for item in tqdm(data['annotations'], desc="Collecting groundtruth", total=len(data['annotations'])): 38 | image_name = id2image[item["image_id"]] 39 | if image_name not in image2groundtruth: 40 | image2groundtruth[image_name] = [0] * len(self.classes) 41 | image_list.append(image_name) 42 | 43 | image2groundtruth[image_name][int(item["category_id"])-1] = 1 44 | 45 | return sorted(image_list), image2groundtruth 46 | 47 | 48 | def __getitem__(self, idx): 49 | 50 | image_name = self.image_list[idx] 51 | origin_image = Image.open(os.path.join(self.image_path, image_name)).convert('RGB') 52 | 53 | # return image, gt, image_name 54 | return torch.tensor(np.array(origin_image)), torch.tensor(self.image2groundtruth[image_name]), image_name 55 | 56 | 57 | def __len__(self): 58 | return len(self.image_list) 59 | 60 | 61 | class DOTA_style(COCO_style): 62 | 63 | convert_name = None 64 | 65 | def load_data(self): 66 | 67 | image_list = os.listdir(self.image_path) 68 | 
image2groundtruth = {} 69 | 70 | for image_name in image_list: 71 | image2groundtruth[image_name] = [0] * len(self.classes) 72 | 73 | # skip for no gt available in DOTA test set 74 | if self.split != 'test': 75 | for image_name in tqdm(image_list, desc="Collecting groundtruth", total=len(image_list)): 76 | anno_name = os.path.splitext(image_name)[0] + ".txt" 77 | with open(os.path.join(self.anno_path, anno_name), 'r') as f: 78 | for line in f: 79 | data = line.strip().split(' ') 80 | if data[8] in self.convert_name: 81 | data[8] = self.convert_name[data[8]] 82 | image2groundtruth[image_name][self.classes.index(data[8])] = 1 83 | 84 | return sorted(image_list), image2groundtruth 85 | 86 | 87 | class DIOR_MLC(COCO_style): 88 | 89 | classes = ['airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge', 'chimney', 'dam', 'expressway-service-area', 'expressway-toll-station', 'golffield', 'groundtrackfield', 'harbor', 'overpass', 'ship', 'stadium', 'storagetank', 'tenniscourt', 'trainstation', 'vehicle', 'windmill'] 90 | 91 | 92 | class DOTAv2_MLC(DOTA_style): 93 | 94 | classes = ['airplane', 'ship', 'storagetank', 'baseballfield', 'tenniscourt', 'basketballcourt', 'groundtrackfield', 'harbor', 'bridge', 'large-vehicle', 'small-vehicle', 'helicopter', 'roundabout', 'soccer-ball-field', 'swimming-pool', "container-crane", "airport", "helipad"] 95 | 96 | # unify classname 97 | convert_name = { 98 | 'plane': 'airplane', 99 | 'storage-tank': 'storagetank', 100 | 'baseball-diamond': 'baseballfield', 101 | 'tennis-court': 'tenniscourt', 102 | 'basketball-court': 'basketballcourt', 103 | 'ground-track-field': 'groundtrackfield', 104 | } -------------------------------------------------------------------------------- /data/datasets/VG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | import json 5 | from tqdm import tqdm 6 | import xml.etree.cElementTree as ET 7 | import numpy as np 8 | 9 | 10 | class RSVG_VG(torch.utils.data.Dataset): 11 | 12 | def __init__(self, args, max_tokens=20): 13 | 14 | self.args = args 15 | self.split = args.split 16 | self.max_tokens = max_tokens 17 | 18 | self.image_path = args.imageFolder 19 | self.anno_path = args.annoFolder 20 | 21 | self.image_list, self.groundtruth, self.sentence = self.load_data() 22 | 23 | 24 | def load_data(self): 25 | 26 | image_list = [] 27 | groundtruth = [] 28 | sentence = [] 29 | 30 | split_txt = os.path.join(os.path.dirname(self.anno_path), self.split + ".txt") 31 | with open(split_txt, "r", encoding="utf-8") as f: 32 | lines = f.readlines() 33 | Index = [ int(line.strip()) for line in lines ] 34 | 35 | count = 0 36 | anno_list = sorted(os.listdir(self.anno_path)) 37 | for anno in tqdm(anno_list, desc="Generating sentence embeddings", total=len(anno_list)): 38 | root = ET.parse(os.path.join(self.anno_path, anno)).getroot() 39 | image_name = root.find('filename').text 40 | objects = [] 41 | for obj in root.findall('object'): 42 | if count in Index: 43 | 44 | bndbox = obj.find('bndbox') 45 | xmin = bndbox.find('xmin').text 46 | ymin = bndbox.find('ymin').text 47 | xmax = bndbox.find('xmax').text 48 | ymax = bndbox.find('ymax').text 49 | box = [int(xmin), int(ymin), int(xmax), int(ymax)] 50 | 51 | description = obj.find('description').text 52 | 53 | image_list.append(image_name) 54 | groundtruth.append(box) 55 | sentence.append(description) 56 | 57 | count += 1 58 | 59 | return image_list, groundtruth, sentence 60 | 61 | 62 | def 
__getitem__(self, idx): 63 | 64 | image_name = self.image_list[idx] 65 | origin_image = Image.open(os.path.join(self.image_path, image_name)).convert('RGB') 66 | 67 | # return image, gt, image_name, sentence 68 | return torch.tensor(np.array(origin_image)), torch.tensor(self.groundtruth[idx]), image_name, self.sentence[idx] 69 | 70 | 71 | def __len__(self): 72 | return len(self.image_list) -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .CNT import DIOR_CNT, DOTA_CNT 2 | from .VG import RSVG_VG 3 | from .DET import DIOR_DET, DOTAv2_DET, DOTAv1_DET 4 | from .MLC import DIOR_MLC, DOTAv2_MLC 5 | from .MCC import UCM_MCC, AID_MCC 6 | from .CAP import UCM_CAP -------------------------------------------------------------------------------- /lib/_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import sys 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from bert.modeling_bert import BertModel 7 | 8 | 9 | def load_weights(model, load_path): 10 | dict_trained = torch.load(load_path)['model'] 11 | dict_new = model.state_dict().copy() 12 | for key in dict_new.keys(): 13 | if key in dict_trained.keys(): 14 | dict_new[key] = dict_trained[key] 15 | model.load_state_dict(dict_new) 16 | del dict_new 17 | del dict_trained 18 | torch.cuda.empty_cache() 19 | print('load weights from {}'.format(load_path)) 20 | return model 21 | 22 | 23 | class _LAVTSimpleDecode(nn.Module): 24 | def __init__(self, backbone, classifier): 25 | super(_LAVTSimpleDecode, self).__init__() 26 | self.backbone = backbone 27 | self.classifier = classifier 28 | 29 | def forward(self, x, l_feats, l_mask): 30 | input_shape = x.shape[-2:] 31 | features = self.backbone(x, l_feats, l_mask) 32 | x_c1, x_c2, x_c3, x_c4 = features 33 | 34 | x = self.classifier(x_c4, x_c3, x_c2, x_c1) 35 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) 36 | 37 | return x 38 | 39 | 40 | class LAVT(_LAVTSimpleDecode): 41 | pass 42 | 43 | 44 | ############################################### 45 | # LAVT One: put BERT inside the overall model # 46 | ############################################### 47 | class _LAVTOneSimpleDecode(nn.Module): 48 | def __init__(self, backbone, classifier, args): 49 | super(_LAVTOneSimpleDecode, self).__init__() 50 | self.backbone = backbone 51 | self.classifier = classifier 52 | self.text_encoder = BertModel.from_pretrained(args.ck_bert) 53 | self.text_encoder.pooler = None 54 | 55 | def forward(self, x, text, l_mask): 56 | input_shape = x.shape[-2:] 57 | ### language inference ### 58 | l_feats = self.text_encoder(text, attention_mask=l_mask)[0] # (6, 10, 768) 59 | l_feats = l_feats.permute(0, 2, 1) # (B, 768, N_l) 60 | l_mask = l_mask.unsqueeze(dim=-1) # (batch, N_l, 1) 61 | ########################## 62 | features = self.backbone(x, l_feats, l_mask) 63 | x_c1, x_c2, x_c3, x_c4 = features # e.g. 
x_c1:[B, 128, 120, 120], x_c2:[B, 256, 60, 60], x_c3:[B, 512, 30, 30], x_c4:[B, 1024, 15, 15] 64 | x = self.classifier(x_c4, x_c3, x_c2, x_c1) 65 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) 66 | return x 67 | 68 | 69 | class LAVTOne(_LAVTOneSimpleDecode): #change 70 | pass 71 | -------------------------------------------------------------------------------- /lib/cross_scale_interaction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class h_sigmoid(nn.Module): 7 | def __init__(self, inplace=True): 8 | super(h_sigmoid, self).__init__() 9 | self.relu = nn.ReLU6(inplace=inplace) 10 | 11 | def forward(self, x): 12 | return self.relu(x + 3) / 6 13 | 14 | 15 | class Linear_BN(torch.nn.Sequential): 16 | def __init__(self, a, b, bn_weight_init=1): 17 | super().__init__() 18 | self.add_module('c', torch.nn.Linear(a, b, bias=False)) 19 | bn = torch.nn.BatchNorm1d(b) 20 | torch.nn.init.constant_(bn.weight, bn_weight_init) 21 | torch.nn.init.constant_(bn.bias, 0) 22 | self.add_module('bn', bn) 23 | 24 | @torch.no_grad() 25 | def fuse(self): 26 | l, bn = self._modules.values() 27 | w = bn.weight / (bn.running_var + bn.eps)**0.5 28 | w = l.weight * w[:, None] 29 | b = bn.bias - bn.running_mean * bn.weight / \ 30 | (bn.running_var + bn.eps)**0.5 31 | m = torch.nn.Linear(w.size(1), w.size(0)) 32 | m.weight.data.copy_(w) 33 | m.bias.data.copy_(b) 34 | return m 35 | 36 | def forward(self, x): 37 | l, bn = self._modules.values() 38 | x = l(x) 39 | return bn(x.flatten(0, 1)).reshape_as(x) 40 | 41 | 42 | class Residual(torch.nn.Module): 43 | def __init__(self, m): 44 | super().__init__() 45 | self.m = m 46 | 47 | def forward(self, x): 48 | return x + self.m(x) 49 | 50 | 51 | class ScaleAwareGate(nn.Module): 52 | def __init__(self, inp, oup): 53 | super(ScaleAwareGate, self).__init__() 54 | 55 | self.local_embedding = nn.Conv2d(inp, oup, kernel_size=1) 56 | self.bn1 = nn.BatchNorm2d(oup) 57 | 58 | self.global_embedding = nn.Conv2d(inp, oup, kernel_size=1) 59 | self.bn2 = nn.BatchNorm2d(oup) 60 | 61 | self.global_act = nn.Conv2d(inp, oup, kernel_size=1) 62 | self.bn3 = nn.BatchNorm2d(oup) 63 | self.act = h_sigmoid() 64 | 65 | def forward(self, x_l, x_g): 66 | B, C, H, W = x_l.shape 67 | local_feat = self.local_embedding(x_l) 68 | local_feat = self.bn1(local_feat) 69 | 70 | global_feat = self.global_embedding(x_g) 71 | global_feat = self.bn2(global_feat) 72 | global_feat = F.interpolate(global_feat, size=(H, W), mode='bilinear', align_corners=False) 73 | 74 | global_act = self.global_act(x_g) 75 | global_act = self.bn3(global_act) 76 | sig_act = F.interpolate(self.act(global_act), size=(H, W), mode='bilinear', align_corners=False) 77 | 78 | out = local_feat * sig_act + global_feat 79 | return out 80 | 81 | 82 | class Attention(torch.nn.Module): 83 | def __init__(self, dim, img_shape, att_shape, key_dim=32, num_heads=8, attn_ratio=2, activation=torch.nn.Hardswish): 84 | super().__init__() 85 | self.num_heads = num_heads 86 | self.scale = key_dim ** -0.5 87 | self.key_dim = key_dim 88 | self.img_shape = img_shape 89 | self.nh_kd = nh_kd = key_dim * num_heads 90 | self.d = int(attn_ratio * key_dim) 91 | self.dh = int(attn_ratio * key_dim) * num_heads 92 | self.attn_ratio = attn_ratio 93 | h = self.dh + nh_kd * 2 94 | self.qkv = Linear_BN(dim, h) 95 | 96 | self.parallel_conv = nn.Sequential( 97 | nn.Hardswish(inplace=False), 98 | nn.Conv2d(self.dh, self.dh, 
kernel_size=3, padding=1, groups=self.dh), 99 | ) 100 | self.to_out = nn.Linear(self.dh, dim) 101 | self.proj = nn.Linear(att_shape, img_shape) 102 | 103 | def forward(self, x): # x (B,N,C) 104 | B, N, C = x.shape 105 | qkv = self.qkv(x) 106 | q, k, v = qkv.view(B, N, self.num_heads, - 107 | 1).split([self.key_dim, self.key_dim, self.d], dim=3) 108 | q = q.permute(0, 2, 1, 3) 109 | k = k.permute(0, 2, 1, 3) 110 | v = v.permute(0, 2, 1, 3) 111 | 112 | v0 = v[:, :, :self.img_shape, :] 113 | 114 | v0 = v0.reshape(B, self.dh, int(self.img_shape ** 0.5), -1) 115 | v_conv = self.parallel_conv(v0).flatten(2) 116 | 117 | attn = ( 118 | (q @ k.transpose(-2, -1)) * self.scale 119 | ) 120 | attn = attn.softmax(dim=-1) 121 | x = (attn @ v).transpose(1, 2).reshape(B, -1, N) 122 | x = self.proj(x) + v_conv 123 | x = self.to_out(x.permute(0, 2, 1)) # + v_conv 124 | return x 125 | 126 | 127 | class CrossScaleAttention(nn.Module): 128 | def __init__(self, dim, img_shape=784, att_shape=1080): 129 | super().__init__() 130 | self.bn1 = nn.BatchNorm2d(dim) 131 | 132 | self.DWConv1 = nn.Sequential( 133 | nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1, groups=dim), 134 | nn.BatchNorm2d(dim), 135 | ) 136 | self.DWConv2 = nn.Sequential( 137 | nn.Conv2d(dim, dim, kernel_size=5, stride=3, padding=2, groups=dim), 138 | nn.BatchNorm2d(dim), 139 | ) 140 | self.attention = Attention(dim, img_shape, att_shape) 141 | self.bn4 = nn.BatchNorm2d(dim) 142 | self.activate = nn.Hardswish() 143 | self.conv = nn.Conv2d(dim, dim, 1) 144 | 145 | 146 | def forward(self, x): 147 | x0 = self.bn1(x) 148 | x1 = self.DWConv1(x0) 149 | x2 = self.DWConv2(x0) 150 | # [B, C, H, W] -> [B, C, H*W] 151 | x0, x1, x2 = x0.view(x0.shape[0], x0.shape[1], -1), x1.view(x1.shape[0], x1.shape[1], -1), x2.view(x2.shape[0], x2.shape[1], -1) 152 | attn = torch.cat((x0, x1, x2), dim=2).permute(0, 2, 1) 153 | attn = self.attention(attn) 154 | attn = attn.permute(0, 2, 1).contiguous().view(x0.shape[0], x0.shape[1], 28, 28) 155 | x = self.conv(self.activate(self.bn4(attn))) 156 | return x 157 | 158 | 159 | class FeedForward(nn.Module): 160 | def __init__(self, dim, hidden_dim): 161 | super().__init__() 162 | self.bn1 = nn.BatchNorm2d(dim) 163 | self.conv1 = nn.Conv2d(dim, hidden_dim, 1) 164 | self.bn2 = nn.BatchNorm2d(hidden_dim) 165 | self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=dim) 166 | self.relu = nn.ReLU6() 167 | self.conv3 = nn.Conv2d(hidden_dim, dim, 1) 168 | 169 | def forward(self, x): 170 | out = self.conv3(self.relu(self.conv2(self.bn2(self.conv1(self.bn1(x)))))) 171 | return out 172 | 173 | 174 | class IntraFeedForward(nn.Module): 175 | def __init__(self, channels, mlp_ratio=2): 176 | super().__init__() 177 | self.channels = [channels[i]//4 for i in range(len(channels))] 178 | 179 | self.ff1 = Residual(FeedForward(self.channels[0], mlp_ratio*self.channels[0])) 180 | self.ff2 = Residual(FeedForward(self.channels[1], mlp_ratio*self.channels[1])) 181 | self.ff3 = Residual(FeedForward(self.channels[2], mlp_ratio*self.channels[2])) 182 | self.ff4 = Residual(FeedForward(self.channels[3], mlp_ratio*self.channels[3])) 183 | 184 | def forward(self, x): 185 | x1, x2, x3, x4 = x.split(self.channels, dim=1) 186 | x1 = self.ff1(x1) 187 | x2 = self.ff2(x2) 188 | x3 = self.ff3(x3) 189 | x4 = self.ff4(x4) 190 | return torch.cat([x1, x2, x3, x4], dim=1) 191 | 192 | 193 | class CIMBlock(nn.Module): 194 | def __init__(self, dim, channels, mlp_ratio=2): 195 | super().__init__() 196 | self.csa1 = Residual(CrossScaleAttention(dim)) 197 | 
self.intra_ff = Residual(IntraFeedForward(channels, mlp_ratio)) 198 | self.csa2 = Residual(CrossScaleAttention(dim)) 199 | self.ff = Residual(FeedForward(dim, dim*mlp_ratio)) 200 | 201 | def forward(self, x): 202 | x = self.csa1(x) 203 | x = self.intra_ff(x) 204 | x = self.csa2(x) 205 | x = self.ff(x) 206 | return x 207 | 208 | 209 | class PyramidPoolAgg(nn.Module): 210 | def __init__(self, stride): 211 | super().__init__() 212 | self.stride = stride 213 | 214 | def forward(self, inputs): 215 | B, C, H, W = inputs[-1].shape 216 | H = (H - 1) // self.stride + 1 217 | W = (W - 1) // self.stride + 1 218 | return torch.cat([nn.functional.adaptive_avg_pool2d(inp, (H, W)) for inp in inputs], dim=1) 219 | 220 | 221 | class CIM(nn.Module): 222 | def __init__(self, dim, num_layers=1, channels=[128, 256, 512, 1024], downsample=1): 223 | super().__init__() 224 | self.hidden_dim = dim // 4 225 | self.channels = channels 226 | self.stride = downsample 227 | 228 | self.down_channel = nn.Conv2d(dim, self.hidden_dim, 1) 229 | self.up_channel = nn.Conv2d(self.hidden_dim, dim, 1) 230 | 231 | # downsample to h/32, w/32 232 | self.pool = PyramidPoolAgg(stride=self.stride) 233 | self.block = nn.ModuleList([ 234 | CIMBlock(self.hidden_dim, channels) 235 | for _ in range(num_layers) 236 | ]) 237 | self.bn = nn.BatchNorm2d(self.hidden_dim) 238 | self.fusion = nn.ModuleList([ 239 | ScaleAwareGate(channels[i], channels[i]) 240 | for i in range(len(channels)) 241 | ]) 242 | 243 | def forward(self, input): # [B, C, H, W] 244 | out = self.pool(input) 245 | out = self.down_channel(out) 246 | for layer in self.block: 247 | out = layer(out) 248 | out = self.bn(out) 249 | out = self.up_channel(out) 250 | xx = out.split(self.channels, dim=1) 251 | results = [] 252 | for i in range(len(self.channels)): 253 | CIM_before = input[i] 254 | CIM_after = xx[i] 255 | out_ = self.fusion[i](CIM_before, CIM_after) 256 | results.append(out_) 257 | return results 258 | 259 | 260 | 261 | if __name__ == '__main__': 262 | model = CIM(1920) 263 | x1 = torch.randn(2, 128, 120, 120) 264 | x2 = torch.randn(2, 256, 60, 60) 265 | x3 = torch.randn(2, 512, 30, 30) 266 | x4 = torch.randn(2, 1024, 15, 15) 267 | x = tuple([x1, x2, x3, x4]) 268 | y = model(x) 269 | 270 | 271 | -------------------------------------------------------------------------------- /lib/logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter(logging.Formatter(fmt=color_fmt, 
datefmt='%Y-%m-%d %H:%M:%S')) 32 | logger.addHandler(console_handler) 33 | 34 | # create file handlers 35 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 36 | file_handler.setLevel(logging.DEBUG) 37 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 38 | logger.addHandler(file_handler) 39 | 40 | return logger 41 | -------------------------------------------------------------------------------- /lib/mask_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from collections import OrderedDict 5 | from arc import AdaptiveRotatedConv2d, RountingFunction 6 | 7 | 8 | class SimpleDecoding(nn.Module): 9 | def __init__(self, c4_dims, factor=2): 10 | super(SimpleDecoding, self).__init__() 11 | 12 | hidden_size = c4_dims//factor 13 | c4_size = c4_dims 14 | c3_size = c4_dims//(factor**1) 15 | c2_size = c4_dims//(factor**2) 16 | c1_size = c4_dims//(factor**3) 17 | 18 | self.conv1_4 = nn.Conv2d(c4_size+c3_size, hidden_size, 3, padding=1, bias=False) 19 | routing_function1 = RountingFunction(in_channels=hidden_size, kernel_number=1) 20 | self.conv2_4 = AdaptiveRotatedConv2d(in_channels=hidden_size, out_channels=hidden_size, 21 | kernel_size=3, padding=1, rounting_func=routing_function1, bias=False, kernel_number=1) 22 | 23 | self.bn1_4 = nn.BatchNorm2d(hidden_size) 24 | self.relu1_4 = nn.ReLU() 25 | self.bn2_4 = nn.BatchNorm2d(hidden_size) 26 | self.relu2_4 = nn.ReLU() 27 | 28 | self.conv1_3 = nn.Conv2d(hidden_size + c2_size, hidden_size, 3, padding=1, bias=False) 29 | routing_function2 = RountingFunction(in_channels=hidden_size, kernel_number=1) 30 | self.conv2_3 = AdaptiveRotatedConv2d(in_channels=hidden_size, out_channels=hidden_size, 31 | kernel_size=3, padding=1, rounting_func=routing_function2, bias=False, kernel_number=1) 32 | self.bn1_3 = nn.BatchNorm2d(hidden_size) 33 | self.relu1_3 = nn.ReLU() 34 | self.bn2_3 = nn.BatchNorm2d(hidden_size) 35 | self.relu2_3 = nn.ReLU() 36 | 37 | self.conv1_2 = nn.Conv2d(hidden_size + c1_size, hidden_size, 3, padding=1, bias=False) 38 | routing_function3 = RountingFunction(in_channels=hidden_size, kernel_number=1) 39 | self.conv2_2 = AdaptiveRotatedConv2d(in_channels=hidden_size, out_channels=hidden_size, 40 | kernel_size=3, padding=1, rounting_func=routing_function3, bias=False, kernel_number=1) 41 | self.bn1_2 = nn.BatchNorm2d(hidden_size) 42 | self.relu1_2 = nn.ReLU() 43 | self.bn2_2 = nn.BatchNorm2d(hidden_size) 44 | self.relu2_2 = nn.ReLU() 45 | 46 | self.conv1_1 = nn.Conv2d(hidden_size, 2, 1) 47 | 48 | def forward(self, x_c4, x_c3, x_c2, x_c1): 49 | # fuse Y4 and Y3 50 | if x_c4.size(-2) < x_c3.size(-2) or x_c4.size(-1) < x_c3.size(-1): 51 | x_c4 = F.interpolate(input=x_c4, scale_factor=2, mode='bilinear', align_corners=True) 52 | x = torch.cat([x_c4, x_c3], dim=1) 53 | x = self.conv1_4(x) 54 | x = self.bn1_4(x) 55 | x = self.relu1_4(x) 56 | x = self.conv2_4(x) 57 | x = self.bn2_4(x) 58 | x = self.relu2_4(x) 59 | 60 | # fuse top-down features and Y2 features 61 | if x.size(-2) < x_c2.size(-2) or x.size(-1) < x_c2.size(-1): 62 | x = F.interpolate(input=x, scale_factor=2, mode='bilinear', align_corners=True) 63 | x = torch.cat([x, x_c2], dim=1) 64 | x = self.conv1_3(x) 65 | x = self.bn1_3(x) 66 | x = self.relu1_3(x) 67 | x = self.conv2_3(x) 68 | x = self.bn2_3(x) 69 | x = self.relu2_3(x) 70 | 71 | # fuse top-down features and Y1 features 72 | if 
x.size(-2) < x_c1.size(-2) or x.size(-1) < x_c1.size(-1): 73 | x = F.interpolate(input=x, scale_factor=2, mode='bilinear', align_corners=True) 74 | x = torch.cat([x, x_c1], dim=1) 75 | x = self.conv1_2(x) 76 | x = self.bn1_2(x) 77 | x = self.relu1_2(x) 78 | x = self.conv2_2(x) 79 | x = self.bn2_2(x) 80 | x = self.relu2_2(x) 81 | 82 | return self.conv1_1(x) 83 | -------------------------------------------------------------------------------- /lib/mmcv_custom/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import load_checkpoint 4 | 5 | __all__ = ['load_checkpoint'] 6 | -------------------------------------------------------------------------------- /lib/sa/functional.py: -------------------------------------------------------------------------------- 1 | from . import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | return out 38 | -------------------------------------------------------------------------------- /lib/sa/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 | -------------------------------------------------------------------------------- /lib/sa/functions/subtraction_zeropad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | 
_subtraction_zeropad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void subtraction_zeropad_forward_kernel( 25 | const ${Dtype}* bottom_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | const int h_in_center = -${pad_h} + h * ${stride_h} + (${kernel_h} - 1) / 2 * ${dilation_h}; 32 | const int w_in_center = -${pad_w} + w * ${stride_w} + (${kernel_w} - 1) / 2 * ${dilation_w}; 33 | const int offset_center = ((n * ${input_channels} + c) * ${bottom_height} + h_in_center) * ${bottom_width} + w_in_center; 34 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 35 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 36 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 37 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 38 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 40 | const int offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 41 | top_data[offset_top] = bottom_data[offset_center] - bottom_data[offset_bottom]; 42 | } 43 | else 44 | top_data[offset_top] = bottom_data[offset_center]; 45 | } 46 | } 47 | } 48 | } 49 | ''' 50 | 51 | 52 | _subtraction_zeropad_input_backward_kernel = kernel_loop + ''' 53 | extern "C" 54 | __global__ void subtraction_zeropad_input_backward_kernel( 55 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 56 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 57 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 58 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 59 | const int h = (index / ${bottom_width}) % ${bottom_height}; 60 | const int w = index % ${bottom_width}; 61 | ${Dtype} value = 0; 62 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 63 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 64 | const int h_out_s = h + ${pad_h} - kh * ${dilation_h}; 65 | const int w_out_s = w + ${pad_w} - kw * ${dilation_w}; 66 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 67 | const int h_out = h_out_s / ${stride_h}; 68 | const int w_out = w_out_s / ${stride_w}; 69 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 70 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 71 | value += -top_diff[offset_top]; 72 | } 73 | } 74 | } 75 | } 76 | if (((h % ${stride_h}) == 0) && ((w % ${stride_w}) == 0)) { 77 | const int h_out = h / ${stride_h}; 78 | const int w_out = w / ${stride_w}; 79 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 80 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 81 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 82 | value += top_diff[offset_top]; 83 | } 84 | } 85 | } 86 | bottom_diff[index] = value; 87 | } 88 | } 89 | ''' 90 | 91 | 92 | class SubtractionZeropad(Function): 93 | @staticmethod 94 | def forward(ctx, input, kernel_size, stride, 
padding, dilation): 95 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 96 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 97 | assert input.dim() == 4 and input.is_cuda 98 | batch_size, input_channels, input_height, input_width = input.size() 99 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 100 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 101 | output = input.new(batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) 102 | n = output.numel() // output.shape[2] 103 | with torch.cuda.device_of(input): 104 | f = load_kernel('subtraction_zeropad_forward_kernel', _subtraction_zeropad_forward_kernel, Dtype=Dtype(input), nthreads=n, 105 | num=batch_size, input_channels=input_channels, 106 | bottom_height=input_height, bottom_width=input_width, 107 | top_height=output_height, top_width=output_width, 108 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 109 | stride_h=stride[0], stride_w=stride[1], 110 | dilation_h=dilation[0], dilation_w=dilation[1], 111 | pad_h=padding[0], pad_w=padding[1]) 112 | f(block=(CUDA_NUM_THREADS, 1, 1), 113 | grid=(GET_BLOCKS(n), 1, 1), 114 | args=[input.data_ptr(), output.data_ptr()], 115 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 116 | ctx.save_for_backward(input) 117 | return output 118 | 119 | @staticmethod 120 | def backward(ctx, grad_output): 121 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 122 | input, = ctx.saved_tensors 123 | assert grad_output.is_cuda 124 | if not grad_output.is_contiguous(): 125 | grad_output = grad_output.contiguous() 126 | batch_size, input_channels, input_height, input_width = input.size() 127 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 128 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 129 | grad_input = None 130 | opt = dict(Dtype=Dtype(grad_output), 131 | num=batch_size, input_channels=input_channels, 132 | bottom_height=input_height, bottom_width=input_width, 133 | top_height=output_height, top_width=output_width, 134 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 135 | stride_h=stride[0], stride_w=stride[1], 136 | dilation_h=dilation[0], dilation_w=dilation[1], 137 | pad_h=padding[0], pad_w=padding[1]) 138 | with torch.cuda.device_of(input): 139 | if ctx.needs_input_grad[0]: 140 | grad_input = input.new(input.size()) 141 | n = grad_input.numel() 142 | opt['nthreads'] = n 143 | f = load_kernel('subtraction_zeropad_input_backward_kernel', _subtraction_zeropad_input_backward_kernel, **opt) 144 | f(block=(CUDA_NUM_THREADS, 1, 1), 145 | grid=(GET_BLOCKS(n), 1, 1), 146 | args=[grad_output.data_ptr(), grad_input.data_ptr()], 147 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 148 | return grad_input, None, None, None, None 149 | 150 | 151 | def subtraction_zeropad(input, kernel_size=3, stride=1, padding=0, dilation=1): 152 | assert input.dim() == 4 153 | if input.is_cuda: 154 | out = SubtractionZeropad.apply(input, kernel_size, stride, padding, dilation) 155 | else: 156 | raise NotImplementedError 157 | return out 158 | 159 | 160 | def test_subtraction_zeropad(): 161 | import os 162 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 163 | 
kernel_size, stride, dilation = 5, 4, 2 164 | padding = (dilation * (kernel_size - 1) + 1) // 2 165 | n, c, in_height, in_width = 2, 8, 9, 9 166 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 167 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 168 | x = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 169 | 170 | y1 = subtraction_zeropad(x, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 171 | unfold_i = torch.nn.Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride) 172 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) 173 | y2 = unfold_i(x).view(n, c, 1, out_height * out_width) - unfold_j(x).view(n, c, pow(kernel_size, 2), out_height * out_width) 174 | # y2 = unfold_i(x[:, :, kernel_size//2:-(kernel_size//2), kernel_size//2:-(kernel_size//2)]).view(n, c, 1, out_height * out_width) - unfold_j(x).view(n, c, pow(kernel_size, 2), out_height * out_width) 175 | assert (y1 - y2).abs().max() < 1e-9 176 | 177 | gx1 = torch.autograd.grad(y1.mean(), x, retain_graph=True)[0] 178 | gx2 = torch.autograd.grad(y2.mean(), x, retain_graph=True)[0] 179 | assert (gx1 - gx2).abs().max() < 1e-9 180 | 181 | from functools import partial 182 | assert torch.autograd.gradcheck(partial(subtraction_zeropad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), x) 183 | print('test case passed') 184 | 185 | 186 | if __name__ == '__main__': 187 | test_subtraction_zeropad() 188 | -------------------------------------------------------------------------------- /lib/sa/functions/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from string import Template 3 | import cupy 4 | import torch 5 | 6 | 7 | Stream = namedtuple('Stream', ['ptr']) 8 | 9 | 10 | def Dtype(t): 11 | if isinstance(t, torch.cuda.FloatTensor): 12 | return 'float' 13 | elif isinstance(t, torch.cuda.DoubleTensor): 14 | return 'double' 15 | 16 | 17 | @cupy.memoize(for_each_device=True) 18 | def load_kernel(kernel_name, code, **kwargs): 19 | code = Template(code).substitute(**kwargs) 20 | kernel_code = cupy.cuda.compile_with_cache(code) 21 | return kernel_code.get_function(kernel_name) 22 | -------------------------------------------------------------------------------- /lib/sa/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from .subtraction import * 3 | from .subtraction2 import * 4 | -------------------------------------------------------------------------------- /lib/sa/modules/aggregation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Aggregation(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Aggregation, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input, weight): 18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /lib/sa/modules/subtraction.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input): 18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /lib/sa/modules/subtraction2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction2(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction2, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input1, input2): 18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /lib/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .mask_predictor import SimpleDecoding 4 | from .backbone import MultiModalSwinTransformer 5 | from ._utils import LAVT, LAVTOne 6 | 7 | 8 | __all__ = ['lavt', 'lavt_one'] 9 | 10 | 11 | # LAVT 12 | def _segm_lavt(pretrained, args): 13 | # initialize the SwinTransformer backbone with the specified version 14 | if args.swin_type == 'tiny': 15 | embed_dim = 96 16 | depths = [2, 2, 6, 2] 17 | num_heads = [3, 6, 12, 24] 18 | elif args.swin_type == 'small': 19 | embed_dim = 96 20 | depths = [2, 2, 18, 2] 21 | num_heads = [3, 6, 12, 24] 22 | elif args.swin_type == 'base': 23 | embed_dim = 128 24 | depths = [2, 2, 18, 2] 25 | num_heads = [4, 8, 16, 32] 26 | elif args.swin_type == 'large': 27 | embed_dim = 192 28 | depths = [2, 2, 18, 2] 29 | num_heads = [6, 12, 24, 48] 30 | else: 31 | assert False 32 | # args.window12 added for test.py because state_dict is loaded after model initialization 33 | if 'window12' in pretrained or args.window12: 34 | print('Window size 12!') 35 | window_size = 12 36 | else: 37 | window_size = 7 38 | 39 | if args.mha: 40 | mha = args.mha.split('-') # if non-empty, then ['a', 'b', 'c', 'd'] 41 | mha = [int(a) for a in mha] 42 | else: 43 | mha = [1, 1, 
1, 1] 44 | 45 | out_indices = (0, 1, 2, 3) 46 | backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads, 47 | window_size=window_size, 48 | ape=False, drop_path_rate=0.3, patch_norm=True, 49 | out_indices=out_indices, 50 | use_checkpoint=False, num_heads_fusion=mha, 51 | fusion_drop=args.fusion_drop 52 | ) 53 | if pretrained: 54 | print('Initializing Multi-modal Swin Transformer weights from ' + pretrained) 55 | backbone.init_weights(pretrained=pretrained) 56 | else: 57 | print('Randomly initialize Multi-modal Swin Transformer weights.') 58 | backbone.init_weights() 59 | 60 | model_map = [SimpleDecoding, LAVT] 61 | 62 | classifier = model_map[0](8*embed_dim) 63 | base_model = model_map[1] 64 | 65 | model = base_model(backbone, classifier) 66 | return model 67 | 68 | 69 | def _load_model_lavt(pretrained, args): 70 | model = _segm_lavt(pretrained, args) 71 | return model 72 | 73 | 74 | def lavt(pretrained='', args=None): 75 | return _load_model_lavt(pretrained, args) 76 | 77 | 78 | ############################################### 79 | # LAVT One: put BERT inside the overall model # 80 | ############################################### 81 | def _segm_lavt_one(pretrained, args): 82 | # initialize the SwinTransformer backbone with the specified version 83 | if args.swin_type == 'tiny': 84 | embed_dim = 96 85 | depths = [2, 2, 6, 2] 86 | num_heads = [3, 6, 12, 24] 87 | elif args.swin_type == 'small': 88 | embed_dim = 96 89 | depths = [2, 2, 18, 2] 90 | num_heads = [3, 6, 12, 24] 91 | elif args.swin_type == 'base': 92 | embed_dim = 128 93 | depths = [2, 2, 18, 2] 94 | num_heads = [4, 8, 16, 32] 95 | elif args.swin_type == 'large': 96 | embed_dim = 192 97 | depths = [2, 2, 18, 2] 98 | num_heads = [6, 12, 24, 48] 99 | else: 100 | assert False 101 | # args.window12 added for test.py because state_dict is loaded after model initialization 102 | if 'window12' in pretrained or args.window12: 103 | print('Window size 12!') 104 | window_size = 12 105 | else: 106 | window_size = 7 107 | 108 | if args.mha: 109 | mha = args.mha.split('-') # if non-empty, then ['a', 'b', 'c', 'd'] 110 | mha = [int(a) for a in mha] 111 | else: 112 | mha = [1, 1, 1, 1] 113 | 114 | out_indices = (0, 1, 2, 3) 115 | backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads, 116 | window_size=window_size, 117 | ape=False, drop_path_rate=0.3, patch_norm=True, 118 | out_indices=out_indices, 119 | use_checkpoint=False, num_heads_fusion=mha, 120 | fusion_drop=args.fusion_drop, 121 | # frozen_stages=args.frozen_stages, 122 | # only_fusion=args.only_fusion, 123 | ) 124 | if pretrained: 125 | print('Initializing Multi-modal Swin Transformer weights from ' + pretrained) 126 | backbone.init_weights(pretrained=pretrained) 127 | else: 128 | print('Randomly initialize Multi-modal Swin Transformer weights.') 129 | backbone.init_weights() 130 | 131 | model_map = [SimpleDecoding, LAVTOne] 132 | classifier = model_map[0](8*embed_dim) 133 | base_model = model_map[1] 134 | 135 | model = base_model(backbone, classifier, args) 136 | return model 137 | 138 | 139 | def _load_model_lavt_one(pretrained, args): 140 | model = _segm_lavt_one(pretrained, args) 141 | return model 142 | 143 | 144 | def lavt_one(pretrained='', args=None): 145 | return _load_model_lavt_one(pretrained, args) 146 | -------------------------------------------------------------------------------- /lib/various_receptive.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | from torch import mean, nn 3 | from collections import OrderedDict 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from numpy import random 7 | import os 8 | 9 | 10 | 11 | def transI_fusebn(kernel, bn): 12 | gamma = bn.weight 13 | std = (bn.running_var + bn.eps).sqrt() 14 | return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std 15 | 16 | 17 | def transIV_depthconcat(kernels, biases): 18 | return torch.cat(kernels, dim=0), torch.cat(biases) 19 | 20 | 21 | def transIII_1x1_kxk(k1, b1, k2, b2, groups): 22 | if groups == 1: 23 | k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) # 24 | b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) 25 | else: 26 | k_slices = [] 27 | b_slices = [] 28 | k1_T = k1.permute(1, 0, 2, 3) 29 | k1_group_width = k1.size(0) // groups 30 | k2_group_width = k2.size(0) // groups 31 | for g in range(groups): 32 | k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :] 33 | k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :] 34 | k_slices.append(F.conv2d(k2_slice, k1_T_slice)) 35 | b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3))) 36 | k, b_hat = transIV_depthconcat(k_slices, b_slices) 37 | return k, b_hat + b2 38 | 39 | 40 | def _conv_bn(input_channel,output_channel,kernel_size=3,padding=1,stride=1,groups=1): 41 | res=nn.Sequential() 42 | res.add_module('conv',nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=kernel_size,padding=padding,padding_mode='zeros',stride=stride,groups=groups,bias=False)) 43 | res.add_module('bn',nn.BatchNorm2d(output_channel)) 44 | return res 45 | 46 | 47 | def _conv_bn2(input_channel,output_channel,kernel_size=3,padding=1,stride=1,groups=1): 48 | res=nn.Sequential() 49 | res.add_module('conv1',nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=1,padding=0,padding_mode='zeros',stride=stride,groups=groups,bias=False)) 50 | res.add_module('bn1',nn.BatchNorm2d(output_channel)) 51 | res.add_module('conv2',nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=kernel_size,padding=padding,padding_mode='zeros',stride=stride,groups=groups,bias=False)) 52 | res.add_module('bn2',nn.BatchNorm2d(output_channel)) 53 | return res 54 | 55 | 56 | class RepBlock(nn.Module): 57 | def __init__(self,input_channel,output_channel,kernel_size=3,groups=1,stride=1): 58 | super().__init__() 59 | self.input_channel=input_channel 60 | self.output_channel=output_channel 61 | self.kernel_size=kernel_size 62 | self.padding=kernel_size//2 63 | self.groups=groups 64 | self.activation=nn.ReLU() 65 | self.sigmoid=nn.Sigmoid() 66 | 67 | #make sure kernel_size=3 padding=1 68 | assert self.kernel_size==3 69 | assert self.padding==1 70 | 71 | self.brb_3x3=_conv_bn2(input_channel,output_channel,kernel_size=self.kernel_size,padding=self.padding,groups=groups) 72 | self.brb_1x1=_conv_bn(input_channel,output_channel,kernel_size=1,padding=0,groups=groups) 73 | self.brb_identity=nn.BatchNorm2d(self.input_channel) if self.input_channel == self.output_channel else None 74 | 75 | self.brb_3x3_2=_conv_bn2(input_channel,output_channel,kernel_size=self.kernel_size,padding=self.padding,groups=groups) 76 | self.brb_1x1_2=_conv_bn(input_channel,output_channel,kernel_size=1,padding=0,groups=groups) 77 | self.brb_identity_2=nn.BatchNorm2d(self.input_channel) if self.input_channel == self.output_channel else None 78 | 79 | def forward(self, inputs): 80 | if(self.brb_identity==None): 81 | 
identity_out=0 82 | else: 83 | identity_out=self.brb_identity(inputs) 84 | out1=self.activation(self.brb_1x1(inputs)+self.brb_3x3(inputs)+identity_out) 85 | 86 | 87 | if(self.brb_identity_2==None): 88 | identity_out_2=0 89 | else: 90 | identity_out_2=self.brb_identity_2(out1) 91 | out2=self.brb_1x1_2(out1)+self.brb_3x3_2(out1)+identity_out_2 92 | 93 | # print('relu') 94 | 95 | return self.sigmoid(out2) 96 | 97 | 98 | class VariousReceptive(nn.Module): 99 | def __init__(self,dim): 100 | super().__init__() 101 | self.repblock = RepBlock(1, 1) 102 | 103 | def forward(self, x): 104 | bs, n, dim = x.shape 105 | h, w = int(np.sqrt(n)), int(np.sqrt(n)) 106 | 107 | input = x.view(bs, h, w, dim).permute(0, 3, 1, 2) # bs,dim,h,w 108 | mean_input = torch.mean(input,dim=1,keepdim=True) # bs,1,h,w 109 | weight = self.repblock(mean_input) # bs,1,h,w 110 | out = input * weight 111 | out = out.reshape(bs, dim, -1).permute(0, 2, 1) # bs,n,dim 112 | return out 113 | 114 | 115 | ###test 116 | if __name__ == '__main__': 117 | input=torch.randn(50,1,49,49) 118 | repblock=RepBlock(1,1) 119 | repblock.eval() 120 | out=repblock(input) 121 | 122 | -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class DiceLoss: 6 | "Dice loss for segmentation" 7 | 8 | def __init__(self, 9 | axis: int = 1, # Class axis 10 | smooth: float = 1e-6, # Helps with numerical stabilities in the IoU division 11 | reduction: str = "sum", # PyTorch reduction to apply to the output 12 | square_in_union: bool = False # Squares predictions to increase slope of gradients 13 | ): 14 | self.axis = axis 15 | self.smooth = smooth 16 | self.reduction = reduction 17 | self.square_in_union = square_in_union 18 | 19 | def __call__(self, pred, targ): 20 | "One-hot encodes targ, then runs IoU calculation then takes 1-dice value" 21 | targ = self._one_hot(targ, pred.shape[self.axis]) 22 | assert pred.shape == targ.shape, 'input and target dimensions differ, DiceLoss expects non one-hot targs' 23 | pred = self.activation(pred) 24 | sum_dims = list(range(2, len(pred.shape))) 25 | inter = torch.sum(pred * targ, dim=sum_dims) 26 | union = (torch.sum(pred ** 2 + targ, dim=sum_dims) if self.square_in_union 27 | else torch.sum(pred + targ, dim=sum_dims)) 28 | dice_score = (2. 
* inter + self.smooth) / (union + self.smooth) 29 | loss = 1 - dice_score 30 | if self.reduction == 'mean': 31 | loss = loss.mean() 32 | elif self.reduction == 'sum': 33 | loss = loss.sum() 34 | return loss 35 | 36 | @staticmethod 37 | def _one_hot( 38 | x, # Non one-hot encoded targs 39 | classes: int, # The number of classes 40 | axis: int = 1 # The axis to stack for encoding (class dimension) 41 | ): 42 | "Creates one binary mask per class" 43 | return torch.stack([torch.where(x == c, 1, 0) for c in range(classes)], axis=axis) 44 | 45 | def activation(self, x): 46 | "Activation function applied to model output" 47 | return F.softmax(x, dim=self.axis) 48 | 49 | def decodes(self, x): 50 | "Converts model output to target format" 51 | return x.argmax(dim=self.axis) 52 | 53 | 54 | class Loss(): 55 | def __init__(self, weight=0.1): 56 | self.dice_loss = DiceLoss() 57 | self.ce_loss = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor([0.9, 1.1]).cuda()) 58 | self.weight = weight 59 | 60 | def __call__(self, pred, targ): 61 | dice_loss = self.dice_loss(pred, targ) 62 | ce_loss = self.ce_loss(pred, targ) 63 | return (1 - self.weight) * ce_loss + self.weight * dice_loss -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | filelock 3 | ftfy 4 | regex 5 | opencv-python==4.5.3.56 6 | h5py 7 | matplotlib==3.6.1 8 | scikit-image==0.19.3 9 | scipy==1.9.2 10 | timm==0.6.11 11 | tokenizers==0.13.1 12 | tqdm==4.64.1 13 | transformers==4.30.2 14 | yacs 15 | einops 16 | termcolor 17 | pycocotools 18 | Pillow 19 | mmdet 20 | mmsegmentation==0.17.0 21 | scikit-learn 22 | pytorch-ignite 23 | accelerate 24 | shapely 25 | torchmetrics 26 | spacy 27 | -------------------------------------------------------------------------------- /tasks/CAP.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE="./result/CAP/CAP.log" 4 | : > $FILE 5 | 6 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 \ 7 | ./tasks/code/eval/CAP.py \ 8 | --resume ./pretrained_weights/checkpoint.pth \ 9 | --window12 \ 10 | --save_path "./result/CAP/" \ 11 | --task "CAP" \ 12 | --dataset "UCM" \ 13 | --imageFolder "./refer/data/UCM_captions/Images" \ 14 | --annoFolder "./refer/data/UCM_captions/dataset.json" \ 15 | 2>&1 | tee -a $FILE -------------------------------------------------------------------------------- /tasks/CNT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE="./result/CNT/CNT.log" 4 | : > $FILE 5 | 6 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 \ 7 | ./tasks/code/eval/CNT.py \ 8 | --resume ./pretrained_weights/checkpoint.pth \ 9 | --split test \ 10 | --window12 \ 11 | --save_path "./result/CNT/" \ 12 | --task "CNT" \ 13 | --dataset "DIOR" \ 14 | --imageFolder "./refer/data/Counting/test/DIOR/Images" \ 15 | --annoFolder "./refer/data/Counting/test/DIOR/Annotation/test_IMG_CT.json" \ 16 | --EPOC \ 17 | 2>&1 | tee -a $FILE -------------------------------------------------------------------------------- /tasks/DET.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE="./result/DET/DET.log" 4 | : > $FILE 5 | 6 | CATEGORY=("airplane" "airport" "baseballfield" "basketballcourt" "bridge" "chimney" "dam" "expressway-service-area" "expressway-toll-station" 
"golffield" "groundtrackfield" "harbor" "overpass" "ship" "stadium" "storagetank" "tenniscourt" "trainstation" "vehicle" "windmill") 7 | 8 | for i in "${!CATEGORY[@]}"; do 9 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 \ 10 | ./tasks/code/eval/DET.py \ 11 | --resume ./pretrained_weights/checkpoint.pth \ 12 | --split test \ 13 | --window12 \ 14 | --save_path "./result/DET/" \ 15 | --task "DET" \ 16 | --dataset "DIOR" \ 17 | --imageFolder "./refer/data/DIORcoco/JPEGImages" \ 18 | --annoFolder "./refer/data/DIORcoco/Annotations/DIOR_test.json" \ 19 | --_class "${CATEGORY[$i]}" \ 20 | --EPOC \ 21 | 2>&1 | tee -a $FILE 22 | done -------------------------------------------------------------------------------- /tasks/DET_DOTA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE="./result/DET_DOTA/DET_DOTA.log" 4 | : > $FILE 5 | 6 | CATEGORY=("airplane" "ship" "storagetank" "baseballfield" "tenniscourt" "basketballcourt" "groundtrackfield" "harbor" "bridge" "large-vehicle" "small-vehicle" "helicopter" "roundabout" "soccer-ball-field" "swimming-pool" "container-crane" "airport" "helipad") 7 | 8 | for i in "${!CATEGORY[@]}"; do 9 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 \ 10 | ./tasks/code/eval/DET_DOTA.py \ 11 | --resume ./pretrained_weights/checkpoint.pth \ 12 | --split test \ 13 | --window12 \ 14 | --save_path "./result/DET_DOTA/" \ 15 | --task "DET" \ 16 | --dataset "DOTAv2" \ 17 | --imageFolder "./refer/data/DOTAv2_patches/test-dev/images" \ 18 | --annoFolder "./refer/data/DOTAv2_patches/test-dev/annfiles" \ 19 | --_class "${CATEGORY[$i]}" \ 20 | 2>&1 | tee -a $FILE 21 | done -------------------------------------------------------------------------------- /tasks/MCC.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE="./result/MCC/MCC.log" 4 | : > $FILE 5 | 6 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 \ 7 | ./tasks/code/eval/MCC.py \ 8 | --resume ./pretrained_weights/checkpoint.pth \ 9 | --split test \ 10 | --window12 \ 11 | --save_path "./result/MCC/" \ 12 | --task "MCC" \ 13 | --dataset "UCM" \ 14 | --imageFolder "./refer/data/UCMerced_LandUse/Images" \ 15 | --annoFolder "./refer/data/UCMerced_LandUse/Images" \ 16 | 2>&1 | tee -a $FILE -------------------------------------------------------------------------------- /tasks/MLC.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE="./result/MLC/MLC.log" 4 | : > $FILE 5 | 6 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 \ 7 | ./tasks/code/eval/MLC.py \ 8 | --resume ./pretrained_weights/checkpoint.pth \ 9 | --split test \ 10 | --window12 \ 11 | --save_path "./result/MLC/" \ 12 | --task "MLC" \ 13 | --dataset "DIOR" \ 14 | --imageFolder "./refer/data/DIORcoco/JPEGImages" \ 15 | --annoFolder "./refer/data/DIORcoco/Annotations/DIOR_test.json" \ 16 | 2>&1 | tee -a $FILE -------------------------------------------------------------------------------- /tasks/REF.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | FILE="./result/REF/REF.log" 3 | : > $FILE 4 | 5 | 6 | python ./tasks/code/eval/REF.py --nproc_per_node 1 --master_port 12345 \ 7 | --swin_type base \ 8 | --dataset rrsisd \ 9 | --resume ./pretrained_weights/checkpoint.pth \ 10 | --split val \ 11 | --workers 4 \ 12 | --window12 \ 13 | 
--img_size 896 \ 14 | 2>&1 | tee -a $FILE 15 | 16 | -------------------------------------------------------------------------------- /tasks/SEG.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | FILE="./result/SEG/SEG.log" 3 | : > $FILE 4 | 5 | 6 | python ./tasks/code/eval/SEG.py --nproc_per_node 1 --master_port 12345 \ 7 | --resume ./pretrained_weights/checkpoint.pth \ 8 | --split val \ 9 | --workers 4 \ 10 | --window12 \ 11 | --img_size 896 \ 12 | --save_path "./result/SEG/" \ 13 | --task "SEG" \ 14 | 2>&1 | tee -a $FILE 15 | 16 | -------------------------------------------------------------------------------- /tasks/VG.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE="./result/VG/VG.log" 4 | : > $FILE 5 | 6 | python \ 7 | ./tasks/code/eval/VG.py \ 8 | --resume ./pretrained_weights/checkpoint.pth \ 9 | --split test \ 10 | --window12 \ 11 | --task "VG" \ 12 | --dataset "RSVG" \ 13 | --imageFolder "./refer/data/RSVG/JPEGImages" \ 14 | --annoFolder "./refer/data/RSVG/Annotations" \ 15 | 2>&1 | tee -a $FILE -------------------------------------------------------------------------------- /tasks/code/RuleBasedCaptioning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import xml.etree.ElementTree as ET 4 | import json 5 | from .metric.cidereval.eval import CIDErEvalCap 6 | from pycocotools.coco import COCO 7 | 8 | def object_existence_cap(objects): 9 | sentence_templates = [ 10 | "It contains {}.", 11 | "The image includes {}.", 12 | "There are {} in the image.", 13 | "This image features {}.", 14 | "This picture shows {}." 15 | ] 16 | 17 | obj_count = {} 18 | for obj in objects: 19 | obj_count[obj] = obj_count.get(obj, 0) + 1 20 | 21 | object_names = [] 22 | for obj, count in obj_count.items(): 23 | if count >= 5: 24 | object_names.append(f"many {obj}s") 25 | else: 26 | object_names.append(f"{count} {obj}s" if count > 1 else obj) 27 | 28 | # random select a sentence template 29 | sentence_template = random.choice(sentence_templates) 30 | 31 | return sentence_template.format(", ".join(object_names)) 32 | 33 | 34 | def object_distribution_cap(objects, width, height, region_count): 35 | assert region_count in [4, 9], "region_count must be 4 or 9" 36 | 37 | split = int(region_count ** 0.5) 38 | 39 | # 0[0]: object names, 0[1]: region, 0[2]: is/are 40 | sentence_templates = [ 41 | "The {0[1]} region contains {0[0]}.", 42 | "{0[0]} {0[2]} distributed in the {0[1]}.", 43 | "There {0[2]} {0[0]} in the {0[1]}.", 44 | "The {0[1]} region has {0[0]}.", 45 | "{0[0]} {0[2]} observed in the {0[1]}.", 46 | "In the {0[1]}, {0[0]} {0[2]} present.", 47 | ] 48 | 49 | region_width = width / split 50 | region_height = height / split 51 | 52 | if region_count == 9: 53 | 54 | # ┌──────────────┬───────────────┬───────────────┐ 55 | # │ upper left │ upper center │ upper right │ 56 | # ├──────────────┼───────────────┼───────────────┤ 57 | # │ middle left │ center │ middle right │ 58 | # ├──────────────┼───────────────┼───────────────┤ 59 | # │ lower left │ lower center │ lower right │ 60 | # └──────────────┴───────────────┴───────────────┘ 61 | 62 | regions = [ 63 | ["upper left", "upper center", "upper right"], 64 | ["middle left", "center", "middle right"], 65 | ["lower left", "lower center", "lower right"], 66 | ] 67 | 68 | elif region_count == 4: 69 | 70 | regions = [ 71 | ["upper left", "upper right"], 72 | ["lower left", 
"lower right"], 73 | ] 74 | 75 | # fit into regions based on center coordinates 76 | region_objects = {} 77 | 78 | for i in range(split): 79 | for j in range(split): 80 | region_objects[regions[i][j]] = {} 81 | 82 | for obj, coords in objects: 83 | center_x = (coords[0] + coords[2]) / 2 84 | center_y = (coords[1] + coords[3]) / 2 85 | 86 | loc_i = int(center_y // region_height) 87 | loc_j = int(center_x // region_width) 88 | 89 | region = regions[loc_i][loc_j] 90 | 91 | if obj not in region_objects[region]: 92 | region_objects[region][obj] = 0 93 | region_objects[region][obj] += 1 94 | 95 | distribution_sentences = [] 96 | for region, objs in region_objects.items(): 97 | if objs: 98 | components = ["", region, ""] 99 | cum_cnt = 0 100 | for obj, cnt in objs.items(): 101 | cum_cnt += cnt 102 | if cnt > 1: 103 | components[0] += f"{obj}s, " 104 | else: 105 | components[0] += f"{obj}, " 106 | 107 | components[2] = "are" if cum_cnt > 1 else "is" 108 | components[0] = components[0][:-2] 109 | 110 | # random select a sentence template 111 | sentence_template = random.choice(sentence_templates) 112 | distribution_sentences.append(sentence_template.format(components)) 113 | 114 | return " ".join(distribution_sentences) 115 | 116 | 117 | def merge_json(jsonpath): 118 | """ 119 | Merge multiple json files from Object Detection into one list 120 | """ 121 | filelist = [file for file in os.listdir(jsonpath) if file.endswith('.json')] 122 | result = [] 123 | for file in filelist: 124 | with open(os.path.join(jsonpath, file), 'r') as f: 125 | data = json.load(f) 126 | result += data 127 | result.sort(key=lambda x: x['image_id']) 128 | return result 129 | 130 | 131 | def evaluate(gtjson, predjson): 132 | cocoGt = COCO(gtjson) 133 | cocoDt = cocoGt.loadRes(predjson) 134 | cocoeval_cap = CIDErEvalCap(cocoGt, cocoDt) 135 | cocoeval_cap.evaluate() 136 | 137 | print("\n########## Evaluation Summary ##########") 138 | print("CIDEr:\t{:.3f}%".format(cocoeval_cap.eval['CIDEr'] * 100.)) 139 | print() 140 | 141 | 142 | def single_captioning(pred, shape, region_count): 143 | 144 | assert isinstance(pred, dict) 145 | 146 | objects = [] 147 | for classname in pred: 148 | for box in pred[classname]: 149 | objects.append((classname, box[:4])) 150 | 151 | # part1: object existence 152 | object_names = [obj[0] for obj in objects] 153 | sentence_1 = object_existence_cap(object_names) 154 | 155 | # part2: object distribution 156 | sentence_2 = object_distribution_cap(objects, shape[1], shape[0], region_count=region_count) 157 | 158 | return sentence_1 + " " + sentence_2 159 | 160 | 161 | def captioning(gtjson, predjson, output_json): 162 | 163 | image_captions = [] 164 | all_pred = merge_json(predjson) 165 | 166 | image2pred = {} 167 | for pred in all_pred: 168 | if pred['image_id'] not in image2pred: 169 | image2pred[pred['image_id']] = [] 170 | image2pred[pred['image_id']].append(pred) 171 | 172 | with open(gtjson, 'r') as f: 173 | gtdata = json.load(f) 174 | 175 | for item in gtdata['images']: 176 | imageid = item['id'] 177 | 178 | if imageid not in image2pred: 179 | # form empty caption for json 180 | image_captions.append({ 181 | "image_id": imageid, 182 | "caption": "" 183 | }) 184 | continue 185 | 186 | objects = [] 187 | for obj in image2pred[imageid]: 188 | box = obj['bbox'] 189 | box[2] += box[0] 190 | box[3] += box[1] 191 | objects.append((obj['category_id'], box)) 192 | 193 | # part1: object existence 194 | object_names = [obj[0] for obj in objects] 195 | sentence_1 = object_existence_cap(object_names) 196 | 
197 | # part2: object distribution 198 | sentence_2 = object_distribution_cap(objects, item['width'], item['height'], region_count=4) 199 | 200 | # form json 201 | image_captions.append({ 202 | "image_id": imageid, 203 | "caption": sentence_1 + " " + sentence_2 204 | }) 205 | 206 | with open(output_json, 'w') as f: 207 | json.dump(image_captions, f, indent=4) 208 | 209 | 210 | if __name__ == '__main__': 211 | from argparse import ArgumentParser 212 | parser = ArgumentParser() 213 | parser.add_argument('--gt_json', type=str, default='', help='Path to the ground truth JSON file') 214 | parser.add_argument('--detections', type=str, default='', help='Path to the folder of detection JSON files') 215 | parser.add_argument('--output_json', type=str, default='', help='Path to the output JSON file') 216 | parser.add_argument('--eval', action='store_true', help='Whether to evaluate the output JSON file') 217 | args = parser.parse_args() 218 | 219 | captioning(args.gt_json, args.detections, args.output_json) 220 | 221 | if args.eval: 222 | evaluate(args.gt_json, args.output_json) -------------------------------------------------------------------------------- /tasks/code/eval/CAP.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | # set working directory to root 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))) 5 | 6 | import torch 7 | import torch.utils.data 8 | import utils 9 | from bert.modeling_bert import BertModel 10 | from lib import segmentation 11 | from PIL import Image 12 | import json 13 | from pycocotools.coco import COCO 14 | from tasks.code.metric.cidereval.eval import CIDErEvalCap 15 | from data.DiverseDataset import DiverseDataset 16 | from tasks.code.model import RemoteSAM 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | 21 | def evaluate(gtjson, predjson): 22 | cocoGt = COCO(gtjson) 23 | cocoDt = cocoGt.loadRes(predjson) 24 | cocoeval_cap = CIDErEvalCap(cocoGt, cocoDt) 25 | cocoeval_cap.evaluate() 26 | 27 | print("\n########## Evaluation Summary ##########") 28 | print("CIDEr:\t{:.3f}%".format(cocoeval_cap.eval['CIDEr'] * 100.)) 29 | print() 30 | 31 | 32 | def infer(model, data_loader, bert_model, args): 33 | model = RemoteSAM(model, args.local_rank, use_EPOC=args.EPOC, EPOC_threshold = 0.25) 34 | metric_logger = utils.MetricLogger(delimiter=" ") 35 | 36 | # evaluation variables 37 | cum_total = 0 38 | json_file = 'result.json' 39 | gt_json = {"images": [], "annotations": []} 40 | pred_json = [] 41 | classnames = data_loader.dataset.classes 42 | 43 | with torch.no_grad(): 44 | for data in metric_logger.log_every(data_loader, 100, 'Test:'): 45 | origin_image, groundtruth, image_name = data 46 | 47 | groundtruth = [g[0] for g in groundtruth] 48 | image_name = image_name[0] 49 | 50 | result = model.captioning(image=origin_image.squeeze(0).numpy(), classnames=classnames) 51 | # result = result.split(".")[0] + '.' 
52 | 53 | # to form json text 54 | pred_json.append({ 55 | "image_id": data_loader.dataset.image2id[image_name], 56 | "caption": result 57 | }) 58 | 59 | # to form groundtruth json 60 | gt_json['images'].append({ 61 | "id": data_loader.dataset.image2id[image_name], 62 | "file_name": image_name, 63 | "height": origin_image.shape[1], 64 | "width": origin_image.shape[2] 65 | }) 66 | for caption in groundtruth: 67 | gt_json['annotations'].append({ 68 | "id": len(gt_json['annotations']), 69 | "image_id": data_loader.dataset.image2id[image_name], 70 | "caption": caption 71 | }) 72 | 73 | cum_total += 1 74 | 75 | del groundtruth, origin_image, image_name 76 | 77 | #! only for debug 78 | # if cum_total == 100: 79 | # break 80 | 81 | # sync ddp processes 82 | # gt_json['images'] 83 | gathered = [None for _ in range(utils.get_world_size())] 84 | torch.distributed.all_gather_object(gathered, gt_json['images']) 85 | gt_json['images'] = sum(gathered, []) 86 | # gt_json['annotations'] 87 | gathered = [None for _ in range(utils.get_world_size())] 88 | torch.distributed.all_gather_object(gathered, gt_json['annotations']) 89 | gt_json['annotations'] = sum(gathered, []) 90 | # pred_json 91 | gathered = [None for _ in range(utils.get_world_size())] 92 | torch.distributed.all_gather_object(gathered, pred_json) 93 | pred_json = sum(gathered, []) 94 | 95 | if args.local_rank == 0: 96 | gt_path = os.path.join(args.save_path, 'groundtruth.json') 97 | pred_path = os.path.join(args.save_path, 'result.json') 98 | # save caption 99 | pred_json.sort(key=lambda x: x['image_id']) 100 | with open(pred_path, 'w') as f: 101 | json.dump(pred_json, f, indent=4) 102 | # save groundtruth 103 | gt_json['images'].sort(key=lambda x: x['id']) 104 | gt_json['annotations'].sort(key=lambda x: x['image_id']) 105 | with open(gt_path, 'w') as f: 106 | json.dump(gt_json, f, indent=4) 107 | 108 | evaluate(gt_path, pred_path) 109 | 110 | 111 | def main(args): 112 | utils.init_distributed_mode(args) 113 | 114 | dataset_test = DiverseDataset(args) 115 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, 116 | num_replicas=utils.get_world_size(), rank=utils.get_rank(), shuffle=False) 117 | data_loader_test = torch.utils.data.DataLoader(dataset_test, 118 | batch_size=1, sampler=test_sampler, num_workers=args.workers, pin_memory=args.pin_mem) 119 | 120 | model = segmentation.__dict__[args.model](pretrained='',args=args) 121 | model.to(args.local_rank) 122 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 123 | single_model = model.module 124 | 125 | checkpoint = torch.load(args.resume, map_location='cpu') 126 | single_model.load_state_dict(checkpoint['model'], strict=False) 127 | 128 | if args.model != 'lavt_one': 129 | model_class = BertModel 130 | bert_model = model_class.from_pretrained(args.ck_bert) 131 | if args.ddp_trained_weights: 132 | bert_model.pooler = None 133 | bert_model.to(args.local_rank) 134 | bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank]) 135 | single_bert_model = bert_model 136 | single_bert_model.load_state_dict(checkpoint['bert_model']) 137 | else: 138 | bert_model = None 139 | single_bert_model = None 140 | 141 | infer(model, data_loader_test, bert_model, args=args) 142 | 143 | 144 | 145 | 146 | if __name__ == "__main__": 147 | from args import get_parser 148 | parser = get_parser() 149 | args = parser.parse_args() 150 | main(args) 
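For readers who want to score their own captions with the `evaluate()` helper above, here is a minimal sketch (not part of the repository) of the two JSON layouts that CAP.py writes before computing CIDEr; the image id, file name, and caption strings below are illustrative placeholders.

```python
# Minimal sketch of the JSON files CAP.py produces before scoring (assumed toy
# values; only the field layout matters). groundtruth.json is COCO-caption
# style, result.json is a flat list of {image_id, caption} records.
import json
import os

save_path = "./result/CAP/"   # same default as tasks/CAP.sh
os.makedirs(save_path, exist_ok=True)

gt_json = {
    "images": [{"id": 0, "file_name": "demo.jpg", "height": 256, "width": 256}],
    "annotations": [{"id": 0, "image_id": 0, "caption": "an airplane parked on the apron"}],
}
pred_json = [{"image_id": 0, "caption": "one airplane in the image"}]

with open(os.path.join(save_path, "groundtruth.json"), "w") as f:
    json.dump(gt_json, f, indent=4)
with open(os.path.join(save_path, "result.json"), "w") as f:
    json.dump(pred_json, f, indent=4)

# Calling evaluate("./result/CAP/groundtruth.json", "./result/CAP/result.json")
# from tasks/code/eval/CAP.py then reports the CIDEr score, as in infer() above.
```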
-------------------------------------------------------------------------------- /tasks/code/eval/CNT.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | # set working directory to root 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))) 5 | 6 | import torch 7 | import torch.utils.data 8 | import utils 9 | from bert.modeling_bert import BertModel 10 | from lib import segmentation 11 | from PIL import Image 12 | from data.DiverseDataset import DiverseDataset 13 | from tasks.code.model import RemoteSAM 14 | import warnings 15 | warnings.filterwarnings("ignore") 16 | 17 | 18 | def infer(model, data_loader, bert_model, args): 19 | model = RemoteSAM(model, args.local_rank, use_EPOC=args.EPOC, EPOC_threshold = 0.15) 20 | metric_logger = utils.MetricLogger(delimiter=" ") 21 | 22 | # evaluation variables 23 | cum_correct, cum_total = 0, 0 24 | cum_boxes = [] 25 | 26 | with torch.no_grad(): 27 | for data in metric_logger.log_every(data_loader, 100, 'Test:'): 28 | origin_image, groundtruth, image_name, classname, foo = data 29 | 30 | groundtruth = groundtruth.numpy()[0] 31 | image_name = image_name[0] 32 | foo = foo.numpy()[0] 33 | classname = classname[0] 34 | 35 | boxes = model.detection(image=origin_image.squeeze(0).numpy(), classnames=[classname]) 36 | 37 | if boxes[classname]: 38 | boxes = boxes[classname] 39 | # prior_area_info filter 40 | indice = [] 41 | for i in range(len(boxes)): 42 | x1, y1, x2, y2, conf = boxes[i] 43 | area = (x2-x1)*(y2-y1) 44 | if area <= foo * 0.8: 45 | continue 46 | indice.append(i) 47 | boxes = [boxes[i] for i in indice] 48 | cnt = len(boxes) 49 | else: 50 | cnt = 0 51 | 52 | # evaluate 53 | if cnt == groundtruth: 54 | cum_correct += 1 55 | 56 | cum_total += 1 57 | 58 | del groundtruth, origin_image, image_name, foo, boxes 59 | 60 | #! 
only for debug 61 | # if cum_total == 10: 62 | # break 63 | 64 | # sync ddp processes 65 | # cum_total 66 | gathered = [None for _ in range(utils.get_world_size())] 67 | torch.distributed.all_gather_object(gathered, cum_total) 68 | cum_total = sum(gathered) 69 | # cum_correct 70 | gathered = [None for _ in range(utils.get_world_size())] 71 | torch.distributed.all_gather_object(gathered, cum_correct) 72 | cum_correct = sum(gathered) 73 | 74 | # sumarize evaluation 75 | print("\n########## Evaluation Summary ##########") 76 | print("Total count: {}".format(cum_total)) 77 | print("Correct count: {}".format(cum_correct)) 78 | print("Accuracy: {:.2f}%".format(cum_correct/cum_total * 100.)) 79 | print() 80 | 81 | 82 | def main(args): 83 | utils.init_distributed_mode(args) 84 | 85 | dataset_test = DiverseDataset(args) 86 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, 87 | num_replicas=utils.get_world_size(), rank=utils.get_rank(), shuffle=False) 88 | data_loader_test = torch.utils.data.DataLoader(dataset_test, 89 | batch_size=1, sampler=test_sampler, num_workers=args.workers, pin_memory=args.pin_mem) 90 | 91 | model = segmentation.__dict__[args.model](pretrained='',args=args) 92 | model.to(args.local_rank) 93 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 94 | single_model = model.module 95 | 96 | checkpoint = torch.load(args.resume, map_location='cpu') 97 | single_model.load_state_dict(checkpoint['model'], strict=False) 98 | 99 | if args.model != 'lavt_one': 100 | model_class = BertModel 101 | bert_model = model_class.from_pretrained(args.ck_bert) 102 | if args.ddp_trained_weights: 103 | bert_model.pooler = None 104 | bert_model.to(args.local_rank) 105 | bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank]) 106 | single_bert_model = bert_model 107 | single_bert_model.load_state_dict(checkpoint['bert_model']) 108 | else: 109 | bert_model = None 110 | single_bert_model = None 111 | 112 | infer(model, data_loader_test, bert_model, args=args) 113 | 114 | 115 | if __name__ == "__main__": 116 | from args import get_parser 117 | parser = get_parser() 118 | args = parser.parse_args() 119 | main(args) -------------------------------------------------------------------------------- /tasks/code/eval/DET.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | # set working directory to root 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))) 5 | 6 | import torch 7 | import torch.utils.data 8 | import utils 9 | from bert.modeling_bert import BertModel 10 | from lib import segmentation 11 | from PIL import Image 12 | import json 13 | from pycocotools.coco import COCO 14 | from pycocotools.cocoeval import COCOeval 15 | from data.DiverseDataset import DiverseDataset 16 | from tasks.code.model import RemoteSAM 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | 21 | def infer(model, data_loader, bert_model, args): 22 | model = RemoteSAM(model, args.local_rank, use_EPOC=args.EPOC, EPOC_threshold = 0.25) 23 | metric_logger = utils.MetricLogger(delimiter=" ") 24 | 25 | # evaluation variables 26 | cum_total = 0 27 | cum_boxes = [] 28 | json_file = 'preds_%s.json' % args._class 29 | 30 | with torch.no_grad(): 31 | for data in metric_logger.log_every(data_loader, 100, 'Test:'): 32 | origin_image, groundtruth, image_name = data 33 | 34 | groundtruth = groundtruth.squeeze(0).numpy() 
35 | image_name = image_name[0] 36 | classindex = data_loader.dataset.classes.index(args._class) 37 | 38 | boxes = model.detection(image=origin_image.squeeze(0).numpy(), classnames=[args._class]) 39 | if boxes[args._class]: 40 | boxes = boxes[args._class] 41 | 42 | # to form json text 43 | for i in range(len(boxes)): 44 | x1, y1, x2, y2, conf = boxes[i] 45 | cum_boxes.append({ 46 | 'image_id': data_loader.dataset.image2id[image_name], 47 | 'category_id': classindex, 48 | 'bbox': [int(x1), int(y1), int(x2-x1), int(y2-y1)], 49 | 'score': float(conf) 50 | }) 51 | 52 | cum_total += 1 53 | 54 | del groundtruth, origin_image, image_name, boxes 55 | 56 | #! only for debug 57 | # if cum_total == 10: 58 | # break 59 | 60 | # sync ddp processes 61 | # cum_boxes 62 | gathered = [None for _ in range(utils.get_world_size())] 63 | torch.distributed.all_gather_object(gathered, cum_boxes) 64 | cum_boxes = sum(gathered, []) 65 | 66 | cum_boxes.sort(key=lambda x: x['image_id']) 67 | 68 | if args.local_rank == 0: 69 | # save result 70 | with open(args.save_path + json_file, 'w') as f: 71 | json.dump(cum_boxes, f) 72 | 73 | # evaluate 74 | if len(cum_boxes) > 0: 75 | cocoGt = COCO(args.annoFolder) 76 | cocoDt = cocoGt.loadRes(args.save_path + json_file) 77 | cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') 78 | cocoEval.params.catIds = [classindex] 79 | cocoEval.evaluate() 80 | cocoEval.accumulate() 81 | # sumarize evaluation 82 | print("\n########## Evaluation Summary ##########") 83 | print("Done evaluating for category: ", args._class) 84 | cocoEval.summarize() 85 | print() 86 | else: 87 | print("No detection for category: ", args._class) 88 | print() 89 | 90 | 91 | def main(args): 92 | utils.init_distributed_mode(args) 93 | 94 | dataset_test = DiverseDataset(args) 95 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, 96 | num_replicas=utils.get_world_size(), rank=utils.get_rank(), shuffle=False) 97 | data_loader_test = torch.utils.data.DataLoader(dataset_test, 98 | batch_size=1, sampler=test_sampler, num_workers=args.workers, pin_memory=args.pin_mem) 99 | 100 | model = segmentation.__dict__[args.model](pretrained='',args=args) 101 | model.to(args.local_rank) 102 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 103 | single_model = model.module 104 | 105 | checkpoint = torch.load(args.resume, map_location='cpu') 106 | single_model.load_state_dict(checkpoint['model'], strict=False) 107 | 108 | if args.model != 'lavt_one': 109 | model_class = BertModel 110 | bert_model = model_class.from_pretrained(args.ck_bert) 111 | if args.ddp_trained_weights: 112 | bert_model.pooler = None 113 | bert_model.to(args.local_rank) 114 | bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank]) 115 | single_bert_model = bert_model 116 | single_bert_model.load_state_dict(checkpoint['bert_model']) 117 | else: 118 | bert_model = None 119 | single_bert_model = None 120 | 121 | infer(model, data_loader_test, bert_model, args=args) 122 | 123 | 124 | if __name__ == "__main__": 125 | from args import get_parser 126 | parser = get_parser() 127 | args = parser.parse_args() 128 | main(args) -------------------------------------------------------------------------------- /tasks/code/eval/DET_DOTA.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | # set working directory to root 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))) 
5 | 6 | import torch 7 | import torch.utils.data 8 | import utils 9 | from bert.modeling_bert import BertModel 10 | from lib import segmentation 11 | from PIL import Image 12 | from data.DiverseDataset import DiverseDataset 13 | from tasks.code.model import RemoteSAM 14 | import warnings 15 | warnings.filterwarnings("ignore") 16 | 17 | 18 | def infer(model, data_loader, bert_model, args): 19 | model = RemoteSAM(model, args.local_rank, use_EPOC=args.EPOC, EPOC_threshold = 0.25) 20 | metric_logger = utils.MetricLogger(delimiter=" ") 21 | 22 | # evaluation variables 23 | cum_total = 0 24 | cum_boxes = [] 25 | 26 | with torch.no_grad(): 27 | for data in metric_logger.log_every(data_loader, 100, 'Test:'): 28 | origin_image, groundtruth, image_name = data 29 | 30 | groundtruth = groundtruth.squeeze(0).numpy() 31 | image_name = image_name[0] 32 | 33 | boxes = model.detection(image=origin_image.squeeze(0).numpy(), classnames=[args._class]) 34 | if boxes[args._class]: 35 | boxes = boxes[args._class] 36 | 37 | # to form txt 38 | if len(boxes) > 0: 39 | for i in range(len(boxes)): 40 | temp = boxes[i] 41 | temp.append(image_name) 42 | cum_boxes.append(temp) 43 | 44 | cum_total += 1 45 | 46 | del groundtruth, origin_image, image_name, boxes 47 | 48 | #! only for debug 49 | # if cum_total == 10: 50 | # break 51 | 52 | # sync ddp processes 53 | # cum_boxes 54 | gathered = [None for _ in range(utils.get_world_size())] 55 | torch.distributed.all_gather_object(gathered, cum_boxes) 56 | cum_boxes = sum(gathered, []) 57 | 58 | cum_boxes.sort(key=lambda x: x[5]) 59 | 60 | if args.local_rank == 0: 61 | # save result 62 | with open(args.save_path + 'Task2_' + args._class + '.txt', 'w') as f: 63 | for box in cum_boxes: 64 | x1, y1, x2, y2, conf, image_name = box 65 | pts = [x1, y1, x2, y1, x2, y2, x1, y2] 66 | box_string = ' '.join([str(float(a)) for a in pts]) 67 | f.write(image_name + " " + str(conf) + " " + box_string + '\n') 68 | 69 | print("Done evaluating the category: ", args._class) 70 | print() 71 | 72 | 73 | def main(args): 74 | utils.init_distributed_mode(args) 75 | 76 | dataset_test = DiverseDataset(args) 77 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, 78 | num_replicas=utils.get_world_size(), rank=utils.get_rank(), shuffle=False) 79 | data_loader_test = torch.utils.data.DataLoader(dataset_test, 80 | batch_size=1, sampler=test_sampler, num_workers=args.workers, pin_memory=args.pin_mem) 81 | 82 | model = segmentation.__dict__[args.model](pretrained='',args=args) 83 | model.to(args.local_rank) 84 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 85 | single_model = model.module 86 | 87 | checkpoint = torch.load(args.resume, map_location='cpu') 88 | single_model.load_state_dict(checkpoint['model'], strict=False) 89 | 90 | if args.model != 'lavt_one': 91 | model_class = BertModel 92 | bert_model = model_class.from_pretrained(args.ck_bert) 93 | if args.ddp_trained_weights: 94 | bert_model.pooler = None 95 | bert_model.to(args.local_rank) 96 | bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank]) 97 | single_bert_model = bert_model 98 | single_bert_model.load_state_dict(checkpoint['bert_model']) 99 | else: 100 | bert_model = None 101 | single_bert_model = None 102 | 103 | infer(model, data_loader_test, bert_model, args=args) 104 | 105 | 106 | if __name__ == "__main__": 107 | from args import get_parser 108 | parser = get_parser() 109 | args = parser.parse_args() 110 | main(args) 
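As a side note, the Task2_<class>.txt files written above follow the DOTA horizontal-detection submission format, one line per box: `image_name score x1 y1 x2 y1 x2 y2 x1 y2`. The sketch below is not part of the repository, and the patch name and coordinates in the sample line are illustrative; it only shows how such a line maps back to an axis-aligned box.

```python
# Hypothetical helper: read one line of a Task2_<class>.txt produced above
# back into (image_name, score, [x1, y1, x2, y2]).
def parse_task2_line(line):
    parts = line.split()
    image_name, score = parts[0], float(parts[1])
    pts = list(map(float, parts[2:]))          # four corner points, x/y interleaved
    xs, ys = pts[0::2], pts[1::2]
    return image_name, score, [min(xs), min(ys), max(xs), max(ys)]

# illustrative patch name and values
print(parse_task2_line("P0003__1024__0___0 0.91 10.0 20.0 60.0 20.0 60.0 80.0 10.0 80.0"))
# -> ('P0003__1024__0___0', 0.91, [10.0, 20.0, 60.0, 80.0])
```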
-------------------------------------------------------------------------------- /tasks/code/eval/MCC.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | # set working directory to root 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))) 5 | 6 | import torch 7 | import torch.utils.data 8 | import utils 9 | from bert.modeling_bert import BertModel 10 | from lib import segmentation 11 | from PIL import Image 12 | from data.DiverseDataset import DiverseDataset 13 | from tasks.code.model import RemoteSAM 14 | import warnings 15 | warnings.filterwarnings("ignore") 16 | 17 | 18 | def eval(class_count, true, prob): 19 | from tasks.code.metric.RunningScore import MultiClassRunningScore 20 | 21 | prob = torch.Tensor(prob) 22 | true = torch.Tensor(true) 23 | 24 | running_metrics = MultiClassRunningScore(class_count) 25 | running_metrics.update(true.type(torch.int64), prob.type(torch.float64)) 26 | 27 | accuracy = running_metrics.accuracy() 28 | 29 | # sumarize evaluation 30 | print("\n########## Evaluation Summary ##########") 31 | print("Accuracy: {:.3f}%".format(accuracy["Accuracy"] * 100.)) 32 | print() 33 | 34 | 35 | def classification(model, data_loader, bert_model, args): 36 | model = RemoteSAM(model, args.local_rank, use_EPOC=args.EPOC, EPOC_threshold = 0.25) 37 | metric_logger = utils.MetricLogger(delimiter=" ") 38 | 39 | # evaluation variables 40 | cum_total = 0 41 | true = [] 42 | prob = [] 43 | classnames = data_loader.dataset.classes 44 | 45 | with torch.no_grad(): 46 | for data in metric_logger.log_every(data_loader, 100, 'Test:'): 47 | origin_image, groundtruth, image_name = data 48 | 49 | groundtruth = groundtruth.item() 50 | image_name = image_name[0] 51 | 52 | _, prob_ins = model.multi_class_cls(image=origin_image.squeeze(0).numpy(), classnames=classnames, return_prob=True) 53 | 54 | true.append(groundtruth) 55 | prob.append([prob_ins[classname] for classname in classnames]) 56 | 57 | cum_total += 1 58 | 59 | del groundtruth, origin_image, image_name, prob_ins 60 | 61 | #! 
only for debug 62 | # if cum_total == 10: 63 | # break 64 | 65 | # sync the process 66 | # true 67 | gathered = [None for _ in range(utils.get_world_size())] 68 | torch.distributed.all_gather_object(gathered, true) 69 | true = sum(gathered, []) 70 | # prob 71 | gathered = [None for _ in range(utils.get_world_size())] 72 | torch.distributed.all_gather_object(gathered, prob) 73 | prob = sum(gathered, []) 74 | 75 | return len(classnames), true, prob 76 | 77 | 78 | def main(args): 79 | utils.init_distributed_mode(args) 80 | 81 | dataset_test = DiverseDataset(args) 82 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, 83 | num_replicas=utils.get_world_size(), rank=utils.get_rank(), shuffle=False) 84 | data_loader_test = torch.utils.data.DataLoader(dataset_test, 85 | batch_size=1, sampler=test_sampler, num_workers=args.workers, pin_memory=args.pin_mem) 86 | 87 | model = segmentation.__dict__[args.model](pretrained='',args=args) 88 | model.to(args.local_rank) 89 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 90 | single_model = model.module 91 | 92 | checkpoint = torch.load(args.resume, map_location='cpu') 93 | single_model.load_state_dict(checkpoint['model'], strict=False) 94 | 95 | if args.model != 'lavt_one': 96 | model_class = BertModel 97 | bert_model = model_class.from_pretrained(args.ck_bert) 98 | if args.ddp_trained_weights: 99 | bert_model.pooler = None 100 | bert_model.to(args.local_rank) 101 | bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank]) 102 | single_bert_model = bert_model 103 | single_bert_model.load_state_dict(checkpoint['bert_model']) 104 | else: 105 | bert_model = None 106 | single_bert_model = None 107 | 108 | class_count, true, prob = classification(model, data_loader_test, bert_model, args=args) 109 | 110 | eval(class_count, true, prob) 111 | 112 | 113 | if __name__ == "__main__": 114 | from args import get_parser 115 | parser = get_parser() 116 | args = parser.parse_args() 117 | main(args) -------------------------------------------------------------------------------- /tasks/code/eval/MLC.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | # set working directory to root 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))) 5 | 6 | import torch 7 | import torch.utils.data 8 | import utils 9 | from bert.modeling_bert import BertModel 10 | from lib import segmentation 11 | from PIL import Image 12 | from data.DiverseDataset import DiverseDataset 13 | from tasks.code.model import RemoteSAM 14 | import warnings 15 | warnings.filterwarnings("ignore") 16 | 17 | 18 | def eval(class_count, true, pred, prob): 19 | from tasks.code.metric.RunningScore import MultiLabelRunningScore 20 | 21 | true = torch.Tensor(true) 22 | pred = torch.Tensor(pred) 23 | prob = torch.Tensor(prob) 24 | 25 | running_metrics = MultiLabelRunningScore(class_count) 26 | running_metrics.update(true.type(torch.int64), pred.type(torch.int64), prob) 27 | 28 | accuracy = running_metrics.accuracy() 29 | 30 | # sumarize evaluation 31 | print("\n########## Evaluation Summary ##########") 32 | print("Accuracy: {:.2f}%".format(accuracy["Accuracy"] * 100.)) 33 | print("Accuracy per class: {}".format( 34 | " ".join([str(round(acc * 100., 2)) + "%" for acc in accuracy["Accuracy per Class"]]) 35 | )) 36 | print() 37 | 38 | 39 | def classification(model, data_loader, bert_model, args): 40 | model = 
RemoteSAM(model, args.local_rank, use_EPOC=args.EPOC, EPOC_threshold = 0.25) 41 | metric_logger = utils.MetricLogger(delimiter=" ") 42 | 43 | # evaluation variables 44 | cum_total = 0 45 | true = [] 46 | pred = [] 47 | prob = [] 48 | classnames = data_loader.dataset.classes 49 | 50 | with torch.no_grad(): 51 | for data in metric_logger.log_every(data_loader, 100, 'Test:'): 52 | origin_image, groundtruth, image_name = data 53 | 54 | groundtruth = groundtruth.squeeze(0).numpy() 55 | image_name = image_name[0] 56 | 57 | pred_ins, prob_ins = model.multi_label_cls(image=origin_image.squeeze(0).numpy(), classnames=classnames, return_prob=True) 58 | 59 | true.append(groundtruth) 60 | pred.append([1 if classname in pred_ins else 0 for classname in classnames]) 61 | prob.append([prob_ins[classname] for classname in classnames]) 62 | 63 | cum_total += 1 64 | 65 | del groundtruth, origin_image, image_name, pred_ins, prob_ins 66 | 67 | #! only for debug 68 | # if cum_total == 10: 69 | # break 70 | 71 | # sync the process 72 | # true 73 | gathered = [None for _ in range(utils.get_world_size())] 74 | torch.distributed.all_gather_object(gathered, true) 75 | true = sum(gathered, []) 76 | # pred 77 | gathered = [None for _ in range(utils.get_world_size())] 78 | torch.distributed.all_gather_object(gathered, pred) 79 | pred = sum(gathered, []) 80 | # prob 81 | gathered = [None for _ in range(utils.get_world_size())] 82 | torch.distributed.all_gather_object(gathered, prob) 83 | prob = sum(gathered, []) 84 | 85 | return len(classnames), true, pred, prob 86 | 87 | 88 | def main(args): 89 | utils.init_distributed_mode(args) 90 | 91 | dataset_test = DiverseDataset(args) 92 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, 93 | num_replicas=utils.get_world_size(), rank=utils.get_rank(), shuffle=False) 94 | data_loader_test = torch.utils.data.DataLoader(dataset_test, 95 | batch_size=1, sampler=test_sampler, num_workers=args.workers, pin_memory=args.pin_mem) 96 | 97 | model = segmentation.__dict__[args.model](pretrained='',args=args) 98 | model.to(args.local_rank) 99 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 100 | single_model = model.module 101 | 102 | checkpoint = torch.load(args.resume, map_location='cpu') 103 | single_model.load_state_dict(checkpoint['model'], strict=False) 104 | 105 | if args.model != 'lavt_one': 106 | model_class = BertModel 107 | bert_model = model_class.from_pretrained(args.ck_bert) 108 | if args.ddp_trained_weights: 109 | bert_model.pooler = None 110 | bert_model.to(args.local_rank) 111 | bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank]) 112 | single_bert_model = bert_model 113 | single_bert_model.load_state_dict(checkpoint['bert_model']) 114 | else: 115 | bert_model = None 116 | single_bert_model = None 117 | 118 | class_count, true, pred, prob = classification(model, data_loader_test, bert_model, args=args) 119 | 120 | eval(class_count, true, pred, prob) 121 | 122 | 123 | if __name__ == "__main__": 124 | from args import get_parser 125 | parser = get_parser() 126 | args = parser.parse_args() 127 | main(args) -------------------------------------------------------------------------------- /tasks/code/eval/REF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import utils 4 | import numpy as np 5 | import transforms as T 6 | import random 7 | from bert.modeling_bert import BertModel 8 | 
from lib import segmentation 9 | 10 | 11 | 12 | 13 | def get_dataset(image_set, transform, args): 14 | from data.dataset_refer_bert import ReferDataset 15 | ds = ReferDataset(args, 16 | split=image_set, 17 | image_transforms=transform, 18 | target_transforms=None, 19 | eval_mode=True 20 | ) 21 | num_classes = 2 22 | return ds, num_classes 23 | 24 | 25 | def evaluate(model, data_loader, bert_model, device): 26 | model.eval() 27 | metric_logger = utils.MetricLogger(delimiter=" ") 28 | 29 | # evaluation variables 30 | cum_I, cum_U = 0, 0 31 | eval_seg_iou_list = [.5, .6, .7, .8, .9] 32 | seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) 33 | seg_total = 0 34 | mean_IoU = [] 35 | header = 'Test:' 36 | 37 | with torch.no_grad(): 38 | 39 | for data in metric_logger.log_every(data_loader, 100, header): 40 | image, target, sentences, attentions = data 41 | image, target, sentences, attentions = image.to(device), target.to(device), \ 42 | sentences.to(device), attentions.to(device) 43 | sentences = sentences.squeeze(1) 44 | attentions = attentions.squeeze(1) 45 | target = target.cpu().data.numpy() 46 | for j in range(sentences.size(-1)): 47 | if bert_model is not None: 48 | last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0] 49 | embedding = last_hidden_states.permute(0, 2, 1) 50 | output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1)) 51 | else: 52 | output = model(image, sentences[:, :, j], l_mask=attentions[:, :, j]) 53 | 54 | output = output.cpu() 55 | 56 | output_mask = output.argmax(1).data.numpy() 57 | 58 | I, U = computeIoU(output_mask, target) 59 | if U == 0: 60 | this_iou = 0.0 61 | else: 62 | this_iou = I*1.0/U 63 | mean_IoU.append(this_iou) 64 | cum_I += I 65 | cum_U += U 66 | for n_eval_iou in range(len(eval_seg_iou_list)): 67 | eval_seg_iou = eval_seg_iou_list[n_eval_iou] 68 | seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou) 69 | 70 | seg_total += 1 71 | 72 | 73 | del image, target, sentences, attentions, output,output_mask 74 | if bert_model is not None: 75 | del last_hidden_states, embedding 76 | 77 | mean_IoU = np.array(mean_IoU) 78 | mIoU = np.mean(mean_IoU) 79 | print('Final results:') 80 | print('Mean IoU is %.2f\n' % (mIoU*100.)) 81 | results_str = '' 82 | for n_eval_iou in range(len(eval_seg_iou_list)): 83 | results_str += ' precision@%s = %.2f\n' % \ 84 | (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) 85 | results_str += ' overall IoU = %.2f\n' % (cum_I * 100. 
/ cum_U) 86 | print(results_str) 87 | 88 | 89 | 90 | 91 | def get_transform(args): 92 | transforms = [T.Resize(args.img_size, args.img_size), 93 | T.ToTensor(), 94 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 95 | ] 96 | 97 | return T.Compose(transforms) 98 | 99 | 100 | def computeIoU(pred_seg, gd_seg): 101 | I = np.sum(np.logical_and(pred_seg, gd_seg)) 102 | U = np.sum(np.logical_or(pred_seg, gd_seg)) 103 | 104 | return I, U 105 | 106 | 107 | def main(args): 108 | device = torch.device(args.device) 109 | dataset_test, _ = get_dataset(args.split, get_transform(args=args), args) 110 | 111 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 112 | data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, 113 | sampler=test_sampler, num_workers=args.workers) 114 | print(args.model) 115 | single_model = segmentation.__dict__[args.model](pretrained='',args=args) 116 | checkpoint = torch.load(args.resume, map_location='cpu') 117 | single_model.load_state_dict(checkpoint['model'], strict=False) 118 | model = single_model.to(device) 119 | 120 | if args.model != 'lavt_one': 121 | model_class = BertModel 122 | single_bert_model = model_class.from_pretrained(args.ck_bert) 123 | # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines 124 | if args.ddp_trained_weights: 125 | single_bert_model.pooler = None 126 | single_bert_model.load_state_dict(checkpoint['bert_model']) 127 | bert_model = single_bert_model.to(device) 128 | else: 129 | bert_model = None 130 | 131 | evaluate(model, data_loader_test, bert_model, device=device) 132 | 133 | 134 | if __name__ == "__main__": 135 | from args import get_parser 136 | parser = get_parser() 137 | args = parser.parse_args() 138 | print('Image size: {}'.format(str(args.img_size))) 139 | main(args) 140 | -------------------------------------------------------------------------------- /tasks/code/eval/SEG.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch.nn.functional as F 5 | from torchvision.transforms import functional as F1 6 | 7 | from torchvision.transforms import InterpolationMode 8 | import transforms as T 9 | from bert.tokenization_bert import BertTokenizer 10 | 11 | import os 12 | import numpy as np 13 | import torch 14 | import warnings 15 | import random 16 | 17 | import torch.backends.cudnn as cudnn 18 | 19 | from lib import segmentation 20 | 21 | warnings.filterwarnings("ignore") 22 | 23 | 24 | def fast_hist(a, b, n): 25 | k = (a >= 0) & (a < n) 26 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) 27 | 28 | 29 | def per_class_iu(hist): 30 | return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) 31 | 32 | 33 | def per_class_PA(hist): 34 | return np.diag(hist) / np.maximum(hist.sum(1), 1) 35 | 36 | 37 | def compute_mIoU(args, model, bert_model, cat_dictionary, refToInput, refToAttention, gt_dir, pred_dir, num_classes, 38 | name_classes): 39 | print('Num classes', num_classes) 40 | hist = np.zeros((num_classes, num_classes)) 41 | png_name_list = [] 42 | for file in os.listdir(gt_dir): 43 | if file.endswith(('.jpg', '.png', '.jpeg')): 44 | file_name_without_extension = os.path.splitext(file)[0] 45 | png_name_list.append(file_name_without_extension) 46 | 47 | gt_imgs = [os.path.join(gt_dir, x + ".png") for x in png_name_list] 48 | pred_imgs = [os.path.join(pred_dir, x[0:5] + ".png") for x in png_name_list] 49 
| 50 | 51 | for ind in range(len(gt_imgs)): 52 | 53 | catToid = {} 54 | 55 | for index, (category, color) in enumerate(cat_dictionary.items()): 56 | catToid[category] = index + 1 # ??0??? 57 | img = Image.open(pred_imgs[ind]).convert('RGB') 58 | target = Image.open(gt_imgs[ind]).convert('RGB') 59 | target = F1.resize(target, (896, 896), interpolation=InterpolationMode.NEAREST) 60 | o_W, o_H = target.size 61 | 62 | label = change_label(target, cat_dictionary) 63 | 64 | transformimg = get_transform(args) 65 | img, img2 = transformimg(img, img) 66 | output_list = [] 67 | img = img.unsqueeze(0).cuda() 68 | 69 | for cat, cat_value in cat_dictionary.items(): 70 | emb = refToInput[cat].cuda() 71 | att_mask = refToAttention[cat].cuda() 72 | output = model(img, emb, att_mask) 73 | 74 | # output = output 75 | # output = F.interpolate(output, (o_H, o_W), align_corners=True, mode='bilinear') 76 | 77 | output = output.argmax(1).data 78 | # output=torch.sigmoid(output) 79 | 80 | 81 | output = class_colors_transform(output, catToid, cat) 82 | output = np.array(output.cpu()).astype(int) 83 | output_list.append(output) 84 | 85 | # merge_output = np.zeros((o_H, o_H), dtype=output_list[0].dtype) 86 | 87 | merge_output = np.maximum.reduce(output_list) 88 | pred = merge_output 89 | 90 | 91 | if len(label.flatten()) != len(pred.flatten()): 92 | print( 93 | 'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format( 94 | len(label.flatten()), len(pred.flatten()), gt_imgs[ind], 95 | pred_imgs[ind])) 96 | continue 97 | 98 | 99 | hist += fast_hist(label.flatten(), pred.flatten(), num_classes) 100 | if ind > 0 and ind % 10 == 0: 101 | print('{:d} / {:d}: mIou-{:0.2f}; mPA-{:0.2f}'.format(ind, len(gt_imgs), 102 | 100 * np.nanmean(per_class_iu(hist)), 103 | 100 * np.nanmean(per_class_PA(hist)))) 104 | # ------------------------------------------------# 105 | # mIoU 106 | # ------------------------------------------------# 107 | mIoUs = per_class_iu(hist) 108 | mPA = per_class_PA(hist) 109 | for ind_class in range(num_classes): 110 | # ind_class=ind_class+1 111 | print('===>' + name_classes[ind_class] + ':\tmIou-' + str(round(mIoUs[ind_class] * 100, 2)) + '; mPA-' + str( 112 | round(mPA[ind_class] * 100, 2))) 113 | 114 | print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(mPA) * 100, 2))) 115 | return mIoUs 116 | 117 | 118 | def class_colors_transform(predict, catToid, cat): 119 | color_value = catToid[cat] 120 | zeros_tensor = torch.zeros_like(predict) 121 | 122 | transformed_predict = torch.where(predict > 0, torch.full_like(predict, color_value), zeros_tensor) 123 | # transformed_predict = torch.where(predict > 0.3, torch.full_like(predict, color_value), predict) 124 | 125 | return transformed_predict 126 | 127 | 128 | def get_sent_embedding(cat_dictionary, tokenizer, refToInput, refToAttention, refToExist, max_tokens=20): 129 | for cat, _ in cat_dictionary.items(): 130 | attention_mask = [0] * max_tokens 131 | padded_input_id = [0] * max_tokens 132 | catnew = cat + " in the image." 
133 | input_id = tokenizer.encode(text=catnew, add_special_tokens=True) 134 | input_id = input_id[:max_tokens] 135 | 136 | padded_input_id[:len(input_id)] = input_id 137 | attention_mask[:len(input_id)] = [1] * len(input_id) 138 | 139 | input_id = torch.tensor(padded_input_id).unsqueeze(0) 140 | attention_mask = torch.tensor(attention_mask).unsqueeze(0) 141 | exist = torch.Tensor([True]) 142 | refToInput[cat] = input_id 143 | refToAttention[cat] = attention_mask 144 | refToExist[cat] = exist 145 | return refToInput, refToAttention, refToExist 146 | 147 | 148 | def change_label(label, cat_dictionary): 149 | label = np.array(label).astype(int) 150 | label_mapped = np.zeros((label.shape[0], label.shape[1]), dtype=np.uint8) 151 | 152 | categories = list(cat_dictionary.keys()) 153 | 154 | for index, (category, color) in enumerate(cat_dictionary.items()): 155 | mask = np.all(label == color, axis=-1) 156 | label_mapped[mask] = index + 1 157 | 158 | return label_mapped 159 | 160 | 161 | def get_transform(args): 162 | transforms = [T.Resize(args.img_size, args.img_size), 163 | T.ToTensor(), 164 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 165 | ] 166 | 167 | return T.Compose(transforms) 168 | 169 | 170 | 171 | def main(args): 172 | cudnn.benchmark = True 173 | cudnn.deterministic = True 174 | 175 | device = torch.device(args.device) 176 | 177 | single_model = segmentation.__dict__[args.model](pretrained='', args=args) 178 | checkpoint = torch.load(args.resume, map_location='cpu') 179 | single_model.load_state_dict(checkpoint['model'], strict=False) 180 | model = single_model.to(device) 181 | 182 | if args.model != 'lavt_one': 183 | model_class = BertModel 184 | single_bert_model = model_class.from_pretrained(args.ck_bert) 185 | if args.ddp_trained_weights: 186 | single_bert_model.pooler = None 187 | single_bert_model.load_state_dict(checkpoint['bert_model']) 188 | bert_model = single_bert_model.to(device) 189 | else: 190 | bert_model = None 191 | 192 | model.eval() 193 | img_path = "/data/dishimin/iSAID/test/images" 194 | label_path = "/data/dishimin/iSAID/test/masks" 195 | 196 | max_tokens = 20 197 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 198 | 199 | refToInput = {} 200 | refToAttention = {} 201 | refToExist = {} 202 | cat_dictionary = { 203 | "ship": (0, 0, 63), 204 | "storagetank": (0, 63, 63), 205 | "baseballfield": (0, 63, 0), 206 | "tenniscourt": (0, 63, 127), 207 | "basketballcourt": (0, 63, 191), 208 | "groundtrackfield": (0, 63, 255), 209 | "bridge": (0, 127, 63), 210 | "large-vehicle": (0, 127, 127), 211 | "small-vehicle": (0, 0, 127), 212 | "helicopter": (0, 0, 191), 213 | "swimming-pool": (0, 0, 255), 214 | "roundabout": (0, 191, 127), 215 | "soccer-ball-field": (0, 127, 191), 216 | "airplane": (0, 127, 255), 217 | "harbor": (0, 100, 155) 218 | } 219 | name_classes = [] 220 | name_classes.append("background") 221 | for keys, _ in cat_dictionary.items(): 222 | name_classes.append(keys) 223 | refToInput, refToAttention, refToExist = get_sent_embedding(cat_dictionary, tokenizer, refToInput, refToAttention, 224 | refToExist, max_tokens=20) 225 | with torch.no_grad(): 226 | compute_mIoU(args, model, bert_model, cat_dictionary, refToInput, refToAttention, label_path, img_path, 227 | len(cat_dictionary.keys()) + 1, name_classes) 228 | 229 | if __name__ == '__main__': 230 | from args import get_parser 231 | parser = get_parser() 232 | args = parser.parse_args() 233 | main(args) 
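To make the mIoU bookkeeping above concrete, here is a small self-contained check (not part of the repository) that exercises the same fast_hist / per_class_iu logic as SEG.py; the two helpers are copied so the snippet runs standalone, and the 2x2 label maps and class layout (0 = background, 1 = one foreground class) are toy values.

```python
# Toy check of the confusion-matrix mIoU used in SEG.py.
import numpy as np

def fast_hist(a, b, n):
    # n x n confusion matrix between flattened labels a (ground truth) and b (prediction)
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iu(hist):
    # per-class IoU = diagonal / (row sum + column sum - diagonal)
    return np.diag(hist) / np.maximum(hist.sum(1) + hist.sum(0) - np.diag(hist), 1)

gt   = np.array([[0, 1], [1, 1]])   # 0 = background, 1 = e.g. "ship"
pred = np.array([[0, 1], [0, 1]])   # one foreground pixel missed
hist = fast_hist(gt.flatten(), pred.flatten(), n=2)
print(hist)                 # [[1 0]
                            #  [1 2]]
print(per_class_iu(hist))   # [0.5        0.66666667]  -> mIoU = np.nanmean(...) ~ 0.58
```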
-------------------------------------------------------------------------------- /tasks/code/eval/VG.py: --------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | # set working directory to root
4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
5 | 
6 | import torch
7 | import torch.utils.data
8 | import utils
9 | import numpy as np
10 | from bert.modeling_bert import BertModel
11 | from lib import segmentation
12 | from data.DiverseDataset import DiverseDataset
13 | from tasks.code.model import RemoteSAM
14 | import warnings
15 | warnings.filterwarnings("ignore")
16 | 
17 | 
18 | def infer(model, data_loader, bert_model, device):
19 |     model = RemoteSAM(model, device, use_EPOC=args.EPOC, EPOC_threshold = 0.25)
20 |     metric_logger = utils.MetricLogger(delimiter="  ")
21 | 
22 |     # evaluation variables
23 |     cum_correct, cum_total = 0, 0
24 |     cum_IoU = []
25 |     eval_seg_iou_list = [.5, .6, .7, .8, .9]
26 |     seg_correct = np.zeros_like(eval_seg_iou_list, dtype=np.int32)
27 | 
28 |     with torch.no_grad():
29 |         for data in metric_logger.log_every(data_loader, 100, 'Test:'):
30 |             origin_image, groundtruth, image_name, sentence = data
31 | 
32 |             # prepare
33 |             groundtruth = groundtruth.squeeze(0).numpy()
34 |             image_name = image_name[0]
35 |             sentence = sentence[0]
36 | 
37 |             boxes = model.visual_grounding(image=origin_image.squeeze(0).numpy(), sentence=sentence)
38 | 
39 |             if boxes:
40 |                 IoU = computeIoU(boxes, groundtruth)
41 |             else:
42 |                 IoU = 0.0  # no box predicted: count this sample as IoU 0
43 |             cum_IoU.append(IoU)
44 | 
45 |             # evaluate
46 |             for n_eval_iou in range(len(eval_seg_iou_list)):
47 |                 if IoU >= eval_seg_iou_list[n_eval_iou]:
48 |                     seg_correct[n_eval_iou] += 1
49 | 
50 |             cum_total += 1
51 | 
52 |             del groundtruth, origin_image, image_name, boxes
53 |             # NOTE: unlike REF.py, the text encoder is wrapped inside RemoteSAM here,
54 |             # so there are no last_hidden_states / embedding tensors to free
55 | 
56 |             #!
only for debug 57 | # if cum_total == 10: 58 | # break 59 | 60 | # sumarize evaluation 61 | print("\n########## Evaluation Summary ##########") 62 | print('Mean IoU:\t{:.2f}%'.format(np.mean(np.array(cum_IoU)) * 100.)) 63 | for n_eval_iou in range(len(eval_seg_iou_list)): 64 | print('precision@{}:\t{:.2f}%'.format(eval_seg_iou_list[n_eval_iou], seg_correct[n_eval_iou] / cum_total * 100.)) 65 | print() 66 | 67 | 68 | def computeIoU(pred, gt): 69 | pred_x1, pred_y1, pred_x2, pred_y2 = pred 70 | gt_x1, gt_y1, gt_x2, gt_y2 = gt 71 | 72 | # calculate intersection 73 | inter_x1 = max(pred_x1, gt_x1) 74 | inter_y1 = max(pred_y1, gt_y1) 75 | inter_x2 = min(pred_x2, gt_x2) 76 | inter_y2 = min(pred_y2, gt_y2) 77 | inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1) 78 | 79 | # calculate union 80 | pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1) 81 | gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1) 82 | union_area = pred_area + gt_area - inter_area 83 | 84 | if union_area == 0: 85 | IoU = 0.0 86 | else: 87 | IoU = inter_area / union_area 88 | 89 | return IoU 90 | 91 | 92 | def main(args): 93 | device = torch.device(args.device) 94 | 95 | dataset_test = DiverseDataset(args) 96 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 97 | data_loader_test = torch.utils.data.DataLoader(dataset_test, 98 | batch_size=1, sampler=test_sampler, num_workers=args.workers) 99 | 100 | single_model = segmentation.__dict__[args.model](pretrained='',args=args) 101 | checkpoint = torch.load(args.resume, map_location='cpu') 102 | single_model.load_state_dict(checkpoint['model'], strict=False) 103 | model = single_model.to(device) 104 | 105 | if args.model != 'lavt_one': 106 | model_class = BertModel 107 | single_bert_model = model_class.from_pretrained(args.ck_bert) 108 | if args.ddp_trained_weights: 109 | single_bert_model.pooler = None 110 | single_bert_model.load_state_dict(checkpoint['bert_model']) 111 | bert_model = single_bert_model.to(device) 112 | else: 113 | bert_model = None 114 | 115 | infer(model, data_loader_test, bert_model, device=device) 116 | 117 | 118 | if __name__ == "__main__": 119 | from args import get_parser 120 | parser = get_parser() 121 | args = parser.parse_args() 122 | main(args) -------------------------------------------------------------------------------- /tasks/code/metric/RunningScore.py: -------------------------------------------------------------------------------- 1 | ''' 2 | from: https://github.com/biasvariancelabs/aitlas/ 3 | ''' 4 | import dill 5 | import numpy as np 6 | import torch 7 | from ignite.metrics import confusion_matrix 8 | from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix 9 | from sklearn.metrics import average_precision_score, roc_auc_score 10 | from torchmetrics.detection.mean_ap import MeanAveragePrecision 11 | 12 | 13 | class BaseMetric: 14 | """Base class for metrics""" 15 | 16 | def __init__(self, device="cpu", **kwargs): 17 | self.device = device 18 | 19 | def calculate(self, y_true, y_pred): 20 | raise NotImplementedError("Please implement you metric calculation logic here.") 21 | 22 | 23 | class RunningScore(object): 24 | 25 | def __init__(self, num_classes, device): 26 | self.num_classes = num_classes 27 | self.device = device 28 | self.confusion_matrix = None 29 | 30 | def __getstate__(self): 31 | state = self.__dict__.copy() 32 | state["confusion_matrix"] = dill.dumps(state["confusion_matrix"]) 33 | return state 34 | 35 | def __setstate__(self, state): 36 | new_state = state 37 | 
new_state["confusion_matrix"] = dill.loads(state["confusion_matrix"]) 38 | self.__dict__.update(new_state) 39 | 40 | def update(self, y_true, y_pred, y_prob=None): 41 | """Updates stats on each batch""" 42 | 43 | self.confusion_matrix.update((y_pred, y_true)) 44 | 45 | def reset(self): 46 | """Reset the confusion matrix""" 47 | self.confusion_matrix.reset() 48 | 49 | def get_computed(self): 50 | return self.confusion_matrix.compute().type(torch.DoubleTensor) 51 | 52 | def precision(self): 53 | raise NotImplementedError 54 | 55 | def accuracy(self): 56 | raise NotImplementedError 57 | 58 | def weights(self): 59 | raise NotImplementedError 60 | 61 | def recall(self): 62 | raise NotImplementedError 63 | 64 | def f1_score(self): 65 | precision = self.precision() 66 | recall = self.recall() 67 | micro = ( 68 | 2 69 | * precision["Precision Micro"] 70 | * recall["Recall Micro"] 71 | / (precision["Precision Micro"] + recall["Recall Micro"] + 1e-15) 72 | ) 73 | per_class = ( 74 | 2 75 | * precision["Precision per Class"] 76 | * recall["Recall per Class"] 77 | / (precision["Precision per Class"] + recall["Recall per Class"] + 1e-15) 78 | ) 79 | 80 | return { 81 | "F1_score Micro": float(micro), 82 | "F1_score Macro": np.mean(per_class), 83 | "F1_score Weighted": np.sum(self.weights() * per_class), 84 | "F1_score per Class": per_class, 85 | } 86 | 87 | def iou(self): 88 | raise NotImplementedError 89 | 90 | def get_scores(self, metrics): 91 | """Returns the specified metrics""" 92 | result = [] 93 | for metric in metrics: 94 | result.append(getattr(self, metric)()) 95 | return result 96 | 97 | 98 | class MultiClassRunningScore(RunningScore): 99 | """Calculates confusion matrix for multi-class data. This class contains metrics that are averaged over batches. """ 100 | 101 | def __init__(self, num_classes, device='cpu'): 102 | super().__init__(num_classes, device) 103 | self.confusion_matrix = confusion_matrix.ConfusionMatrix( 104 | num_classes=num_classes, device=device 105 | ) 106 | 107 | def accuracy(self): 108 | cm = self.get_computed() 109 | accuracy = cm.diag().sum() / (cm.sum() + 1e-15) 110 | return {"Accuracy": float(accuracy)} 111 | 112 | def weights(self): 113 | cm = self.get_computed() 114 | return (cm.sum(dim=1) / cm.sum()).numpy() 115 | 116 | def recall(self): 117 | cm = self.get_computed() 118 | micro = cm.diag().sum() / (cm.sum() + 1e-15) # same as accuracy for multiclass 119 | macro = ( 120 | cm.diag() / (cm.sum(dim=1) + 1e-15) 121 | ).mean() # same as average accuracy in breizhcrops 122 | weighted = ( 123 | (cm.diag() / (cm.sum(dim=1) + 1e-15)) 124 | * ((cm.sum(dim=1)) / (cm.sum() + 1e-15)) 125 | ).sum() 126 | per_class = cm.diag() / (cm.sum(dim=1) + 1e-15) 127 | 128 | return { 129 | "Recall Micro": float(micro), 130 | "Recall Macro": float(macro), 131 | "Recall Weighted": float(weighted), 132 | "Recall per Class": per_class.numpy(), 133 | } 134 | 135 | def precision(self): 136 | cm = self.get_computed() 137 | micro = cm.diag().sum() / (cm.sum() + 1e-15) # same as accuracy for multiclass 138 | macro = (cm.diag() / (cm.sum(dim=0) + 1e-15)).mean() 139 | weighted = ( 140 | (cm.diag() / (cm.sum(dim=0) + 1e-15)) 141 | * ((cm.sum(dim=1)) / (cm.sum() + 1e-15)) 142 | ).sum() 143 | per_class = cm.diag() / (cm.sum(dim=0) + 1e-15) 144 | 145 | return { 146 | "Precision Micro": float(micro), 147 | "Precision Macro": float(macro), 148 | "Precision Weighted": float(weighted), 149 | "Precision per Class": per_class.numpy(), 150 | } 151 | 152 | def iou(self): 153 | cm = self.get_computed() 154 | iou = 
cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15) 155 | 156 | return {"IOU": iou.tolist(), "mIOU": float(iou.mean())} 157 | 158 | def kappa(self): 159 | cm = self.get_computed() 160 | N = cm.shape[0] 161 | 162 | act_hist = cm.sum(axis=1) 163 | 164 | pred_hist = cm.sum(axis=0) 165 | 166 | num_samples = cm.sum() 167 | 168 | total_agreements = cm.diag().sum() 169 | agreements_chance = (act_hist * pred_hist) / num_samples 170 | agreements_chance = agreements_chance.sum() 171 | kappa = (total_agreements - agreements_chance) / ( 172 | num_samples - agreements_chance 173 | ) 174 | return {"Kappa metric": kappa} 175 | 176 | 177 | class MultiLabelRunningScore(RunningScore): 178 | """Calculates a confusion matrix for multi-labelled, multi-class data in addition to the """ 179 | 180 | def __init__(self, num_classes, device='cpu'): 181 | super().__init__(num_classes, device) 182 | self.confusion_matrix = MultiLabelConfusionMatrix( 183 | num_classes=self.num_classes, device=self.device, 184 | ) 185 | self.list_y_prob = [] 186 | self.list_y_true = [] 187 | 188 | def reset(self): 189 | """Reset the confusion matrix and list of probabilities""" 190 | self.confusion_matrix.reset() 191 | self.list_y_prob = [] 192 | self.list_y_true = [] 193 | 194 | def update(self, y_true, y_pred, y_prob=None): 195 | """Updates stats on each batch""" 196 | self.confusion_matrix.update((y_pred, y_true)) 197 | self.list_y_prob.extend(y_prob.tolist()) 198 | self.list_y_true.extend(y_true.tolist()) 199 | 200 | def map(self): 201 | return { 202 | "mAP": average_precision_score( 203 | np.array(self.list_y_true), np.array(self.list_y_prob) 204 | ) 205 | } 206 | 207 | def roc_auc_score(self): 208 | return { 209 | "roc_auc_score": roc_auc_score( 210 | np.array(self.list_y_true), np.array(self.list_y_prob), average=None 211 | ) 212 | } 213 | 214 | def accuracy(self): 215 | tp, tn, fp, fn = self.get_outcomes() 216 | tp_total, tn_total, fp_total, fn_total = self.get_outcomes(total=True) 217 | 218 | accuracy = (tp_total + tn_total) / ( 219 | tp_total + tn_total + fp_total + fn_total + 1e-15 220 | ) 221 | accuracy_per_class = (tp + tn) / (tp + tn + fp + fn + 1e-15) 222 | 223 | return {"Accuracy": accuracy, "Accuracy per Class": accuracy_per_class} 224 | 225 | def precision(self): 226 | tp, tn, fp, fn = self.get_outcomes() 227 | tp_total, tn_total, fp_total, fn_total = self.get_outcomes(total=True) 228 | micro = tp_total / (tp_total + fp_total + 1e-15) 229 | per_class = tp / (tp + fp + 1e-15) 230 | macro = np.mean(per_class) 231 | weighted = np.sum(per_class * self.weights()) 232 | return { 233 | "Precision Micro": float(micro), 234 | "Precision Macro": macro, 235 | "Precision Weighted": weighted, 236 | "Precision per Class": per_class, 237 | } 238 | 239 | def weights(self): 240 | tp, tn, fp, fn = self.get_outcomes() 241 | weights = (tp + fn) / self.get_samples() 242 | return weights 243 | 244 | def recall(self): 245 | tp, tn, fp, fn = self.get_outcomes() 246 | tp_total, tn_total, fp_total, fn_total = self.get_outcomes(total=True) 247 | micro = tp_total / (tp_total + fn_total + 1e-15) 248 | per_class = tp / (tp + fn + 1e-15) 249 | macro = np.mean(per_class) 250 | weighted = np.sum(per_class * self.weights()) 251 | return { 252 | "Recall Micro": float(micro), 253 | "Recall Macro": macro, 254 | "Recall Weighted": weighted, 255 | "Recall per Class": per_class, 256 | } 257 | 258 | def get_outcomes(self, total=False): 259 | """ 260 | Return true/false positives/negatives from the confusion matrix 261 | :param total: do we need to 
return per class or total 262 | """ 263 | cm = self.get_computed() 264 | tp = cm[:, 1, 1] 265 | tn = cm[:, 0, 0] 266 | fp = cm[:, 0, 1] 267 | fn = cm[:, 1, 0] 268 | 269 | if total: # sum it all if we need to calculate the totals 270 | tp, tn, fp, fn = tp.sum(), tn.sum(), fp.sum(), fn.sum() 271 | 272 | return tp.numpy(), tn.numpy(), fp.numpy(), fn.numpy() 273 | 274 | def count(self): 275 | tp, tn, fp, fn = self.get_outcomes(True) 276 | return tp + tn + fp + fn 277 | 278 | def get_samples(self): 279 | cm = self.confusion_matrix.compute().cpu().detach().numpy() 280 | return np.sum(cm[:, 1, 0]) + np.sum(cm[:, 1, 1]) 281 | 282 | def iou(self): 283 | tp, tn, fp, fn = self.get_outcomes() 284 | tp_total, tn_total, fp_total, fn_total = self.get_outcomes(total=True) 285 | 286 | iou_per_class = tp / (tp + fp + fn + 1e-15) 287 | iou = tp_total / (tp_total + fp_total + fn_total + 1e-15) 288 | 289 | return { 290 | "IOU": float(iou), 291 | "IOU mean": np.mean(iou_per_class), 292 | "IOU per Class": iou_per_class, 293 | } -------------------------------------------------------------------------------- /tasks/code/metric/cidereval/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /tasks/code/metric/cidereval/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # 4 | # Description: Describes the class to compute the CIDEr 5 | # (Consensus-Based Image Description Evaluation) Metric 6 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 7 | # 8 | # Creation Date: Sun Feb 8 14:16:54 2015 9 | # 10 | # Authors: Ramakrishna Vedantam and 11 | # Tsung-Yi Lin 12 | 13 | # edited by Michele Cafagna 14 | 15 | from .cider_scorer import CiderScorer 16 | 17 | 18 | class Cider: 19 | """ 20 | Main Class to compute the CIDEr metric 21 | 22 | """ 23 | def __init__(self, n=4, df="corpus"): 24 | """ 25 | Initialize the CIDEr scoring function 26 | : param n (int): n-gram size 27 | : param df (string): specifies where to get the IDF values from 28 | takes values 'corpus', 'coco-val' 29 | : return: None 30 | """ 31 | # set cider to sum over 1 to 4-grams 32 | self._n = n 33 | self._df = df 34 | self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) 35 | 36 | def compute_score(self, gts, res): 37 | """ 38 | Main function to compute CIDEr score 39 | : param gts (dict) : {image:tokenized reference sentence} 40 | : param res (dict) : {image:tokenized candidate sentence} 41 | : return: cider (float) : computed CIDEr score for the corpus 42 | """ 43 | 44 | # clear all the previous hypos and refs 45 | self.cider_scorer.clear() 46 | 47 | for res_id in res: 48 | 49 | hypo = res[res_id] 50 | ref = gts[res_id] 51 | 52 | # Sanity check. 
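# (expected format: res[res_id] is a one-element list holding the tokenized candidate caption; gts[res_id] is a non-empty list of tokenized reference captions)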
53 | assert(type(hypo) is list) 54 | assert(len(hypo) == 1) 55 | assert(type(ref) is list) 56 | assert(len(ref) > 0) 57 | self.cider_scorer += (hypo[0], ref) 58 | 59 | (score, scores) = self.cider_scorer.compute_score() 60 | 61 | return score, scores 62 | 63 | def save_df(self, df_name="corpus"): 64 | self.cider_scorer.save_df(df_name) 65 | 66 | def method(self): 67 | return "CIDEr" 68 | -------------------------------------------------------------------------------- /tasks/code/metric/cidereval/cider/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/tasks/code/metric/cidereval/cider/data/__init__.py -------------------------------------------------------------------------------- /tasks/code/metric/cidereval/cider/data/coco-val.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/tasks/code/metric/cidereval/cider/data/coco-val.p -------------------------------------------------------------------------------- /tasks/code/metric/cidereval/eval.py: -------------------------------------------------------------------------------- 1 | __author__ = 'rama' 2 | from .tokenizer.ptbtokenizer import PTBTokenizer 3 | from .cider.cider import Cider 4 | 5 | class CIDErEvalCap: 6 | 7 | def __init__(self, coco, cocoRes, df='corpus'): 8 | print('tokenization...') 9 | 10 | imgIds = {'image_id': coco.getImgIds()}['image_id'] 11 | gts = {} 12 | res = {} 13 | for imgId in imgIds: 14 | gts[imgId] = coco.imgToAnns[imgId] 15 | res[imgId] = cocoRes.imgToAnns[imgId] 16 | 17 | tokenizer = PTBTokenizer() 18 | gts = tokenizer.tokenize(gts) 19 | res = tokenizer.tokenize(res) 20 | 21 | self.eval = {} 22 | self.gts = gts 23 | self.res = res 24 | self.df = df 25 | 26 | 27 | def evaluate(self): 28 | metric_scores = {} 29 | scorer = Cider(df=self.df) 30 | score, scores = scorer.compute_score(self.gts, self.res) 31 | metric_scores["CIDEr"] = list(scores) 32 | self.setEval(score, "CIDEr") 33 | return metric_scores 34 | 35 | 36 | def setEval(self, score, method): 37 | self.eval[method] = score 38 | -------------------------------------------------------------------------------- /tasks/code/metric/cidereval/scorers.py: -------------------------------------------------------------------------------- 1 | from .tokenizer import PTBTokenizer 2 | from .cider.cider import Cider 3 | 4 | 5 | def _preprocess_for_cider(refs, preds): 6 | r""" 7 | Convert preds and refs to the cider data format 8 | 9 | refs: List[List[str]] 10 | preds : List[str] 11 | 12 | return gts: Dict[str : List[Dict['caption':str] : str ]], 13 | res: List[Dict['image_id':str]: 'caption':str] 14 | """ 15 | 16 | assert len(refs) == len(preds) 17 | 18 | gts = {} 19 | res = [] 20 | 21 | for i, (caps, pred) in enumerate(zip(refs, preds)): 22 | gts[i] = [{ 'caption': cap } for cap in caps ] 23 | 24 | res.append({ 'image_id': i, 25 | 'caption': pred}) 26 | return gts, res 27 | 28 | 29 | def cider(predictions, references, df="coco-val"): 30 | r""" 31 | Compute the cider score for the given predictions and references 32 | 33 | predictions : List[str], model's predictions 34 | references: List[List[str]], references 35 | df: str, either 'coco-val' or 'corpus' (default : 'coco-val'). If 'coco-val' the TF-IDF COCO validation split is \\ 36 | used. If 'corpus' the TF-IDF is computed over the reference set provided. 
37 | 38 | returns {"avg_score": mp.float, "scores": np.array(np.float)} 39 | """ 40 | gts, res = _preprocess_for_cider(references, predictions) 41 | tokenizer_res = PTBTokenizer('res') 42 | tokenizer_gts = PTBTokenizer('gts') 43 | 44 | _gts = tokenizer_gts.tokenize(gts) 45 | _res = tokenizer_res.tokenize(res) 46 | 47 | scorer = Cider(df=df) 48 | 49 | score, scores = scorer.compute_score(_gts, _res) 50 | 51 | return {"avg_score": score, "scores": scores} 52 | -------------------------------------------------------------------------------- /tasks/code/metric/cidereval/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | # edited by Michele Cafagna 3 | from .ptbtokenizer import PTBTokenizer 4 | from .simpletokenizer import SimpleTokenizer -------------------------------------------------------------------------------- /tasks/code/metric/cidereval/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | # path to the stanford corenlp jar 16 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 17 | 18 | # punctuations to be removed from the sentences 19 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", 20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 21 | 22 | 23 | class PTBTokenizer: 24 | """Python wrapper of Stanford PTBTokenizer""" 25 | 26 | def __init__(self, _source='gts'): 27 | self.source = _source 28 | 29 | def tokenize(self, captions_for_image): 30 | """Tokenize a sample 31 | 32 | Args: 33 | captions_for_image : 34 | 35 | IF _source='gts' follows format: 36 | dict: { str : [ 37 | { "caption" : str }, 38 | { "caption" : str }, 39 | ... 40 | ], 41 | str : [ ... ], 42 | ... 43 | } 44 | IF _source='res' follows format: 45 | list: [ {"image_id" : str, 46 | "caption" : str, 47 | }, 48 | ... 49 | ] 50 | Returns: 51 | final_tokenized_captions_for_index: 52 | list: [ {"image_id" : str, 53 | "caption" : str, 54 | }, 55 | ... 
56 | ] 57 | """ 58 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, 59 | 'edu.stanford.nlp.process.PTBTokenizer', 60 | '-preserveLines', '-lowerCase'] 61 | 62 | # ====================================================== 63 | # prepare data for PTB Tokenizer 64 | # ====================================================== 65 | 66 | if self.source == 'gts': 67 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 68 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 69 | final_tokenized_captions_for_image = {} 70 | 71 | elif self.source == 'res': 72 | index = [i for i, v in enumerate(captions_for_image)] 73 | image_id = [v["image_id"] for v in captions_for_image] 74 | sentences = '\n'.join(v["caption"].replace('\n', ' ') for v in captions_for_image) 75 | final_tokenized_captions_for_index = [] 76 | 77 | # ====================================================== 78 | # save sentences to temporary file 79 | # ====================================================== 80 | path_to_jar_dir_name = os.path.dirname(os.path.abspath(__file__)) 81 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dir_name, mode='w') 82 | tmp_file.write(sentences) 83 | tmp_file.close() 84 | 85 | # ====================================================== 86 | # tokenize sentence 87 | # ====================================================== 88 | cmd.append(os.path.basename(tmp_file.name)) 89 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dir_name, stdout=subprocess.PIPE) 90 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0].decode("utf-8") 91 | lines = token_lines.split('\n') 92 | # remove temp file 93 | os.remove(tmp_file.name) 94 | 95 | # ====================================================== 96 | # create dictionary for tokenized captions 97 | # ====================================================== 98 | if self.source == 'gts': 99 | for k, line in zip(image_id, lines): 100 | if k not in final_tokenized_captions_for_image: 101 | final_tokenized_captions_for_image[k] = [] 102 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') if w not in PUNCTUATIONS]) 103 | final_tokenized_captions_for_image[k].append(tokenized_caption) 104 | 105 | return final_tokenized_captions_for_image 106 | 107 | elif self.source == 'res': 108 | for k, img, line in zip(index, image_id, lines): 109 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') if w not in PUNCTUATIONS]) 110 | final_tokenized_captions_for_index.append({'image_id': img, 'caption': [tokenized_caption]}) 111 | 112 | return final_tokenized_captions_for_index 113 | -------------------------------------------------------------------------------- /tasks/code/metric/cidereval/tokenizer/simpletokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : simpletokenizer.py 4 | # 5 | # Description : Yet another tokenizer. 
6 | # 7 | # Creation Date : 12-11-2021 8 | 9 | import spacy 10 | from spacy.lang.char_classes import ALPHA, ALPHA_LOWER, ALPHA_UPPER 11 | from spacy.lang.char_classes import CONCAT_QUOTES, LIST_ELLIPSES, LIST_ICONS 12 | from spacy.util import compile_infix_regex 13 | 14 | 15 | # punctuations to be removed from the sentences 16 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", 17 | ".", "?", "!", ",", ":", "-", "--", "...", ";", " ", ""] 18 | 19 | infixes = ( 20 | LIST_ELLIPSES 21 | + LIST_ICONS 22 | + [ 23 | r"(?<=[0-9])[+\-\*^](?=[0-9-])", 24 | r"(?<=[{al}{q}])\.(?=[{au}{q}])".format( 25 | al=ALPHA_LOWER, au=ALPHA_UPPER, q=CONCAT_QUOTES 26 | ), 27 | r"(?<=[{a}]),(?=[{a}])".format(a=ALPHA), 28 | # ✅ Commented out regex that splits on hyphens between letters: 29 | # r"(?<=[{a}])(?:{h})(?=[{a}])".format(a=ALPHA, h=HYPHENS), 30 | r"(?<=[{a}0-9])[:<>=/](?=[{a}])".format(a=ALPHA), 31 | ] 32 | ) 33 | 34 | 35 | class SimpleTokenizer: 36 | """Simple Tokenizer""" 37 | 38 | def __init__(self, _source='gts'): 39 | self.source = _source 40 | 41 | # setting up the tokenizer 42 | self._nlp = spacy.load("en_core_web_sm") 43 | infix_re = compile_infix_regex(infixes) 44 | self._nlp.tokenizer.infix_finditer = infix_re.finditer 45 | self._tokenizer = self._nlp.tokenizer 46 | 47 | def tokenize(self, captions_for_image): 48 | """Tokenize a sample 49 | 50 | Args: 51 | captions_for_image : 52 | 53 | IF _source='gts' follows format: 54 | dict: { str : [ 55 | { "caption" : str }, 56 | { "caption" : str }, 57 | ... 58 | ], 59 | str : [ ... ], 60 | ... 61 | } 62 | IF _source='res' follows format: 63 | list: [ {"image_id" : str, 64 | "caption" : str, 65 | }, 66 | ... 67 | ] 68 | Returns: 69 | final_tokenized_captions_for_index: 70 | list: [ {"image_id" : str, 71 | "caption" : str, 72 | }, 73 | ... 
74 | ] 75 | """ 76 | 77 | tokenized_captions = None 78 | 79 | if self.source == 'gts': 80 | tokenized_captions= {} 81 | 82 | for k in captions_for_image: 83 | 84 | if k not in tokenized_captions: 85 | tokenized_captions[k] = [] 86 | 87 | for item in captions_for_image[k]: 88 | 89 | tokenized_captions[k].append( 90 | " ".join([ tok.text.lower().strip() for tok in self._tokenizer(item['caption']) if tok.text.lower().strip() not in PUNCTUATIONS])) 91 | 92 | elif self.source == 'res': 93 | 94 | tokenized_captions= [] 95 | 96 | for item in captions_for_image: 97 | 98 | tokenized_captions.append( 99 | { 'image_id' : item['image_id'], 100 | 'caption' : [" ".join([ tok.text.lower().strip() for tok in self._tokenizer(item['caption']) if tok.text.lower().strip() not in PUNCTUATIONS])] 101 | }) 102 | 103 | else: 104 | ValueError("source can be either 'gts' or 'res' ") 105 | 106 | return tokenized_captions -------------------------------------------------------------------------------- /tasks/code/metric/cidereval/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/tasks/code/metric/cidereval/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import random 4 | 5 | import torch 6 | from torchvision import transforms as T 7 | from torchvision.transforms import functional as F 8 | 9 | 10 | def pad_if_smaller(img, size, fill=0): 11 | min_size = min(img.size) 12 | if min_size < size: 13 | ow, oh = img.size 14 | padh = size - oh if oh < size else 0 15 | padw = size - ow if ow < size else 0 16 | img = F.pad(img, (0, 0, padw, padh), fill=fill) 17 | return img 18 | 19 | 20 | class Compose(object): 21 | def __init__(self, transforms): 22 | self.transforms = transforms 23 | 24 | def __call__(self, image, target): 25 | for t in self.transforms: 26 | image, target = t(image, target) 27 | return image, target 28 | 29 | 30 | class Resize(object): 31 | def __init__(self, h, w): 32 | self.h = h 33 | self.w = w 34 | 35 | def __call__(self, image, target): 36 | image = F.resize(image, (self.h, self.w)) 37 | # If size is a sequence like (h, w), the output size will be matched to this. 38 | # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio 39 | target = F.resize(target, (self.h, self.w), interpolation=Image.NEAREST) 40 | return image, target 41 | 42 | 43 | class RandomResize(object): 44 | def __init__(self, min_size, max_size=None): 45 | self.min_size = min_size 46 | if max_size is None: 47 | max_size = min_size 48 | self.max_size = max_size 49 | 50 | def __call__(self, image, target): 51 | size = random.randint(self.min_size, self.max_size) # Return a random integer N such that a <= N <= b. Alias for randrange(a, b+1) 52 | image = F.resize(image, size) 53 | # If size is a sequence like (h, w), the output size will be matched to this. 
54 | # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio 55 | target = F.resize(target, size, interpolation=Image.NEAREST) 56 | return image, target 57 | 58 | 59 | class RandomHorizontalFlip(object): 60 | def __init__(self, flip_prob): 61 | self.flip_prob = flip_prob 62 | 63 | def __call__(self, image, target): 64 | if random.random() < self.flip_prob: 65 | image = F.hflip(image) 66 | target = F.hflip(target) 67 | return image, target 68 | 69 | 70 | class RandomCrop(object): 71 | def __init__(self, size): 72 | self.size = size 73 | 74 | def __call__(self, image, target): 75 | image = pad_if_smaller(image, self.size) 76 | target = pad_if_smaller(target, self.size, fill=255) 77 | crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) 78 | image = F.crop(image, *crop_params) 79 | target = F.crop(target, *crop_params) 80 | return image, target 81 | 82 | 83 | class CenterCrop(object): 84 | def __init__(self, size): 85 | self.size = size 86 | 87 | def __call__(self, image, target): 88 | image = F.center_crop(image, self.size) 89 | target = F.center_crop(target, self.size) 90 | return image, target 91 | 92 | 93 | class ToTensor(object): 94 | def __call__(self, image, target): 95 | image = F.to_tensor(image) 96 | target = torch.as_tensor(np.asarray(target).copy(), dtype=torch.int64) 97 | return image, target 98 | 99 | 100 | class RandomAffine(object): 101 | def __init__(self, angle, translate, scale, shear, resample=0, fillcolor=None): 102 | self.angle = angle 103 | self.translate = translate 104 | self.scale = scale 105 | self.shear = shear 106 | self.resample = resample 107 | self.fillcolor = fillcolor 108 | 109 | def __call__(self, image, target): 110 | affine_params = T.RandomAffine.get_params(self.angle, self.translate, self.scale, self.shear, image.size) 111 | image = F.affine(image, *affine_params) 112 | target = F.affine(target, *affine_params) 113 | return image, target 114 | 115 | 116 | class Normalize(object): 117 | def __init__(self, mean, std): 118 | self.mean = mean 119 | self.std = std 120 | 121 | def __call__(self, image, target): 122 | image = F.normalize(image, mean=self.mean, std=self.std) 123 | return image, target 124 | 125 | --------------------------------------------------------------------------------
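Unlike the stock torchvision transforms, every transform in `transforms.py` above takes and returns an `(image, target)` pair, so the mask receives the same resize/crop/flip as the image while `Normalize` touches the image only. A minimal usage sketch, assuming the repository root is on `PYTHONPATH` (so the module imports as `transforms`) and using an arbitrary 480x480 size for illustration; the composition mirrors the `get_transform` helper in the evaluation code above:

```python
import numpy as np
from PIL import Image

import transforms as T  # the paired transforms defined above

# dummy RGB image and single-channel mask standing in for a real sample
image = Image.fromarray(np.random.randint(0, 255, (600, 800, 3), dtype=np.uint8))
target = Image.fromarray(np.random.randint(0, 2, (600, 800), dtype=np.uint8))

# resize both image and mask, tensorize, then normalize the image only
eval_transforms = T.Compose([
    T.Resize(480, 480),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image_t, target_t = eval_transforms(image, target)
print(image_t.shape, target_t.shape)  # torch.Size([3, 480, 480]) torch.Size([480, 480])
```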