├── README.md
├── RemoteSAM.pdf
├── arc
│   ├── __init__.py
│   ├── adaptive_rotated_conv.py
│   ├── routing_function.py
│   └── weight_init.py
├── args.py
├── assets
│   ├── HKUST.jpg
│   ├── NUIST.jpg
│   ├── Radar.png
│   ├── RemoteSAM.png
│   ├── RemoteSAM270K.png
│   ├── SEU.png
│   ├── demo.jpg
│   └── hhu_logo.png
├── bert
│   ├── activations.py
│   ├── configuration_bert.py
│   ├── configuration_utils.py
│   ├── file_utils.py
│   ├── generation_utils.py
│   ├── modeling_bert.py
│   ├── modeling_utils.py
│   ├── tokenization_bert.py
│   ├── tokenization_utils.py
│   └── tokenization_utils_base.py
├── data
│   ├── DiverseDataset.py
│   ├── all_dataset.py
│   ├── dataset_refer_bert.py
│   └── datasets
│       ├── CAP.py
│       ├── CNT.py
│       ├── DET.py
│       ├── MCC.py
│       ├── MLC.py
│       ├── VG.py
│       └── __init__.py
├── lib
│   ├── _utils.py
│   ├── backbone.py
│   ├── cross_scale_interaction.py
│   ├── logger.py
│   ├── mask_predictor.py
│   ├── mmcv_custom
│   │   ├── __init__.py
│   │   └── checkpoint.py
│   ├── sa
│   │   ├── functional.py
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   ├── aggregation_refpad.py
│   │   │   ├── aggregation_zeropad.py
│   │   │   ├── subtraction2_refpad.py
│   │   │   ├── subtraction2_zeropad.py
│   │   │   ├── subtraction_refpad.py
│   │   │   ├── subtraction_zeropad.py
│   │   │   └── utils.py
│   │   └── modules
│   │       ├── __init__.py
│   │       ├── aggregation.py
│   │       ├── subtraction.py
│   │       └── subtraction2.py
│   ├── segmentation.py
│   ├── transformer.py
│   └── various_receptive.py
├── loss
│   └── loss.py
├── refer
│   └── refer.py
├── requirements.txt
├── tasks
│   ├── CAP.sh
│   ├── CNT.sh
│   ├── DET.sh
│   ├── DET_DOTA.sh
│   ├── MCC.sh
│   ├── MLC.sh
│   ├── REF.sh
│   ├── SEG.sh
│   ├── VG.sh
│   └── code
│       ├── RuleBasedCaptioning.py
│       ├── eval
│       │   ├── CAP.py
│       │   ├── CNT.py
│       │   ├── DET.py
│       │   ├── DET_DOTA.py
│       │   ├── MCC.py
│       │   ├── MLC.py
│       │   ├── REF.py
│       │   ├── SEG.py
│       │   └── VG.py
│       ├── metric
│       │   ├── RunningScore.py
│       │   └── cidereval
│       │       ├── cider
│       │       │   ├── __init__.py
│       │       │   ├── cider.py
│       │       │   ├── cider_scorer.py
│       │       │   └── data
│       │       │       ├── __init__.py
│       │       │       └── coco-val.p
│       │       ├── eval.py
│       │       ├── scorers.py
│       │       └── tokenizer
│       │           ├── __init__.py
│       │           ├── ptbtokenizer.py
│       │           ├── simpletokenizer.py
│       │           └── stanford-corenlp-3.4.1.jar
│       └── model.py
├── train.py
├── transforms.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | # RemoteSAM: Towards Segment Anything for Earth Observation
 4 | 
 5 | 
 6 | 
 7 | [Liang Yao (姚亮)*](https://multimodality.group/author/%E5%A7%9A%E4%BA%AE/),
 8 | 
 9 | [Fan Liu (刘凡)*](https://multimodality.group/author/%E5%88%98%E5%87%A1/) ✉,
10 | 
11 | [Delong Chen (陈德龙)*](https://chendelong.world/),
12 | 
13 | 
14 | [Chuanyi Zhang (张传一)](https://ai.hhu.edu.cn/2023/0809/c17670a264073/page.htm),
15 | 
16 | [Yijun Wang (王翌骏)](https://multimodality.group/author/%E7%8E%8B%E7%BF%8C%E9%AA%8F/),
17 | 
18 | [Ziyun Chen (陈子赟)](https://multimodality.group/author/%E9%99%88%E5%AD%90%E8%B5%9F/),
19 | 
20 | 
21 | [Wei Xu (许玮)](https://multimodality.group/author/%E8%AE%B8%E7%8E%AE/),
22 | 
23 | [Shimin Di (邸世民)](https://cs.seu.edu.cn/shimindi/main.htm),
24 | 
25 | [Yuhui Zheng (郑钰辉)](https://faculty.nuist.edu.cn/zhengyuhui/en/index.htm)
26 | 
27 | 
28 | \* *Equal Contribution* ✉ *Corresponding Author*
29 |
30 | Model: 🤗[RemoteSAM](https://huggingface.co/1e12Leon/RemoteSAM)
31 | 
32 | Dataset: 🤗[RemoteSAM-270K](https://huggingface.co/datasets/1e12Leon/RemoteSAM_270K)
33 |
34 |
35 |
36 | ## News
37 |
38 | - **2025/5/7**: We have released the model and dataset! You can download RemoteSAM-270K from 🤗[RemoteSAM-270K](https://huggingface.co/datasets/1e12Leon/RemoteSAM_270K) and the checkpoint from 🤗[RemoteSAM](https://huggingface.co/1e12Leon/RemoteSAM).
39 | - **2025/5/3**: Welcome to RemoteSAM! The preprint of our paper is available, and the dataset and model are open-sourced in this repository.
40 |
41 |
42 |
43 | ## Introduction
44 | Welcome to the official repository of our paper "RemoteSAM: Towards Segment Anything for Earth Observation"!
45 |
46 | 
47 |
48 | Recent advances in AI have revolutionized Earth observation, yet most remote sensing tasks still rely on specialized models with fragmented interfaces. To address this, we present **RemoteSAM**, a vision foundation model that unifies pixel-, region-, and image-level tasks through a novel architecture centered on Referring Expression Segmentation (RES). Unlike existing paradigms—task-specific heads with limited knowledge sharing or text-based models struggling with dense outputs—RemoteSAM leverages pixel-level predictions as atomic units, enabling upward compatibility to higher-level tasks while eliminating computationally heavy language model backbones. This design achieves an order-of-magnitude parameter reduction (billions to millions), enabling efficient high-resolution data processing.
49 |
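To make the idea of pixel-level predictions as atomic units concrete, here is a minimal sketch (our illustration, not code from this repository; the `mask_to_hbox` helper is hypothetical) of how a binary segmentation mask can be lifted to a region-level output such as a horizontal bounding box:

```python
import numpy as np

def mask_to_hbox(mask: np.ndarray):
    """Derive a horizontal box (x1, y1, x2, y2) from a binary mask.

    Illustration only; RemoteSAM's task heads may post-process masks differently.
    """
    ys, xs = np.nonzero(mask)           # coordinates of foreground pixels
    if xs.size == 0:
        return None                     # empty mask: nothing to box
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
```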
50 | 
51 |
52 | We also build the **RemoteSAM-270K** dataset, a large-scale collection of 270K Image-Text-Mask triplets generated via an automated pipeline powered by vision-language models (VLMs). This dataset surpasses existing resources in semantic diversity, covering 1,000+ object categories and rich attributes (e.g., color, spatial relations) through linguistically varied prompts. We further introduce RSVocab-1K, a hierarchical semantic vocabulary to quantify dataset coverage and adaptability.
53 |
54 | 
55 |
56 | ## Setting Up
57 |
58 | The code has been verified to work with PyTorch v1.13.0 and Python 3.8.
59 | 1. Clone this repository.
60 | 2. Change directory to root of this repository.
61 |
62 | ### Package Dependencies
63 | 1. Create a new Conda environment with Python 3.8, then activate it:
64 | ```shell
65 | conda create -n RemoteSAM python==3.8
66 | conda activate RemoteSAM
67 | ```
68 |
69 | 2. Install PyTorch v1.13.0 with a CUDA version that works on your cluster/machine (CUDA 11.6 is used in this example):
70 | ```shell
71 | pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116
72 | ```
73 |
74 | 3. Install `mmcv-full` from OpenMMLab:
75 | ```shell
76 | pip install mmcv-full==1.7.1 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.13.0/index.html
77 | ```
78 | 4. Install the packages in `requirements.txt` via `pip`:
79 | ```shell
80 | pip install -r requirements.txt
81 | ```
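(Optional) You can quickly confirm the environment from Python. This is a small check of ours, not a required step:

```python
import torch
import mmcv

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("mmcv:", mmcv.__version__)
```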
82 | ### The Initialization Weights for Training
83 | 1. Create the `./pretrained_weights` directory where we will be storing the weights.
84 | ```shell
85 | mkdir ./pretrained_weights
86 | ```
87 | 2. Download the [pre-trained classification weights of the Swin Transformer](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth),
88 | and put the `.pth` file in `./pretrained_weights`.
89 | These weights are needed to initialize the model for training.
90 |
91 | ## Data Preparation
92 | We perform all experiments on our proposed RemoteSAM-270K dataset.
93 |
94 | ### Usage
95 | 1. Download our dataset from [HuggingFace](https://huggingface.co/datasets/1e12Leon/RemoteSAM_270K).
96 | 2. Copy all the downloaded files to `./refer/data/`. The dataset folder should be organized as follows:
97 | ```
98 | $DATA_PATH
99 | └── RemoteSAM-270K
100 |     ├── JPEGImages
101 |     ├── Annotations
102 |     ├── refs(unc).p
103 |     └── instances.json
104 | ```
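After copying the files, you can sanity-check the layout with a few lines of Python. This is a minimal sketch of ours, assuming `refs(unc).p` is a Python pickle of referring-expression records and `instances.json` follows a COCO-style structure; the exact record fields may differ:

```python
import json
import os
import pickle

data_root = "./refer/data/RemoteSAM-270K"  # adjust to your local path

with open(os.path.join(data_root, "refs(unc).p"), "rb") as f:
    refs = pickle.load(f)              # referring-expression records
with open(os.path.join(data_root, "instances.json")) as f:
    instances = json.load(f)           # COCO-style images / annotations / categories

print(len(refs), "referring expressions")
print(len(instances.get("images", [])), "images")
```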
105 |
106 |
107 | ## RemoteSAM
108 |
109 | ### Training
110 | We use DistributedDataParallel from PyTorch for training. Additional training settings can be changed in `args.py`.
111 | To train on 8 GPUs on a single node, run:
112 | ```shell
113 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
114 | python -m torch.distributed.launch \
115 | --nproc_per_node 8 --master_port 12345 train.py \
116 | --epochs 40 --img_size 896 2>&1 | tee ./output
117 | ```
118 | ### Getting Started
119 |
120 | To get started with RemoteSAM, please first initialize a model and load the RemoteSAM checkpoint with a few lines of code:
121 |
122 | ```python
123 | from tasks.code.model import RemoteSAM, init_demo_model
124 | import cv2
125 | import numpy as np
126 |
127 | device = 'cuda:0'
128 | checkpoint = "./pretrained_weights/checkpoint.pth"
129 |
130 | model = init_demo_model(checkpoint, device)
131 | model = RemoteSAM(model, device, use_EPOC=True)
132 | ```
133 |
134 | Then, you can explore different tasks with RemoteSAM via:
135 |
136 | - **Referring Expression Segmentation**
137 |
138 | ```python
139 | image = cv2.imread("./assets/demo.jpg")
140 | mask = model.referring_seg(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), sentence="the airplane on the right")
141 | ```
142 |
143 | - **Semantic Segmentation**
144 |
145 | ```python
146 | image = cv2.imread("./assets/demo.jpg")
147 | result = model.semantic_seg(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'])
148 | for classname in ["airplane", "vehicle"]:
149 | mask = result[classname]
150 | ```
151 |
152 | - **Object Detection**
153 |
154 | ```python
155 | image = cv2.imread("./assets/demo.jpg")
156 | result = model.detection(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'])
157 | for classname in ["airplane", "vehicle"]:
158 | boxes = result[classname]
159 | ```
160 |
161 | - **Visual Grounding**
162 |
163 | ```python
164 | image = cv2.imread("./assets/demo.jpg")
165 | box = model.visual_grounding(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), sentence="the airplane on the right")
166 | ```
167 |
168 | - **Multi-Label Classification**
169 |
170 | ```python
171 | image = cv2.imread("./assets/demo.jpg")
172 | result = model.multi_label_cls(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'])
173 | print(result)
174 | ```
175 |
176 | - **Image Classification**
177 |
178 | ```python
179 | image = cv2.imread("./assets/demo.jpg")
180 | result = model.multi_class_cls(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'])
181 | print(result)
182 | ```
183 |
184 | - **Image Captioning**
185 |
186 | ```python
187 | image = cv2.imread("./assets/demo.jpg")
188 | result = model.captioning(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'], region_split=9)
189 | print(result)
190 | ```
191 |
192 | - **Object Counting**
193 |
194 | ```python
195 | image = cv2.imread("./assets/demo.jpg")
196 | result = model.counting(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), classnames=['airplane', 'vehicle'])
197 | for classname in ["airplane", "vehicle"]:
198 | print("{}: {}".format(classname, result[classname]))
199 | ```
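For a quick visual check, the mask returned by `referring_seg` can be overlaid on the input image with OpenCV. This is an illustrative snippet of ours (not part of the official API), assuming the returned mask is a binary array with the same height and width as the image:

```python
image = cv2.imread("./assets/demo.jpg")
mask = model.referring_seg(image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), sentence="the airplane on the right")

overlay = image.copy()
overlay[mask.astype(bool)] = (0, 0, 255)                 # paint masked pixels red (BGR)
blended = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)   # 60/40 blend of image and overlay
cv2.imwrite("demo_overlay.jpg", blended)
```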
200 |
201 | ### Evaluation
202 |
203 | - **Evaluation of Referring Expression Segmentation**
204 |
205 | ```shell
206 | bash tasks/REF.sh
207 | ```
208 |
209 | - **Evaluation of Semantic Segmentation**
210 |
211 | ```shell
212 | bash tasks/SEG.sh
213 | ```
214 |
215 | - **Evaluation of Object Detection**
216 |
217 | ```shell
218 | bash tasks/DET.sh
219 | ```
220 | - **Evaluation of Visual Grounding**
221 |
222 | ```shell
223 | bash tasks/VG.sh
224 | ```
225 | - **Evaluation of Multi-Label Classification**
226 |
227 | ```shell
228 | bash tasks/MLC.sh
229 | ```
230 | - **Evaluation of Image Classification**
231 |
232 | ```shell
233 | bash tasks/MCC.sh
234 | ```
235 | - **Evaluation of Image Captioning**
236 |
237 | ```shell
238 | bash tasks/CAP.sh
239 | ```
240 | - **Evaluation of Object Counting**
241 |
242 | ```shell
243 | bash tasks/CNT.sh
244 | ```
245 |
246 | ## Acknowledgements
247 | - We thank Lu Wang (王璐) for his efforts on the RemoteSAM-270K dataset.
248 | - The code in this repository is built on [RMSIN](https://github.com/Lsan2401/RMSIN). We'd like to thank the authors for open-sourcing their project.
249 |
250 | ## Contact
251 | Please contact yaoliang@hhu.edu.cn.
252 |
253 |
254 | ## Cite
255 | If you find this work useful, please cite our paper as:
256 | ```bibtex
257 | @misc{yao2025RemoteSAM,
258 | title={RemoteSAM: Towards Segment Anything for Earth Observation},
259 | author={Liang Yao and Fan Liu and Delong Chen and Chuanyi Zhang and Yijun Wang and Ziyun Chen and Wei Xu and Shimin Di and Yuhui Zheng},
260 | year={2025},
261 | eprint={2505.18022},
262 | archivePrefix={arXiv},
263 | primaryClass={cs.CV},
264 | url={https://arxiv.org/abs/2505.18022},
265 | }
266 | ```
267 |
--------------------------------------------------------------------------------
/RemoteSAM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/RemoteSAM.pdf
--------------------------------------------------------------------------------
/arc/__init__.py:
--------------------------------------------------------------------------------
1 | from .adaptive_rotated_conv import AdaptiveRotatedConv2d
2 | from .routing_function import RountingFunction
3 |
4 | __all__ = [
5 | 'AdaptiveRotatedConv2d', 'RountingFunction',
6 | ]
7 |
--------------------------------------------------------------------------------
/arc/adaptive_rotated_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | __all__ = ['AdaptiveRotatedConv2d']
7 |
8 |
9 | def _get_rotation_matrix(thetas):
10 | bs, g = thetas.shape
11 | device = thetas.device
12 | thetas = thetas.reshape(-1) # [bs, n] --> [bs x n]
13 |
14 | x = torch.cos(thetas)
15 | y = torch.sin(thetas)
16 | x = x.unsqueeze(0).unsqueeze(0) # shape = [1, 1, bs * g]
17 | y = y.unsqueeze(0).unsqueeze(0)
18 | a = x - y
19 | b = x * y
20 | c = x + y
21 |
22 | rot_mat_positive = torch.cat((
23 | torch.cat((a, 1-a, torch.zeros(1, 7, bs*g, device=device)), dim=1),
24 | torch.cat((torch.zeros(1, 1, bs*g, device=device), x-b, b, torch.zeros(1, 1, bs*g, device=device), 1-c+b, y-b, torch.zeros(1, 3, bs*g, device=device)), dim=1),
25 | torch.cat((torch.zeros(1, 2, bs*g, device=device), a, torch.zeros(1, 2, bs*g, device=device), 1-a, torch.zeros(1, 3, bs*g, device=device)), dim=1),
26 | torch.cat((b, y-b, torch.zeros(1,1 , bs*g, device=device), x-b, 1-c+b, torch.zeros(1, 4, bs*g, device=device)), dim=1),
27 | torch.cat((torch.zeros(1, 4, bs*g, device=device), torch.ones(1, 1, bs*g, device=device), torch.zeros(1, 4, bs*g, device=device)), dim=1),
28 | torch.cat((torch.zeros(1, 4, bs*g, device=device), 1-c+b, x-b, torch.zeros(1, 1, bs*g, device=device), y-b, b), dim=1),
29 | torch.cat((torch.zeros(1, 3, bs*g, device=device), 1-a, torch.zeros(1, 2, bs*g, device=device), a, torch.zeros(1, 2, bs*g, device=device)), dim=1),
30 | torch.cat((torch.zeros(1, 3, bs*g, device=device), y-b, 1-c+b, torch.zeros(1, 1, bs*g, device=device), b, x-b, torch.zeros(1, 1, bs*g, device=device)), dim=1),
31 | torch.cat((torch.zeros(1, 7, bs*g, device=device), 1-a, a), dim=1)
32 | ), dim=0) # shape = [k^2, k^2, bs*g]
33 |
34 | rot_mat_negative = torch.cat((
35 | torch.cat((c, torch.zeros(1, 2, bs*g, device=device), 1-c, torch.zeros(1, 5, bs*g, device=device)), dim=1),
36 | torch.cat((-b, x+b, torch.zeros(1, 1, bs*g, device=device), b-y, 1-a-b, torch.zeros(1, 4, bs*g, device=device)), dim=1),
37 | torch.cat((torch.zeros(1, 1, bs*g, device=device), 1-c, c, torch.zeros(1, 6, bs*g, device=device)), dim=1),
38 | torch.cat((torch.zeros(1, 3, bs*g, device=device), x+b, 1-a-b, torch.zeros(1, 1, bs*g, device=device), -b, b-y, torch.zeros(1, 1, bs*g, device=device)), dim=1),
39 | torch.cat((torch.zeros(1, 4, bs*g, device=device), torch.ones(1, 1, bs*g, device=device), torch.zeros(1, 4, bs*g, device=device)), dim=1),
40 | torch.cat((torch.zeros(1, 1, bs*g, device=device), b-y, -b, torch.zeros(1, 1, bs*g, device=device), 1-a-b, x+b, torch.zeros(1, 3, bs*g, device=device)), dim=1),
41 | torch.cat((torch.zeros(1, 6, bs*g, device=device), c, 1-c, torch.zeros(1, 1, bs*g, device=device)), dim=1),
42 | torch.cat((torch.zeros(1, 4, bs*g, device=device), 1-a-b, b-y, torch.zeros(1, 1, bs*g, device=device), x+b, -b), dim=1),
43 | torch.cat((torch.zeros(1, 5, bs*g, device=device), 1-c, torch.zeros(1, 2, bs*g, device=device), c), dim=1)
44 | ), dim=0) # shape = [k^2, k^2, bs*g]
45 |
46 | mask = (thetas >= 0).unsqueeze(0).unsqueeze(0)
47 | mask = mask.float() # shape = [1, 1, bs*g]
48 | rot_mat = mask * rot_mat_positive + (1 - mask) * rot_mat_negative # shape = [k*k, k*k, bs*g]
49 | rot_mat = rot_mat.permute(2, 0, 1) # shape = [bs*g, k*k, k*k]
50 | rot_mat = rot_mat.reshape(bs, g, rot_mat.shape[1], rot_mat.shape[2]) # shape = [bs, g, k*k, k*k]
51 | return rot_mat
52 |
53 |
54 | def batch_rotate_multiweight(weights, lambdas, thetas):
55 | """
56 | Let
57 | batch_size = b
58 | kernel_number = n
59 | kernel_size = 3
60 | Args:
61 | weights: tensor, shape = [kernel_number, Cout, Cin, k, k]
62 | thetas: tensor of thetas, shape = [batch_size, kernel_number]
63 | Return:
64 | weights_out: tensor, shape = [batch_size x Cout, Cin // groups, k, k]
65 | """
66 | assert(thetas.shape == lambdas.shape)
67 | assert(lambdas.shape[1] == weights.shape[0])
68 |
69 | b = thetas.shape[0]
70 | n = thetas.shape[1]
71 | k = weights.shape[-1]
72 | _, Cout, Cin, _, _ = weights.shape
73 |
74 | if k == 3 :
75 | # Stage 1:
76 | # input: thetas: [b, n]
77 | # lambdas: [b, n]
78 | # output: rotation_matrix: [b, n, 9, 9] (with gate) --> [b*9, n*9]
79 |
80 | # Sub_Stage 1.1:
81 | # input: [b, n] kernel
82 | # output: [b, n, 9, 9] rotation matrix
83 | rotation_matrix = _get_rotation_matrix(thetas)
84 |
85 | # Sub_Stage 1.2:
86 | # input: [b, n, 9, 9] rotation matrix
87 | # [b, n] lambdas
88 | # --> [b, n, 1, 1] lambdas
89 | # --> [b, n, 1, 1] lambdas dot [b, n, 9, 9] rotation matrix
90 | # --> [b, n, 9, 9] rotation matrix with gate (done)
91 | # output: [b, n, 9, 9] rotation matrix with gate
92 | lambdas = lambdas.unsqueeze(2).unsqueeze(3)
93 | rotation_matrix = torch.mul(rotation_matrix, lambdas)
94 |
95 | # Sub_Stage 1.3: Reshape
96 | # input: [b, n, 9, 9] rotation matrix with gate
97 | # output: [b*9, n*9] rotation matrix with gate
98 | rotation_matrix = rotation_matrix.permute(0, 2, 1, 3)
99 | rotation_matrix = rotation_matrix.reshape(b*k*k, n*k*k)
100 |
101 | # Stage 2: Reshape
102 | # input: weights: [n, Cout, Cin, 3, 3]
103 | # --> [n, 3, 3, Cout, Cin]
104 | # --> [n*9, Cout*Cin] done
105 | # output: weights: [n*9, Cout*Cin]
106 | weights = weights.permute(0, 3, 4, 1, 2)
107 | weights = weights.contiguous().view(n*k*k, Cout*Cin)
108 |
109 |
110 | # Stage 3: torch.mm
111 | # [b*9, n*9] x [n*9, Cout*Cin]
112 | # --> [b*9, Cout*Cin]
113 | weights = torch.mm(rotation_matrix, weights)
114 |
115 | # Stage 4: Reshape Back
116 | # input: [b*9, Cout*Cin]
117 | # --> [b, 3, 3, Cout, Cin]
118 | # --> [b, Cout, Cin, 3, 3]
119 | # --> [b * Cout, Cin, 3, 3] done
120 | # output: [b * Cout, Cin, 3, 3]
121 | weights = weights.contiguous().view(b, k, k, Cout, Cin)
122 | weights = weights.permute(0, 3, 4, 1, 2)
123 | weights = weights.reshape(b * Cout, Cin, k, k)
124 | else:
125 | thetas = thetas.reshape(-1) # [bs, n] --> [bs x n]
126 |
127 | x = torch.cos(thetas)
128 | y = torch.sin(thetas)
129 | # build one 2x3 affine rotation matrix per (sample, kernel) pair: [[cos, -sin, 0], [sin, cos, 0]]
130 | rotate_matrix = torch.stack((torch.stack((x, -y, torch.zeros_like(x)), dim=-1), torch.stack((y, x, torch.zeros_like(x)), dim=-1)), dim=1)  # shape = [b*n, 2, 3]
131 |
132 | weights = weights.contiguous().view(n, Cout*Cin, k, k).unsqueeze(0).repeat(b, 1, 1, 1, 1).reshape(b*n, Cout*Cin, k, k)
133 |
134 | grid = F.affine_grid(rotate_matrix, weights.shape)
135 | weights = F.grid_sample(weights, grid, mode='bilinear')
136 |
137 | return weights
138 |
139 |
140 | class AdaptiveRotatedConv2d(nn.Module):
141 |
142 | def __init__(self, in_channels, out_channels, kernel_size,
143 | stride=1, padding=1, dilation=1, groups=1, bias=False,
144 | kernel_number=1, rounting_func=None, rotate_func=batch_rotate_multiweight):
145 | super().__init__()
146 | self.kernel_number = kernel_number
147 | self.in_channels = in_channels
148 | self.out_channels = out_channels
149 | self.kernel_size = kernel_size
150 | self.stride = stride
151 | self.padding = padding
152 | self.dilation = dilation
153 | self.groups = groups
154 | self.bias = bias
155 |
156 | self.rounting_func = rounting_func
157 | self.rotate_func = rotate_func
158 |
159 | self.weight = nn.Parameter(
160 | torch.Tensor(
161 | kernel_number,
162 | out_channels,
163 | in_channels // groups,
164 | kernel_size,
165 | kernel_size,
166 | )
167 | )
168 | nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')
169 |
170 | def forward(self, x):
171 | # get alphas, angles
172 | # # [bs, Cin, h, w] --> [bs, n_theta], [bs, n_theta]
173 | alphas, angles = self.rounting_func(x)
174 |
175 | # rotate weight
176 | # # [n, Cout, Cin, k, k] --> [bs * Cout, Cin, k, k]
177 | # print(self.weight.shape)
178 | rotated_weight = self.rotate_func(self.weight, alphas, angles)
179 |
180 | # reshape images
181 | bs, Cin, h, w = x.shape
182 | x = x.reshape(1, bs * Cin, h, w) # [1, bs * Cin, h, w]
183 |
184 | # adaptive conv over images using group conv
185 | out = F.conv2d(input=x, weight=rotated_weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=(self.groups * bs))
186 |
187 | # reshape back
188 | out = out.reshape(bs, self.out_channels, *out.shape[2:])
189 | return out
190 |
191 | def extra_repr(self):
192 | s = ('{in_channels}, {out_channels}, kernel_number={kernel_number}'
193 | ', kernel_size={kernel_size}, stride={stride}, bias={bias}')
194 |
195 | if self.padding != (0,) * len([self.padding]):
196 | s += ', padding={padding}'
197 | if self.dilation != (1,) * len([self.dilation]):
198 | s += ', dilation={dilation}'
199 | if self.groups != 1:
200 | s += ', groups={groups}'
201 | return s.format(**self.__dict__)
202 |
--------------------------------------------------------------------------------
/arc/routing_function.py:
--------------------------------------------------------------------------------
1 | import math
2 | import einops
3 | import torch
4 | import torch.nn as nn
5 | from .weight_init import trunc_normal_
6 |
7 |
8 | class LayerNormProxy(nn.Module):
9 | # copy from https://github.com/LeapLabTHU/DAT/blob/main/models/dat_blocks.py
10 | def __init__(self, dim):
11 | super().__init__()
12 | self.norm = nn.LayerNorm(dim)
13 |
14 | def forward(self, x):
15 | x = einops.rearrange(x, 'b c h w -> b h w c')
16 | x = self.norm(x)
17 | return einops.rearrange(x, 'b h w c -> b c h w')
18 |
19 |
20 | class RountingFunction(nn.Module):
21 |
22 | def __init__(self, in_channels, kernel_number, dropout_rate=0.2, proportion=40.0):
23 | super().__init__()
24 | self.kernel_number = kernel_number
25 | self.dwc = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1,
26 | groups=in_channels, bias=False)
27 | self.norm = LayerNormProxy(in_channels)
28 | self.relu = nn.ReLU(inplace=True)
29 |
30 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
31 |
32 | self.dropout1 = nn.Dropout(dropout_rate)
33 | self.fc_alpha = nn.Linear(in_channels, kernel_number, bias=True)
34 |
35 | self.dropout2= nn.Dropout(dropout_rate)
36 | self.fc_theta = nn.Linear(in_channels, kernel_number, bias=False)
37 |
38 | self.act_func = nn.Softsign()
39 | self.proportion = proportion / 180.0 * math.pi
40 |
41 | # init weights
42 | trunc_normal_(self.dwc.weight, std=.02)
43 | trunc_normal_(self.fc_alpha.weight, std=.02)
44 | trunc_normal_(self.fc_theta.weight, std=.02)
45 |
46 | def forward(self, x):
47 |
48 | x = self.dwc(x)
49 | x = self.norm(x)
50 | x = self.relu(x)
51 |
52 | x = self.avg_pool(x).squeeze(dim=-1).squeeze(dim=-1) # avg_x.shape = [batch_size, Cin]
53 |
54 | alphas = self.dropout1(x)
55 | alphas = self.fc_alpha(alphas)
56 | alphas = torch.sigmoid(alphas)
57 |
58 | angles = self.dropout2(x)
59 | angles = self.fc_theta(angles)
60 | angles = self.act_func(angles)
61 | angles = angles * self.proportion
62 |
63 | return alphas, angles
64 |
65 | def extra_repr(self):
66 | s = (f'kernel_number={self.kernel_number}')
67 | return s.format(**self.__dict__)
68 |
--------------------------------------------------------------------------------
/arc/weight_init.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # get from https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/layers/weight_init.py
3 | # --------------------------------------------------------
4 | import torch
5 | import math
6 | import warnings
7 |
8 |
9 | def _trunc_normal_(tensor, mean, std, a, b):
10 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
11 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
12 | def norm_cdf(x):
13 | # Computes standard normal cumulative distribution function
14 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
15 |
16 | if (mean < a - 2 * std) or (mean > b + 2 * std):
17 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
18 | "The distribution of values may be incorrect.",
19 | stacklevel=2)
20 |
21 | # Values are generated by using a truncated uniform distribution and
22 | # then using the inverse CDF for the normal distribution.
23 | # Get upper and lower cdf values
24 | l = norm_cdf((a - mean) / std)
25 | u = norm_cdf((b - mean) / std)
26 |
27 | # Uniformly fill tensor with values from [l, u], then translate to
28 | # [2l-1, 2u-1].
29 | tensor.uniform_(2 * l - 1, 2 * u - 1)
30 |
31 | # Use inverse cdf transform for normal distribution to get truncated
32 | # standard normal
33 | tensor.erfinv_()
34 |
35 | # Transform to proper mean, std
36 | tensor.mul_(std * math.sqrt(2.))
37 | tensor.add_(mean)
38 |
39 | # Clamp to ensure it's in the proper range
40 | tensor.clamp_(min=a, max=b)
41 | return tensor
42 |
43 |
44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45 | # type: (Tensor, float, float, float, float) -> Tensor
46 | r"""Fills the input Tensor with values drawn from a truncated
47 | normal distribution. The values are effectively drawn from the
48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49 | with values outside :math:`[a, b]` redrawn until they are within
50 | the bounds. The method used for generating the random values works
51 | best when :math:`a \leq \text{mean} \leq b`.
52 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
53 | applied while sampling the normal with mean/std applied, therefore a, b args
54 | should be adjusted to match the range of mean, std args.
55 | Args:
56 | tensor: an n-dimensional `torch.Tensor`
57 | mean: the mean of the normal distribution
58 | std: the standard deviation of the normal distribution
59 | a: the minimum cutoff value
60 | b: the maximum cutoff value
61 | Examples:
62 | >>> w = torch.empty(3, 5)
63 | >>> nn.init.trunc_normal_(w)
64 | """
65 | with torch.no_grad():
66 | return _trunc_normal_(tensor, mean, std, a, b)
67 |
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_parser():
5 | parser = argparse.ArgumentParser(description='RemoteSAM training and testing')
6 | parser.add_argument('--amsgrad', action='store_true',
7 | help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
8 | parser.add_argument('-b', '--batch-size', default=2, type=int)
9 | parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
10 | parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
11 | parser.add_argument('--dataset', default='rrsisd', help='dataset name')
12 | parser.add_argument('--ddp_trained_weights', action='store_true',
13 | help='Only needs to be specified when testing: '
14 | 'whether the weights to be loaded are from a DDP-trained model')
15 | parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine
16 | parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
17 | parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
18 | parser.add_argument('--img_size', default=896, type=int, help='input image size')
19 | parser.add_argument("--local_rank", type=int,default=0,help='local rank for DistributedDataParallel')
20 | parser.add_argument('--lr', default=0.00003, type=float, help='the initial learning rate') # 0.00003
21 | parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4,'
22 | 'where a, b, c, and d refer to the numbers of heads in stage-1,'
23 | 'stage-2, stage-3, and stage-4 PWAMs')
24 | parser.add_argument('--model', default='lavt_one', help='model: lavt, lavt_one')
25 | parser.add_argument('--model_id', default='RemoteSAM', help='name to identify the model')
26 | parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
27 | parser.add_argument('--pin_mem', action='store_true',
28 | help='If true, pin memory when using the data loader.')
29 | parser.add_argument('--pretrained_swin_weights', default='./pretrained_weights/swin_base_patch4_window12_384_22k.pth',
30 | help='path to pre-trained Swin backbone weights')
31 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
32 | parser.add_argument('--refer_data_root', default='/data/dishimin/RemoteSAM-270K', help='REFER dataset root directory')
33 | parser.add_argument('--resume', default='', help='resume from checkpoint')
34 | parser.add_argument('--split', default='test', help='only used when testing')
35 | parser.add_argument('--splitBy', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
36 | parser.add_argument('--swin_type', default='base',
37 | help='tiny, small, base, or large variants of the Swin Transformer')
38 | parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
39 | dest='weight_decay')
40 | parser.add_argument('--window12', action='store_true',
41 | help='only needs to be specified when testing; '
42 | 'when training, the window size is inferred from the pre-trained weights file name '
43 | '(containing \'window12\'). Initialize Swin with window size 12 instead of the default 7.')
44 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers')
45 | parser.add_argument("--imageFolder", default='/data/RemoteSAM-270K/JPEGImages/', type=str,
46 | help="imageFolder path")
47 | parser.add_argument("--ref_file_path", default='/data/RemoteSAM-270K/refs(unc)_RemoteSAM.p', type=str,
48 | help="ref_file_path")
49 | parser.add_argument("--instances_path", default='/data/RemoteSAM-270K/instances.json', type=str,
50 | help="instances.json's path")
51 | parser.add_argument("--eval", action="store_true", help="Only run evaluation")
52 | parser.add_argument('--EPOC', action='store_true', help='if true, use EPOC')
53 | parser.add_argument('--save_path', default='', help='save_path')
54 | parser.add_argument('--task', type=str, choices=['CNT', 'DET', 'CAP', 'VG', 'MCC', 'MLC','SEG','REF'],
55 | help='For tasks other than referring segmentation and semantic segmentation, specify the task')
56 | parser.add_argument('--annoFolder', default='/data/dishimin/iSAID/rrsisd/Annotations/', type=str, help="annoFolder path")
57 | parser.add_argument('--sentence', default='', help='sentence input')
58 | parser.add_argument('--_class', default='', help='class input')
59 |
60 | return parser
61 |
62 |
63 | if __name__ == "__main__":
64 | parser = get_parser()
65 | args_dict = parser.parse_args()
66 |
--------------------------------------------------------------------------------
/assets/HKUST.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/HKUST.jpg
--------------------------------------------------------------------------------
/assets/NUIST.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/NUIST.jpg
--------------------------------------------------------------------------------
/assets/Radar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/Radar.png
--------------------------------------------------------------------------------
/assets/RemoteSAM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/RemoteSAM.png
--------------------------------------------------------------------------------
/assets/RemoteSAM270K.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/RemoteSAM270K.png
--------------------------------------------------------------------------------
/assets/SEU.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/SEU.png
--------------------------------------------------------------------------------
/assets/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/demo.jpg
--------------------------------------------------------------------------------
/assets/hhu_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/RemoteSAM/8e534adf542b406129970ea93f1a6dabbfadae35/assets/hhu_logo.png
--------------------------------------------------------------------------------
/bert/activations.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def swish(x):
12 | return x * torch.sigmoid(x)
13 |
14 |
15 | def _gelu_python(x):
16 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
19 | This is now written in C in torch.nn.functional
20 | Also see https://arxiv.org/abs/1606.08415
21 | """
22 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
23 |
24 |
25 | def gelu_new(x):
26 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
27 | Also see https://arxiv.org/abs/1606.08415
28 | """
29 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
30 |
31 |
32 | # compare parsed version numbers; a plain string comparison misorders versions such as "1.13.0" vs "1.4.0"
33 | if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) < (1, 4):
34 | gelu = _gelu_python
35 | else:
36 | gelu = F.gelu
37 |
38 | def gelu_fast(x):
39 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
40 |
41 |
42 | ACT2FN = {
43 | "relu": F.relu,
44 | "swish": swish,
45 | "gelu": gelu,
46 | "tanh": torch.tanh,
47 | "gelu_new": gelu_new,
48 | "gelu_fast": gelu_fast,
49 | }
50 |
51 |
52 | def get_activation(activation_string):
53 | if activation_string in ACT2FN:
54 | return ACT2FN[activation_string]
55 | else:
56 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
57 |
--------------------------------------------------------------------------------
/bert/configuration_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT model configuration """
17 |
18 |
19 | import logging
20 |
21 | from .configuration_utils import PretrainedConfig
22 |
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
28 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
29 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
30 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
31 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
32 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
33 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
34 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
35 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
36 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
37 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
38 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
39 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
40 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
41 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
42 | "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
43 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
44 | "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
45 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
46 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
47 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
48 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
49 | # See all BERT models at https://huggingface.co/models?filter=bert
50 | }
51 |
52 |
53 | class BertConfig(PretrainedConfig):
54 | r"""
55 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
56 | It is used to instantiate a BERT model according to the specified arguments, defining the model
57 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
58 | the BERT `bert-base-uncased