├── LICENSE ├── README.md ├── configs ├── config_airs.yaml ├── config_birds.yaml └── config_cars.yaml ├── lib ├── __init__.py ├── make.sh └── nms │ ├── __init__.py │ ├── _ext │ ├── __init__.py │ └── nms │ │ ├── __init__.py │ │ └── _nms.so │ ├── build.py │ ├── pth_nms.py │ └── src │ ├── cuda │ ├── nms_kernel.cu │ ├── nms_kernel.cu.o │ └── nms_kernel.h │ ├── nms.c │ ├── nms.h │ ├── nms_cuda.c │ └── nms_cuda.h ├── model ├── resnet50.py └── vgg19.py ├── requirements.txt ├── train.py └── utils ├── split_dataset ├── airs_dataset.py ├── birds_dataset.py └── cars_dataset.py ├── transform.py ├── utils.py └── visualize.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 PRIS-CV: Computer Vision Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AP-CNN 2 | 3 | Code release for Weakly Supervised Attention Pyramid Convolutional Neural Network for Fine-Grained Visual Classification (TIP2021). 4 | 5 | ### Dependencies 6 | Python 3.6 with all of the `pip install -r requirements.txt` packages including: 7 | - `torch == 0.4.1` 8 | - `opencv-python` 9 | - `visdom` 10 | 11 | ### Data 12 | 1. Download the FGVC image data. Extract them to `data/cars/`, `data/birds/` and `data/airs/`, respectively. 13 | * [Stanford-Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) (cars) 14 | ``` 15 | -/cars/ 16 | └─── car_ims 17 | └─── 00001.jpg 18 | └─── 00002.jpg 19 | └─── ... 20 | └─── cars_annos.mat 21 | ``` 22 | * [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) (birds) 23 | ``` 24 | -/birds/ 25 | └─── images.txt 26 | └─── image_class_labels.txt 27 | └─── train_test_split.txt 28 | └─── images 29 | └─── 001.Black_footed_Albatross 30 | └─── Black_Footed_Albatross_0001_796111.jpg 31 | └─── ... 32 | └─── 002.Laysan_Albatross 33 | └─── ... 34 | ``` 35 | * [FGVC-Aircraft](http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) (airs) 36 | ``` 37 | -/airs/ 38 | └─── images 39 | └─── 0034309.jpg 40 | └─── 0034958.jpg 41 | └─── ... 42 | └─── variants.txt 43 | └─── images_variant_trainval.txt 44 | └─── images_variant_test.txt 45 | ``` 46 | 47 | 2. Preprocess images. 
48 | - For birds: `python utils/split_dataset/birds_dataset.py` 49 | - For cars: `python utils/split_dataset/cars_dataset.py` 50 | - For airs: `python utils/split_dataset/airs_dataset.py` 51 | 52 | ### Training 53 | **Start:** 54 | 55 | 1. `python train.py --dataset {cars,airs,birds} --model {resnet50,vgg19} [options: --visualize]` to start training. 56 | - For example, to train ResNet50 on Stanford-Cars: `python train.py --dataset cars --model resnet50` 57 | - Run `python train.py --help` to see the full list of arguments. 58 | 59 | **Visualize:** 60 | 1. `python -m visdom.server` to start the visdom server. 61 | 62 | 2. Visualize online attention masks and ROIs at `http://localhost:8097`. 63 | 64 | ### Pretrained Checkpoints 65 | Pretrained checkpoints with the following settings are available at the [download link](https://pan.baidu.com/s/1Z3SafB3UgYaZW1ApqMeTAA), with access code "kjqu". 66 | 67 | | Dataset | Base model | Accuracy (%) | 68 | | :----: | :----: | :----: | 69 | | CUB-200-2011 | resnet50 | 88.4 | 70 | | Stanford-Cars | resnet50 | 95.3 | 71 | | FGVC-Aircraft | resnet50 | 94.0 | 72 | 73 | 74 | ### Citation 75 | If you find this paper useful in your research, please consider citing: 76 | ``` 77 | @ARTICLE{9350209, 78 | author={Y. {Ding} and Z. {Ma} and S. {Wen} and J. {Xie} and D. {Chang} and Z. {Si} and M. {Wu} and H. {Ling}}, 79 | journal={IEEE Transactions on Image Processing}, 80 | title={AP-CNN: Weakly Supervised Attention Pyramid Convolutional Neural Network for Fine-Grained Visual Classification}, 81 | year={2021}, 82 | volume={30}, 83 | number={}, 84 | pages={2826-2836}, 85 | doi={10.1109/TIP.2021.3055617}} 86 | ``` 87 | 88 | ### Contact 89 | Thanks for your attention! 90 | If you have any suggestions or questions, you can leave a message here or contact us directly: 91 | - mazhanyu@bupt.edu.cn 92 | - dingyf@bupt.edu.cn 93 | 94 | -------------------------------------------------------------------------------- /configs/config_airs.yaml: -------------------------------------------------------------------------------- 1 | # data config 2 | train_dir: "data/Aircraft/train" 3 | test_dir: "data/Aircraft/test" 4 | num_class: 100 5 | 6 | # model config 7 | batch_size: 16 8 | learning_rate: 0.001 9 | momentum: 0.9 10 | weight_decay: 5e-4 11 | num_epoch: 100 12 | resize_size: 448 13 | crop_size: 448 14 | 15 | # visualizer config 16 | vis_host: "http://localhost" 17 | vis_port: 8097 -------------------------------------------------------------------------------- /configs/config_birds.yaml: -------------------------------------------------------------------------------- 1 | # data config 2 | train_dir: "data/Birds/train" 3 | test_dir: "data/Birds/test" 4 | num_class: 200 5 | 6 | # model config 7 | batch_size: 16 8 | learning_rate: 0.001 9 | momentum: 0.9 10 | weight_decay: 5e-4 11 | num_epoch: 100 12 | resize_size: 600 13 | crop_size: 448 14 | 15 | # visualizer config 16 | vis_host: "http://localhost" 17 | vis_port: 8097 18 | -------------------------------------------------------------------------------- /configs/config_cars.yaml: -------------------------------------------------------------------------------- 1 | # data config 2 | train_dir: "data/StandCars/train" 3 | test_dir: "data/StandCars/test" 4 | num_class: 196 5 | 6 | # model config 7 | batch_size: 16 8 | learning_rate: 0.001 9 | momentum: 0.9 10 | weight_decay: 5e-4 11 | num_epoch: 100 12 | resize_size: 448 13 | crop_size: 448 14 | 15 | # visualizer config 16 | vis_host: "http://localhost" 17 | vis_port: 8097 18 |
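The three configs above share a single schema (data paths, optimizer hyperparameters, input sizes, and visdom settings). As a quick reference, the sketch below shows one way such a file could be parsed; it assumes PyYAML is available, and `load_config` is a hypothetical helper named here for illustration (the actual loading code presumably lives in `train.py`, which is not shown in this listing).

```python
# Minimal sketch, assuming PyYAML; load_config is a hypothetical name,
# not a function defined anywhere in this repository.
import yaml

def load_config(path):
    """Parse a dataset config such as configs/config_cars.yaml into a dict."""
    with open(path) as f:
        return yaml.safe_load(f)

cfg = load_config("configs/config_cars.yaml")
# e.g. cfg["num_class"] == 196, cfg["batch_size"] == 16, cfg["crop_size"] == 448
print(cfg["train_dir"], cfg["num_epoch"], cfg["vis_port"])
```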
-------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRIS-CV/AP-CNN_Pytorch-master/4a7cbc66fc539e42ed4ec88863123a12f1dd2fac/lib/__init__.py -------------------------------------------------------------------------------- /lib/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_ARCH="-gencode arch=compute_30,code=sm_30 \ 4 | -gencode arch=compute_35,code=sm_35 \ 5 | -gencode arch=compute_50,code=sm_50 \ 6 | -gencode arch=compute_52,code=sm_52 \ 7 | -gencode arch=compute_60,code=sm_60 \ 8 | -gencode arch=compute_61,code=sm_61" 9 | 10 | # compile NMS 11 | cd nms/src/cuda 12 | echo "Compiling nms kernels by nvcc..." 13 | nvcc -c -o nms_kernel.cu.o nms_kernel.cu -x cu -Xcompiler -fPIC $CUDA_ARCH 14 | cd ../../ 15 | python build.py 16 | -------------------------------------------------------------------------------- /lib/nms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRIS-CV/AP-CNN_Pytorch-master/4a7cbc66fc539e42ed4ec88863123a12f1dd2fac/lib/nms/__init__.py -------------------------------------------------------------------------------- /lib/nms/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRIS-CV/AP-CNN_Pytorch-master/4a7cbc66fc539e42ed4ec88863123a12f1dd2fac/lib/nms/_ext/__init__.py -------------------------------------------------------------------------------- /lib/nms/_ext/nms/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._nms import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /lib/nms/_ext/nms/_nms.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRIS-CV/AP-CNN_Pytorch-master/4a7cbc66fc539e42ed4ec88863123a12f1dd2fac/lib/nms/_ext/nms/_nms.so -------------------------------------------------------------------------------- /lib/nms/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.ffi import create_extension 4 | 5 | 6 | sources = ['src/nms.c'] 7 | headers = ['src/nms.h'] 8 | defines = [] 9 | with_cuda = False 10 | 11 | if torch.cuda.is_available(): 12 | print('Including CUDA code.') 13 | sources += ['src/nms_cuda.c'] 14 | headers += ['src/nms_cuda.h'] 15 | defines += [('WITH_CUDA', None)] 16 | with_cuda = True 17 | 18 | this_file = os.path.dirname(os.path.realpath(__file__)) 19 | print(this_file) 20 | extra_objects = ['src/cuda/nms_kernel.cu.o'] 21 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 22 | 23 | ffi = create_extension( 24 | '_ext.nms', 25 | headers=headers, 26 | sources=sources, 27 | define_macros=defines, 28 | relative_to=__file__, 29 | with_cuda=with_cuda, 30 | extra_objects=extra_objects, 31 | extra_compile_args=['-std=c99'] 32 | ) 33 | 34 | 
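# NOTE (added comment, not in the original file): torch.utils.ffi.create_extension,
# used above, only exists in the PyTorch 0.4.x line pinned by this project
# (torch == 0.4.1 in the README's dependencies); it is no longer usable in
# PyTorch >= 1.0, where C/CUDA extensions are built with torch.utils.cpp_extension.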
if __name__ == '__main__': 35 | ffi.build() 36 | -------------------------------------------------------------------------------- /lib/nms/pth_nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ._ext import nms 3 | import numpy as np 4 | 5 | def pth_nms(dets, thresh): 6 | """ 7 | dets has to be a tensor 8 | """ 9 | if not dets.is_cuda: 10 | x1 = dets[:, 0] 11 | y1 = dets[:, 1] 12 | x2 = dets[:, 2] 13 | y2 = dets[:, 3] 14 | scores = dets[:, 4] 15 | 16 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 17 | order = scores.sort(0, descending=True)[1] 18 | # order = torch.from_numpy(np.ascontiguousarray(scores.numpy().argsort()[::-1])).long() 19 | 20 | keep = torch.LongTensor(dets.size(0)) 21 | num_out = torch.LongTensor(1) 22 | nms.cpu_nms(keep, num_out, dets, order, areas, thresh) 23 | 24 | return keep[:num_out[0]] 25 | else: 26 | x1 = dets[:, 0] 27 | y1 = dets[:, 1] 28 | x2 = dets[:, 2] 29 | y2 = dets[:, 3] 30 | scores = dets[:, 4] 31 | 32 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 33 | order = scores.sort(0, descending=True)[1] 34 | # order = torch.from_numpy(np.ascontiguousarray(scores.cpu().numpy().argsort()[::-1])).long().cuda() 35 | 36 | dets = dets[order].contiguous() 37 | 38 | keep = torch.LongTensor(dets.size(0)) 39 | num_out = torch.LongTensor(1) 40 | # keep = torch.cuda.LongTensor(dets.size(0)) 41 | # num_out = torch.cuda.LongTensor(1) 42 | nms.gpu_nms(keep, num_out, dets, thresh) 43 | 44 | return order[keep[:num_out[0]].cuda()].contiguous() 45 | # return order[keep[:num_out[0]]].contiguous() 46 | 47 | -------------------------------------------------------------------------------- /lib/nms/src/cuda/nms_kernel.cu: -------------------------------------------------------------------------------- 1 | // ------------------------------------------------------------------ 2 | // Faster R-CNN 3 | // Copyright (c) 2015 Microsoft 4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details] 5 | // Written by Shaoqing Ren 6 | // ------------------------------------------------------------------ 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | 11 | #include 12 | #include 13 | #include 14 | #include "nms_kernel.h" 15 | 16 | __device__ inline float devIoU(float const * const a, float const * const b) { 17 | float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); 18 | float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); 19 | float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f); 20 | float interS = width * height; 21 | float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); 22 | float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); 23 | return interS / (Sa + Sb - interS); 24 | } 25 | 26 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, 27 | const float *dev_boxes, unsigned long long *dev_mask) { 28 | const int row_start = blockIdx.y; 29 | const int col_start = blockIdx.x; 30 | 31 | // if (row_start > col_start) return; 32 | 33 | const int row_size = 34 | fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock); 35 | const int col_size = 36 | fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock); 37 | 38 | __shared__ float block_boxes[threadsPerBlock * 5]; 39 | if (threadIdx.x < col_size) { 40 | block_boxes[threadIdx.x * 5 + 0] = 41 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; 42 | block_boxes[threadIdx.x * 5 + 1] = 43 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; 44 | block_boxes[threadIdx.x * 5 + 2] = 45 | 
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; 46 | block_boxes[threadIdx.x * 5 + 3] = 47 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; 48 | block_boxes[threadIdx.x * 5 + 4] = 49 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; 50 | } 51 | __syncthreads(); 52 | 53 | if (threadIdx.x < row_size) { 54 | const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; 55 | const float *cur_box = dev_boxes + cur_box_idx * 5; 56 | int i = 0; 57 | unsigned long long t = 0; 58 | int start = 0; 59 | if (row_start == col_start) { 60 | start = threadIdx.x + 1; 61 | } 62 | for (i = start; i < col_size; i++) { 63 | if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { 64 | t |= 1ULL << i; 65 | } 66 | } 67 | const int col_blocks = DIVUP(n_boxes, threadsPerBlock); 68 | dev_mask[cur_box_idx * col_blocks + col_start] = t; 69 | } 70 | } 71 | 72 | 73 | void _nms(int boxes_num, float * boxes_dev, 74 | unsigned long long * mask_dev, float nms_overlap_thresh) { 75 | 76 | dim3 blocks(DIVUP(boxes_num, threadsPerBlock), 77 | DIVUP(boxes_num, threadsPerBlock)); 78 | dim3 threads(threadsPerBlock); 79 | nms_kernel<<>>(boxes_num, 80 | nms_overlap_thresh, 81 | boxes_dev, 82 | mask_dev); 83 | } 84 | 85 | #ifdef __cplusplus 86 | } 87 | #endif 88 | -------------------------------------------------------------------------------- /lib/nms/src/cuda/nms_kernel.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRIS-CV/AP-CNN_Pytorch-master/4a7cbc66fc539e42ed4ec88863123a12f1dd2fac/lib/nms/src/cuda/nms_kernel.cu.o -------------------------------------------------------------------------------- /lib/nms/src/cuda/nms_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _NMS_KERNEL 2 | #define _NMS_KERNEL 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 9 | int const threadsPerBlock = sizeof(unsigned long long) * 8; 10 | 11 | void _nms(int boxes_num, float * boxes_dev, 12 | unsigned long long * mask_dev, float nms_overlap_thresh); 13 | 14 | #ifdef __cplusplus 15 | } 16 | #endif 17 | 18 | #endif 19 | 20 | -------------------------------------------------------------------------------- /lib/nms/src/nms.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh) { 5 | // boxes has to be sorted 6 | THArgCheck(THLongTensor_isContiguous(keep_out), 0, "keep_out must be contiguous"); 7 | THArgCheck(THLongTensor_isContiguous(boxes), 2, "boxes must be contiguous"); 8 | THArgCheck(THLongTensor_isContiguous(order), 3, "order must be contiguous"); 9 | THArgCheck(THLongTensor_isContiguous(areas), 4, "areas must be contiguous"); 10 | // Number of ROIs 11 | long boxes_num = THFloatTensor_size(boxes, 0); 12 | long boxes_dim = THFloatTensor_size(boxes, 1); 13 | 14 | long * keep_out_flat = THLongTensor_data(keep_out); 15 | float * boxes_flat = THFloatTensor_data(boxes); 16 | long * order_flat = THLongTensor_data(order); 17 | float * areas_flat = THFloatTensor_data(areas); 18 | 19 | THByteTensor* suppressed = THByteTensor_newWithSize1d(boxes_num); 20 | THByteTensor_fill(suppressed, 0); 21 | unsigned char * suppressed_flat = THByteTensor_data(suppressed); 22 | 23 | // nominal indices 24 | int i, 
j; 25 | // sorted indices 26 | int _i, _j; 27 | // temp variables for box i's (the box currently under consideration) 28 | float ix1, iy1, ix2, iy2, iarea; 29 | // variables for computing overlap with box j (lower scoring box) 30 | float xx1, yy1, xx2, yy2; 31 | float w, h; 32 | float inter, ovr; 33 | 34 | long num_to_keep = 0; 35 | for (_i=0; _i < boxes_num; ++_i) { 36 | i = order_flat[_i]; 37 | if (suppressed_flat[i] == 1) { 38 | continue; 39 | } 40 | keep_out_flat[num_to_keep++] = i; 41 | ix1 = boxes_flat[i * boxes_dim]; 42 | iy1 = boxes_flat[i * boxes_dim + 1]; 43 | ix2 = boxes_flat[i * boxes_dim + 2]; 44 | iy2 = boxes_flat[i * boxes_dim + 3]; 45 | iarea = areas_flat[i]; 46 | for (_j = _i + 1; _j < boxes_num; ++_j) { 47 | j = order_flat[_j]; 48 | if (suppressed_flat[j] == 1) { 49 | continue; 50 | } 51 | xx1 = fmaxf(ix1, boxes_flat[j * boxes_dim]); 52 | yy1 = fmaxf(iy1, boxes_flat[j * boxes_dim + 1]); 53 | xx2 = fminf(ix2, boxes_flat[j * boxes_dim + 2]); 54 | yy2 = fminf(iy2, boxes_flat[j * boxes_dim + 3]); 55 | w = fmaxf(0.0, xx2 - xx1 + 1); 56 | h = fmaxf(0.0, yy2 - yy1 + 1); 57 | inter = w * h; 58 | ovr = inter / (iarea + areas_flat[j] - inter); 59 | if (ovr >= nms_overlap_thresh) { 60 | suppressed_flat[j] = 1; 61 | } 62 | } 63 | } 64 | 65 | long *num_out_flat = THLongTensor_data(num_out); 66 | *num_out_flat = num_to_keep; 67 | THByteTensor_free(suppressed); 68 | return 1; 69 | } -------------------------------------------------------------------------------- /lib/nms/src/nms.h: -------------------------------------------------------------------------------- 1 | int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh); -------------------------------------------------------------------------------- /lib/nms/src/nms_cuda.c: -------------------------------------------------------------------------------- 1 | // ------------------------------------------------------------------ 2 | // Faster R-CNN 3 | // Copyright (c) 2015 Microsoft 4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details] 5 | // Written by Shaoqing Ren 6 | // ------------------------------------------------------------------ 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "cuda/nms_kernel.h" 13 | 14 | 15 | extern THCState *state; 16 | 17 | int gpu_nms(THLongTensor * keep, THLongTensor* num_out, THCudaTensor * boxes, float nms_overlap_thresh) { 18 | // boxes has to be sorted 19 | THArgCheck(THLongTensor_isContiguous(keep), 0, "boxes must be contiguous"); 20 | THArgCheck(THCudaTensor_isContiguous(state, boxes), 2, "boxes must be contiguous"); 21 | // Number of ROIs 22 | int boxes_num = THCudaTensor_size(state, boxes, 0); 23 | int boxes_dim = THCudaTensor_size(state, boxes, 1); 24 | 25 | float* boxes_flat = THCudaTensor_data(state, boxes); 26 | 27 | const int col_blocks = DIVUP(boxes_num, threadsPerBlock); 28 | THCudaLongTensor * mask = THCudaLongTensor_newWithSize2d(state, boxes_num, col_blocks); 29 | unsigned long long* mask_flat = THCudaLongTensor_data(state, mask); 30 | 31 | _nms(boxes_num, boxes_flat, mask_flat, nms_overlap_thresh); 32 | 33 | THLongTensor * mask_cpu = THLongTensor_newWithSize2d(boxes_num, col_blocks); 34 | THLongTensor_copyCuda(state, mask_cpu, mask); 35 | THCudaLongTensor_free(state, mask); 36 | 37 | unsigned long long * mask_cpu_flat = THLongTensor_data(mask_cpu); 38 | 39 | THLongTensor * remv_cpu = THLongTensor_newWithSize1d(col_blocks); 40 | unsigned long long* 
remv_cpu_flat = THLongTensor_data(remv_cpu); 41 | THLongTensor_fill(remv_cpu, 0); 42 | 43 | long * keep_flat = THLongTensor_data(keep); 44 | long num_to_keep = 0; 45 | 46 | int i, j; 47 | for (i = 0; i < boxes_num; i++) { 48 | int nblock = i / threadsPerBlock; 49 | int inblock = i % threadsPerBlock; 50 | 51 | if (!(remv_cpu_flat[nblock] & (1ULL << inblock))) { 52 | keep_flat[num_to_keep++] = i; 53 | unsigned long long *p = &mask_cpu_flat[0] + i * col_blocks; 54 | for (j = nblock; j < col_blocks; j++) { 55 | remv_cpu_flat[j] |= p[j]; 56 | } 57 | } 58 | } 59 | 60 | long * num_out_flat = THLongTensor_data(num_out); 61 | * num_out_flat = num_to_keep; 62 | 63 | THLongTensor_free(mask_cpu); 64 | THLongTensor_free(remv_cpu); 65 | 66 | return 1; 67 | } 68 | -------------------------------------------------------------------------------- /lib/nms/src/nms_cuda.h: -------------------------------------------------------------------------------- 1 | int gpu_nms(THLongTensor * keep_out, THLongTensor* num_out, THCudaTensor * boxes, float nms_overlap_thresh); -------------------------------------------------------------------------------- /model/resnet50.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.autograd import Variable 4 | import math 5 | import time 6 | import os 7 | import numpy as np 8 | import cv2 9 | import random 10 | import torch.utils.model_zoo as model_zoo 11 | import torch.nn.functional as F 12 | from torch.nn import init 13 | 14 | from lib.nms.pth_nms import pth_nms 15 | 16 | def get_merge_bbox(dets, inds): 17 | xx1 = np.min(dets[inds][:,0]) 18 | yy1 = np.min(dets[inds][:,1]) 19 | xx2 = np.max(dets[inds][:,2]) 20 | yy2 = np.max(dets[inds][:,3]) 21 | 22 | return np.array((xx1, yy1, xx2, yy2)) 23 | 24 | def pth_nms_merge(dets, thresh, topk): 25 | dets = dets.cpu().data.numpy() 26 | x1 = dets[:, 0] 27 | y1 = dets[:, 1] 28 | x2 = dets[:, 2] 29 | y2 = dets[:, 3] 30 | scores = dets[:, 4] 31 | 32 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 33 | order = scores.argsort()[::-1] 34 | 35 | boxes_merge = [] 36 | cnt = 0 37 | while order.size > 0: 38 | i = order[0] 39 | 40 | xx1 = np.maximum(x1[i], x1[order[1:]]) 41 | yy1 = np.maximum(y1[i], y1[order[1:]]) 42 | xx2 = np.minimum(x2[i], x2[order[1:]]) 43 | yy2 = np.minimum(y2[i], y2[order[1:]]) 44 | w = np.maximum(0.0, xx2 - xx1 + 1) 45 | h = np.maximum(0.0, yy2 - yy1 + 1) 46 | inter = w * h 47 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 48 | inds = np.where(ovr <= thresh)[0] 49 | 50 | inds_merge = np.where((ovr > 0.5)*(0.9*scores[i]= topk: 56 | break 57 | 58 | return torch.from_numpy(np.array(boxes_merge)) 59 | 60 | class BasicConv(nn.Module): 61 | 62 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 63 | bn=True, bias=False): 64 | super(BasicConv, self).__init__() 65 | self.out_channels = out_planes 66 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 67 | dilation=dilation, groups=groups, bias=bias) 68 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 69 | self.relu = nn.ReLU(inplace=True) if relu else None 70 | 71 | def forward(self, x): 72 | x = self.conv(x) 73 | if self.bn is not None: 74 | x = self.bn(x) 75 | if self.relu is not None: 76 | x = self.relu(x) 77 | return x 78 | 79 | 80 | def conv3x3(in_planes, out_planes, stride=1): 81 | """3x3 convolution with padding""" 82 | return 
nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 83 | padding=1, bias=False) 84 | 85 | 86 | class BasicBlock(nn.Module): 87 | expansion = 1 88 | 89 | def __init__(self, inplanes, planes, stride=1, downsample=None): 90 | super(BasicBlock, self).__init__() 91 | self.conv1 = conv3x3(inplanes, planes, stride) 92 | self.bn1 = nn.BatchNorm2d(planes) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.conv2 = conv3x3(planes, planes) 95 | self.bn2 = nn.BatchNorm2d(planes) 96 | self.downsample = downsample 97 | self.stride = stride 98 | 99 | def forward(self, x): 100 | residual = x 101 | 102 | out = self.conv1(x) 103 | out = self.bn1(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv2(out) 107 | out = self.bn2(out) 108 | 109 | if self.downsample is not None: 110 | residual = self.downsample(x) 111 | 112 | out += residual 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class Bottleneck(nn.Module): 119 | expansion = 4 120 | 121 | def __init__(self, inplanes, planes, stride=1, downsample=None): 122 | super(Bottleneck, self).__init__() 123 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 124 | self.bn1 = nn.BatchNorm2d(planes) 125 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 126 | padding=1, bias=False) 127 | self.bn2 = nn.BatchNorm2d(planes) 128 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 129 | self.bn3 = nn.BatchNorm2d(planes * 4) 130 | self.relu = nn.ReLU(inplace=True) 131 | self.downsample = downsample 132 | self.stride = stride 133 | 134 | def forward(self, x): 135 | residual = x 136 | 137 | out = self.conv1(x) 138 | out = self.bn1(out) 139 | out = self.relu(out) 140 | 141 | out = self.conv2(out) 142 | out = self.bn2(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv3(out) 146 | out = self.bn3(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | 157 | class SimpleFPA(nn.Module): 158 | def __init__(self, in_planes, out_planes): 159 | """ 160 | Feature Pyramid Attention 161 | :type channels: int 162 | """ 163 | super(SimpleFPA, self).__init__() 164 | 165 | self.channels_cond = in_planes 166 | # Master branch 167 | self.conv_master = BasicConv(in_planes, out_planes, kernel_size=1, stride=1) 168 | 169 | # Global pooling branch 170 | self.conv_gpb = BasicConv(in_planes, out_planes, kernel_size=1, stride=1) 171 | 172 | def forward(self, x): 173 | """ 174 | :param x: Shape: [b, 2048, h, w] 175 | :return: out: Feature maps. 
Shape: [b, 2048, h, w] 176 | """ 177 | # Master branch 178 | x_master = self.conv_master(x) 179 | 180 | # Global pooling branch 181 | x_gpb = nn.AvgPool2d(x.shape[2:])(x).view(x.shape[0], self.channels_cond, 1, 1) 182 | x_gpb = self.conv_gpb(x_gpb) 183 | 184 | out = x_master + x_gpb 185 | 186 | return out 187 | 188 | class PyramidFeatures(nn.Module): 189 | """Feature pyramid module with top-down feature pathway""" 190 | def __init__(self, B2_size, B3_size, B4_size, B5_size, feature_size=256): 191 | super(PyramidFeatures, self).__init__() 192 | 193 | self.P5_1 = SimpleFPA(B5_size, feature_size) 194 | self.P5_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 195 | 196 | self.P4_1 = nn.Conv2d(B4_size, feature_size, kernel_size=1, stride=1, padding=0) 197 | self.P4_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 198 | 199 | self.P3_1 = nn.Conv2d(B3_size, feature_size, kernel_size=1, stride=1, padding=0) 200 | self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 201 | 202 | def forward(self, inputs): 203 | B3, B4, B5 = inputs 204 | 205 | P5_x = self.P5_1(B5) 206 | P5_upsampled_x = F.interpolate(P5_x, scale_factor=2) 207 | P5_x = self.P5_2(P5_x) 208 | 209 | P4_x = self.P4_1(B4) 210 | P4_x = P5_upsampled_x + P4_x 211 | P4_upsampled_x = F.interpolate(P4_x, scale_factor=2) 212 | P4_x = self.P4_2(P4_x) 213 | 214 | P3_x = self.P3_1(B3) 215 | P3_x = P3_x + P4_upsampled_x 216 | P3_x = self.P3_2(P3_x) 217 | 218 | return [P3_x, P4_x, P5_x] 219 | 220 | class PyramidAttentions(nn.Module): 221 | """Attention pyramid module with bottom-up attention pathway""" 222 | def __init__(self, channel_size=256): 223 | super(PyramidAttentions, self).__init__() 224 | 225 | self.A3_1 = SpatialGate(channel_size) 226 | self.A3_2 = ChannelGate(channel_size) 227 | 228 | self.A4_1 = SpatialGate(channel_size) 229 | self.A4_2 = ChannelGate(channel_size) 230 | 231 | self.A5_1 = SpatialGate(channel_size) 232 | self.A5_2 = ChannelGate(channel_size) 233 | 234 | def forward(self, inputs): 235 | F3, F4, F5 = inputs 236 | 237 | A3_spatial = self.A3_1(F3) 238 | A3_channel = self.A3_2(F3) 239 | A3 = A3_spatial*F3 + A3_channel*F3 240 | 241 | A4_spatial = self.A4_1(F4) 242 | A4_channel = self.A4_2(F4) 243 | A4_channel = (A4_channel + A3_channel) / 2 244 | A4 = A4_spatial*F4 + A4_channel*F4 245 | 246 | A5_spatial = self.A5_1(F5) 247 | A5_channel = self.A5_2(F5) 248 | A5_channel = (A5_channel + A4_channel) / 2 249 | A5 = A5_spatial*F5 + A5_channel*F5 250 | 251 | return [A3, A4, A5, A3_spatial, A4_spatial, A5_spatial] 252 | 253 | class SpatialGate(nn.Module): 254 | """generation spatial attention mask""" 255 | def __init__(self, out_channels): 256 | super(SpatialGate, self).__init__() 257 | self.conv = nn.ConvTranspose2d(out_channels,1,kernel_size=3,stride=1,padding=1) 258 | def forward(self, x): 259 | x = self.conv(x) 260 | return torch.sigmoid(x) 261 | 262 | class ChannelGate(nn.Module): 263 | """generation channel attention mask""" 264 | def __init__(self, out_channels): 265 | super(ChannelGate, self).__init__() 266 | self.conv1 = nn.Conv2d(out_channels,out_channels//16,kernel_size=1,stride=1,padding=0) 267 | self.conv2 = nn.Conv2d(out_channels//16,out_channels,kernel_size=1,stride=1,padding=0) 268 | def forward(self, x): 269 | x = nn.AdaptiveAvgPool2d(output_size=1)(x) 270 | x = F.relu(self.conv1(x), inplace=True) 271 | x = torch.sigmoid(self.conv2(x)) 272 | return x 273 | 274 | class Flatten(nn.Module): 275 | def __init__(self): 276 | 
super(Flatten, self).__init__() 277 | 278 | def forward(self, x): 279 | return x.view(x.size(0), -1) 280 | 281 | def generate_anchors_single_pyramid(scales, ratios, shape, feature_stride, anchor_stride): 282 | """ 283 | scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128] 284 | ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2] 285 | shape: [height, width] spatial shape of the feature map over which 286 | to generate anchors. 287 | feature_stride: Stride of the feature map relative to the image in pixels. 288 | anchor_stride: Stride of anchors on the feature map. For example, if the 289 | value is 2 then generate anchors for every other feature map pixel. 290 | """ 291 | # Get all combinations of scales and ratios 292 | scales, ratios = np.meshgrid(np.array(scales), np.array(ratios)) 293 | scales = scales.flatten() 294 | ratios = ratios.flatten() 295 | 296 | # Enumerate heights and widths from scales and ratios 297 | heights = scales / np.sqrt(ratios) 298 | widths = scales * np.sqrt(ratios) 299 | 300 | # Enumerate shifts in feature space 301 | shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride 302 | shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride 303 | shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y) 304 | 305 | # Enumerate combinations of shifts, widths, and heights 306 | box_widths, box_centers_x = np.meshgrid(widths, shifts_x) 307 | box_heights, box_centers_y = np.meshgrid(heights, shifts_y) 308 | 309 | box_centers = np.stack( 310 | [box_centers_x, box_centers_y], axis=2).reshape([-1, 2]) 311 | box_sizes = np.stack([box_widths, box_heights], axis=2).reshape([-1, 2]) 312 | 313 | # Convert to corner coordinates (x1, y1, x2, y2) 314 | boxes = np.concatenate([box_centers - 0.5 * box_sizes, 315 | box_centers + 0.5 * box_sizes], axis=1) 316 | return torch.from_numpy(boxes).cuda() 317 | 318 | class ResNet(nn.Module): 319 | """implementation of AP-CNN on ResNet""" 320 | def __init__(self, num_classes, block, layers): 321 | super(ResNet, self).__init__() 322 | self.inplanes = 64 323 | self.num_classes = num_classes 324 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 325 | self.bn1 = nn.BatchNorm2d(64) 326 | self.relu = nn.ReLU(inplace=True) 327 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 328 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 329 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 330 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 331 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 332 | 333 | if self.num_classes == 200: 334 | hidden_num = 512 335 | else: 336 | hidden_num = 256 337 | 338 | if block == BasicBlock: 339 | fpn_sizes = [self.layer1[layers[0] - 1].conv2.out_channels, self.layer2[layers[1] - 1].conv2.out_channels, 340 | self.layer3[layers[2] - 1].conv2.out_channels, 341 | self.layer4[layers[3] - 1].conv2.out_channels] 342 | elif block == Bottleneck: 343 | fpn_sizes = [self.layer1[layers[0] - 1].conv3.out_channels, self.layer2[layers[1] - 1].conv3.out_channels, 344 | self.layer3[layers[2] - 1].conv3.out_channels, 345 | self.layer4[layers[3] - 1].conv3.out_channels] 346 | 347 | self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2], fpn_sizes[3]) 348 | self.apn = PyramidAttentions(channel_size=256) 349 | 350 | self.cls5 = nn.Sequential( 351 | nn.AdaptiveAvgPool2d(1), 352 | Flatten(), 353 | nn.BatchNorm1d(256), 354 | nn.Linear(256, hidden_num), 355 | 
nn.BatchNorm1d(hidden_num), 356 | nn.ELU(inplace=True), 357 | nn.Linear(hidden_num, self.num_classes) 358 | ) 359 | 360 | self.cls4 = nn.Sequential( 361 | nn.AdaptiveAvgPool2d(1), 362 | Flatten(), 363 | nn.BatchNorm1d(256), 364 | nn.Linear(256, hidden_num), 365 | nn.BatchNorm1d(hidden_num), 366 | nn.ELU(inplace=True), 367 | nn.Linear(hidden_num, self.num_classes) 368 | ) 369 | 370 | self.cls3 = nn.Sequential( 371 | nn.AdaptiveAvgPool2d(1), 372 | Flatten(), 373 | nn.BatchNorm1d(256), 374 | nn.Linear(256, hidden_num), 375 | nn.BatchNorm1d(hidden_num), 376 | nn.ELU(inplace=True), 377 | nn.Linear(hidden_num, self.num_classes) 378 | ) 379 | 380 | self.cls_concate = nn.Sequential( 381 | Flatten(), 382 | nn.BatchNorm1d(256*3), 383 | nn.Linear(256*3, hidden_num), 384 | nn.BatchNorm1d(hidden_num), 385 | nn.ELU(inplace=True), 386 | nn.Linear(hidden_num, self.num_classes) 387 | ) 388 | 389 | self.criterion = nn.CrossEntropyLoss() 390 | 391 | for m in self.modules(): 392 | if isinstance(m, nn.Conv2d): 393 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 394 | m.weight.data.normal_(0, math.sqrt(2. / n)) 395 | elif isinstance(m, nn.BatchNorm2d): 396 | m.weight.data.fill_(1) 397 | m.bias.data.zero_() 398 | 399 | def _make_layer(self, block, planes, blocks, stride=1): 400 | downsample = None 401 | if stride != 1 or self.inplanes != planes * block.expansion: 402 | downsample = nn.Sequential( 403 | nn.Conv2d(self.inplanes, planes * block.expansion, 404 | kernel_size=1, stride=stride, bias=False), 405 | nn.BatchNorm2d(planes * block.expansion), 406 | ) 407 | 408 | layers = [] 409 | layers.append(block(self.inplanes, planes, stride, downsample)) 410 | self.inplanes = planes * block.expansion 411 | for i in range(1, blocks): 412 | layers.append(block(self.inplanes, planes)) 413 | 414 | return nn.Sequential(*layers) 415 | 416 | def get_att_roi(self, att_mask, feature_stride, anchor_size, img_h, img_w, iou_thred=0.2, topk=1): 417 | """generation multi-leve ROIs upon spatial attention masks with NMS method""" 418 | with torch.no_grad(): 419 | roi_ret_nms = [] 420 | n, c, h, w = att_mask.size() 421 | att_corner_unmask = torch.zeros_like(att_mask).cuda() 422 | if self.num_classes == 200: 423 | att_corner_unmask[:, :, int(0.2 * h):int(0.8 * h), int(0.2 * w):int(0.8 * w)] = 1 424 | else: 425 | att_corner_unmask[:, :, int(0.1 * h):int(0.9 * h), int(0.1 * w):int(0.9 * w)] = 1 426 | att_mask = att_mask * att_corner_unmask 427 | feat_anchor = generate_anchors_single_pyramid([anchor_size], [1], [h, w], feature_stride, 1) 428 | feat_new_cls = att_mask.clone() 429 | for i in range(n): 430 | boxes = feat_anchor.clone().float() 431 | scores = feat_new_cls[i].view(-1) 432 | score_thred_index = scores > scores.mean() 433 | boxes = boxes[score_thred_index, :] 434 | scores = scores[score_thred_index] 435 | # nms 436 | nms_index = pth_nms(torch.cat([boxes, scores.unsqueeze(1)], dim=1), iou_thred)[:topk] 437 | boxes_nms = boxes[nms_index, :] 438 | if len(boxes_nms.size()) == 1: 439 | boxes_nms = boxes_nms.unsqueeze(0) 440 | # boxes_nms = pth_nms_merge(torch.cat([boxes, scores.unsqueeze(1)], dim=1), iou_thred, topk).cuda() 441 | boxes_nms[:, 0] = torch.clamp(boxes_nms[:, 0], min=0) 442 | boxes_nms[:, 1] = torch.clamp(boxes_nms[:, 1], min=0) 443 | boxes_nms[:, 2] = torch.clamp(boxes_nms[:, 2], max=img_w - 1) 444 | boxes_nms[:, 3] = torch.clamp(boxes_nms[:, 3], max=img_h - 1) 445 | roi_ret_nms.append(torch.cat([torch.FloatTensor([i] * boxes_nms.size(0)).unsqueeze(1).cuda(), boxes_nms], 1)) 446 | 447 | return 
torch.cat(roi_ret_nms, 0) 448 | 449 | def get_roi_crop_feat(self, x, roi_list, scale): 450 | """ROI guided refinement: ROI guided Zoom-in & ROI guided Dropblock""" 451 | n, c, x2_h, x2_w = x.size() 452 | roi_3, roi_4, roi_5 = roi_list 453 | roi_all = torch.cat([roi_3, roi_4, roi_5], 0) 454 | x2_ret = [] 455 | crop_info_all = [] 456 | if self.training: 457 | for i in range(n): 458 | roi_all_i = roi_all[roi_all[:, 0] == i] / scale 459 | xx1_resize, yy1_resize, = torch.min(roi_all_i[:, 1:3], 0)[0] 460 | xx2_resize, yy2_resize = torch.max(roi_all_i[:, 3:5], 0)[0] 461 | roi_3_i = roi_3[roi_3[:, 0] == i] / scale 462 | roi_4_i = roi_4[roi_4[:, 0] == i] / scale 463 | # alway drop the roi with highest score 464 | mask_un = torch.ones(c, x2_h, x2_w).cuda() 465 | pro_rand = random.random() 466 | if pro_rand < 0.3: 467 | ind_rand = random.randint(0, roi_3_i.size(0) - 1) 468 | xx1_drop, yy1_drop = roi_3_i[ind_rand, 1:3] 469 | xx2_drop, yy2_drop = roi_3_i[ind_rand, 3:5] 470 | mask_un[:, yy1_drop.long():yy2_drop.long(), xx1_drop.long():xx2_drop.long()] = 0 471 | elif pro_rand < 0.6: 472 | ind_rand = random.randint(0, roi_4_i.size(0) - 1) 473 | xx1_drop, yy1_drop = roi_4_i[ind_rand, 1:3] 474 | xx2_drop, yy2_drop = roi_4_i[ind_rand, 3:5] 475 | mask_un[:, yy1_drop.long():yy2_drop.long(), xx1_drop.long():xx2_drop.long()] = 0 476 | x2_drop = x[i] * mask_un 477 | x2_crop = x2_drop[:, yy1_resize.long():yy2_resize.long(), 478 | xx1_resize.long():xx2_resize.long()].contiguous().unsqueeze(0) 479 | # normalize 480 | scale_rate = c*(yy2_resize-yy1_resize)*(xx2_resize-xx1_resize) / torch.sum(mask_un[:, yy1_resize.long():yy2_resize.long(), 481 | xx1_resize.long():xx2_resize.long()]) 482 | x2_crop = x2_crop * scale_rate 483 | 484 | x2_crop_resize = F.interpolate(x2_crop, (x2_h, x2_w), mode='bilinear', align_corners=False) 485 | x2_ret.append(x2_crop_resize) 486 | 487 | crop_info = [xx1_resize, xx2_resize, yy1_resize, yy2_resize] 488 | crop_info_all.append(crop_info) 489 | else: 490 | for i in range(n): 491 | roi_all_i = roi_all[roi_all[:, 0] == i] / scale 492 | xx1_resize, yy1_resize, = torch.min(roi_all_i[:, 1:3], 0)[0] 493 | xx2_resize, yy2_resize = torch.max(roi_all_i[:, 3:5], 0)[0] 494 | x2_crop = x[i, :, yy1_resize.long():yy2_resize.long(), 495 | xx1_resize.long():xx2_resize.long()].contiguous().unsqueeze(0) 496 | x2_crop_resize = F.interpolate(x2_crop, (x2_h, x2_w), mode='bilinear', align_corners=False) 497 | x2_ret.append(x2_crop_resize) 498 | 499 | crop_info = [xx1_resize, xx2_resize, yy1_resize, yy2_resize] 500 | crop_info_all.append(crop_info) 501 | return torch.cat(x2_ret, 0), crop_info_all 502 | 503 | def Concate(self, f3, f4, f5): 504 | f3 = nn.AdaptiveAvgPool2d(output_size=1)(f3) 505 | f5 = nn.AdaptiveAvgPool2d(output_size=1)(f5) 506 | f4 = nn.AdaptiveAvgPool2d(output_size=1)(f4) 507 | f_concate = torch.cat([f3, f4, f5], dim=1) 508 | return f_concate 509 | 510 | def forward(self, inputs, targets): 511 | # ResNet backbone with FC removed 512 | n, c, img_h, img_w = inputs.size() 513 | x = self.conv1(inputs) 514 | x = self.bn1(x) 515 | x = self.relu(x) 516 | x = self.maxpool(x) 517 | 518 | x1 = self.layer1(x) 519 | x2 = self.layer2(x1) 520 | x3 = self.layer3(x2) 521 | x4 = self.layer4(x3) 522 | 523 | # stage I 524 | f3, f4, f5 = self.fpn([x2, x3, x4]) 525 | f3_att, f4_att, f5_att, a3, a4, a5 = self.apn([f3, f4, f5]) 526 | 527 | # feature concat 528 | f_concate = self.Concate(f3, f4, f5) 529 | out_concate = self.cls_concate(f_concate) 530 | loss_concate = self.criterion(out_concate, targets) 531 | 532 | out3 
= self.cls3(f3_att) 533 | out4 = self.cls4(f4_att) 534 | out5 = self.cls5(f5_att) 535 | 536 | loss3 = self.criterion(out3, targets) 537 | loss4 = self.criterion(out4, targets) 538 | loss5 = self.criterion(out5, targets) 539 | 540 | loss = loss3 + loss4 + loss5 + loss_concate 541 | out = (out3 + out4 + out5 + out_concate) / 4 542 | _, predicted = torch.max(out.data, 1) 543 | correct = predicted.eq(targets.data).cpu().sum().item() 544 | 545 | 546 | # roi pyramid 547 | roi_3 = self.get_att_roi(a3, 2 ** 3, 64, img_h, img_w, iou_thred=0.05, topk=5) 548 | roi_4 = self.get_att_roi(a4, 2 ** 4, 128, img_h, img_w, iou_thred=0.05, topk=3) 549 | roi_5 = self.get_att_roi(a5, 2 ** 5, 256, img_h, img_w, iou_thred=0.05, topk=1) 550 | roi_list = [roi_3, roi_4, roi_5] 551 | 552 | # stage II 553 | x2_crop_resize, _ = self.get_roi_crop_feat(x2, roi_list, 2 ** 3) 554 | x3_crop_resize = self.layer3(x2_crop_resize) 555 | x4_crop_resize = self.layer4(x3_crop_resize) 556 | 557 | f3_crop_resize, f4_crop_resize, f5_crop_resize = self.fpn([x2_crop_resize, x3_crop_resize, x4_crop_resize]) 558 | f3_att_crop_resize, f4_att_crop_resize, f5_att_crop_resize, a3_crop_resize, a4_crop_resize, a5_crop_resize = self.apn([f3_crop_resize, f4_crop_resize, f5_crop_resize]) 559 | 560 | # feature concat 561 | f_concate_crop_resize = self.Concate(f3_crop_resize, f4_crop_resize, f5_crop_resize) 562 | out_concate_crop_resize = self.cls_concate(f_concate_crop_resize) 563 | loss_concate_crop_resize = self.criterion(out_concate_crop_resize, targets) 564 | 565 | out3_crop_resize = self.cls3(f3_att_crop_resize) 566 | out4_crop_resize = self.cls4(f4_att_crop_resize) 567 | out5_crop_resize = self.cls5(f5_att_crop_resize) 568 | 569 | loss3_crop_resize = self.criterion(out3_crop_resize, targets) 570 | loss4_crop_resize = self.criterion(out4_crop_resize, targets) 571 | loss5_crop_resize = self.criterion(out5_crop_resize, targets) 572 | 573 | loss_crop_resize = loss3_crop_resize + loss4_crop_resize + loss5_crop_resize + loss_concate_crop_resize 574 | out_crop_resize = (out3_crop_resize + out4_crop_resize + out5_crop_resize + out_concate_crop_resize) / 4 575 | _, predicted_crop_resize = torch.max(out_crop_resize.data, 1) 576 | correct_crop_resize = predicted_crop_resize.eq(targets.data).cpu().sum().item() 577 | 578 | 579 | out_mean = (out_crop_resize + out) / 2 580 | predicted_mean_, predicted_mean = torch.max(out_mean.data, 1) 581 | correct_mean = predicted_mean.eq(targets.data).cpu().sum().item() 582 | 583 | loss_ret = {'loss': loss + loss_crop_resize, 'loss1': loss, 'loss2': loss_crop_resize, 'loss3': loss} 584 | acc_ret = {'acc': correct_mean, 'acc1': correct, 'acc2': correct_crop_resize, 'acc3': correct} 585 | 586 | # attetion masks for visualizaton 587 | mask_cat = torch.cat([a3, 588 | F.interpolate(a4, a3.size()[2:]), 589 | F.interpolate(a5, a3.size()[2:])], 1) 590 | 591 | return loss_ret, acc_ret, mask_cat, roi_list 592 | 593 | 594 | def resnet18(num_classes, **kwargs): 595 | """Constructs a ResNet-18 model. 596 | Args: 597 | pretrained (bool): If True, returns a model pre-trained on ImageNet 598 | """ 599 | model = ResNet(num_classes, BasicBlock, [2, 2, 2, 2], **kwargs) 600 | return model 601 | 602 | 603 | def resnet34(num_classes, **kwargs): 604 | """Constructs a ResNet-34 model. 
605 | Args: 606 | pretrained (bool): If True, returns a model pre-trained on ImageNet 607 | """ 608 | model = ResNet(num_classes, BasicBlock, [3, 4, 6, 3], **kwargs) 609 | return model 610 | 611 | 612 | def resnet50(num_classes, **kwargs): 613 | """Constructs a ResNet-50 model. 614 | Args: 615 | pretrained (bool): If True, returns a model pre-trained on ImageNet 616 | """ 617 | model = ResNet(num_classes, Bottleneck, [3, 4, 6, 3], **kwargs) 618 | return model 619 | 620 | 621 | def resnet101(num_classes, **kwargs): 622 | """Constructs a ResNet-101 model. 623 | Args: 624 | pretrained (bool): If True, returns a model pre-trained on ImageNet 625 | """ 626 | model = ResNet(num_classes, Bottleneck, [3, 4, 23, 3], **kwargs) 627 | return model 628 | 629 | 630 | def resnet152(num_classes, **kwargs): 631 | """Constructs a ResNet-152 model. 632 | Args: 633 | pretrained (bool): If True, returns a model pre-trained on ImageNet 634 | """ 635 | model = ResNet(num_classes, Bottleneck, [3, 8, 36, 3], **kwargs) 636 | return model 637 | -------------------------------------------------------------------------------- /model/vgg19.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.autograd import Variable 4 | import math 5 | import time 6 | import os 7 | import numpy as np 8 | import cv2 9 | import random 10 | import torch.utils.model_zoo as model_zoo 11 | import torch.nn.functional as F 12 | from torch.nn import init 13 | from torch.utils.checkpoint import checkpoint_sequential 14 | 15 | def get_merge_bbox(dets, inds): 16 | xx1 = np.min(dets[inds][:,0]) 17 | yy1 = np.min(dets[inds][:,1]) 18 | xx2 = np.max(dets[inds][:,2]) 19 | yy2 = np.max(dets[inds][:,3]) 20 | 21 | return np.array((xx1, yy1, xx2, yy2)) 22 | 23 | def pth_nms_merge(dets, thresh, topk): 24 | dets = dets.cpu().data.numpy() 25 | x1 = dets[:, 0] 26 | y1 = dets[:, 1] 27 | x2 = dets[:, 2] 28 | y2 = dets[:, 3] 29 | scores = dets[:, 4] 30 | 31 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 32 | order = scores.argsort()[::-1] 33 | 34 | boxes_merge = [] 35 | cnt = 0 36 | while order.size > 0: 37 | i = order[0] 38 | 39 | xx1 = np.maximum(x1[i], x1[order[1:]]) 40 | yy1 = np.maximum(y1[i], y1[order[1:]]) 41 | xx2 = np.minimum(x2[i], x2[order[1:]]) 42 | yy2 = np.minimum(y2[i], y2[order[1:]]) 43 | w = np.maximum(0.0, xx2 - xx1 + 1) 44 | h = np.maximum(0.0, yy2 - yy1 + 1) 45 | inter = w * h 46 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 47 | inds = np.where(ovr <= thresh)[0] 48 | 49 | inds_merge = np.where((ovr > 0.5)*(0.9*scores[i]= topk: 55 | break 56 | 57 | return torch.from_numpy(np.array(boxes_merge)) 58 | 59 | def pth_nms(dets, thresh, topk): 60 | dets = dets.cpu().data.numpy() 61 | x1 = dets[:, 0] 62 | y1 = dets[:, 1] 63 | x2 = dets[:, 2] 64 | y2 = dets[:, 3] 65 | scores = dets[:, 4] 66 | 67 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 68 | order = scores.argsort()[::-1] 69 | 70 | boxes_merge = [] 71 | cnt = 0 72 | while order.size > 0: 73 | i = order[0] 74 | 75 | xx1 = np.maximum(x1[i], x1[order[1:]]) 76 | yy1 = np.maximum(y1[i], y1[order[1:]]) 77 | xx2 = np.minimum(x2[i], x2[order[1:]]) 78 | yy2 = np.minimum(y2[i], y2[order[1:]]) 79 | w = np.maximum(0.0, xx2 - xx1 + 1) 80 | h = np.maximum(0.0, yy2 - yy1 + 1) 81 | inter = w * h 82 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 83 | inds = np.where(ovr <= thresh)[0] 84 | 85 | xx1 = dets[i, 0] 86 | yy1 = dets[i, 1] 87 | xx2 = dets[i, 2] 88 | yy2 = dets[i, 3] 89 | boxes_merge.append(np.array((xx1, yy1, xx2, 
yy2))) 90 | order = order[inds + 1] 91 | 92 | cnt += 1 93 | if cnt >= topk: 94 | break 95 | 96 | return torch.from_numpy(np.array(boxes_merge)) 97 | 98 | class BasicConv(nn.Module): 99 | 100 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 101 | bn=True, bias=False): 102 | super(BasicConv, self).__init__() 103 | self.out_channels = out_planes 104 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 105 | dilation=dilation, groups=groups, bias=bias) 106 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 107 | self.relu = nn.ReLU(inplace=True) if relu else None 108 | 109 | def forward(self, x): 110 | x = self.conv(x) 111 | if self.bn is not None: 112 | x = self.bn(x) 113 | if self.relu is not None: 114 | x = self.relu(x) 115 | return x 116 | 117 | 118 | def conv3x3(in_planes, out_planes, stride=1): 119 | """3x3 convolution with padding""" 120 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 121 | padding=1, bias=False) 122 | 123 | 124 | class BasicBlock(nn.Module): 125 | expansion = 1 126 | 127 | def __init__(self, inplanes, planes, stride=1, downsample=None): 128 | super(BasicBlock, self).__init__() 129 | self.conv1 = conv3x3(inplanes, planes, stride) 130 | self.bn1 = nn.BatchNorm2d(planes) 131 | self.relu = nn.ReLU(inplace=True) 132 | self.conv2 = conv3x3(planes, planes) 133 | self.bn2 = nn.BatchNorm2d(planes) 134 | self.downsample = downsample 135 | self.stride = stride 136 | 137 | def forward(self, x): 138 | residual = x 139 | 140 | out = self.conv1(x) 141 | out = self.bn1(out) 142 | out = self.relu(out) 143 | 144 | out = self.conv2(out) 145 | out = self.bn2(out) 146 | 147 | if self.downsample is not None: 148 | residual = self.downsample(x) 149 | 150 | out += residual 151 | out = self.relu(out) 152 | 153 | return out 154 | 155 | 156 | class Bottleneck(nn.Module): 157 | expansion = 4 158 | 159 | def __init__(self, inplanes, planes, stride=1, downsample=None): 160 | super(Bottleneck, self).__init__() 161 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 162 | self.bn1 = nn.BatchNorm2d(planes) 163 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 164 | padding=1, bias=False) 165 | self.bn2 = nn.BatchNorm2d(planes) 166 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 167 | self.bn3 = nn.BatchNorm2d(planes * 4) 168 | self.relu = nn.ReLU(inplace=True) 169 | self.downsample = downsample 170 | self.stride = stride 171 | 172 | def forward(self, x): 173 | residual = x 174 | 175 | out = self.conv1(x) 176 | out = self.bn1(out) 177 | out = self.relu(out) 178 | 179 | out = self.conv2(out) 180 | out = self.bn2(out) 181 | out = self.relu(out) 182 | 183 | out = self.conv3(out) 184 | out = self.bn3(out) 185 | 186 | if self.downsample is not None: 187 | residual = self.downsample(x) 188 | 189 | out += residual 190 | out = self.relu(out) 191 | 192 | return out 193 | 194 | 195 | class SimpleFPA(nn.Module): 196 | def __init__(self, in_planes, out_planes): 197 | """ 198 | Feature Pyramid Attention 199 | :type channels: int 200 | """ 201 | super(SimpleFPA, self).__init__() 202 | 203 | self.channels_cond = in_planes 204 | # Master branch 205 | self.conv_master = BasicConv(in_planes, out_planes, kernel_size=1, stride=1) 206 | 207 | # Global pooling branch 208 | self.conv_gpb = BasicConv(in_planes, out_planes, kernel_size=1, stride=1) 209 | 210 | def forward(self, x): 211 | 
""" 212 | :param x: Shape: [b, 2048, h, w] 213 | :return: out: Feature maps. Shape: [b, 2048, h, w] 214 | """ 215 | # Master branch 216 | x_master = self.conv_master(x) 217 | 218 | # Global pooling branch 219 | x_gpb = nn.AvgPool2d(x.shape[2:])(x).view(x.shape[0], self.channels_cond, 1, 1) 220 | x_gpb = self.conv_gpb(x_gpb) 221 | 222 | out = x_master + x_gpb 223 | 224 | return out 225 | 226 | 227 | class PyramidFeatures(nn.Module): 228 | def __init__(self, C2_size, C3_size, C4_size, C5_size, feature_size=256): 229 | super(PyramidFeatures, self).__init__() 230 | 231 | # upsample C5 to get P5 from the FPN paper 232 | # self.P5_1 = nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0) 233 | self.P5_1 = SimpleFPA(C5_size, feature_size) 234 | self.P5_upsampled = nn.Upsample(scale_factor=2, mode='nearest') 235 | self.P5_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 236 | 237 | # add P5 elementwise to C4 238 | # self.P4_1 = nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0) 239 | self.P4_upsampled = nn.Upsample(scale_factor=2, mode='nearest') 240 | self.P4_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 241 | 242 | # add P4 elementwise to C3 243 | # self.P3_1 = nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0) 244 | self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 245 | 246 | self.relu = nn.ReLU(inplace=True) 247 | 248 | def forward(self, inputs): 249 | C3, C4, C5 = inputs 250 | 251 | # remove lateral connection 252 | P5_x = self.P5_1(C5) 253 | P5_upsampled_x = self.P5_upsampled(P5_x) 254 | P5_x = self.P5_2(P5_x) 255 | # add activation 256 | P5_x = self.relu(P5_x) 257 | 258 | P4_x = P5_upsampled_x 259 | P4_upsampled_x = self.P4_upsampled(P4_x) 260 | P4_x = self.P4_2(P4_x) 261 | # add activation 262 | P4_x = self.relu(P4_x) 263 | 264 | P3_x = P4_upsampled_x 265 | P3_x = self.P3_2(P3_x) 266 | # add activation 267 | P3_x = self.relu(P3_x) 268 | 269 | return [P3_x, P4_x, P5_x] 270 | 271 | class PyramidAttentions(nn.Module): 272 | def __init__(self, channel_size=256): 273 | super(PyramidAttentions, self).__init__() 274 | 275 | self.A3_1 = SpatialGate(channel_size) 276 | self.A3_2 = ChannelGate(channel_size) 277 | 278 | self.A4_1 = SpatialGate(channel_size) 279 | self.A4_2 = ChannelGate(channel_size) 280 | 281 | self.A5_1 = SpatialGate(channel_size) 282 | self.A5_2 = ChannelGate(channel_size) 283 | 284 | def forward(self, inputs): 285 | f3, f4, f5 = inputs 286 | 287 | A3_spatial = self.A3_1(f3) 288 | A3_channel = self.A3_2(f3) 289 | A3 = A3_spatial*f3 + A3_channel*f3 290 | 291 | A4_spatial = self.A4_1(f4) 292 | A4_channel = self.A4_2(f4) 293 | A4_channel = (A4_channel + A3_channel) / 2 294 | A4 = A4_spatial*f4 + A4_channel*f4 295 | 296 | A5_spatial = self.A5_1(f5) 297 | A5_channel = self.A5_2(f5) 298 | A5_channel = (A5_channel + A4_channel) / 2 299 | A5 = A5_spatial*f5 + A5_channel*f5 300 | 301 | return [A3, A4, A5, A3_spatial, A4_spatial, A5_spatial] 302 | 303 | class SpatialGate(nn.Module): 304 | """docstring for SpatialGate""" 305 | def __init__(self, out_channels): 306 | super(SpatialGate, self).__init__() 307 | self.conv = nn.ConvTranspose2d(out_channels,1,kernel_size=3,stride=1,padding=1) 308 | def forward(self, x): 309 | x = self.conv(x) 310 | return torch.sigmoid(x) 311 | 312 | class ChannelGate(nn.Module): 313 | """docstring for SpatialGate""" 314 | def __init__(self, out_channels): 315 | super(ChannelGate, self).__init__() 316 | self.conv1 = 
nn.Conv2d(out_channels,out_channels//16,kernel_size=1,stride=1,padding=0) 317 | self.conv2 = nn.Conv2d(out_channels//16,out_channels,kernel_size=1,stride=1,padding=0) 318 | def forward(self, x): 319 | x = nn.AdaptiveAvgPool2d(output_size=1)(x) 320 | x = F.relu(self.conv1(x), inplace=True) 321 | x = torch.sigmoid(self.conv2(x)) 322 | return x 323 | 324 | 325 | class Flatten(nn.Module): 326 | def __init__(self): 327 | super(Flatten, self).__init__() 328 | 329 | def forward(self, x): 330 | return x.view(x.size(0), -1) 331 | 332 | def generate_anchors_single_pyramid(scales, ratios, shape, feature_stride, anchor_stride): 333 | """ 334 | scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128] 335 | ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2] 336 | shape: [height, width] spatial shape of the feature map over which 337 | to generate anchors. 338 | feature_stride: Stride of the feature map relative to the image in pixels. 339 | anchor_stride: Stride of anchors on the feature map. For example, if the 340 | value is 2 then generate anchors for every other feature map pixel. 341 | """ 342 | # Get all combinations of scales and ratios 343 | scales, ratios = np.meshgrid(np.array(scales), np.array(ratios)) 344 | scales = scales.flatten() 345 | ratios = ratios.flatten() 346 | 347 | # Enumerate heights and widths from scales and ratios 348 | heights = scales / np.sqrt(ratios) 349 | widths = scales * np.sqrt(ratios) 350 | 351 | # Enumerate shifts in feature space 352 | shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride 353 | shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride 354 | shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y) 355 | 356 | # Enumerate combinations of shifts, widths, and heights 357 | box_widths, box_centers_x = np.meshgrid(widths, shifts_x) 358 | box_heights, box_centers_y = np.meshgrid(heights, shifts_y) 359 | 360 | box_centers = np.stack( 361 | [box_centers_x, box_centers_y], axis=2).reshape([-1, 2]) 362 | box_sizes = np.stack([box_widths, box_heights], axis=2).reshape([-1, 2]) 363 | 364 | # Convert to corner coordinates (x1, y1, x2, y2) 365 | boxes = np.concatenate([box_centers - 0.5 * box_sizes, 366 | box_centers + 0.5 * box_sizes], axis=1) 367 | return torch.from_numpy(boxes).cuda() 368 | 369 | class VGG(nn.Module): 370 | 371 | def __init__(self, features, num_classes=1000, init_weights=False): 372 | super(VGG, self).__init__() 373 | self.num_classes = num_classes 374 | self.features = features 375 | 376 | self.num_segments = 2 377 | 378 | fpn_sizes = [128, 256, 512, 512] 379 | self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2], fpn_sizes[3]) 380 | self.apn = PyramidAttentions() 381 | 382 | self.cls5 = nn.Sequential( 383 | nn.AdaptiveAvgPool2d(1), 384 | Flatten(), 385 | nn.BatchNorm1d(256), 386 | nn.Linear(256, 512), 387 | nn.BatchNorm1d(512), 388 | nn.ELU(inplace=True), 389 | nn.Linear(512, self.num_classes) 390 | ) 391 | 392 | self.cls4 = nn.Sequential( 393 | nn.AdaptiveAvgPool2d(1), 394 | Flatten(), 395 | nn.BatchNorm1d(256), 396 | nn.Linear(256, 512), 397 | nn.BatchNorm1d(512), 398 | nn.ELU(inplace=True), 399 | nn.Linear(512, self.num_classes) 400 | ) 401 | 402 | self.cls3 = nn.Sequential( 403 | nn.AdaptiveAvgPool2d(1), 404 | Flatten(), 405 | nn.BatchNorm1d(256), 406 | nn.Linear(256, 512), 407 | nn.BatchNorm1d(512), 408 | nn.ELU(inplace=True), 409 | nn.Linear(512, self.num_classes) 410 | ) 411 | 412 | self.cls_main = nn.Sequential( 413 | nn.AdaptiveAvgPool2d(1), 414 | Flatten(), 415 | 
nn.BatchNorm1d(512), 416 | nn.Linear(512, 512), 417 | nn.BatchNorm1d(512), 418 | nn.ELU(inplace=True), 419 | nn.Linear(512, self.num_classes) 420 | ) 421 | 422 | self.criterion = nn.CrossEntropyLoss() 423 | 424 | if init_weights: 425 | for m in self.modules(): 426 | if isinstance(m, nn.Conv2d): 427 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 428 | m.weight.data.normal_(0, math.sqrt(2. / n)) 429 | # init.xavier_normal(m.weight) 430 | elif isinstance(m, nn.BatchNorm2d): 431 | m.weight.data.fill_(1) 432 | m.bias.data.zero_() 433 | 434 | def get_att_roi(self, att_mask, feature_stride, anchor_size, img_h, img_w, iou_thred=0.2, topk=1): 435 | with torch.no_grad(): 436 | roi_ret_nms = [] 437 | n, c, h, w = att_mask.size() 438 | att_corner_unmask = torch.zeros_like(att_mask).cuda() 439 | if self.num_classes == 200: 440 | att_corner_unmask[:, :, int(0.2 * h):int(0.8 * h), int(0.2 * w):int(0.8 * w)] = 1 441 | else: 442 | att_corner_unmask[:, :, int(0.1 * h):int(0.9 * h), int(0.1 * w):int(0.9 * w)] = 1 443 | att_mask = att_mask * att_corner_unmask 444 | feat_anchor = generate_anchors_single_pyramid([anchor_size], [1], [h, w], feature_stride, 1) 445 | feat_new_cls = att_mask.clone() 446 | for i in range(n): 447 | boxes = feat_anchor.clone().float() 448 | scores = feat_new_cls[i].view(-1) 449 | score_thred_index = scores > scores.mean() 450 | boxes = boxes[score_thred_index, :] 451 | scores = scores[score_thred_index] 452 | boxes_nms = pth_nms(torch.cat([boxes, scores.unsqueeze(1)], dim=1), iou_thred, topk).cuda() 453 | boxes_nms[:, 0] = torch.clamp(boxes_nms[:, 0], min=0) 454 | boxes_nms[:, 1] = torch.clamp(boxes_nms[:, 1], min=0) 455 | boxes_nms[:, 2] = torch.clamp(boxes_nms[:, 2], max=img_w - 1) 456 | boxes_nms[:, 3] = torch.clamp(boxes_nms[:, 3], max=img_h - 1) 457 | roi_ret_nms.append(torch.cat([torch.FloatTensor([i] * boxes_nms.size(0)).unsqueeze(1).cuda(), boxes_nms], 1)) 458 | 459 | return torch.cat(roi_ret_nms, 0) 460 | 461 | def get_roi_crop_feat(self, x, roi_list, scale): 462 | n, c, x2_h, x2_w = x.size() 463 | roi_3, roi_4, roi_5 = roi_list 464 | roi_all = torch.cat([roi_3, roi_4, roi_5], 0) 465 | x2_ret = [] 466 | if self.training: 467 | for i in range(n): 468 | roi_all_i = roi_all[roi_all[:, 0] == i] / scale 469 | xx1_resize, yy1_resize, = torch.min(roi_all_i[:, 1:3], 0)[0] 470 | xx2_resize, yy2_resize = torch.max(roi_all_i[:, 3:5], 0)[0] 471 | roi_3_i = roi_3[roi_3[:, 0] == i] / scale 472 | roi_4_i = roi_4[roi_4[:, 0] == i] / scale 473 | # randomly drop (zero out) one ROI region from level 3 (p=0.3) or level 4 (p=0.4) as augmentation 474 | mask_un = torch.ones(c, x2_h, x2_w).cuda() 475 | pro_rand = random.random() 476 | if pro_rand < 0.3: 477 | ind_rand = random.randint(0, roi_3_i.size(0) - 1) 478 | xx1_drop, yy1_drop = roi_3_i[ind_rand, 1:3] 479 | xx2_drop, yy2_drop = roi_3_i[ind_rand, 3:5] 480 | mask_un[:, yy1_drop.long():yy2_drop.long(), xx1_drop.long():xx2_drop.long()] = 0 481 | elif pro_rand < 0.7: 482 | ind_rand = random.randint(0, roi_4_i.size(0) - 1) 483 | xx1_drop, yy1_drop = roi_4_i[ind_rand, 1:3] 484 | xx2_drop, yy2_drop = roi_4_i[ind_rand, 3:5] 485 | mask_un[:, yy1_drop.long():yy2_drop.long(), xx1_drop.long():xx2_drop.long()] = 0 486 | x2_drop = x[i] * mask_un 487 | x2_crop = x2_drop[:, yy1_resize.long():yy2_resize.long(), 488 | xx1_resize.long():xx2_resize.long()].contiguous().unsqueeze(0) 489 | 490 | # normalize: rescale to compensate for the dropped (zeroed) region 491 | scale_rate = c*(yy2_resize-yy1_resize)*(xx2_resize-xx1_resize) / torch.sum(mask_un[:, yy1_resize.long():yy2_resize.long(), 492 | xx1_resize.long():xx2_resize.long()]) 493 | x2_crop = x2_crop * 
scale_rate 494 | 495 | x2_crop_resize = F.upsample(x2_crop, (x2_h, x2_w), mode='bilinear', align_corners=False) 496 | x2_ret.append(x2_crop_resize) 497 | else: 498 | for i in range(n): 499 | roi_all_i = roi_all[roi_all[:, 0] == i] / scale 500 | xx1_resize, yy1_resize, = torch.min(roi_all_i[:, 1:3], 0)[0] 501 | xx2_resize, yy2_resize = torch.max(roi_all_i[:, 3:5], 0)[0] 502 | x2_crop = x[i, :, yy1_resize.long():yy2_resize.long(), 503 | xx1_resize.long():xx2_resize.long()].contiguous().unsqueeze(0) 504 | x2_crop_resize = F.upsample(x2_crop, (x2_h, x2_w), mode='bilinear', align_corners=False) 505 | x2_ret.append(x2_crop_resize) 506 | return torch.cat(x2_ret, 0) 507 | 508 | 509 | def forward(self, inputs, targets): 510 | # inputs.requires_grad = True 511 | n, c, img_h, img_w = inputs.size() 512 | 513 | x3 = checkpoint_sequential(nn.Sequential(*list(self.features.children())[:27]), self.num_segments, inputs) 514 | x4 = checkpoint_sequential(nn.Sequential(*list(self.features.children())[27:40]), self.num_segments, x3) 515 | x5 = checkpoint_sequential(nn.Sequential(*list(self.features.children())[40:]), self.num_segments, x4) 516 | 517 | # stage I 518 | f3, f4, f5 = self.fpn([x3, x4, x5]) 519 | f3_att, f4_att, f5_att, a3, a4, a5 = self.apn([f3, f4, f5]) 520 | 521 | 522 | out3 = self.cls3(f3_att) 523 | out4 = self.cls4(f4_att) 524 | out5 = self.cls5(f5_att) 525 | loss3 = self.criterion(out3, targets) 526 | loss4 = self.criterion(out4, targets) 527 | loss5 = self.criterion(out5, targets) 528 | 529 | # origin classifier 530 | out_main = self.cls_main(x5) 531 | loss_main = self.criterion(out_main, targets) 532 | 533 | loss = loss3 + loss4 + loss5 + loss_main 534 | out = (F.softmax(out3, 1) + F.softmax(out4, 1) + F.softmax(out5, 1) + F.softmax(out_main, 1)) / 4 535 | _, predicted = torch.max(out.data, 1) 536 | correct = predicted.eq(targets.data).cpu().sum().item() 537 | 538 | # stage II 539 | roi_3 = self.get_att_roi(a3, 2 ** 3, 64, img_h, img_w, iou_thred=0.05, topk=5) 540 | roi_4 = self.get_att_roi(a4, 2 ** 4, 128, img_h, img_w, iou_thred=0.05, topk=3) 541 | roi_5 = self.get_att_roi(a5, 2 ** 5, 256, img_h, img_w, iou_thred=0.05, topk=1) 542 | roi_list = [roi_3, roi_4, roi_5] 543 | 544 | x3_crop_resize = self.get_roi_crop_feat(x3, roi_list, 2 ** 3) 545 | x4_crop_resize = checkpoint_sequential(nn.Sequential(*list(self.features.children())[27:40]), self.num_segments, x3_crop_resize) 546 | x5_crop_resize = checkpoint_sequential(nn.Sequential(*list(self.features.children())[40:]), self.num_segments, x4_crop_resize) 547 | 548 | f3_crop_resize, f4_crop_resize, f5_crop_resize = self.fpn([x3_crop_resize, x4_crop_resize, x5_crop_resize]) 549 | f3_att_crop_resize, f4_att_crop_resize, f5_att_crop_resize, _, _, _ = self.apn([f3_crop_resize, f4_crop_resize, f5_crop_resize]) 550 | 551 | 552 | out3_crop_resize = self.cls3(f3_att_crop_resize) 553 | out4_crop_resize = self.cls4(f4_att_crop_resize) 554 | out5_crop_resize = self.cls5(f5_att_crop_resize) 555 | loss3_crop_resize = self.criterion(out3_crop_resize, targets) 556 | loss4_crop_resize = self.criterion(out4_crop_resize, targets) 557 | loss5_crop_resize = self.criterion(out5_crop_resize, targets) 558 | 559 | # origin classifier 560 | out_main_crop_resize = self.cls_main(x5_crop_resize) 561 | loss_main_crop_resize = self.criterion(out_main_crop_resize, targets) 562 | 563 | loss_crop_resize = loss3_crop_resize + loss4_crop_resize + loss5_crop_resize + loss_main_crop_resize 564 | out_crop_resize = (F.softmax(out3_crop_resize, 1) + F.softmax(out4_crop_resize, 1) + 
F.softmax(out5_crop_resize, 1) + F.softmax(out_main_crop_resize, 1)) / 4 565 | _, predicted_crop_resize = torch.max(out_crop_resize.data, 1) 566 | correct_crop_resize = predicted_crop_resize.eq(targets.data).cpu().sum().item() 567 | 568 | 569 | out_mean = (out_crop_resize + out) / 2 570 | predicted_mean_, predicted_mean = torch.max(out_mean.data, 1) 571 | correct_mean = predicted_mean.eq(targets.data).cpu().sum().item() 572 | 573 | loss_ret = {'loss': loss + loss_crop_resize, 'loss1': loss, 'loss2': loss_crop_resize, 'loss3': loss} 574 | acc_ret = {'acc': correct_mean, 'acc1': correct, 'acc2': correct_crop_resize, 'acc3': correct} 575 | 576 | mask_cat = torch.cat([a3, 577 | F.upsample(a4, a3.size()[2:]), 578 | F.upsample(a5, a3.size()[2:])], 1) 579 | 580 | return loss_ret, acc_ret, mask_cat, roi_list 581 | 582 | 583 | def make_layers(cfg, batch_norm=False): 584 | layers = [] 585 | in_channels = 3 586 | for v in cfg: 587 | if v == 'M': 588 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 589 | else: 590 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 591 | if batch_norm: 592 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 593 | else: 594 | layers += [conv2d, nn.ReLU(inplace=True)] 595 | in_channels = v 596 | return nn.Sequential(*layers) 597 | 598 | cfg = { 599 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 600 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 601 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 602 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 603 | } 604 | 605 | def vgg16(num_class): 606 | """VGG 16-layer model (configuration "D") with batch normalization 607 | 608 | """ 609 | model = VGG(make_layers(cfg['D'], batch_norm=True), num_classes=num_class) 610 | 611 | return model 612 | 613 | def vgg19(num_class): 614 | 615 | model = VGG(make_layers(cfg['E'], batch_norm=True), num_classes=num_class) 616 | 617 | return model 618 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # pip install -r requirements.txt 2 | numpy == 1.22.0 3 | torch == 0.4.1 4 | torchvision == 0.2.1 5 | opencv-python 6 | visdom 7 | pillow 8 | wget 9 | cffi 10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import logging 5 | import argparse 6 | import torchvision 7 | import torch.nn as nn 8 | import numpy as np 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | import torch.backends.cudnn as cudnn 13 | import torchvision 14 | import cv2 15 | import torchvision.transforms as transforms 16 | from requests.utils import urlparse 17 | import wget 18 | 19 | import model 20 | import model.resnet50 21 | import model.vgg19 22 | from utils.utils import load_config, setup_seed, plot_roi, plot_mask_cat 23 | from utils.visualize import Visualizer 24 | from utils.transform import UnNormalizer 25 | from PIL import Image 26 | 27 | def main(): 28 | model_options = ['resnet50', 'vgg19'] 29 | dataset_options = ['birds', 'cars', 'airs'] 30 | 31 | parser = argparse.ArgumentParser(description='AP-CNN') 32 | parser.add_argument('--dataset', '-d', 
default='birds', 33 | choices=dataset_options) 34 | parser.add_argument('--model', '-a', default='resnet50', 35 | choices=model_options) 36 | parser.add_argument('--seed', type=int, default=1, 37 | help='random seed (default: 1)') 38 | parser.add_argument("--gpu", type=int, default=0, 39 | help='gpu index (default: 0)') 40 | parser.add_argument('--visualize', action='store_true', default=False, 41 | help='plot attention masks and ROIs') 42 | 43 | args = parser.parse_args() 44 | 45 | setup_seed(args.seed) 46 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 47 | 48 | ### prepare configurations 49 | config_file = "configs/config_{}.yaml".format(args.dataset) 50 | config = load_config(config_file) 51 | # data config 52 | train_dir = config['train_dir'] 53 | test_dir = config['test_dir'] 54 | num_class = config['num_class'] 55 | # model config 56 | batch_size = config['batch_size'] 57 | learning_rate = config['learning_rate'] 58 | momentum = config['momentum'] 59 | weight_decay = float(config['weight_decay']) 60 | num_epoch = config['num_epoch'] 61 | resize_size = config['resize_size'] 62 | crop_size = config['crop_size'] 63 | # visualizer config 64 | vis_host = config['vis_host'] 65 | vis_port = config['vis_port'] 66 | 67 | ### setup exp_dir 68 | exp_name = "AP-CNN_{}_{}".format(args.model, args.dataset) 69 | time_str = time.strftime("%m-%d-%H-%M", time.localtime()) 70 | exp_dir = os.path.join("./logs", exp_name + '_' + time_str) 71 | if not os.path.exists(exp_dir): 72 | os.makedirs(exp_dir) 73 | # generate log files 74 | logger = logging.getLogger() 75 | logger.setLevel(logging.INFO) 76 | logging.basicConfig(filename=os.path.join(exp_dir, 'train.log'), level=logging.INFO, filemode='w') 77 | console = logging.StreamHandler() 78 | console.setLevel(logging.INFO) 79 | formatter = logging.Formatter('%(levelname)-4s %(message)s') 80 | console.setFormatter(formatter) 81 | logging.getLogger('').addHandler(console) 82 | 83 | logging.info('==>exp dir:%s' % exp_dir) 84 | logging.info("OPENING " + exp_dir + '/results_train.csv') 85 | logging.info("OPENING " + exp_dir + '/results_test.csv') 86 | 87 | results_train_file = open(exp_dir + '/results_train.csv', 'w') 88 | results_train_file.write('epoch, train_acc, train_loss\n') 89 | results_train_file.flush() 90 | results_test_file = open(exp_dir + '/results_test.csv', 'w') 91 | results_test_file.write('epoch, test_acc, test_loss\n') 92 | results_test_file.flush() 93 | 94 | # set up Visualizer 95 | vis = Visualizer(env=exp_name, port=vis_port, server=vis_host) 96 | 97 | ### preparing data 98 | logging.info('==> Preparing data..') 99 | 100 | transform_train = transforms.Compose([ 101 | transforms.Resize((resize_size, resize_size), Image.BILINEAR), 102 | transforms.RandomCrop(crop_size), 103 | transforms.RandomHorizontalFlip(), 104 | transforms.ToTensor(), 105 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 106 | ]) 107 | 108 | transform_test = transforms.Compose([ 109 | transforms.Resize((resize_size, resize_size), Image.BILINEAR), 110 | transforms.CenterCrop(crop_size), 111 | transforms.ToTensor(), 112 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 113 | ]) 114 | 115 | unorm = UnNormalizer([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 116 | 117 | trainset = torchvision.datasets.ImageFolder(root=train_dir, transform=transform_train) 118 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4) 119 | 120 | testset = torchvision.datasets.ImageFolder(root=test_dir, 
transform=transform_test) 121 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4) 122 | logging.info('==> Successfully Preparing data..') 123 | 124 | ### building model 125 | logging.info('==> Building model..') 126 | # load pretrained backbone on ImageNet 127 | if args.model == "resnet50": 128 | url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth' 129 | elif args.model == "vgg19": 130 | url = 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth' 131 | model_dir = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch/models')) 132 | filename = os.path.basename(urlparse(url).path) 133 | pretrained_path = os.path.join(model_dir, filename) 134 | if not os.path.exists(pretrained_path): 135 | wget.download(url, pretrained_path) 136 | net = getattr(getattr(model, args.model), args.model)(num_class) 137 | if pretrained_path: 138 | logging.info('load pretrained backbone') 139 | net_dict = net.state_dict() 140 | pretrained_dict = torch.load(pretrained_path) 141 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net_dict} 142 | net_dict.update(pretrained_dict) 143 | net.load_state_dict(net_dict) 144 | use_cuda = torch.cuda.is_available() 145 | if use_cuda: 146 | net.cuda() 147 | cudnn.benchmark = True 148 | logging.info('==> Successfully Building model..') 149 | 150 | ### training scripts 151 | def train(epoch): 152 | logging.info('Epoch: %d' % epoch) 153 | net.train() 154 | train_loss = 0 155 | correct = 0 156 | total = 0 157 | idx = 0 158 | flag = 0 159 | count = 0 160 | 161 | for batch_idx, (inputs, targets) in enumerate(trainloader): 162 | idx = batch_idx 163 | if use_cuda: 164 | inputs, targets = inputs.cuda(), targets.cuda() 165 | optimizer.zero_grad() 166 | inputs, targets = Variable(inputs), Variable(targets) 167 | loss_ret, acc_ret, mask_cat, roi_list = net(inputs, targets) 168 | loss = loss_ret['loss'] 169 | loss.backward() 170 | optimizer.step() 171 | train_loss += loss.data 172 | total += targets.size(0) 173 | correct += acc_ret['acc'] 174 | if args.visualize and flag % 100 == 0: 175 | plot_mask_cat(inputs, mask_cat, unorm, vis, 'train') 176 | plot_roi(inputs, roi_list, unorm, vis, 'train') 177 | flag += 1 178 | train_acc = 100. * correct / total 179 | train_loss = train_loss / (idx + 1) 180 | logging.info('Iteration %d, train_acc = %.4f, train_loss = %.4f' % (epoch, train_acc, train_loss)) 181 | results_train_file.write('%d, %.4f,%.4f\n' % (epoch, train_acc, train_loss)) 182 | results_train_file.flush() 183 | return train_acc, train_loss 184 | 185 | ### test scripts 186 | def test(epoch): 187 | with torch.no_grad(): 188 | net.eval() 189 | test_loss = 0 190 | correct = 0 191 | total = 0 192 | idx = 0 193 | flag = 0 194 | count = 0 195 | for batch_idx, (inputs, targets) in enumerate(testloader): 196 | idx = batch_idx 197 | if use_cuda: 198 | inputs, targets = inputs.cuda(), targets.cuda() 199 | inputs, targets = Variable(inputs), Variable(targets) 200 | loss_ret, acc_ret, mask_cat, roi_list = net(inputs, targets) 201 | loss = loss_ret['loss'] 202 | 203 | test_loss += loss.data 204 | total += targets.size(0) 205 | correct += acc_ret['acc'] 206 | if args.visualize and flag % 100 == 0: 207 | plot_mask_cat(inputs, mask_cat, unorm, vis, 'test') 208 | plot_roi(inputs, roi_list, unorm, vis, 'test') 209 | flag += 1 210 | 211 | test_acc = 100. 
* correct / total 212 | test_loss = test_loss / (idx + 1) 213 | logging.info('Iteration %d, test_acc = %.4f, test_loss = %.4f' % (epoch, test_acc, test_loss)) 214 | results_test_file.write('%d, %.4f,%.4f\n' % (epoch, test_acc, test_loss)) 215 | results_test_file.flush() 216 | return test_acc, test_loss 217 | 218 | if args.dataset == 'birds': 219 | optimizer = optim.SGD([ 220 | {'params': nn.Sequential(*list(net.children())[7:]).parameters(), 'lr': learning_rate}, 221 | {'params': nn.Sequential(*list(net.children())[:7]).parameters(), 'lr': learning_rate/10} 222 | 223 | ], 224 | momentum=momentum, weight_decay=weight_decay) 225 | 226 | def cosine_anneal_schedule(t): 227 | cos_inner = np.pi * (t % (num_epoch)) 228 | cos_inner /= (num_epoch) 229 | cos_out = np.cos(cos_inner) + 1 230 | return float( learning_rate / 2 * cos_out) 231 | 232 | max_test_acc = 0. 233 | for epoch in range(0, num_epoch): 234 | optimizer.param_groups[0]['lr'] = cosine_anneal_schedule(epoch) 235 | optimizer.param_groups[1]['lr'] = cosine_anneal_schedule(epoch) / 10 236 | for param_group in optimizer.param_groups: 237 | print(param_group['lr']) 238 | train(epoch) 239 | test_acc, _ = test(epoch) 240 | if test_acc > max_test_acc: 241 | max_test_acc = test_acc 242 | torch.save(net.state_dict(), os.path.join(exp_dir, 'model_best.pth')) 243 | print('max_test_acc=',max_test_acc) 244 | 245 | else: 246 | optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay) 247 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoch) 248 | 249 | max_test_acc = 0. 250 | for epoch in range(0, num_epoch): 251 | scheduler.step(epoch) 252 | for param_group in optimizer.param_groups: 253 | print(param_group['lr']) 254 | train(epoch) 255 | test_acc, _ = test(epoch) 256 | if test_acc > max_test_acc: 257 | max_test_acc = test_acc 258 | torch.save(net.state_dict(), os.path.join(exp_dir, 'model_best.pth')) 259 | print('max_test_acc=',max_test_acc) 260 | 261 | torch.save(net.state_dict(), os.path.join(exp_dir, 'model_final.pth')) 262 | 263 | if __name__=="__main__": 264 | main() -------------------------------------------------------------------------------- /utils/split_dataset/airs_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import shutil 4 | 5 | 6 | # divide dataset (without annotations) 7 | img_dir = 'data/airs/' 8 | 9 | save_dir = 'data/Aircraft/' 10 | if not os.path.exists(save_dir): 11 | os.mkdir(save_dir) 12 | save_dir_train = os.path.join(save_dir, 'train') 13 | if not os.path.exists(save_dir_train): 14 | os.mkdir(save_dir_train) 15 | save_dir_test = os.path.join(save_dir, 'test') 16 | if not os.path.exists(save_dir_test): 17 | os.mkdir(save_dir_test) 18 | 19 | # generate train dataset 20 | f = open(os.path.join(img_dir, "images_variant_trainval.txt")) 21 | foo = f.readlines() 22 | 23 | for i in range(len(foo)): 24 | index = foo[i].find(" ") 25 | image_name = foo[i][:index] + ".jpg" 26 | classes = foo[i][index+1:][:-1] 27 | if classes.find("/")>=0: 28 | classes = classes[:classes.find("/")] + "_" + classes[classes.find("/")+1:] 29 | else: 30 | pass 31 | if classes.find(" ")>=0: 32 | classes = classes[:classes.find(" ")] + "_" + classes[classes.find(" ")+1:] 33 | else: 34 | pass 35 | # make class dir 36 | try: 37 | os.mkdir(os.path.join(save_dir_train, classes)) 38 | except: 39 | print("file already exists") 40 | src_path = os.path.join(img_dir, 'images', image_name) 41 | dst_path = 
os.path.join(save_dir_train, classes, image_name) 42 | try: 43 | shutil.copyfile(src_path, dst_path) 44 | print("src:", src_path, "dst:", dst_path) 45 | except: 46 | print("error",i,foo[i]) 47 | break 48 | 49 | # generate test dataset 50 | f = open(os.path.join(img_dir, "images_variant_test.txt")) 51 | foo = f.readlines() 52 | 53 | for i in range(len(foo)): 54 | index = foo[i].find(" ") 55 | image_name = foo[i][:index] + ".jpg" 56 | classes = foo[i][index+1:][:-1] 57 | if classes.find("/")>=0: 58 | classes = classes[:classes.find("/")] + "_" + classes[classes.find("/")+1:] 59 | else: 60 | pass 61 | if classes.find(" ")>=0: 62 | classes = classes[:classes.find(" ")] + "_" + classes[classes.find(" ")+1:] 63 | else: 64 | pass 65 | # make class dir 66 | try: 67 | os.mkdir(os.path.join(save_dir_test, classes)) 68 | except: 69 | print("file already exists") 70 | src_path = os.path.join(img_dir, 'images', image_name) 71 | dst_path = os.path.join(save_dir_test, classes, image_name) 72 | try: 73 | shutil.copyfile(src_path, dst_path) 74 | print("src:", src_path, "dst:", dst_path) 75 | except: 76 | print("error",i,foo[i]) 77 | break -------------------------------------------------------------------------------- /utils/split_dataset/birds_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import shutil 4 | 5 | 6 | # divide dataset (without annotations) 7 | img_dir = 'data/birds/' 8 | 9 | save_dir = 'data/Birds/' 10 | if not os.path.exists(save_dir): 11 | os.mkdir(save_dir) 12 | save_dir_train = os.path.join(save_dir, 'train') 13 | if not os.path.exists(save_dir_train): 14 | os.mkdir(save_dir_train) 15 | save_dir_test = os.path.join(save_dir, 'test') 16 | if not os.path.exists(save_dir_test): 17 | os.mkdir(save_dir_test) 18 | 19 | f2 = open(os.path.join(img_dir, "images.txt")) 20 | foo = f2.readlines() 21 | 22 | f = open(os.path.join(img_dir, "train_test_split.txt")) 23 | bar = f.readlines() 24 | 25 | f3 = open(os.path.join(img_dir, "image_class_labels.txt")) 26 | baz = f3.readlines() 27 | 28 | for i in range(len(foo)): 29 | image_id = foo[i].split(" ")[0] 30 | image_path = foo[i].split(" ")[1][:-1] 31 | image_name = image_path.split("/")[1] 32 | is_train = int(bar[i].split(" ")[1][:-1]) 33 | classes = baz[i].split(" ")[1][:-1].zfill(2) 34 | # split train & test data 35 | if is_train: 36 | # make class dir 37 | try: 38 | os.mkdir(os.path.join(save_dir_train, classes)) 39 | except: 40 | print("file already exists") 41 | src_path = os.path.join(img_dir, 'images', image_path) 42 | dst_path = os.path.join(save_dir_train, classes, image_name) 43 | else: 44 | # make class dir 45 | try: 46 | os.mkdir(os.path.join(save_dir_test, classes)) 47 | except: 48 | print("file already exists") 49 | src_path = os.path.join(img_dir, 'images', image_path) 50 | dst_path = os.path.join(save_dir_test, classes, image_name) 51 | shutil.copyfile(src_path, dst_path) 52 | print("src:", src_path, "dst:", dst_path) -------------------------------------------------------------------------------- /utils/split_dataset/cars_dataset.py: -------------------------------------------------------------------------------- 1 | from scipy.io import loadmat 2 | import cv2 3 | import os 4 | 5 | 6 | # divide dataset (without annotations) 7 | img_dir = 'data/cars/' 8 | 9 | save_dir = 'data/StandCars/' 10 | if not os.path.exists(save_dir): 11 | os.mkdir(save_dir) 12 | save_dir_train = os.path.join(save_dir, 'train') 13 | if not os.path.exists(save_dir_train): 14 | 
os.mkdir(save_dir_train) 15 | save_dir_test = os.path.join(save_dir, 'test') 16 | if not os.path.exists(save_dir_test): 17 | os.mkdir(save_dir_test) 18 | 19 | m = loadmat(os.path.join(img_dir, "cars_annos.mat")) 20 | info = m['annotations'][0] 21 | for img_info in info: 22 | img_name = img_info[0][0] 23 | img_path = os.path.join(img_dir, img_name) 24 | classes = str(int(img_info[-2])) 25 | # make class dir 26 | try: 27 | os.mkdir(os.path.join(save_dir_train, classes)) 28 | except: 29 | print("file already exists") 30 | try: 31 | os.mkdir(os.path.join(save_dir_test, classes)) 32 | except: 33 | print("file already exists") 34 | 35 | # split to train/test 36 | img_test_flag = int(img_info[-1]) 37 | if img_test_flag: 38 | save_path = os.path.join(save_dir_test, classes, img_name[8:]) 39 | img = cv2.imread(img_path) 40 | # save origin image 41 | cv2.imwrite(save_path, img) 42 | else: 43 | save_path = os.path.join(save_dir_train, classes, img_name[8:]) 44 | img = cv2.imread(img_path) 45 | # save origin image 46 | cv2.imwrite(save_path, img) -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class UnNormalizer(object): 6 | def __init__(self, mean=None, std=None): 7 | if mean == None: 8 | self.mean = [0.485, 0.456, 0.406] 9 | else: 10 | self.mean = mean 11 | if std == None: 12 | self.std = [0.229, 0.224, 0.225] 13 | else: 14 | self.std = std 15 | 16 | def __call__(self, tensor): 17 | """ 18 | Args: 19 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 20 | Returns: 21 | Tensor: Normalized image. 22 | """ 23 | for t, m, s in zip(tensor, self.mean, self.std): 24 | t.mul_(s).add_(m) 25 | return tensor 26 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | import numpy as np 4 | import cv2 5 | 6 | def load_config(config_file): 7 | with open(config_file, "r") as f: 8 | config = yaml.load(f, Loader=yaml.FullLoader) 9 | return config 10 | 11 | def setup_seed(seed): 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | np.random.seed(seed) 15 | torch.backends.cudnn.deterministic = True 16 | 17 | # plot multi-level ROIs 18 | def plot_roi(inputs, roi_list, unorm, vis, mode='train'): 19 | with torch.no_grad(): 20 | color = [(0, 255, 0), (255, 0, 0), (0, 0, 255)] 21 | for i in range(inputs.size(0)): 22 | img = unorm(inputs.data[i].cpu()).numpy().copy() 23 | img = np.clip(img * 255, 0, 255).astype(np.uint8) 24 | img = np.transpose(img, [1, 2, 0]) 25 | r, g, b = cv2.split(img) 26 | img = cv2.merge([b, g, r]) 27 | for j, roi in enumerate(roi_list): 28 | roi = roi[roi[:, 0] == i] 29 | if len(roi.size()) == 1: 30 | b = roi.data.cpu().numpy() 31 | cv2.rectangle(img, (b[1], b[2]), (b[3], b[4]), color[j % len(color)], 2) 32 | else: 33 | for k in range(roi.size(0)): 34 | b = roi[k].data.cpu().numpy() 35 | cv2.rectangle(img, (b[1], b[2]), (b[3], b[4]), color[j % len(color)], 2) 36 | img = np.transpose(img, [2, 0, 1]) 37 | vis.img('%s_img_%d' % (mode, i), img) 38 | 39 | # plot attention masks 40 | def plot_mask_cat(inputs, mask_cat, unorm, vis, mode='train'): 41 | with torch.no_grad(): 42 | for i in range(inputs.size(0)): 43 | img = unorm(inputs.data[i].cpu()).numpy().copy() 44 | img = np.clip(img * 255, 0, 255).astype(np.uint8) 45 | img = 
np.transpose(img, [1, 2, 0]) 46 | r, g, b = cv2.split(img) 47 | img = cv2.merge([b, g, r]) 48 | img = np.transpose(img, [2, 0, 1]) 49 | vis.img('%s_img_%d' % (mode, i), img) 50 | for j in range(mask_cat.size(1)): 51 | mask = mask_cat[i, j, :, :].data.cpu().numpy() 52 | img_mask = (255.0 * (mask - np.min(mask)) / (np.max(mask) - np.min(mask))).astype(np.uint8) 53 | # img_mask = (255.0 * mask).astype(np.uint8) 54 | img_mask = cv2.resize(img_mask, dsize=(448, 448)) 55 | vis.img('%s_img_%d_mask%d' % (mode, i, j), img_mask) -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import visdom 3 | import time 4 | import numpy as np 5 | import torch as t 6 | 7 | 8 | class Visualizer(object): 9 | """ 10 | wrapper for visdom 11 | you can still access native visdom functions via 12 | self.line, self.scatter, self._send, etc. 13 | due to the implementation of `__getattr__` 14 | """ 15 | 16 | def __init__(self, env='default', **kwargs): 17 | self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs) 18 | self._vis_kw = kwargs 19 | 20 | # e.g. ('loss', 23): the 23rd logged value of loss 21 | self.index = {} 22 | self.log_text = '' 23 | 24 | def reinit(self, env='default', **kwargs): 25 | """ 26 | change the config of visdom 27 | """ 28 | self.vis = visdom.Visdom(env=env, **kwargs) 29 | return self 30 | 31 | def plot_many(self, d): 32 | """ 33 | plot multiple values 34 | @params d: dict of (name, value) pairs, e.g. ('loss', 0.11) 35 | """ 36 | for k, v in d.items(): 37 | if v is not None: 38 | self.plot(k, v) 39 | 40 | def img_many(self, d): 41 | for k, v in d.items(): 42 | self.img(k, v) 43 | 44 | def plot(self, name, y, **kwargs): 45 | """ 46 | self.plot('loss',1.00) 47 | """ 48 | x = self.index.get(name, 0) 49 | self.vis.line(Y=np.array([y]), X=np.array([x]), 50 | win=name, 51 | opts=dict(title=name), 52 | update=None if x == 0 else 'append', 53 | **kwargs 54 | ) 55 | self.index[name] = x + 1 56 | 57 | def img(self, name, img_, **kwargs): 58 | """ 59 | self.img('input_img',t.Tensor(64,64)) 60 | self.img('input_imgs',t.Tensor(3,64,64)) 61 | self.img('input_imgs',t.Tensor(100,1,64,64)) 62 | self.img('input_imgs',t.Tensor(100,3,64,64),nrows=10) 63 | !!don't ~~self.img('input_imgs',t.Tensor(100,64,64),nrows=10)~~!! 64 | """ 65 | self.vis.image(img_, 66 | win=name, 67 | opts=dict(title=name), 68 | **kwargs 69 | ) 70 | 71 | def log(self, info, win='log_text'): 72 | """ 73 | self.log({'loss':1,'lr':0.0001}) 74 | """ 75 | self.log_text += ('[{time}] {info} <br>
'.format( 76 | time=time.strftime('%m%d_%H%M%S'), \ 77 | info=info)) 78 | self.vis.text(self.log_text, win) 79 | 80 | def __getattr__(self, name): 81 | return getattr(self.vis, name) 82 | 83 | def state_dict(self): 84 | return { 85 | 'index': self.index, 86 | 'vis_kw': self._vis_kw, 87 | 'log_text': self.log_text, 88 | 'env': self.vis.env 89 | } 90 | 91 | def load_state_dict(self, d): 92 | self.vis = visdom.Visdom(env=d.get('env', self.vis.env), **(d.get('vis_kw', dict()))) 93 | self.log_text = d.get('log_text', '') 94 | self.index = d.get('index', dict()) 95 | return self 96 | --------------------------------------------------------------------------------