├── .github └── FUNDING.yml ├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── Makefile ├── README.md ├── config.py ├── data ├── co_occour_count.npy └── stanford_filtered │ └── README.md ├── dataloaders ├── __init__.py ├── blob.py ├── image_transforms.py ├── mscoco.py └── visual_genome.py ├── docs ├── LICENSE.md ├── _config.yaml ├── _includes │ └── image.html ├── _layouts │ └── default.html ├── index.md ├── teaser.png └── upload.sh ├── install_package.sh ├── lib ├── __init__.py ├── draw_rectangles │ ├── draw_rectangles.c │ ├── draw_rectangles.pyx │ └── setup.py ├── evaluation │ ├── __init__.py │ ├── sg_eval.py │ ├── sg_eval_slow.py │ └── test_sg_eval.py ├── fpn │ ├── anchor_targets.py │ ├── box_intersections_cpu │ │ ├── bbox.c │ │ ├── bbox.pyx │ │ └── setup.py │ ├── box_utils.py │ ├── generate_anchors.py │ ├── make.sh │ ├── nms │ │ ├── Makefile │ │ ├── build.py │ │ ├── functions │ │ │ └── nms.py │ │ └── src │ │ │ ├── cuda │ │ │ ├── Makefile │ │ │ ├── nms_kernel.cu │ │ │ └── nms_kernel.h │ │ │ ├── nms_cuda.c │ │ │ └── nms_cuda.h │ ├── proposal_assignments │ │ ├── proposal_assignments_det.py │ │ ├── proposal_assignments_gtbox.py │ │ ├── proposal_assignments_postnms.py │ │ ├── proposal_assignments_rel.py │ │ └── rel_assignments.py │ └── roi_align │ │ ├── Makefile │ │ ├── __init__.py │ │ ├── _ext │ │ ├── __init__.py │ │ └── roi_align │ │ │ └── __init__.py │ │ ├── build.py │ │ ├── functions │ │ ├── __init__.py │ │ └── roi_align.py │ │ ├── modules │ │ ├── __init__.py │ │ └── roi_align.py │ │ └── src │ │ ├── cuda │ │ ├── Makefile │ │ ├── roi_align_kernel.cu │ │ └── roi_align_kernel.h │ │ ├── roi_align_cuda.c │ │ └── roi_align_cuda.h ├── get_dataset_counts.py ├── get_union_boxes.py ├── lstm │ ├── __init__.py │ ├── decoder_rnn.py │ └── highway_lstm_cuda │ │ ├── __init__.py │ │ ├── _ext │ │ ├── __init__.py │ │ └── highway_lstm_layer │ │ │ └── __init__.py │ │ ├── alternating_highway_lstm.py │ │ ├── build.py │ │ ├── make.sh │ │ └── src │ │ ├── highway_lstm_cuda.c │ │ ├── highway_lstm_cuda.h │ │ ├── highway_lstm_kernel.cu │ │ └── highway_lstm_kernel.h ├── object_detector.py ├── pytorch_misc.py ├── rel_model.py ├── rel_model_stanford.py ├── resnet.py ├── sparse_targets.py ├── surgery.py ├── tree_lstm │ ├── __init__.py │ ├── decoder_tree_lstm.py │ ├── def_tree.py │ ├── draw_tree.py │ ├── gen_tree.py │ ├── graph_to_tree.py │ ├── tree_lstm.py │ └── tree_utils.py └── word_vectors.py ├── models ├── eval_rel_count.py ├── eval_rels.py ├── train_detector.py └── train_rels.py └── scripts ├── eval_models.sh ├── pretrain_detector.sh ├── refine_for_detection.sh ├── train_models_sgcls.sh ├── train_stanford.sh └── train_vctreenet.sh /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [KaihuaTang] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: tkhchipaomg 14 | thanks_dev: # Replace with a single 
thanks.dev username 15 | custom: ['https://kaihuatang.github.io/donate'] 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # dotenv 82 | .env 83 | 84 | # virtualenv 85 | .venv 86 | venv/ 87 | ENV/ 88 | 89 | # Spyder project settings 90 | .spyderproject 91 | .spyproject 92 | 93 | # Rope project settings 94 | .ropeproject 95 | 96 | # mkdocs documentation 97 | /site 98 | 99 | # mypy 100 | .mypy_cache/ 101 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.linting.pylintEnabled": false 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Rowan Zellers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | export PATH := /usr/local/cuda-9.0/bin:$(PATH) 2 | 3 | all: draw_rectangles box_intersections nms roi_align lstm 4 | 5 | draw_rectangles: 6 | cd lib/draw_rectangles; python setup.py build_ext --inplace 7 | box_intersections: 8 | cd lib/fpn/box_intersections_cpu; python setup.py build_ext --inplace 9 | nms: 10 | cd lib/fpn/nms; make 11 | roi_align: 12 | cd lib/fpn/roi_align; make 13 | lstm: 14 | cd lib/lstm/highway_lstm_cuda; bash make.sh -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VCTree-Scene-Graph-Generation 2 | 3 | **If you like our work, and want to start your own scene graph generation project, you might be interested in our new SGG codebase: [Scene-Graph-Benchmark.pytorch](https://github.com/KaihuaTang/Scene-Graph-Benchmark.pytorch). It's much easier to follow, and provides state-of-the-art baseline models.** 4 | 5 | Code for the Scene Graph Generation part of the CVPR 2019 oral paper "[Learning to Compose Dynamic Tree Structures for Visual Contexts][0]". As for the VQA part of this paper, please refer to [KaihuaTang/VCTree-Visual-Question-Answering][6]. 6 | 7 | UGLY CODE WARNING! UGLY CODE WARNING! UGLY CODE WARNING! 8 | 9 | The code is directly modified from the project [rowanz/neural-motifs][1]. Most of the code for the proposed VCTree is located under lib/tree_lstm/*, and if you hit any problem that prevents you from running the project, you can check the issues under [rowanz/neural-motifs][1] first. 10 | 11 | **If my open source projects have inspired you, giving me some sponsorship would be a great help to my subsequent open source work.** 12 | [Support my subsequent open source work❤️🙏](https://kaihuatang.github.io/donate.html) 13 | 14 | # Dependencies 15 | - You may follow these commands to set up the environment on an Ubuntu system 16 | ``` 17 | Install Anaconda 18 | conda update -n base conda 19 | conda create -n motif pip python=3.6 20 | conda install pytorch=0.3 torchvision cuda90 -c pytorch 21 | bash install_package.sh 22 | ``` 23 | 24 | # Prepare Dataset and Setup 25 | 26 | 0. Please follow the [instructions][2] under ./data/stanford_filtered/ to download the dataset and put the files in the proper locations. 27 | 28 | 1. Update the config file with the dataset paths. Specifically: 29 | - Visual Genome (the VG_100K folder, image_data.json, VG-SGG.h5, and VG-SGG-dicts.json). See data/stanford_filtered/README.md for the steps I used to download these. 30 | - You'll also need to fix your PYTHONPATH: ```export PYTHONPATH=/home/YourName/ThePathOfYourProject``` 31 | 32 | 2. Compile everything. Run ```make``` in the main directory: this compiles the bilinear interpolation operation for the RoIs. 33 | 34 | 3. Pretrain VG detection. The old version involved pretraining on COCO as well, but we got rid of that for simplicity. Run ./scripts/pretrain_detector.sh. Note: you might have to modify the learning rate and batch size, particularly if you don't have 3 Titan X GPUs (which is what I used).
[You can also download the pretrained detector checkpoint here.](https://drive.google.com/open?id=11zKRr2OF5oclFL47kjFYBOxScotQzArX) Note that this detector model is the default initialization of all VCTree models, so when you download this checkpoint, you need to change the "-ckpt THE_PATH_OF_INITIAL_CHECKPOINT_MODEL" entry under ./scripts/train_vctreenet.sh 36 | 37 | 38 | # How to Train / Evaluation 39 | 0. Note that most of the parameters are in config.py. The training stages and settings are controlled through ./scripts/train_vctreenet.sh. Each command line in train_vctreenet.sh needs to manually specify the "-ckpt" model (initial parameters) and the "-save_dir" path where the model will be saved. Since we use a hybrid learning strategy, each task (predcls/sgcls/sgdet) has two options, for the supervised stage and the reinforcement finetuning stage, respectively. When iteratively switching stages, the -ckpt PATH should point to the previous -save_dir PATH. The first supervised stage is initialized with the [detector checkpoint](https://drive.google.com/open?id=11zKRr2OF5oclFL47kjFYBOxScotQzArX) mentioned above. 40 | 41 | 1. Train VG predicate classification (predcls) 42 | - stage 1 (supervised stage of hybrid learning): run ./scripts/train_vctreenet.sh 5 43 | - stage 2 (reinforcement finetuning stage of hybrid learning): run ./scripts/train_vctreenet.sh 4 44 | - (By default, it will run on GPU 2; you can modify CUDA_VISIBLE_DEVICES in train_vctreenet.sh.) 45 | - The model will be saved under "-save_dir checkpoints/THE_NAME_YOU_WILL_SAVE_THE_MODEL" 46 | 47 | 2. Train VG scene graph classification (sgcls) 48 | - stage 1 (supervised stage of hybrid learning): run ./scripts/train_vctreenet.sh 3 49 | - stage 2 (reinforcement finetuning stage of hybrid learning): run ./scripts/train_vctreenet.sh 2 50 | - (By default, it will run on GPU 2; you can modify CUDA_VISIBLE_DEVICES in train_vctreenet.sh.) 51 | - The model will be saved under "-save_dir checkpoints/THE_NAME_YOU_WILL_SAVE_THE_MODEL" 52 | 53 | 3. Train VG scene graph detection (sgdet) 54 | - stage 1 (supervised stage of hybrid learning): run ./scripts/train_vctreenet.sh 1 55 | - stage 2 (reinforcement finetuning stage of hybrid learning): run ./scripts/train_vctreenet.sh 0 56 | - (By default, it will run on GPU 2; you can modify CUDA_VISIBLE_DEVICES in train_vctreenet.sh.) 57 | - The model will be saved under "-save_dir checkpoints/THE_NAME_YOU_WILL_SAVE_THE_MODEL" 58 | 59 | 4. Evaluate predicate classification (predcls): 60 | - run ./scripts/eval_models.sh 0 61 | - OR you can simply download our predcls checkpoint: [VCTree/PredCls][3]. 62 | 63 | 5. Evaluate scene graph classification (sgcls): 64 | - run ./scripts/eval_models.sh 1 65 | - OR you can simply download our sgcls checkpoint: [VCTree/SGCls][4]. 66 | 67 | 6. Evaluate scene graph detection (sgdet): 68 | - run ./scripts/eval_models.sh 2 69 | - OR you can simply download our sgdet checkpoint: [VCTree/SGDET][5]. 70 | 71 | 72 | # Other Things You Need To Know 73 | - When you evaluate your model, three metrics are printed: (1) "R@20/50/100" is what we use to report R@20/50/100 in our paper; (2) "cls avg" is the corresponding mean recall mR@20/50/100 proposed in our paper; (3) "total R" is another way of calculating recall that is used in some previous papers/projects, which is quite tricky and unfair, because it almost always gives a higher recall. 74 | - The reinforcement part of hybrid learning is still far from satisfactory.
Hence, if you are interested in improving our work, you may start with this part. 75 | 76 | # If this paper/project inspires your work, please cite our work: 77 | ``` 78 | @inproceedings{tang2018learning, 79 | title={Learning to Compose Dynamic Tree Structures for Visual Contexts}, 80 | author={Tang, Kaihua and Zhang, Hanwang and Wu, Baoyuan and Luo, Wenhan and Liu, Wei}, 81 | booktitle= "Conference on Computer Vision and Pattern Recognition", 82 | year={2019} 83 | } 84 | ``` 85 | 86 | [0]: https://arxiv.org/abs/1812.01880 87 | [1]: https://github.com/rowanz/neural-motifs 88 | [2]: https://github.com/rowanz/neural-motifs/tree/master/data/stanford_filtered 89 | [3]: https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21768059&authkey=APvRgmSUEvf4h8s 90 | [4]: https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21768060&authkey=ADI-fKq10g-niGk 91 | [5]: https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21768063&authkey=ADOyKfb6MGR5seI 92 | [6]: https://github.com/KaihuaTang/VCTree-Visual-Question-Answering 93 | -------------------------------------------------------------------------------- /data/co_occour_count.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/data/co_occour_count.npy -------------------------------------------------------------------------------- /data/stanford_filtered/README.md: -------------------------------------------------------------------------------- 1 | # Filtered data 2 | Adapted from [Danfei Xu](https://github.com/danfeiX/scene-graph-TF-release/blob/master/data_tools/README.md). 3 | 4 | Follow these steps to get the dataset set up. 5 | 1. Download the VG images [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip) [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip). Extract these images to a folder and link to them in `config.py` (e.g. currently I have `VG_IMAGES=data/visual_genome/VG_100K`). 6 | 2. Download the [VG metadata](http://cvgl.stanford.edu/scene-graph/VG/image_data.json). I recommend extracting it to this directory (e.g. `data/stanford_filtered/image_data.json`), or you can edit the path in `config.py`. 7 | 3. Download the [scene graphs](http://cvgl.stanford.edu/scene-graph/dataset/VG-SGG.h5) and extract them to `data/stanford_filtered/VG-SGG.h5` 8 | 4.
Download the [scene graph dataset metadata](http://cvgl.stanford.edu/scene-graph/dataset/VG-SGG-dicts.json) and extract it to `data/stanford_filtered/VG-SGG-dicts.json` 9 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/blob.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data blob, hopefully to make collating less painful and MGPU training possible 3 | """ 4 | from lib.fpn.anchor_targets import anchor_target_layer 5 | import numpy as np 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | 10 | class Blob(object): 11 | def __init__(self, mode='det', is_train=False, num_gpus=1, primary_gpu=0, batch_size_per_gpu=3): 12 | """ 13 | Initializes an empty Blob object. 14 | :param mode: 'det' for detection and 'rel' for det+relationship 15 | :param is_train: True if it's training 16 | """ 17 | assert mode in ('det', 'rel') 18 | assert num_gpus >= 1 19 | self.mode = mode 20 | self.is_train = is_train 21 | self.num_gpus = num_gpus 22 | self.batch_size_per_gpu = batch_size_per_gpu 23 | self.primary_gpu = primary_gpu 24 | 25 | self.imgs = [] # [num_images, 3, IM_SCALE, IM_SCALE] array 26 | self.im_sizes = [] # [num_images, 4] array of (h, w, scale, num_valid_anchors) 27 | self.all_anchor_inds = [] # [all_anchors, 2] array of (img_ind, anchor_idx). Only has valid 28 | # boxes (meaning some are gonna get cut out) 29 | self.all_anchors = [] # [num_im, IM_SCALE/4, IM_SCALE/4, num_anchors, 4] shapes. Anchors outside get squashed 30 | # to 0 31 | self.gt_boxes = [] # [num_gt, 4] boxes 32 | self.gt_classes = [] # [num_gt,2] array of img_ind, class 33 | self.gt_rels = [] # [num_rels, 3]. Each row is (gtbox0, gtbox1, rel). 34 | 35 | self.gt_sents = [] 36 | self.gt_nodes = [] 37 | self.sent_lengths = [] 38 | 39 | self.train_anchor_labels = [] # [train_anchors, 5] array of (img_ind, h, w, A, labels) 40 | self.train_anchors = [] # [train_anchors, 8] shapes with anchor, target 41 | 42 | self.train_anchor_inds = None # This will be split into GPUs, just (img_ind, h, w, A). 
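        # Typical life cycle of a Blob: append() is called once per image in the batch,
        # reduce() then concatenates everything into flat tensors plus per-GPU chunk sizes,
        # and scatter() finally distributes those chunks across the GPUs.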
43 | 44 | self.batch_size = None 45 | self.gt_box_chunks = None 46 | self.anchor_chunks = None 47 | self.train_chunks = None 48 | self.proposal_chunks = None 49 | self.proposals = [] 50 | 51 | @property 52 | def is_flickr(self): 53 | return self.mode == 'flickr' 54 | 55 | @property 56 | def is_rel(self): 57 | return self.mode == 'rel' 58 | 59 | @property 60 | def volatile(self): 61 | return not self.is_train 62 | 63 | def append(self, d): 64 | """ 65 | Adds a single image to the blob 66 | :param datom: 67 | :return: 68 | """ 69 | i = len(self.imgs) 70 | self.imgs.append(d['img']) 71 | 72 | h, w, scale = d['img_size'] 73 | 74 | # all anchors 75 | self.im_sizes.append((h, w, scale)) 76 | 77 | gt_boxes_ = d['gt_boxes'].astype(np.float32) * d['scale'] 78 | self.gt_boxes.append(gt_boxes_) 79 | 80 | self.gt_classes.append(np.column_stack(( 81 | i * np.ones(d['gt_classes'].shape[0], dtype=np.int64), 82 | d['gt_classes'], 83 | ))) 84 | 85 | # Add relationship info 86 | if self.is_rel: 87 | self.gt_rels.append(np.column_stack(( 88 | i * np.ones(d['gt_relations'].shape[0], dtype=np.int64), 89 | d['gt_relations']))) 90 | 91 | # Augment with anchor targets 92 | if self.is_train: 93 | train_anchors_, train_anchor_inds_, train_anchor_targets_, train_anchor_labels_ = \ 94 | anchor_target_layer(gt_boxes_, (h, w)) 95 | 96 | self.train_anchors.append(np.hstack((train_anchors_, train_anchor_targets_))) 97 | 98 | self.train_anchor_labels.append(np.column_stack(( 99 | i * np.ones(train_anchor_inds_.shape[0], dtype=np.int64), 100 | train_anchor_inds_, 101 | train_anchor_labels_, 102 | ))) 103 | 104 | if 'proposals' in d: 105 | self.proposals.append(np.column_stack((i * np.ones(d['proposals'].shape[0], dtype=np.float32), 106 | d['scale'] * d['proposals'].astype(np.float32)))) 107 | 108 | 109 | 110 | def _chunkize(self, datom, tensor=torch.LongTensor): 111 | """ 112 | Turn data list into chunks, one per GPU 113 | :param datom: List of lists of numpy arrays that will be concatenated. 114 | :return: 115 | """ 116 | chunk_sizes = [0] * self.num_gpus 117 | for i in range(self.num_gpus): 118 | for j in range(self.batch_size_per_gpu): 119 | chunk_sizes[i] += datom[i * self.batch_size_per_gpu + j].shape[0] 120 | return Variable(tensor(np.concatenate(datom, 0)), volatile=self.volatile), chunk_sizes 121 | 122 | def reduce(self): 123 | """ Merges all the detections into flat lists + numbers of how many are in each""" 124 | if len(self.imgs) != self.batch_size_per_gpu * self.num_gpus: 125 | raise ValueError("Wrong batch size? 
imgs len {} bsize/gpu {} numgpus {}".format( 126 | len(self.imgs), self.batch_size_per_gpu, self.num_gpus 127 | )) 128 | 129 | self.imgs = Variable(torch.stack(self.imgs, 0), volatile=self.volatile) 130 | self.im_sizes = np.stack(self.im_sizes).reshape( 131 | (self.num_gpus, self.batch_size_per_gpu, 3)) 132 | 133 | if self.is_rel: 134 | self.gt_rels, self.gt_rel_chunks = self._chunkize(self.gt_rels) 135 | 136 | self.gt_boxes, self.gt_box_chunks = self._chunkize(self.gt_boxes, tensor=torch.FloatTensor) 137 | self.gt_classes, _ = self._chunkize(self.gt_classes) 138 | if self.is_train: 139 | self.train_anchor_labels, self.train_chunks = self._chunkize(self.train_anchor_labels) 140 | self.train_anchors, _ = self._chunkize(self.train_anchors, tensor=torch.FloatTensor) 141 | self.train_anchor_inds = self.train_anchor_labels[:, :-1].contiguous() 142 | 143 | if len(self.proposals) != 0: 144 | self.proposals, self.proposal_chunks = self._chunkize(self.proposals, tensor=torch.FloatTensor) 145 | 146 | 147 | 148 | def _scatter(self, x, chunk_sizes, dim=0): 149 | """ Helper function""" 150 | if self.num_gpus == 1: 151 | return x.cuda(self.primary_gpu, async=True) 152 | return torch.nn.parallel.scatter_gather.Scatter.apply( 153 | list(range(self.num_gpus)), chunk_sizes, dim, x) 154 | 155 | def scatter(self): 156 | """ Assigns everything to the GPUs""" 157 | self.imgs = self._scatter(self.imgs, [self.batch_size_per_gpu] * self.num_gpus) 158 | 159 | self.gt_classes_primary = self.gt_classes.cuda(self.primary_gpu, async=True) 160 | self.gt_boxes_primary = self.gt_boxes.cuda(self.primary_gpu, async=True) 161 | 162 | # Predcls might need these 163 | self.gt_classes = self._scatter(self.gt_classes, self.gt_box_chunks) 164 | self.gt_boxes = self._scatter(self.gt_boxes, self.gt_box_chunks) 165 | 166 | if self.is_train: 167 | 168 | self.train_anchor_inds = self._scatter(self.train_anchor_inds, 169 | self.train_chunks) 170 | self.train_anchor_labels = self.train_anchor_labels.cuda(self.primary_gpu, async=True) 171 | self.train_anchors = self.train_anchors.cuda(self.primary_gpu, async=True) 172 | 173 | if self.is_rel: 174 | self.gt_rels = self._scatter(self.gt_rels, self.gt_rel_chunks) 175 | else: 176 | if self.is_rel: 177 | self.gt_rels = self.gt_rels.cuda(self.primary_gpu, async=True) 178 | 179 | if self.proposal_chunks is not None: 180 | self.proposals = self._scatter(self.proposals, self.proposal_chunks) 181 | 182 | def __getitem__(self, index): 183 | """ 184 | Returns a tuple containing data 185 | :param index: Which GPU we're on, or 0 if no GPUs 186 | :return: If training: 187 | (image, im_size, img_start_ind, anchor_inds, anchors, gt_boxes, gt_classes, 188 | train_anchor_inds) 189 | test: 190 | (image, im_size, img_start_ind, anchor_inds, anchors) 191 | """ 192 | if index not in list(range(self.num_gpus)): 193 | raise ValueError("Out of bounds with index {} and {} gpus".format(index, self.num_gpus)) 194 | 195 | if self.is_rel: 196 | rels = self.gt_rels 197 | if index > 0 or self.num_gpus != 1: 198 | rels_i = rels[index] if self.is_rel else None 199 | elif self.is_flickr: 200 | rels = (self.gt_sents, self.gt_nodes) 201 | if index > 0 or self.num_gpus != 1: 202 | rels_i = (self.gt_sents[index], self.gt_nodes[index]) 203 | else: 204 | rels = None 205 | rels_i = None 206 | 207 | if self.proposal_chunks is None: 208 | proposals = None 209 | else: 210 | proposals = self.proposals 211 | 212 | if index == 0 and self.num_gpus == 1: 213 | image_offset = 0 214 | if self.is_train: 215 | return (self.imgs, 
self.im_sizes[0], image_offset, 216 | self.gt_boxes, self.gt_classes, rels, proposals, self.train_anchor_inds) 217 | return self.imgs, self.im_sizes[0], image_offset, self.gt_boxes, self.gt_classes, rels, proposals 218 | 219 | # Otherwise proposals is None 220 | assert proposals is None 221 | 222 | image_offset = self.batch_size_per_gpu * index 223 | # TODO: Return a namedtuple 224 | if self.is_train: 225 | return ( 226 | self.imgs[index], self.im_sizes[index], image_offset, 227 | self.gt_boxes[index], self.gt_classes[index], rels_i, None, self.train_anchor_inds[index]) 228 | return (self.imgs[index], self.im_sizes[index], image_offset, 229 | self.gt_boxes[index], self.gt_classes[index], rels_i, None) 230 | 231 | -------------------------------------------------------------------------------- /dataloaders/image_transforms.py: -------------------------------------------------------------------------------- 1 | # Some image transforms 2 | 3 | from PIL import Image, ImageOps, ImageFilter, ImageEnhance 4 | import numpy as np 5 | from random import randint 6 | # All of these need to be called on PIL imagez 7 | 8 | class SquarePad(object): 9 | def __call__(self, img): 10 | w, h = img.size 11 | img_padded = ImageOps.expand(img, border=(0, 0, max(h - w, 0), max(w - h, 0)), 12 | fill=(int(0.485 * 256), int(0.456 * 256), int(0.406 * 256))) 13 | return img_padded 14 | 15 | 16 | class Grayscale(object): 17 | """ 18 | Converts to grayscale (not always, sometimes). 19 | """ 20 | def __call__(self, img): 21 | factor = np.sqrt(np.sqrt(np.random.rand(1))) 22 | # print("gray {}".format(factor)) 23 | enhancer = ImageEnhance.Color(img) 24 | return enhancer.enhance(factor) 25 | 26 | 27 | class Brightness(object): 28 | """ 29 | Converts to grayscale (not always, sometimes). 30 | """ 31 | def __call__(self, img): 32 | factor = np.random.randn(1)/6+1 33 | factor = min(max(factor, 0.5), 1.5) 34 | # print("brightness {}".format(factor)) 35 | 36 | enhancer = ImageEnhance.Brightness(img) 37 | return enhancer.enhance(factor) 38 | 39 | 40 | class Contrast(object): 41 | """ 42 | Converts to grayscale (not always, sometimes). 43 | """ 44 | def __call__(self, img): 45 | factor = np.random.randn(1)/8+1.0 46 | factor = min(max(factor, 0.5), 1.5) 47 | # print("contrast {}".format(factor)) 48 | 49 | enhancer = ImageEnhance.Contrast(img) 50 | return enhancer.enhance(factor) 51 | 52 | 53 | class Hue(object): 54 | """ 55 | Converts to grayscale 56 | """ 57 | def __call__(self, img): 58 | # 30 seems good 59 | factor = int(np.random.randn(1)*8) 60 | factor = min(max(factor, -30), 30) 61 | factor = np.array(factor, dtype=np.uint8) 62 | 63 | hsv = np.array(img.convert('HSV')) 64 | hsv[:,:,0] += factor 65 | new_img = Image.fromarray(hsv, 'HSV').convert('RGB') 66 | 67 | return new_img 68 | 69 | 70 | class Sharpness(object): 71 | """ 72 | Converts to grayscale 73 | """ 74 | def __call__(self, img): 75 | factor = 1.0 + np.random.randn(1)/5 76 | # print("sharpness {}".format(factor)) 77 | enhancer = ImageEnhance.Sharpness(img) 78 | return enhancer.enhance(factor) 79 | 80 | 81 | def random_crop(img, boxes, box_scale, round_boxes=True, max_crop_fraction=0.1): 82 | """ 83 | Randomly crops the image 84 | :param img: PIL image 85 | :param boxes: Ground truth boxes 86 | :param box_scale: This is the scale that the boxes are at (e.g. 1024 wide). 
We'll preserve that ratio 87 | :param round_boxes: Set this to true if we're going to round the boxes to ints 88 | :return: Cropped image, new boxes 89 | """ 90 | 91 | w, h = img.size 92 | 93 | max_crop_w = int(w*max_crop_fraction) 94 | max_crop_h = int(h*max_crop_fraction) 95 | boxes_scaled = boxes * max(w,h) / box_scale 96 | max_to_crop_top = min(int(boxes_scaled[:, 1].min()), max_crop_h) 97 | max_to_crop_left = min(int(boxes_scaled[:, 0].min()), max_crop_w) 98 | max_to_crop_right = min(int(w - boxes_scaled[:, 2].max()), max_crop_w) 99 | max_to_crop_bottom = min(int(h - boxes_scaled[:, 3].max()), max_crop_h) 100 | 101 | crop_top = randint(0, max(max_to_crop_top, 0)) 102 | crop_left = randint(0, max(max_to_crop_left, 0)) 103 | crop_right = randint(0, max(max_to_crop_right, 0)) 104 | crop_bottom = randint(0, max(max_to_crop_bottom, 0)) 105 | img_cropped = img.crop((crop_left, crop_top, w - crop_right, h - crop_bottom)) 106 | 107 | new_boxes = box_scale / max(img_cropped.size) * np.column_stack( 108 | (boxes_scaled[:,0]-crop_left, boxes_scaled[:,1]-crop_top, boxes_scaled[:,2]-crop_left, boxes_scaled[:,3]-crop_top)) 109 | 110 | if round_boxes: 111 | new_boxes = np.round(new_boxes).astype(np.int32) 112 | return img_cropped, new_boxes 113 | 114 | 115 | class RandomOrder(object): 116 | """ Composes several transforms together in random order - or not at all! 117 | """ 118 | 119 | def __init__(self, transforms): 120 | self.transforms = transforms 121 | 122 | def __call__(self, img): 123 | if self.transforms is None: 124 | return img 125 | num_to_pick = np.random.choice(len(self.transforms)) 126 | if num_to_pick == 0: 127 | return img 128 | 129 | order = np.random.choice(len(self.transforms), size=num_to_pick, replace=False) 130 | for i in order: 131 | img = self.transforms[i](img) 132 | return img -------------------------------------------------------------------------------- /dataloaders/mscoco.py: -------------------------------------------------------------------------------- 1 | from config import COCO_PATH, IM_SCALE, BOX_SCALE 2 | import os 3 | from torch.utils.data import Dataset 4 | from pycocotools.coco import COCO 5 | from PIL import Image 6 | from lib.fpn.anchor_targets import anchor_target_layer 7 | from torchvision.transforms import Resize, Compose, ToTensor, Normalize 8 | from dataloaders.image_transforms import SquarePad, Grayscale, Brightness, Sharpness, Contrast, RandomOrder, Hue, random_crop 9 | import numpy as np 10 | from dataloaders.blob import Blob 11 | import torch 12 | 13 | class CocoDetection(Dataset): 14 | """ 15 | Adapted from the torchvision code 16 | """ 17 | 18 | def __init__(self, mode): 19 | """ 20 | :param mode: train2014 or val2014 21 | """ 22 | self.mode = mode 23 | self.root = os.path.join(COCO_PATH, mode) 24 | self.ann_file = os.path.join(COCO_PATH, 'annotations', 'instances_{}.json'.format(mode)) 25 | self.coco = COCO(self.ann_file) 26 | self.ids = [k for k in self.coco.imgs.keys() if len(self.coco.imgToAnns[k]) > 0] 27 | 28 | 29 | tform = [] 30 | if self.is_train: 31 | tform.append(RandomOrder([ 32 | Grayscale(), 33 | Brightness(), 34 | Contrast(), 35 | Sharpness(), 36 | Hue(), 37 | ])) 38 | 39 | tform += [ 40 | SquarePad(), 41 | Resize(IM_SCALE), 42 | ToTensor(), 43 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 44 | ] 45 | 46 | self.transform_pipeline = Compose(tform) 47 | self.ind_to_classes = ['__background__'] + [v['name'] for k, v in self.coco.cats.items()] 48 | # COCO inds are weird (84 inds in total but a bunch of numbers are 
skipped) 49 | self.id_to_ind = {coco_id:(ind+1) for ind, coco_id in enumerate(self.coco.cats.keys())} 50 | self.id_to_ind[0] = 0 51 | 52 | self.ind_to_id = {x:y for y,x in self.id_to_ind.items()} 53 | 54 | @property 55 | def is_train(self): 56 | return self.mode.startswith('train') 57 | 58 | def __getitem__(self, index): 59 | """ 60 | Args: 61 | index (int): Index 62 | 63 | Returns: entry dict 64 | """ 65 | img_id = self.ids[index] 66 | path = self.coco.loadImgs(img_id)[0]['file_name'] 67 | image_unpadded = Image.open(os.path.join(self.root, path)).convert('RGB') 68 | ann_ids = self.coco.getAnnIds(imgIds=img_id) 69 | anns = self.coco.loadAnns(ann_ids) 70 | gt_classes = np.array([self.id_to_ind[x['category_id']] for x in anns], dtype=np.int64) 71 | 72 | if np.any(gt_classes >= len(self.ind_to_classes)): 73 | raise ValueError("OH NO {}".format(index)) 74 | 75 | if len(anns) == 0: 76 | raise ValueError("Annotations should not be empty") 77 | # gt_boxes = np.array((0, 4), dtype=np.float32) 78 | # else: 79 | gt_boxes = np.array([x['bbox'] for x in anns], dtype=np.float32) 80 | 81 | if np.any(gt_boxes[:, [0,1]] < 0): 82 | raise ValueError("GT boxes empty columns") 83 | if np.any(gt_boxes[:, [2,3]] < 0): 84 | raise ValueError("GT boxes empty h/w") 85 | gt_boxes[:, [2, 3]] += gt_boxes[:, [0, 1]] 86 | 87 | # Rescale so that the boxes are at BOX_SCALE 88 | if self.is_train: 89 | image_unpadded, gt_boxes = random_crop(image_unpadded, 90 | gt_boxes * BOX_SCALE / max(image_unpadded.size), 91 | BOX_SCALE, 92 | round_boxes=False, 93 | ) 94 | else: 95 | # Seems a bit silly because we won't be using GT boxes then but whatever 96 | gt_boxes = gt_boxes * BOX_SCALE / max(image_unpadded.size) 97 | w, h = image_unpadded.size 98 | box_scale_factor = BOX_SCALE / max(w, h) 99 | 100 | # Optionally flip the image if we're doing training 101 | flipped = self.is_train and np.random.random() > 0.5 102 | if flipped: 103 | scaled_w = int(box_scale_factor * float(w)) 104 | image_unpadded = image_unpadded.transpose(Image.FLIP_LEFT_RIGHT) 105 | gt_boxes[:, [0, 2]] = scaled_w - gt_boxes[:, [2, 0]] 106 | 107 | img_scale_factor = IM_SCALE / max(w, h) 108 | if h > w: 109 | im_size = (IM_SCALE, int(w*img_scale_factor), img_scale_factor) 110 | elif h < w: 111 | im_size = (int(h*img_scale_factor), IM_SCALE, img_scale_factor) 112 | else: 113 | im_size = (IM_SCALE, IM_SCALE, img_scale_factor) 114 | 115 | entry = { 116 | 'img': self.transform_pipeline(image_unpadded), 117 | 'img_size': im_size, 118 | 'gt_boxes': gt_boxes, 119 | 'gt_classes': gt_classes, 120 | 'scale': IM_SCALE / BOX_SCALE, 121 | 'index': index, 122 | 'image_id': img_id, 123 | 'flipped': flipped, 124 | 'fn': path, 125 | } 126 | 127 | return entry 128 | 129 | @classmethod 130 | def splits(cls, *args, **kwargs): 131 | """ Helper method to generate splits of the dataset""" 132 | train = cls('train2014', *args, **kwargs) 133 | val = cls('val2014', *args, **kwargs) 134 | return train, val 135 | 136 | def __len__(self): 137 | return len(self.ids) 138 | 139 | 140 | def coco_collate(data, num_gpus=3, is_train=False): 141 | blob = Blob(mode='det', is_train=is_train, num_gpus=num_gpus, 142 | batch_size_per_gpu=len(data) // num_gpus) 143 | for d in data: 144 | blob.append(d) 145 | blob.reduce() 146 | return blob 147 | 148 | 149 | class CocoDataLoader(torch.utils.data.DataLoader): 150 | """ 151 | Iterates through the data, filtering out None, 152 | but also loads everything as a (cuda) variable 153 | """ 154 | # def __iter__(self): 155 | # for x in super(CocoDataLoader, 
self).__iter__(): 156 | # if isinstance(x, tuple) or isinstance(x, list): 157 | # yield tuple(y.cuda(async=True) if hasattr(y, 'cuda') else y for y in x) 158 | # else: 159 | # yield x.cuda(async=True) 160 | 161 | @classmethod 162 | def splits(cls, train_data, val_data, batch_size=3, num_workers=1, num_gpus=3, **kwargs): 163 | train_load = cls( 164 | dataset=train_data, 165 | batch_size=batch_size*num_gpus, 166 | shuffle=True, 167 | num_workers=num_workers, 168 | collate_fn=lambda x: coco_collate(x, num_gpus=num_gpus, is_train=True), 169 | drop_last=True, 170 | # pin_memory=True, 171 | **kwargs, 172 | ) 173 | val_load = cls( 174 | dataset=val_data, 175 | batch_size=batch_size*num_gpus, 176 | shuffle=False, 177 | num_workers=num_workers, 178 | collate_fn=lambda x: coco_collate(x, num_gpus=num_gpus, is_train=False), 179 | drop_last=True, 180 | # pin_memory=True, 181 | **kwargs, 182 | ) 183 | return train_load, val_load 184 | 185 | 186 | if __name__ == '__main__': 187 | train, val = CocoDetection.splits() 188 | gtbox = train[0]['gt_boxes'] 189 | img_size = train[0]['img_size'] 190 | anchor_strides, labels, bbox_targets = anchor_target_layer(gtbox, img_size) 191 | -------------------------------------------------------------------------------- /docs/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Heiswayi Nrird 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/_config.yaml: -------------------------------------------------------------------------------- 1 | exclude: [README.md, LICENSE.md] 2 | 3 | defaults: 4 | - values: 5 | layout: default 6 | -------------------------------------------------------------------------------- /docs/_includes/image.html: -------------------------------------------------------------------------------- 1 |
2 | {{ include.description }} 3 |
4 | 5 | -------------------------------------------------------------------------------- /docs/_layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | {{ page.title }} 9 | 10 | 11 | 12 | 13 | 62 | 75 | 76 | 77 | 78 | 79 | {{ content }} 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | permalink: / 3 | title: Neural Motifs 4 | author: Rowan Zellers 5 | description: Scene Graph Parsing with Global Context (CVPR 2018) 6 | google_analytics_id: UA-84290243-3 7 | --- 8 | # Neural Motifs: Scene Graph Parsing with Global Context (CVPR 2018) 9 | 10 | ### by [Rowan Zellers](https://rowanzellers.com), [Mark Yatskar](https://homes.cs.washington.edu/~my89/), [Sam Thomson](https://http://samthomson.com/), [Yejin Choi](https://homes.cs.washington.edu/~yejin/) 11 | 12 | 13 | {% include image.html url="teaser.png" description="teaser" %} 14 | 15 | # Overview 16 | 17 | * In this work, we investigate the problem of producing structured graph representations of visual scenes. Similar to object detection, we must predict a box around each object. Here, we also need to predict an edge (with one of several labels, possibly `background`) between every ordered pair of boxes, producing a directed graph where the edges hopefully represent the semantics and interactions present in the scene. 18 | * We present an analysis of the [Visual Genome Scene Graphs dataset](http://visualgenome.org/). In particular: 19 | * Object labels (e.g. person, shirt) are highly predictive of edge labels (e.g. wearing), but **not vice versa**. 20 | * Over 90% of the edges in the dataset are non-semantic. 21 | * There is a significant amount of structure in the dataset, in the form of graph motifs (regularly appearing substructures). 22 | * Motivated by our analysis, we present a simple baseline that outperforms previous approaches. 23 | * We introduce Stacked Motif Networks (MotifNet), which is a novel architecture that is designed to capture higher order motifs in scene graphs. In doing so, it achieves a sizeable performance gain over prior state-of-the-art. 24 | 25 | # Read the paper! 26 | The old version of the paper is available at [arxiv link](https://arxiv.org/abs/1711.06640) - camera ready version coming soon! 27 | 28 | # Bibtex 29 | ``` 30 | @inproceedings{zellers2018scenegraphs, 31 | title={Neural Motifs: Scene Graph Parsing with Global Context}, 32 | author={Zellers, Rowan and Yatskar, Mark and Thomson, Sam and Choi, Yejin}, 33 | booktitle = "Conference on Computer Vision and Pattern Recognition", 34 | year={2018} 35 | } 36 | ``` 37 | 38 | # View some examples! 39 | 40 | Check out [this tool](https://rowanzellers.com/scenegraph2/) I made to visualize the scene graph predictions. Disclaimer: the predictions are from an earlier version of the model, but hopefully they're still helpful! 41 | 42 | # Code 43 | 44 | Visit the [`neural-motifs` GitHub repository](https://github.com/rowanz/neural-motifs) for our reference implementation and instructions for running our code. 45 | 46 | It is released under the MIT license. 
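To make the graph representation described in the overview above concrete, here is a minimal sketch of how a predicted scene graph can be stored as plain NumPy arrays (the label ids and variable names are hypothetical, not the exact format used by this codebase):

```
import numpy as np

# A toy scene graph with three detected objects and two labelled edges.
# Hypothetical label ids: 1 = person, 2 = shirt, 3 = horse;
# hypothetical predicate ids: 1 = wearing, 2 = riding.
boxes = np.array([[ 20,  30, 120, 220],   # person (x1, y1, x2, y2)
                  [ 40,  80, 110, 160],   # shirt
                  [ 10, 150, 300, 380]],  # horse
                 dtype=np.float32)
classes = np.array([1, 2, 3], dtype=np.int64)

# Each relation is a directed edge (subject index, object index, predicate id).
relations = np.array([[0, 1, 1],   # person -wearing-> shirt
                      [0, 2, 2]],  # person -riding-> horse
                     dtype=np.int64)

# Triplet view (subject label, predicate label, object label).
triplets = np.column_stack((classes[relations[:, 0]],
                            relations[:, 2],
                            classes[relations[:, 1]]))
print(triplets)  # [[1 1 2]
                 #  [1 2 3]]
```

Each row of `triplets` is one subject-predicate-object prediction; recall-style metrics for scene graph generation are computed over ranked lists of such triplets.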
47 | 48 | # Checkpoints available for download 49 | * [Pretrained Detector](https://drive.google.com/open?id=11zKRr2OF5oclFL47kjFYBOxScotQzArX) 50 | * [Motifnet-SGDet](https://drive.google.com/open?id=1thd_5uSamJQaXAPVGVOUZGAOfGCYZYmb) 51 | * [Motifnet-SGCls/PredCls](https://drive.google.com/open?id=12qziGKYjFD3LAnoy4zDT3bcg5QLC0qN6) 52 | 53 | # questions? 54 | 55 | Feel free to get in touch! My main website is at [rowanzellers.com](https://rowanzellers.com) 56 | -------------------------------------------------------------------------------- /docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/docs/teaser.png -------------------------------------------------------------------------------- /docs/upload.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | scp -r _site/* USERNAME@SITE:~/rowanzellers.com/neuralmotifs -------------------------------------------------------------------------------- /install_package.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | git clone https://github.com/waleedka/coco 4 | cd coco/PythonAPI/ 5 | make 6 | python setup.py build_ext install 7 | cd ../../ 8 | 9 | conda install h5py 10 | 11 | conda install matplotlib 12 | 13 | conda install pandas 14 | 15 | conda install dill 16 | 17 | conda install tqdm 18 | 19 | pip install overrides 20 | 21 | pip install scikit-image -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/lib/__init__.py -------------------------------------------------------------------------------- /lib/draw_rectangles/draw_rectangles.pyx: -------------------------------------------------------------------------------- 1 | ###### 2 | # Draws rectangles 3 | ###### 4 | 5 | cimport cython 6 | import numpy as np 7 | cimport numpy as np 8 | 9 | DTYPE = np.float32 10 | ctypedef np.float32_t DTYPE_t 11 | 12 | def draw_union_boxes(bbox_pairs, pooling_size, padding=0): 13 | """ 14 | Draws union boxes for the image. 15 | :param box_pairs: [num_pairs, 8] 16 | :param fmap_size: Size of the original feature map 17 | :param stride: ratio between fmap size and original img (<1) 18 | :param pooling_size: resize everything to this size 19 | :return: [num_pairs, 2, pooling_size, pooling_size arr 20 | """ 21 | assert padding == 0, "Padding>0 not supported yet" 22 | return draw_union_boxes_c(bbox_pairs, pooling_size) 23 | 24 | cdef DTYPE_t minmax(DTYPE_t x): 25 | return min(max(x, 0), 1) 26 | 27 | cdef np.ndarray[DTYPE_t, ndim=4] draw_union_boxes_c( 28 | np.ndarray[DTYPE_t, ndim=2] box_pairs, unsigned int pooling_size): 29 | """ 30 | Parameters 31 | ---------- 32 | boxes: (N, 4) ndarray of float. 
everything has arbitrary ratios 33 | query_boxes: (K, 4) ndarray of float 34 | Returns 35 | ------- 36 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 37 | """ 38 | cdef unsigned int N = box_pairs.shape[0] 39 | 40 | cdef np.ndarray[DTYPE_t, ndim = 4] uboxes = np.zeros( 41 | (N, 2, pooling_size, pooling_size), dtype=DTYPE) 42 | cdef DTYPE_t x1_union, y1_union, x2_union, y2_union, w, h, x1_box, y1_box, x2_box, y2_box, y_contrib, x_contrib 43 | cdef unsigned int n, i, j, k 44 | 45 | for n in range(N): 46 | x1_union = min(box_pairs[n, 0], box_pairs[n, 4]) 47 | y1_union = min(box_pairs[n, 1], box_pairs[n, 5]) 48 | x2_union = max(box_pairs[n, 2], box_pairs[n, 6]) 49 | y2_union = max(box_pairs[n, 3], box_pairs[n, 7]) 50 | 51 | w = x2_union - x1_union 52 | h = y2_union - y1_union 53 | 54 | for i in range(2): 55 | # Now everything is in the range [0, pooling_size]. 56 | x1_box = (box_pairs[n, 0+4*i] - x1_union)*pooling_size / w 57 | y1_box = (box_pairs[n, 1+4*i] - y1_union)*pooling_size / h 58 | x2_box = (box_pairs[n, 2+4*i] - x1_union)*pooling_size / w 59 | y2_box = (box_pairs[n, 3+4*i] - y1_union)*pooling_size / h 60 | # print("{:.3f}, {:.3f}, {:.3f}, {:.3f}".format(x1_box, y1_box, x2_box, y2_box)) 61 | for j in range(pooling_size): 62 | y_contrib = minmax(j+1-y1_box)*minmax(y2_box-j) 63 | for k in range(pooling_size): 64 | x_contrib = minmax(k+1-x1_box)*minmax(x2_box-k) 65 | # print("j {} yc {} k {} xc {}".format(j, y_contrib, k, x_contrib)) 66 | uboxes[n,i,j,k] = x_contrib*y_contrib 67 | return uboxes 68 | -------------------------------------------------------------------------------- /lib/draw_rectangles/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup(name="draw_rectangles_cython", ext_modules=cythonize('draw_rectangles.pyx'), include_dirs=[numpy.get_include()]) -------------------------------------------------------------------------------- /lib/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/lib/evaluation/__init__.py -------------------------------------------------------------------------------- /lib/evaluation/sg_eval_slow.py: -------------------------------------------------------------------------------- 1 | # JUST TO CHECK THAT IT IS EXACTLY THE SAME.................................. 
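# This module re-implements the R@K evaluation independently of sg_eval.py as a sanity
# check: predictions are formatted into (subject, predicate, object) triplets, ranked by
# the product of the subject, predicate and object scores, and a ground-truth triplet is
# counted as recalled at K if one of the top-K predictions matches its three labels and
# both of its boxes with an IoU of at least iou_thresh.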
2 | import numpy as np 3 | from config import MODES 4 | 5 | class BasicSceneGraphEvaluator: 6 | 7 | def __init__(self, mode): 8 | self.result_dict = {} 9 | self.mode = {'sgdet':'sg_det', 'sgcls':'sg_cls', 'predcls':'pred_cls'}[mode] 10 | 11 | self.result_dict = {} 12 | self.result_dict[self.mode + '_recall'] = {20:[], 50:[], 100:[]} 13 | 14 | 15 | @classmethod 16 | def all_modes(cls): 17 | evaluators = {m: cls(mode=m) for m in MODES} 18 | return evaluators 19 | def evaluate_scene_graph_entry(self, gt_entry, pred_entry, iou_thresh=0.5): 20 | 21 | roidb_entry = { 22 | 'max_overlaps': np.ones(gt_entry['gt_classes'].shape[0], dtype=np.int64), 23 | 'boxes': gt_entry['gt_boxes'], 24 | 'gt_relations': gt_entry['gt_relations'], 25 | 'gt_classes': gt_entry['gt_classes'], 26 | } 27 | sg_entry = { 28 | 'boxes': pred_entry['pred_boxes'], 29 | 'relations': pred_entry['pred_rels'], 30 | 'obj_scores': pred_entry['obj_scores'], 31 | 'rel_scores': pred_entry['rel_scores'], 32 | 'pred_classes': pred_entry['pred_classes'], 33 | } 34 | 35 | pred_triplets, triplet_boxes = \ 36 | eval_relation_recall(sg_entry, roidb_entry, 37 | self.result_dict, 38 | self.mode, 39 | iou_thresh=iou_thresh) 40 | return pred_triplets, triplet_boxes 41 | 42 | 43 | def save(self, fn): 44 | np.save(fn, self.result_dict) 45 | 46 | 47 | def print_stats(self): 48 | print('======================' + self.mode + '============================') 49 | for k, v in self.result_dict[self.mode + '_recall'].items(): 50 | print('R@%i: %f' % (k, np.mean(v))) 51 | 52 | def save(self, fn): 53 | np.save(fn, self.result_dict) 54 | 55 | def print_stats(self): 56 | print('======================' + self.mode + '============================') 57 | for k, v in self.result_dict[self.mode + '_recall'].items(): 58 | print('R@%i: %f' % (k, np.mean(v))) 59 | 60 | 61 | def eval_relation_recall(sg_entry, 62 | roidb_entry, 63 | result_dict, 64 | mode, 65 | iou_thresh): 66 | 67 | # gt 68 | gt_inds = np.where(roidb_entry['max_overlaps'] == 1)[0] 69 | gt_boxes = roidb_entry['boxes'][gt_inds].copy().astype(float) 70 | num_gt_boxes = gt_boxes.shape[0] 71 | gt_relations = roidb_entry['gt_relations'].copy() 72 | gt_classes = roidb_entry['gt_classes'].copy() 73 | 74 | num_gt_relations = gt_relations.shape[0] 75 | if num_gt_relations == 0: 76 | return (None, None) 77 | gt_class_scores = np.ones(num_gt_boxes) 78 | gt_predicate_scores = np.ones(num_gt_relations) 79 | gt_triplets, gt_triplet_boxes, _ = _triplet(gt_relations[:,2], 80 | gt_relations[:,:2], 81 | gt_classes, 82 | gt_boxes, 83 | gt_predicate_scores, 84 | gt_class_scores) 85 | 86 | # pred 87 | box_preds = sg_entry['boxes'] 88 | num_boxes = box_preds.shape[0] 89 | relations = sg_entry['relations'] 90 | classes = sg_entry['pred_classes'].copy() 91 | class_scores = sg_entry['obj_scores'].copy() 92 | 93 | num_relations = relations.shape[0] 94 | 95 | if mode =='pred_cls': 96 | # if predicate classification task 97 | # use ground truth bounding boxes 98 | assert(num_boxes == num_gt_boxes) 99 | classes = gt_classes 100 | class_scores = gt_class_scores 101 | boxes = gt_boxes 102 | elif mode =='sg_cls': 103 | assert(num_boxes == num_gt_boxes) 104 | # if scene graph classification task 105 | # use gt boxes, but predicted classes 106 | # classes = np.argmax(class_preds, 1) 107 | # class_scores = class_preds.max(axis=1) 108 | boxes = gt_boxes 109 | elif mode =='sg_det': 110 | # if scene graph detection task 111 | # use preicted boxes and predicted classes 112 | # classes = np.argmax(class_preds, 1) 113 | # class_scores = 
class_preds.max(axis=1) 114 | boxes = box_preds 115 | else: 116 | raise NotImplementedError('Incorrect Mode! %s' % mode) 117 | 118 | pred_triplets = np.column_stack(( 119 | classes[relations[:, 0]], 120 | relations[:,2], 121 | classes[relations[:, 1]], 122 | )) 123 | pred_triplet_boxes = np.column_stack(( 124 | boxes[relations[:, 0]], 125 | boxes[relations[:, 1]], 126 | )) 127 | relation_scores = np.column_stack(( 128 | class_scores[relations[:, 0]], 129 | sg_entry['rel_scores'], 130 | class_scores[relations[:, 1]], 131 | )).prod(1) 132 | 133 | sorted_inds = np.argsort(relation_scores)[::-1] 134 | # compue recall 135 | for k in result_dict[mode + '_recall']: 136 | this_k = min(k, num_relations) 137 | keep_inds = sorted_inds[:this_k] 138 | recall = _relation_recall(gt_triplets, 139 | pred_triplets[keep_inds,:], 140 | gt_triplet_boxes, 141 | pred_triplet_boxes[keep_inds,:], 142 | iou_thresh) 143 | result_dict[mode + '_recall'][k].append(recall) 144 | 145 | # for visualization 146 | return pred_triplets[sorted_inds, :], pred_triplet_boxes[sorted_inds, :] 147 | 148 | 149 | def _triplet(predicates, relations, classes, boxes, 150 | predicate_scores, class_scores): 151 | 152 | # format predictions into triplets 153 | assert(predicates.shape[0] == relations.shape[0]) 154 | num_relations = relations.shape[0] 155 | triplets = np.zeros([num_relations, 3]).astype(np.int32) 156 | triplet_boxes = np.zeros([num_relations, 8]).astype(np.int32) 157 | triplet_scores = np.zeros([num_relations]).astype(np.float32) 158 | for i in range(num_relations): 159 | triplets[i, 1] = predicates[i] 160 | sub_i, obj_i = relations[i,:2] 161 | triplets[i, 0] = classes[sub_i] 162 | triplets[i, 2] = classes[obj_i] 163 | triplet_boxes[i, :4] = boxes[sub_i, :] 164 | triplet_boxes[i, 4:] = boxes[obj_i, :] 165 | # compute triplet score 166 | score = class_scores[sub_i] 167 | score *= class_scores[obj_i] 168 | score *= predicate_scores[i] 169 | triplet_scores[i] = score 170 | return triplets, triplet_boxes, triplet_scores 171 | 172 | 173 | def _relation_recall(gt_triplets, pred_triplets, 174 | gt_boxes, pred_boxes, iou_thresh): 175 | 176 | # compute the R@K metric for a set of predicted triplets 177 | 178 | num_gt = gt_triplets.shape[0] 179 | num_correct_pred_gt = 0 180 | 181 | for gt, gt_box in zip(gt_triplets, gt_boxes): 182 | keep = np.zeros(pred_triplets.shape[0]).astype(bool) 183 | for i, pred in enumerate(pred_triplets): 184 | if gt[0] == pred[0] and gt[1] == pred[1] and gt[2] == pred[2]: 185 | keep[i] = True 186 | if not np.any(keep): 187 | continue 188 | boxes = pred_boxes[keep,:] 189 | sub_iou = iou(gt_box[:4], boxes[:,:4]) 190 | obj_iou = iou(gt_box[4:], boxes[:,4:]) 191 | inds = np.intersect1d(np.where(sub_iou >= iou_thresh)[0], 192 | np.where(obj_iou >= iou_thresh)[0]) 193 | if inds.size > 0: 194 | num_correct_pred_gt += 1 195 | return float(num_correct_pred_gt) / float(num_gt) 196 | 197 | 198 | def iou(gt_box, pred_boxes): 199 | # computer Intersection-over-Union between two sets of boxes 200 | ixmin = np.maximum(gt_box[0], pred_boxes[:,0]) 201 | iymin = np.maximum(gt_box[1], pred_boxes[:,1]) 202 | ixmax = np.minimum(gt_box[2], pred_boxes[:,2]) 203 | iymax = np.minimum(gt_box[3], pred_boxes[:,3]) 204 | iw = np.maximum(ixmax - ixmin + 1., 0.) 205 | ih = np.maximum(iymax - iymin + 1., 0.) 206 | inters = iw * ih 207 | 208 | # union 209 | uni = ((gt_box[2] - gt_box[0] + 1.) * (gt_box[3] - gt_box[1] + 1.) + 210 | (pred_boxes[:, 2] - pred_boxes[:, 0] + 1.) * 211 | (pred_boxes[:, 3] - pred_boxes[:, 1] + 1.) 
- inters) 212 | 213 | overlaps = inters / uni 214 | return overlaps 215 | -------------------------------------------------------------------------------- /lib/fpn/anchor_targets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generates anchor targets to train the detector. Does this during the collate step in training 3 | as it's much cheaper to do this on a separate thread. 4 | 5 | Heavily adapted from faster_rcnn/rpn_msr/anchor_target_layer.py. 6 | """ 7 | import numpy as np 8 | import numpy.random as npr 9 | 10 | from config import IM_SCALE, RPN_NEGATIVE_OVERLAP, RPN_POSITIVE_OVERLAP, \ 11 | RPN_BATCHSIZE, RPN_FG_FRACTION, ANCHOR_SIZE, ANCHOR_SCALES, ANCHOR_RATIOS 12 | from lib.fpn.box_intersections_cpu.bbox import bbox_overlaps 13 | from lib.fpn.generate_anchors import generate_anchors 14 | 15 | 16 | def anchor_target_layer(gt_boxes, im_size, 17 | allowed_border=0): 18 | """ 19 | Assign anchors to ground-truth targets. Produces anchor classification 20 | labels and bounding-box regression targets. 21 | 22 | for each (H, W) location i 23 | generate 3 anchor boxes centered on cell i 24 | filter out-of-image anchors 25 | measure GT overlap 26 | 27 | :param gt_boxes: [x1, y1, x2, y2] boxes. These are assumed to be at the same scale as 28 | the image (IM_SCALE) 29 | :param im_size: Size of the image (h, w). This is assumed to be scaled to IM_SCALE 30 | """ 31 | if max(im_size) != IM_SCALE: 32 | raise ValueError("im size is {}".format(im_size)) 33 | h, w = im_size 34 | 35 | # Get the indices of the anchors in the feature map. 36 | # h, w, A, 4 37 | ans_np = generate_anchors(base_size=ANCHOR_SIZE, 38 | feat_stride=16, 39 | anchor_scales=ANCHOR_SCALES, 40 | anchor_ratios=ANCHOR_RATIOS, 41 | ) 42 | ans_np_flat = ans_np.reshape((-1, 4)) 43 | inds_inside = np.where( 44 | (ans_np_flat[:, 0] >= -allowed_border) & 45 | (ans_np_flat[:, 1] >= -allowed_border) & 46 | (ans_np_flat[:, 2] < w + allowed_border) & # width 47 | (ans_np_flat[:, 3] < h + allowed_border) # height 48 | )[0] 49 | good_ans_flat = ans_np_flat[inds_inside] 50 | if good_ans_flat.size == 0: 51 | raise ValueError("There were no good anchors for an image of size {} with boxes {}".format(im_size, gt_boxes)) 52 | 53 | # overlaps between the anchors and the gt boxes [num_anchors, num_gtboxes] 54 | overlaps = bbox_overlaps(good_ans_flat, gt_boxes) 55 | anchor_to_gtbox = overlaps.argmax(axis=1) 56 | max_overlaps = overlaps[np.arange(anchor_to_gtbox.shape[0]), anchor_to_gtbox] 57 | gtbox_to_anchor = overlaps.argmax(axis=0) 58 | gt_max_overlaps = overlaps[gtbox_to_anchor, np.arange(overlaps.shape[1])] 59 | gt_argmax_overlaps = np.where(overlaps == gt_max_overlaps)[0] 60 | 61 | # Good anchors are those that match SOMEWHERE within a decent tolerance 62 | # label: 1 is positive, 0 is negative, -1 is dont care. 
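    # Concretely: anchors whose best IoU with every GT box falls below RPN_NEGATIVE_OVERLAP
    # are labelled background (0); the best-matching anchor of each GT box, plus any anchor
    # with an IoU of at least RPN_POSITIVE_OVERLAP, are labelled foreground (1); all
    # remaining anchors keep the label -1 and are excluded from training.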
63 | # assign bg labels first so that positive labels can clobber them 64 | labels = (-1) * np.ones(overlaps.shape[0], dtype=np.int64) 65 | labels[max_overlaps < RPN_NEGATIVE_OVERLAP] = 0 66 | labels[gt_argmax_overlaps] = 1 67 | labels[max_overlaps >= RPN_POSITIVE_OVERLAP] = 1 68 | 69 | # subsample positive labels if we have too many 70 | num_fg = int(RPN_FG_FRACTION * RPN_BATCHSIZE) 71 | fg_inds = np.where(labels == 1)[0] 72 | if len(fg_inds) > num_fg: 73 | labels[npr.choice(fg_inds, size=(len(fg_inds) - num_fg), replace=False)] = -1 74 | 75 | # subsample negative labels if we have too many 76 | num_bg = RPN_BATCHSIZE - np.sum(labels == 1) 77 | bg_inds = np.where(labels == 0)[0] 78 | if len(bg_inds) > num_bg: 79 | labels[npr.choice(bg_inds, size=(len(bg_inds) - num_bg), replace=False)] = -1 80 | # print("{} fg {} bg ratio{:.3f} inds inside {}".format(RPN_BATCHSIZE-num_bg, num_bg, (RPN_BATCHSIZE-num_bg)/RPN_BATCHSIZE, inds_inside.shape[0])) 81 | 82 | 83 | # Get the labels at the original size 84 | labels_unmap = (-1) * np.ones(ans_np_flat.shape[0], dtype=np.int64) 85 | labels_unmap[inds_inside] = labels 86 | 87 | # h, w, A 88 | labels_unmap_res = labels_unmap.reshape(ans_np.shape[:-1]) 89 | anchor_inds = np.column_stack(np.where(labels_unmap_res >= 0)) 90 | 91 | # These ought to be in the same order 92 | anchor_inds_flat = np.where(labels >= 0)[0] 93 | anchors = good_ans_flat[anchor_inds_flat] 94 | bbox_targets = gt_boxes[anchor_to_gtbox[anchor_inds_flat]] 95 | labels = labels[anchor_inds_flat] 96 | 97 | assert np.all(labels >= 0) 98 | 99 | 100 | # Anchors: [num_used, 4] 101 | # Anchor_inds: [num_used, 3] (h, w, A) 102 | # bbox_targets: [num_used, 4] 103 | # labels: [num_used] 104 | 105 | return anchors, anchor_inds, bbox_targets, labels 106 | -------------------------------------------------------------------------------- /lib/fpn/box_intersections_cpu/bbox.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Sergey Karayev 6 | # -------------------------------------------------------- 7 | 8 | cimport cython 9 | import numpy as np 10 | cimport numpy as np 11 | 12 | DTYPE = np.float 13 | ctypedef np.float_t DTYPE_t 14 | 15 | def bbox_overlaps(boxes, query_boxes): 16 | cdef np.ndarray[DTYPE_t, ndim=2] boxes_contig = np.ascontiguousarray(boxes, dtype=DTYPE) 17 | cdef np.ndarray[DTYPE_t, ndim=2] query_contig = np.ascontiguousarray(query_boxes, dtype=DTYPE) 18 | 19 | return bbox_overlaps_c(boxes_contig, query_contig) 20 | 21 | cdef np.ndarray[DTYPE_t, ndim=2] bbox_overlaps_c( 22 | np.ndarray[DTYPE_t, ndim=2] boxes, 23 | np.ndarray[DTYPE_t, ndim=2] query_boxes): 24 | """ 25 | Parameters 26 | ---------- 27 | boxes: (N, 4) ndarray of float 28 | query_boxes: (K, 4) ndarray of float 29 | Returns 30 | ------- 31 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 32 | """ 33 | cdef unsigned int N = boxes.shape[0] 34 | cdef unsigned int K = query_boxes.shape[0] 35 | cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=DTYPE) 36 | cdef DTYPE_t iw, ih, box_area 37 | cdef DTYPE_t ua 38 | cdef unsigned int k, n 39 | for k in range(K): 40 | box_area = ( 41 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 42 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 43 | ) 44 | for n in range(N): 45 | iw = ( 46 | min(boxes[n, 2], query_boxes[k, 2]) - 47 | 
max(boxes[n, 0], query_boxes[k, 0]) + 1 48 | ) 49 | if iw > 0: 50 | ih = ( 51 | min(boxes[n, 3], query_boxes[k, 3]) - 52 | max(boxes[n, 1], query_boxes[k, 1]) + 1 53 | ) 54 | if ih > 0: 55 | ua = float( 56 | (boxes[n, 2] - boxes[n, 0] + 1) * 57 | (boxes[n, 3] - boxes[n, 1] + 1) + 58 | box_area - iw * ih 59 | ) 60 | overlaps[n, k] = iw * ih / ua 61 | return overlaps 62 | 63 | 64 | def bbox_intersections(boxes, query_boxes): 65 | cdef np.ndarray[DTYPE_t, ndim=2] boxes_contig = np.ascontiguousarray(boxes, dtype=DTYPE) 66 | cdef np.ndarray[DTYPE_t, ndim=2] query_contig = np.ascontiguousarray(query_boxes, dtype=DTYPE) 67 | 68 | return bbox_intersections_c(boxes_contig, query_contig) 69 | 70 | 71 | cdef np.ndarray[DTYPE_t, ndim=2] bbox_intersections_c( 72 | np.ndarray[DTYPE_t, ndim=2] boxes, 73 | np.ndarray[DTYPE_t, ndim=2] query_boxes): 74 | """ 75 | For each query box compute the intersection ratio covered by boxes 76 | ---------- 77 | Parameters 78 | ---------- 79 | boxes: (N, 4) ndarray of float 80 | query_boxes: (K, 4) ndarray of float 81 | Returns 82 | ------- 83 | overlaps: (N, K) ndarray of intersec between boxes and query_boxes 84 | """ 85 | cdef unsigned int N = boxes.shape[0] 86 | cdef unsigned int K = query_boxes.shape[0] 87 | cdef np.ndarray[DTYPE_t, ndim=2] intersec = np.zeros((N, K), dtype=DTYPE) 88 | cdef DTYPE_t iw, ih, box_area 89 | cdef DTYPE_t ua 90 | cdef unsigned int k, n 91 | for k in range(K): 92 | box_area = ( 93 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 94 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 95 | ) 96 | for n in range(N): 97 | iw = ( 98 | min(boxes[n, 2], query_boxes[k, 2]) - 99 | max(boxes[n, 0], query_boxes[k, 0]) + 1 100 | ) 101 | if iw > 0: 102 | ih = ( 103 | min(boxes[n, 3], query_boxes[k, 3]) - 104 | max(boxes[n, 1], query_boxes[k, 1]) + 1 105 | ) 106 | if ih > 0: 107 | intersec[n, k] = iw * ih / box_area 108 | return intersec -------------------------------------------------------------------------------- /lib/fpn/box_intersections_cpu/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup(name="bbox_cython", ext_modules=cythonize('bbox.pyx'), include_dirs=[numpy.get_include()]) -------------------------------------------------------------------------------- /lib/fpn/box_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn import functional as F 4 | from lib.fpn.box_intersections_cpu.bbox import bbox_overlaps as bbox_overlaps_np 5 | from lib.fpn.box_intersections_cpu.bbox import bbox_intersections as bbox_intersections_np 6 | 7 | 8 | def bbox_loss(prior_boxes, deltas, gt_boxes, eps=1e-4, scale_before=1): 9 | """ 10 | Computes the loss for predicting the GT boxes from prior boxes 11 | :param prior_boxes: [num_boxes, 4] (x1, y1, x2, y2) 12 | :param deltas: [num_boxes, 4] (tx, ty, th, tw) 13 | :param gt_boxes: [num_boxes, 4] (x1, y1, x2, y2) 14 | :return: 15 | """ 16 | prior_centers = center_size(prior_boxes) #(cx, cy, w, h) 17 | gt_centers = center_size(gt_boxes) #(cx, cy, w, h) 18 | 19 | center_targets = (gt_centers[:, :2] - prior_centers[:, :2]) / prior_centers[:, 2:] 20 | size_targets = torch.log(gt_centers[:, 2:]) - torch.log(prior_centers[:, 2:]) 21 | all_targets = torch.cat((center_targets, size_targets), 1) 22 | 23 | loss = F.smooth_l1_loss(deltas, all_targets, size_average=False)/(eps + 
prior_centers.size(0)) 24 | 25 | return loss 26 | 27 | 28 | def bbox_preds(boxes, deltas): 29 | """ 30 | Converts "deltas" (predicted by the network) along with prior boxes 31 | into (x1, y1, x2, y2) representation. 32 | :param boxes: Prior boxes, represented as (x1, y1, x2, y2) 33 | :param deltas: Offsets (tx, ty, tw, th) 34 | :param box_strides [num_boxes,] distance apart between boxes. anchor box can't go more than 35 | \pm box_strides/2 from its current position. If None then we'll use the widths 36 | and heights 37 | :return: Transformed boxes 38 | """ 39 | 40 | if boxes.size(0) == 0: 41 | return boxes 42 | prior_centers = center_size(boxes) 43 | 44 | xys = prior_centers[:, :2] + prior_centers[:, 2:] * deltas[:, :2] 45 | 46 | whs = torch.exp(deltas[:, 2:]) * prior_centers[:, 2:] 47 | 48 | return point_form(torch.cat((xys, whs), 1)) 49 | 50 | 51 | def center_size(boxes): 52 | """ Convert prior_boxes to (cx, cy, w, h) 53 | representation for comparison to center-size form ground truth data. 54 | Args: 55 | boxes: (tensor) point_form boxes 56 | Return: 57 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 58 | """ 59 | wh = boxes[:, 2:] - boxes[:, :2] + 1.0 60 | 61 | if isinstance(boxes, np.ndarray): 62 | return np.column_stack((boxes[:, :2] + 0.5 * wh, wh)) 63 | return torch.cat((boxes[:, :2] + 0.5 * wh, wh), 1) 64 | 65 | 66 | def point_form(boxes): 67 | """ Convert prior_boxes to (xmin, ymin, xmax, ymax) 68 | representation for comparison to point form ground truth data. 69 | Args: 70 | boxes: (tensor) center-size default boxes from priorbox layers. 71 | Return: 72 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 73 | """ 74 | if isinstance(boxes, np.ndarray): 75 | return np.column_stack((boxes[:, :2] - 0.5 * boxes[:, 2:], 76 | boxes[:, :2] + 0.5 * (boxes[:, 2:] - 2.0))) 77 | return torch.cat((boxes[:, :2] - 0.5 * boxes[:, 2:], 78 | boxes[:, :2] + 0.5 * (boxes[:, 2:] - 2.0)), 1) # xmax, ymax 79 | 80 | 81 | ########################################################################### 82 | ### Torch Utils, creds to Max de Groot 83 | ########################################################################### 84 | 85 | def bbox_intersections(box_a, box_b): 86 | """ We resize both tensors to [A,B,2] without new malloc: 87 | [A,2] -> [A,1,2] -> [A,B,2] 88 | [B,2] -> [1,B,2] -> [A,B,2] 89 | Then we compute the area of intersect between box_a and box_b. 90 | Args: 91 | box_a: (tensor) bounding boxes, Shape: [A,4]. 92 | box_b: (tensor) bounding boxes, Shape: [B,4]. 93 | Return: 94 | (tensor) intersection area, Shape: [A,B]. 95 | """ 96 | if isinstance(box_a, np.ndarray): 97 | assert isinstance(box_b, np.ndarray) 98 | return bbox_intersections_np(box_a, box_b) 99 | A = box_a.size(0) 100 | B = box_b.size(0) 101 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 102 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 103 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 104 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 105 | inter = torch.clamp((max_xy - min_xy + 1.0), min=0) 106 | return inter[:, :, 0] * inter[:, :, 1] 107 | 108 | 109 | def bbox_overlaps(box_a, box_b): 110 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 111 | is simply the intersection over union of two boxes. Here we operate on 112 | ground truth boxes and default boxes. 
113 | E.g.: 114 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 115 | Args: 116 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 117 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 118 | Return: 119 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 120 | """ 121 | if isinstance(box_a, np.ndarray): 122 | assert isinstance(box_b, np.ndarray) 123 | return bbox_overlaps_np(box_a, box_b) 124 | 125 | inter = bbox_intersections(box_a, box_b) 126 | area_a = ((box_a[:, 2] - box_a[:, 0] + 1.0) * 127 | (box_a[:, 3] - box_a[:, 1] + 1.0)).unsqueeze(1).expand_as(inter) # [A,B] 128 | area_b = ((box_b[:, 2] - box_b[:, 0] + 1.0) * 129 | (box_b[:, 3] - box_b[:, 1] + 1.0)).unsqueeze(0).expand_as(inter) # [A,B] 130 | union = area_a + area_b - inter 131 | return inter / union # [A,B] 132 | 133 | 134 | def nms_overlaps(boxes): 135 | """ get overlaps for each channel""" 136 | assert boxes.dim() == 3 137 | N = boxes.size(0) 138 | nc = boxes.size(1) 139 | max_xy = torch.min(boxes[:, None, :, 2:].expand(N, N, nc, 2), 140 | boxes[None, :, :, 2:].expand(N, N, nc, 2)) 141 | 142 | min_xy = torch.max(boxes[:, None, :, :2].expand(N, N, nc, 2), 143 | boxes[None, :, :, :2].expand(N, N, nc, 2)) 144 | 145 | inter = torch.clamp((max_xy - min_xy + 1.0), min=0) 146 | 147 | # n, n, 151 148 | inters = inter[:,:,:,0]*inter[:,:,:,1] 149 | boxes_flat = boxes.view(-1, 4) 150 | areas_flat = (boxes_flat[:,2]- boxes_flat[:,0]+1.0)*( 151 | boxes_flat[:,3]- boxes_flat[:,1]+1.0) 152 | areas = areas_flat.view(boxes.size(0), boxes.size(1)) 153 | union = -inters + areas[None] + areas[:, None] 154 | return inters / union 155 | 156 | -------------------------------------------------------------------------------- /lib/fpn/generate_anchors.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Faster R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick and Sean Bell 6 | # -------------------------------------------------------- 7 | from config import IM_SCALE 8 | 9 | import numpy as np 10 | 11 | 12 | # Verify that we compute the same anchors as Shaoqing's matlab implementation: 13 | # 14 | # >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat 15 | # >> anchors 16 | # 17 | # anchors = 18 | # 19 | # -83 -39 100 56 20 | # -175 -87 192 104 21 | # -359 -183 376 200 22 | # -55 -55 72 72 23 | # -119 -119 136 136 24 | # -247 -247 264 264 25 | # -35 -79 52 96 26 | # -79 -167 96 184 27 | # -167 -343 184 360 28 | 29 | # array([[ -83., -39., 100., 56.], 30 | # [-175., -87., 192., 104.], 31 | # [-359., -183., 376., 200.], 32 | # [ -55., -55., 72., 72.], 33 | # [-119., -119., 136., 136.], 34 | # [-247., -247., 264., 264.], 35 | # [ -35., -79., 52., 96.], 36 | # [ -79., -167., 96., 184.], 37 | # [-167., -343., 184., 360.]]) 38 | 39 | def generate_anchors(base_size=16, feat_stride=16, anchor_scales=(8,16,32), anchor_ratios=(0.5,1,2)): 40 | """ A wrapper function to generate anchors given different scales 41 | Also return the number of anchors in variable 'length' 42 | """ 43 | anchors = generate_base_anchors(base_size=base_size, 44 | ratios=np.array(anchor_ratios), 45 | scales=np.array(anchor_scales)) 46 | A = anchors.shape[0] 47 | shift_x = np.arange(0, IM_SCALE // feat_stride) * feat_stride # Same as shift_x 48 | shift_x, shift_y = np.meshgrid(shift_x, shift_x) 49 | 50 | shifts = 
np.stack([shift_x, shift_y, shift_x, shift_y], -1) # h, w, 4 51 | all_anchors = shifts[:, :, None] + anchors[None, None] #h, w, A, 4 52 | return all_anchors 53 | 54 | # shifts = np.vstack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel())).transpose() 55 | # K = shifts.shape[0] 56 | # # width changes faster, so here it is H, W, C 57 | # anchors = anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2)) 58 | # anchors = anchors.reshape((K * A, 4)).astype(np.float32, copy=False) 59 | # length = np.int32(anchors.shape[0]) 60 | 61 | 62 | def generate_base_anchors(base_size=16, ratios=[0.5, 1, 2], scales=2 ** np.arange(3, 6)): 63 | """ 64 | Generate anchor (reference) windows by enumerating aspect ratios X 65 | scales wrt a reference (0, 0, 15, 15) window. 66 | """ 67 | 68 | base_anchor = np.array([1, 1, base_size, base_size]) - 1 69 | ratio_anchors = _ratio_enum(base_anchor, ratios) 70 | anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales) 71 | for i in range(ratio_anchors.shape[0])]) 72 | return anchors 73 | 74 | 75 | def _whctrs(anchor): 76 | """ 77 | Return width, height, x center, and y center for an anchor (window). 78 | """ 79 | 80 | w = anchor[2] - anchor[0] + 1 81 | h = anchor[3] - anchor[1] + 1 82 | x_ctr = anchor[0] + 0.5 * (w - 1) 83 | y_ctr = anchor[1] + 0.5 * (h - 1) 84 | return w, h, x_ctr, y_ctr 85 | 86 | 87 | def _mkanchors(ws, hs, x_ctr, y_ctr): 88 | """ 89 | Given a vector of widths (ws) and heights (hs) around a center 90 | (x_ctr, y_ctr), output a set of anchors (windows). 91 | """ 92 | 93 | ws = ws[:, np.newaxis] 94 | hs = hs[:, np.newaxis] 95 | anchors = np.hstack((x_ctr - 0.5 * (ws - 1), 96 | y_ctr - 0.5 * (hs - 1), 97 | x_ctr + 0.5 * (ws - 1), 98 | y_ctr + 0.5 * (hs - 1))) 99 | return anchors 100 | 101 | 102 | def _ratio_enum(anchor, ratios): 103 | """ 104 | Enumerate a set of anchors for each aspect ratio wrt an anchor. 105 | """ 106 | 107 | w, h, x_ctr, y_ctr = _whctrs(anchor) 108 | size = w * h 109 | size_ratios = size / ratios 110 | # NOTE: CHANGED TO NOT HAVE ROUNDING 111 | ws = np.sqrt(size_ratios) 112 | hs = ws * ratios 113 | anchors = _mkanchors(ws, hs, x_ctr, y_ctr) 114 | return anchors 115 | 116 | 117 | def _scale_enum(anchor, scales): 118 | """ 119 | Enumerate a set of anchors for each scale wrt an anchor. 120 | """ 121 | 122 | w, h, x_ctr, y_ctr = _whctrs(anchor) 123 | ws = w * scales 124 | hs = h * scales 125 | anchors = _mkanchors(ws, hs, x_ctr, y_ctr) 126 | return anchors 127 | -------------------------------------------------------------------------------- /lib/fpn/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd anchors 4 | python setup.py build_ext --inplace 5 | cd .. 6 | 7 | cd box_intersections_cpu 8 | python setup.py build_ext --inplace 9 | cd .. 10 | 11 | cd cpu_nms 12 | python build.py 13 | cd .. 14 | 15 | cd roi_align 16 | python build.py -C src/cuda clean 17 | python build.py -C src/cuda clean 18 | cd .. 
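# (each CUDA extension also ships its own Makefile: `make -C nms` and `make -C roi_align`
#  compile the kernels with /usr/local/cuda/bin/nvcc and then run python build.py,
#  assuming nvcc is installed at that path)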
19 | 20 | echo "Done compiling hopefully" 21 | -------------------------------------------------------------------------------- /lib/fpn/nms/Makefile: -------------------------------------------------------------------------------- 1 | all: src/cuda/nms.cu.o 2 | python build.py 3 | 4 | src/cuda/nms.cu.o: src/cuda/nms_kernel.cu 5 | $(MAKE) -C src/cuda 6 | 7 | clean: 8 | $(MAKE) -C src/cuda clean 9 | -------------------------------------------------------------------------------- /lib/fpn/nms/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.ffi import create_extension 4 | # Might have to export PATH=/usr/local/cuda-8.0/bin${PATH:+:${PATH}} 5 | 6 | sources = [] 7 | headers = [] 8 | defines = [] 9 | with_cuda = False 10 | 11 | if torch.cuda.is_available(): 12 | print('Including CUDA code.') 13 | sources += ['src/nms_cuda.c'] 14 | headers += ['src/nms_cuda.h'] 15 | defines += [('WITH_CUDA', None)] 16 | with_cuda = True 17 | 18 | this_file = os.path.dirname(os.path.realpath(__file__)) 19 | print(this_file) 20 | extra_objects = ['src/cuda/nms.cu.o'] 21 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 22 | 23 | ffi = create_extension( 24 | '_ext.nms', 25 | headers=headers, 26 | sources=sources, 27 | define_macros=defines, 28 | relative_to=__file__, 29 | with_cuda=with_cuda, 30 | extra_objects=extra_objects 31 | ) 32 | 33 | if __name__ == '__main__': 34 | ffi.build() 35 | 36 | -------------------------------------------------------------------------------- /lib/fpn/nms/functions/nms.py: -------------------------------------------------------------------------------- 1 | # Le code for doing NMS 2 | import torch 3 | import numpy as np 4 | from .._ext import nms 5 | 6 | 7 | def apply_nms(scores, boxes, pre_nms_topn=12000, post_nms_topn=2000, boxes_per_im=None, 8 | nms_thresh=0.7): 9 | """ 10 | Note - this function is non-differentiable so everything is assumed to be a tensor, not 11 | a variable. 
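    :param scores: [num_boxes] objectness scores (CUDA tensor)
    :param boxes: [num_boxes, 4] (x1, y1, x2, y2) boxes, grouped contiguously by image
    :param pre_nms_topn: per image, only the top-scoring boxes up to this count enter NMS
    :param post_nms_topn: per image, at most this many boxes survive NMS
    :param boxes_per_im: optional list of per-image box counts; if omitted, all boxes are
        treated as one image and only the kept indices are returned
    :return: LongTensor of kept indices into the original ordering (plus a list of
        per-image keep counts when boxes_per_im is given)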
12 | """ 13 | just_inds = boxes_per_im is None 14 | if boxes_per_im is None: 15 | boxes_per_im = [boxes.size(0)] 16 | 17 | 18 | s = 0 19 | keep = [] 20 | im_per = [] 21 | for bpi in boxes_per_im: 22 | e = s + int(bpi) 23 | keep_im = _nms_single_im(scores[s:e], boxes[s:e], pre_nms_topn, post_nms_topn, nms_thresh) 24 | keep.append(keep_im + s) 25 | im_per.append(keep_im.size(0)) 26 | 27 | s = e 28 | 29 | inds = torch.cat(keep, 0) 30 | if just_inds: 31 | return inds 32 | return inds, im_per 33 | 34 | 35 | def _nms_single_im(scores, boxes, pre_nms_topn=12000, post_nms_topn=2000, nms_thresh=0.7): 36 | keep = torch.IntTensor(scores.size(0)) 37 | vs, idx = torch.sort(scores, dim=0, descending=True) 38 | if idx.size(0) > pre_nms_topn: 39 | idx = idx[:pre_nms_topn] 40 | boxes_sorted = boxes[idx].contiguous() 41 | num_out = nms.nms_apply(keep, boxes_sorted, nms_thresh) 42 | num_out = min(num_out, post_nms_topn) 43 | keep = keep[:num_out].long() 44 | keep = idx[keep.cuda(scores.get_device())] 45 | return keep 46 | -------------------------------------------------------------------------------- /lib/fpn/nms/src/cuda/Makefile: -------------------------------------------------------------------------------- 1 | all: nms_kernel.cu nms_kernel.h 2 | /usr/local/cuda/bin/nvcc -c -o nms.cu.o nms_kernel.cu --compiler-options -fPIC -gencode arch=compute_61,code=sm_61 3 | clean: 4 | rm nms.cu.o 5 | -------------------------------------------------------------------------------- /lib/fpn/nms/src/cuda/nms_kernel.cu: -------------------------------------------------------------------------------- 1 | // ------------------------------------------------------------------ 2 | // Faster R-CNN 3 | // Copyright (c) 2015 Microsoft 4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details] 5 | // Written by Shaoqing Ren 6 | // ------------------------------------------------------------------ 7 | 8 | #include 9 | #include 10 | 11 | #define CUDA_CHECK(condition) \ 12 | /* Code block avoids redefinition of cudaError_t error */ \ 13 | do { \ 14 | cudaError_t error = condition; \ 15 | if (error != cudaSuccess) { \ 16 | std::cout << cudaGetErrorString(error) << std::endl; \ 17 | } \ 18 | } while (0) 19 | 20 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 21 | int const threadsPerBlock = sizeof(unsigned long long) * 8; 22 | 23 | __device__ inline float devIoU(float const * const a, float const * const b) { 24 | float left = max(a[0], b[0]), right = min(a[2], b[2]); 25 | float top = max(a[1], b[1]), bottom = min(a[3], b[3]); 26 | float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); 27 | float interS = width * height; 28 | float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); 29 | float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); 30 | return interS / (Sa + Sb - interS); 31 | } 32 | 33 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, 34 | const float *dev_boxes, unsigned long long *dev_mask) { 35 | const int row_start = blockIdx.y; 36 | const int col_start = blockIdx.x; 37 | 38 | // if (row_start > col_start) return; 39 | 40 | const int row_size = 41 | min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); 42 | const int col_size = 43 | min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); 44 | 45 | __shared__ float block_boxes[threadsPerBlock * 5]; 46 | if (threadIdx.x < col_size) { 47 | block_boxes[threadIdx.x * 4 + 0] = 48 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0]; 49 | block_boxes[threadIdx.x * 4 + 1] = 50 | 
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1]; 51 | block_boxes[threadIdx.x * 4 + 2] = 52 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2]; 53 | block_boxes[threadIdx.x * 4 + 3] = 54 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3]; 55 | } 56 | __syncthreads(); 57 | 58 | if (threadIdx.x < row_size) { 59 | const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; 60 | const float *cur_box = dev_boxes + cur_box_idx * 4; 61 | int i = 0; 62 | unsigned long long t = 0; 63 | int start = 0; 64 | if (row_start == col_start) { 65 | start = threadIdx.x + 1; 66 | } 67 | for (i = start; i < col_size; i++) { 68 | if (devIoU(cur_box, block_boxes + i * 4) > nms_overlap_thresh) { 69 | t |= 1ULL << i; 70 | } 71 | } 72 | const int col_blocks = DIVUP(n_boxes, threadsPerBlock); 73 | dev_mask[cur_box_idx * col_blocks + col_start] = t; 74 | } 75 | } 76 | 77 | void _set_device(int device_id) { 78 | int current_device; 79 | CUDA_CHECK(cudaGetDevice(¤t_device)); 80 | if (current_device == device_id) { 81 | return; 82 | } 83 | // The call to cudaSetDevice must come before any calls to Get, which 84 | // may perform initialization using the GPU. 85 | CUDA_CHECK(cudaSetDevice(device_id)); 86 | } 87 | 88 | extern "C" int ApplyNMSGPU(int* keep_out, const float* boxes_dev, const int boxes_num, 89 | float nms_overlap_thresh, int device_id) { 90 | _set_device(device_id); 91 | 92 | unsigned long long* mask_dev = NULL; 93 | 94 | const int col_blocks = DIVUP(boxes_num, threadsPerBlock); 95 | 96 | CUDA_CHECK(cudaMalloc(&mask_dev, 97 | boxes_num * col_blocks * sizeof(unsigned long long))); 98 | 99 | dim3 blocks(DIVUP(boxes_num, threadsPerBlock), 100 | DIVUP(boxes_num, threadsPerBlock)); 101 | dim3 threads(threadsPerBlock); 102 | nms_kernel<<>>(boxes_num, 103 | nms_overlap_thresh, 104 | boxes_dev, 105 | mask_dev); 106 | 107 | std::vector mask_host(boxes_num * col_blocks); 108 | CUDA_CHECK(cudaMemcpy(&mask_host[0], 109 | mask_dev, 110 | sizeof(unsigned long long) * boxes_num * col_blocks, 111 | cudaMemcpyDeviceToHost)); 112 | 113 | std::vector remv(col_blocks); 114 | memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); 115 | 116 | int num_to_keep = 0; 117 | for (int i = 0; i < boxes_num; i++) { 118 | int nblock = i / threadsPerBlock; 119 | int inblock = i % threadsPerBlock; 120 | 121 | if (!(remv[nblock] & (1ULL << inblock))) { 122 | keep_out[num_to_keep++] = i; 123 | unsigned long long *p = &mask_host[0] + i * col_blocks; 124 | for (int j = nblock; j < col_blocks; j++) { 125 | remv[j] |= p[j]; 126 | } 127 | } 128 | } 129 | 130 | CUDA_CHECK(cudaFree(mask_dev)); 131 | return num_to_keep; 132 | } 133 | -------------------------------------------------------------------------------- /lib/fpn/nms/src/cuda/nms_kernel.h: -------------------------------------------------------------------------------- 1 | int ApplyNMSGPU(int* keep_out, const float* boxes_dev, const int boxes_num, 2 | float nms_overlap_thresh, int device_id); 3 | 4 | -------------------------------------------------------------------------------- /lib/fpn/nms/src/nms_cuda.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "cuda/nms_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | int nms_apply(THIntTensor* keep, THCudaTensor* boxes_sorted, const float nms_thresh) 8 | { 9 | int* keep_data = THIntTensor_data(keep); 10 | const float* boxes_sorted_data = THCudaTensor_data(state, boxes_sorted); 11 | 12 | const int boxes_num = 
THCudaTensor_size(state, boxes_sorted, 0); 13 | 14 | const int devId = THCudaTensor_getDevice(state, boxes_sorted); 15 | 16 | int numTotalKeep = ApplyNMSGPU(keep_data, boxes_sorted_data, boxes_num, nms_thresh, devId); 17 | return numTotalKeep; 18 | } 19 | 20 | 21 | -------------------------------------------------------------------------------- /lib/fpn/nms/src/nms_cuda.h: -------------------------------------------------------------------------------- 1 | int nms_apply(THIntTensor* keep, THCudaTensor* boxes_sorted, const float nms_thresh); -------------------------------------------------------------------------------- /lib/fpn/proposal_assignments/proposal_assignments_det.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import numpy.random as npr 4 | from config import BG_THRESH_HI, BG_THRESH_LO, FG_FRACTION, ROIS_PER_IMG 5 | from lib.fpn.box_utils import bbox_overlaps 6 | from lib.pytorch_misc import to_variable 7 | import torch 8 | 9 | ############################################################# 10 | # The following is only for object detection 11 | @to_variable 12 | def proposal_assignments_det(rpn_rois, gt_boxes, gt_classes, image_offset, fg_thresh=0.5): 13 | """ 14 | Assign object detection proposals to ground-truth targets. Produces proposal 15 | classification labels and bounding-box regression targets. 16 | :param rpn_rois: [img_ind, x1, y1, x2, y2] 17 | :param gt_boxes: [num_boxes, 4] array of x0, y0, x1, y1 18 | :param gt_classes: [num_boxes, 2] array of [img_ind, class] 19 | :param Overlap threshold for a ROI to be considered foreground (if >= FG_THRESH) 20 | :return: 21 | rois: [num_rois, 5] 22 | labels: [num_rois] array of labels 23 | bbox_targets [num_rois, 4] array of targets for the labels. 
24 | """ 25 | fg_rois_per_image = int(np.round(ROIS_PER_IMG * FG_FRACTION)) 26 | 27 | gt_img_inds = gt_classes[:, 0] - image_offset 28 | 29 | all_boxes = torch.cat([rpn_rois[:, 1:], gt_boxes], 0) 30 | 31 | ims_per_box = torch.cat([rpn_rois[:, 0].long(), gt_img_inds], 0) 32 | 33 | im_sorted, idx = torch.sort(ims_per_box, 0) 34 | all_boxes = all_boxes[idx] 35 | 36 | # Assume that the GT boxes are already sorted in terms of image id 37 | num_images = int(im_sorted[-1]) + 1 38 | 39 | labels = [] 40 | rois = [] 41 | bbox_targets = [] 42 | for im_ind in range(num_images): 43 | g_inds = (gt_img_inds == im_ind).nonzero() 44 | 45 | if g_inds.dim() == 0: 46 | continue 47 | g_inds = g_inds.squeeze(1) 48 | g_start = g_inds[0] 49 | g_end = g_inds[-1] + 1 50 | 51 | t_inds = (im_sorted == im_ind).nonzero().squeeze(1) 52 | t_start = t_inds[0] 53 | t_end = t_inds[-1] + 1 54 | 55 | # Max overlaps: for each predicted box, get the max ROI 56 | # Get the indices into the GT boxes too (must offset by the box start) 57 | ious = bbox_overlaps(all_boxes[t_start:t_end], gt_boxes[g_start:g_end]) 58 | max_overlaps, gt_assignment = ious.max(1) 59 | max_overlaps = max_overlaps.cpu().numpy() 60 | # print("Best overlap is {}".format(max_overlaps.max())) 61 | # print("\ngt assignment is {} while g_start is {} \n ---".format(gt_assignment, g_start)) 62 | gt_assignment += g_start 63 | 64 | keep_inds_np, num_fg = _sel_inds(max_overlaps, fg_thresh, fg_rois_per_image, 65 | ROIS_PER_IMG) 66 | 67 | if keep_inds_np.size == 0: 68 | continue 69 | 70 | keep_inds = torch.LongTensor(keep_inds_np).cuda(rpn_rois.get_device()) 71 | 72 | labels_ = gt_classes[:, 1][gt_assignment[keep_inds]] 73 | bbox_target_ = gt_boxes[gt_assignment[keep_inds]] 74 | 75 | # Clamp labels_ for the background RoIs to 0 76 | if num_fg < labels_.size(0): 77 | labels_[num_fg:] = 0 78 | 79 | rois_ = torch.cat(( 80 | im_sorted[t_start:t_end, None][keep_inds].float(), 81 | all_boxes[t_start:t_end][keep_inds], 82 | ), 1) 83 | 84 | labels.append(labels_) 85 | rois.append(rois_) 86 | bbox_targets.append(bbox_target_) 87 | 88 | rois = torch.cat(rois, 0) 89 | labels = torch.cat(labels, 0) 90 | bbox_targets = torch.cat(bbox_targets, 0) 91 | return rois, labels, bbox_targets 92 | 93 | 94 | def _sel_inds(max_overlaps, fg_thresh=0.5, fg_rois_per_image=128, rois_per_image=256): 95 | # Select foreground RoIs as those with >= FG_THRESH overlap 96 | fg_inds = np.where(max_overlaps >= fg_thresh)[0] 97 | 98 | # Guard against the case when an image has fewer than fg_rois_per_image 99 | # foreground RoIs 100 | fg_rois_per_this_image = min(fg_rois_per_image, fg_inds.shape[0]) 101 | # Sample foreground regions without replacement 102 | if fg_inds.size > 0: 103 | fg_inds = npr.choice(fg_inds, size=fg_rois_per_this_image, replace=False) 104 | 105 | # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI) 106 | bg_inds = np.where((max_overlaps < BG_THRESH_HI) & (max_overlaps >= BG_THRESH_LO))[0] 107 | 108 | # Compute number of background RoIs to take from this image (guarding 109 | # against there being fewer than desired) 110 | bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image 111 | bg_rois_per_this_image = min(bg_rois_per_this_image, bg_inds.size) 112 | # Sample background regions without replacement 113 | if bg_inds.size > 0: 114 | bg_inds = npr.choice(bg_inds, size=bg_rois_per_this_image, replace=False) 115 | 116 | return np.append(fg_inds, bg_inds), fg_rois_per_this_image 117 | 118 | 
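# A minimal usage sketch for _sel_inds (illustrative values only):
#
#   import numpy as np
#   fake_overlaps = np.random.rand(1000)   # stand-in for each RoI's best IoU with a GT box
#   keep, num_fg = _sel_inds(fake_overlaps, fg_thresh=0.5,
#                            fg_rois_per_image=64, rois_per_image=256)
#   # `keep` holds at most 256 RoI indices; the first `num_fg` are foreground samples and
#   # the remainder fall inside the [BG_THRESH_LO, BG_THRESH_HI) background band.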
-------------------------------------------------------------------------------- /lib/fpn/proposal_assignments/proposal_assignments_gtbox.py: -------------------------------------------------------------------------------- 1 | from lib.pytorch_misc import enumerate_by_image, gather_nd, random_choose 2 | from lib.fpn.box_utils import bbox_preds, center_size, bbox_overlaps 3 | import torch 4 | from lib.pytorch_misc import diagonal_inds, to_variable 5 | from config import RELS_PER_IMG, REL_FG_FRACTION 6 | 7 | 8 | @to_variable 9 | def proposal_assignments_gtbox(rois, gt_boxes, gt_classes, gt_rels, image_offset, fg_thresh=0.5): 10 | """ 11 | Assign object detection proposals to ground-truth targets. Produces proposal 12 | classification labels and bounding-box regression targets. 13 | :param rpn_rois: [img_ind, x1, y1, x2, y2] 14 | :param gt_boxes: [num_boxes, 4] array of x0, y0, x1, y1]. Not needed it seems 15 | :param gt_classes: [num_boxes, 2] array of [img_ind, class] 16 | Note, the img_inds here start at image_offset 17 | :param gt_rels [num_boxes, 4] array of [img_ind, box_0, box_1, rel type]. 18 | Note, the img_inds here start at image_offset 19 | :param Overlap threshold for a ROI to be considered foreground (if >= FG_THRESH) 20 | :return: 21 | rois: [num_rois, 5] 22 | labels: [num_rois] array of labels 23 | bbox_targets [num_rois, 4] array of targets for the labels. 24 | rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type) 25 | """ 26 | im_inds = rois[:,0].long() 27 | 28 | num_im = im_inds[-1] + 1 29 | 30 | # Offset the image indices in fg_rels to refer to absolute indices (not just within img i) 31 | fg_rels = gt_rels.clone() 32 | fg_rels[:,0] -= image_offset 33 | offset = {} 34 | for i, s, e in enumerate_by_image(im_inds): 35 | offset[i] = s 36 | for i, s, e in enumerate_by_image(fg_rels[:, 0]): 37 | fg_rels[s:e, 1:3] += offset[i] 38 | 39 | # Try ALL things, not just intersections. 40 | is_cand = (im_inds[:, None] == im_inds[None]) 41 | is_cand.view(-1)[diagonal_inds(is_cand)] = 0 42 | 43 | # # Compute salience 44 | # gt_inds = fg_rels[:, 1:3].contiguous().view(-1) 45 | # labels_arange = labels.data.new(labels.size(0)) 46 | # torch.arange(0, labels.size(0), out=labels_arange) 47 | # salience_labels = ((gt_inds[:, None] == labels_arange[None]).long().sum(0) > 0).long() 48 | # labels = torch.stack((labels, salience_labels), 1) 49 | 50 | # Add in some BG labels 51 | 52 | # NOW WE HAVE TO EXCLUDE THE FGs. 53 | # TODO: check if this causes an error if many duplicate GTs havent been filtered out 54 | 55 | is_cand.view(-1)[fg_rels[:,1]*im_inds.size(0) + fg_rels[:,2]] = 0 56 | is_bgcand = is_cand.nonzero() 57 | # TODO: make this sample on a per image case 58 | # If too many then sample 59 | num_fg = min(fg_rels.size(0), int(RELS_PER_IMG * REL_FG_FRACTION * num_im)) 60 | if num_fg < fg_rels.size(0): 61 | fg_rels = random_choose(fg_rels, num_fg) 62 | 63 | # If too many then sample 64 | num_bg = min(is_bgcand.size(0) if is_bgcand.dim() > 0 else 0, 65 | int(RELS_PER_IMG * num_im) - num_fg) 66 | if num_bg > 0: 67 | bg_rels = torch.cat(( 68 | im_inds[is_bgcand[:, 0]][:, None], 69 | is_bgcand, 70 | (is_bgcand[:, 0, None] < -10).long(), 71 | ), 1) 72 | 73 | if num_bg < is_bgcand.size(0): 74 | bg_rels = random_choose(bg_rels, num_bg) 75 | rel_labels = torch.cat((fg_rels, bg_rels), 0) 76 | else: 77 | rel_labels = fg_rels 78 | 79 | 80 | # last sort by rel. 
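    # (the three sort keys are packed into one scalar so a single torch.sort yields a
    #  lexicographic order: image index first, then subject box index, then object box index)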
81 | _, perm = torch.sort(rel_labels[:, 0]*(gt_boxes.size(0)**2) + 82 | rel_labels[:,1]*gt_boxes.size(0) + rel_labels[:,2]) 83 | 84 | rel_labels = rel_labels[perm].contiguous() 85 | 86 | labels = gt_classes[:,1].contiguous() 87 | return rois, labels, rel_labels 88 | -------------------------------------------------------------------------------- /lib/fpn/proposal_assignments/proposal_assignments_postnms.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Goal: assign ROIs to targets 3 | # -------------------------------------------------------- 4 | 5 | 6 | import numpy as np 7 | import numpy.random as npr 8 | from .proposal_assignments_rel import _sel_rels 9 | from lib.fpn.box_utils import bbox_overlaps 10 | from lib.pytorch_misc import to_variable 11 | import torch 12 | 13 | 14 | @to_variable 15 | def proposal_assignments_postnms( 16 | rois, gt_boxes, gt_classes, gt_rels, nms_inds, image_offset, fg_thresh=0.5, 17 | max_objs=100, max_rels=100, rand_val=0.01): 18 | """ 19 | Assign object detection proposals to ground-truth targets. Produces proposal 20 | classification labels and bounding-box regression targets. 21 | :param rpn_rois: [img_ind, x1, y1, x2, y2] 22 | :param gt_boxes: [num_boxes, 4] array of x0, y0, x1, y1] 23 | :param gt_classes: [num_boxes, 2] array of [img_ind, class] 24 | :param gt_rels [num_boxes, 4] array of [img_ind, box_0, box_1, rel type] 25 | :param Overlap threshold for a ROI to be considered foreground (if >= FG_THRESH) 26 | :return: 27 | rois: [num_rois, 5] 28 | labels: [num_rois] array of labels 29 | rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type) 30 | """ 31 | pred_inds_np = rois[:, 0].cpu().numpy().astype(np.int64) 32 | pred_boxes_np = rois[:, 1:].cpu().numpy() 33 | nms_inds_np = nms_inds.cpu().numpy() 34 | sup_inds_np = np.setdiff1d(np.arange(pred_boxes_np.shape[0]), nms_inds_np) 35 | 36 | # split into chosen and suppressed 37 | chosen_inds_np = pred_inds_np[nms_inds_np] 38 | chosen_boxes_np = pred_boxes_np[nms_inds_np] 39 | 40 | suppre_inds_np = pred_inds_np[sup_inds_np] 41 | suppre_boxes_np = pred_boxes_np[sup_inds_np] 42 | 43 | gt_boxes_np = gt_boxes.cpu().numpy() 44 | gt_classes_np = gt_classes.cpu().numpy() 45 | gt_rels_np = gt_rels.cpu().numpy() 46 | 47 | gt_classes_np[:, 0] -= image_offset 48 | gt_rels_np[:, 0] -= image_offset 49 | 50 | num_im = gt_classes_np[:, 0].max()+1 51 | 52 | rois = [] 53 | obj_labels = [] 54 | rel_labels = [] 55 | num_box_seen = 0 56 | 57 | for im_ind in range(num_im): 58 | chosen_ind = np.where(chosen_inds_np == im_ind)[0] 59 | suppre_ind = np.where(suppre_inds_np == im_ind)[0] 60 | 61 | gt_ind = np.where(gt_classes_np[:, 0] == im_ind)[0] 62 | gt_boxes_i = gt_boxes_np[gt_ind] 63 | gt_classes_i = gt_classes_np[gt_ind, 1] 64 | gt_rels_i = gt_rels_np[gt_rels_np[:, 0] == im_ind, 1:] 65 | 66 | # Get IOUs between chosen and GT boxes and if needed we'll add more in 67 | 68 | chosen_boxes_i = chosen_boxes_np[chosen_ind] 69 | suppre_boxes_i = suppre_boxes_np[suppre_ind] 70 | 71 | n_chosen = chosen_boxes_i.shape[0] 72 | n_suppre = suppre_boxes_i.shape[0] 73 | n_gt_box = gt_boxes_i.shape[0] 74 | 75 | # add a teensy bit of random noise because some GT boxes might be duplicated, etc. 
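        # (rand_val defaults to 0.01, so the jitter is at most +/-0.005 IoU: enough to break
        #  exact ties between duplicated GT boxes while leaving fg_thresh decisions
        #  essentially unchanged)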
76 | pred_boxes_i = np.concatenate((chosen_boxes_i, suppre_boxes_i, gt_boxes_i), 0) 77 | ious = bbox_overlaps(pred_boxes_i, gt_boxes_i) + rand_val*( 78 | np.random.rand(pred_boxes_i.shape[0], gt_boxes_i.shape[0])-0.5) 79 | 80 | # Let's say that a box can only be assigned ONCE for now because we've already done 81 | # the NMS and stuff. 82 | is_hit = ious > fg_thresh 83 | 84 | obj_assignments_i = is_hit.argmax(1) 85 | obj_assignments_i[~is_hit.any(1)] = -1 86 | 87 | vals, first_occurance_ind = np.unique(obj_assignments_i, return_index=True) 88 | obj_assignments_i[np.setdiff1d( 89 | np.arange(obj_assignments_i.shape[0]), first_occurance_ind)] = -1 90 | 91 | extra_to_add = np.where(obj_assignments_i[n_chosen:] != -1)[0] + n_chosen 92 | 93 | # Add them in somewhere at random 94 | num_inds_to_have = min(max_objs, n_chosen + extra_to_add.shape[0]) 95 | boxes_i = np.zeros((num_inds_to_have, 4), dtype=np.float32) 96 | labels_i = np.zeros(num_inds_to_have, dtype=np.int64) 97 | 98 | inds_from_nms = np.sort(np.random.choice(num_inds_to_have, size=n_chosen, replace=False)) 99 | inds_from_elsewhere = np.setdiff1d(np.arange(num_inds_to_have), inds_from_nms) 100 | 101 | boxes_i[inds_from_nms] = chosen_boxes_i 102 | labels_i[inds_from_nms] = gt_classes_i[obj_assignments_i[:n_chosen]] 103 | 104 | boxes_i[inds_from_elsewhere] = pred_boxes_i[extra_to_add] 105 | labels_i[inds_from_elsewhere] = gt_classes_i[obj_assignments_i[extra_to_add]] 106 | 107 | # Now, we do the relationships. same as for rle 108 | all_rels_i = _sel_rels(bbox_overlaps(boxes_i, gt_boxes_i), 109 | boxes_i, 110 | labels_i, 111 | gt_classes_i, 112 | gt_rels_i, 113 | fg_thresh=fg_thresh, 114 | fg_rels_per_image=100) 115 | all_rels_i[:,0:2] += num_box_seen 116 | 117 | rois.append(np.column_stack(( 118 | im_ind * np.ones(boxes_i.shape[0], dtype=np.float32), 119 | boxes_i, 120 | ))) 121 | obj_labels.append(labels_i) 122 | rel_labels.append(np.column_stack(( 123 | im_ind*np.ones(all_rels_i.shape[0], dtype=np.int64), 124 | all_rels_i, 125 | ))) 126 | num_box_seen += boxes_i.size 127 | 128 | rois = torch.FloatTensor(np.concatenate(rois, 0)).cuda(gt_boxes.get_device(), async=True) 129 | labels = torch.LongTensor(np.concatenate(obj_labels, 0)).cuda(gt_boxes.get_device(), async=True) 130 | rel_labels = torch.LongTensor(np.concatenate(rel_labels, 0)).cuda(gt_boxes.get_device(), 131 | async=True) 132 | 133 | return rois, labels, rel_labels 134 | -------------------------------------------------------------------------------- /lib/fpn/proposal_assignments/rel_assignments.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Goal: assign ROIs to targets 3 | # -------------------------------------------------------- 4 | 5 | 6 | import numpy as np 7 | import numpy.random as npr 8 | from config import BG_THRESH_HI, BG_THRESH_LO, REL_FG_FRACTION, RELS_PER_IMG_REFINE 9 | from lib.fpn.box_utils import bbox_overlaps 10 | from lib.pytorch_misc import to_variable, nonintersecting_2d_inds 11 | from collections import defaultdict 12 | import torch 13 | 14 | @to_variable 15 | def rel_assignments(im_inds, rpn_rois, roi_gtlabels, gt_boxes, gt_classes, gt_rels, image_offset, 16 | fg_thresh=0.5, num_sample_per_gt=4, filter_non_overlap=True): 17 | """ 18 | Assign object detection proposals to ground-truth targets. Produces proposal 19 | classification labels and bounding-box regression targets. 
20 | :param rpn_rois: [img_ind, x1, y1, x2, y2] 21 | :param gt_boxes: [num_boxes, 4] array of x0, y0, x1, y1] 22 | :param gt_classes: [num_boxes, 2] array of [img_ind, class] 23 | :param gt_rels [num_boxes, 4] array of [img_ind, box_0, box_1, rel type] 24 | :param Overlap threshold for a ROI to be considered foreground (if >= FG_THRESH) 25 | :return: 26 | rois: [num_rois, 5] 27 | labels: [num_rois] array of labels 28 | bbox_targets [num_rois, 4] array of targets for the labels. 29 | rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type) 30 | """ 31 | fg_rels_per_image = int(np.round(REL_FG_FRACTION * 64)) 32 | 33 | pred_inds_np = im_inds.cpu().numpy() 34 | pred_boxes_np = rpn_rois.cpu().numpy() 35 | pred_boxlabels_np = roi_gtlabels.cpu().numpy() 36 | gt_boxes_np = gt_boxes.cpu().numpy() 37 | gt_classes_np = gt_classes.cpu().numpy() 38 | gt_rels_np = gt_rels.cpu().numpy() 39 | 40 | gt_classes_np[:, 0] -= image_offset 41 | gt_rels_np[:, 0] -= image_offset 42 | 43 | num_im = gt_classes_np[:, 0].max()+1 44 | 45 | # print("Pred inds {} pred boxes {} pred box labels {} gt classes {} gt rels {}".format( 46 | # pred_inds_np, pred_boxes_np, pred_boxlabels_np, gt_classes_np, gt_rels_np 47 | # )) 48 | 49 | rel_labels = [] 50 | fg_rel_labels = [] 51 | num_box_seen = 0 52 | for im_ind in range(num_im): 53 | pred_ind = np.where(pred_inds_np == im_ind)[0] 54 | 55 | gt_ind = np.where(gt_classes_np[:, 0] == im_ind)[0] 56 | gt_boxes_i = gt_boxes_np[gt_ind] 57 | gt_classes_i = gt_classes_np[gt_ind, 1] 58 | gt_rels_i = gt_rels_np[gt_rels_np[:, 0] == im_ind, 1:] 59 | 60 | # [num_pred, num_gt] 61 | pred_boxes_i = pred_boxes_np[pred_ind] 62 | pred_boxlabels_i = pred_boxlabels_np[pred_ind] 63 | 64 | ious = bbox_overlaps(pred_boxes_i, gt_boxes_i) 65 | is_match = (pred_boxlabels_i[:,None] == gt_classes_i[None]) & (ious >= fg_thresh) 66 | 67 | # FOR BG. Limit ourselves to only IOUs that overlap, but are not the exact same box 68 | pbi_iou = bbox_overlaps(pred_boxes_i, pred_boxes_i) 69 | if filter_non_overlap: 70 | rel_possibilities = (pbi_iou < 1) & (pbi_iou > 0) 71 | rels_intersect = rel_possibilities 72 | else: 73 | rel_possibilities = np.ones((pred_boxes_i.shape[0], pred_boxes_i.shape[0]), 74 | dtype=np.int64) - np.eye(pred_boxes_i.shape[0], 75 | dtype=np.int64) 76 | rels_intersect = (pbi_iou < 1) & (pbi_iou > 0) 77 | 78 | # ONLY select relations between ground truth because otherwise we get useless data 79 | rel_possibilities[pred_boxlabels_i == 0] = 0 80 | rel_possibilities[:, pred_boxlabels_i == 0] = 0 81 | 82 | # Sample the GT relationships. 
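        # For each GT relation (subject, object, predicate), every proposal pair whose
        # labels and IoUs match the two GT boxes becomes a candidate; candidates are
        # drawn with probability proportional to the product of their subject/object
        # IoUs, keeping at most num_sample_per_gt pairs per GT relation, and every
        # matched pair is removed from the background pool.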
83 | fg_rels = [] 84 | p_size = [] 85 | for i, (from_gtind, to_gtind, rel_id) in enumerate(gt_rels_i): 86 | fg_rels_i = [] 87 | fg_scores_i = [] 88 | 89 | for from_ind in np.where(is_match[:, from_gtind])[0]: 90 | for to_ind in np.where(is_match[:, to_gtind])[0]: 91 | if from_ind != to_ind: 92 | fg_rels_i.append((from_ind, to_ind, rel_id)) 93 | fg_scores_i.append((ious[from_ind, from_gtind] * ious[to_ind, to_gtind])) 94 | rel_possibilities[from_ind, to_ind] = 0 95 | if len(fg_rels_i) == 0: 96 | continue 97 | p = np.array(fg_scores_i) 98 | p = p / p.sum() 99 | p_size.append(p.shape[0]) 100 | num_to_add = min(p.shape[0], num_sample_per_gt) 101 | for rel_to_add in npr.choice(p.shape[0], p=p, size=num_to_add, replace=False): 102 | fg_rels.append(fg_rels_i[rel_to_add]) 103 | 104 | fg_rels = np.array(fg_rels, dtype=np.int64) 105 | all_fg_rels = fg_rels 106 | if fg_rels.size > 0 and fg_rels.shape[0] > fg_rels_per_image: 107 | fg_rels = fg_rels[npr.choice(fg_rels.shape[0], size=fg_rels_per_image, replace=False)] 108 | elif fg_rels.size == 0: 109 | fg_rels = np.zeros((0, 3), dtype=np.int64) 110 | all_fg_rels = np.zeros((1, 3), dtype=np.int64) 111 | 112 | 113 | bg_rels = np.column_stack(np.where(rel_possibilities)) 114 | bg_rels = np.column_stack((bg_rels, np.zeros(bg_rels.shape[0], dtype=np.int64))) 115 | 116 | num_bg_rel = min(64 - fg_rels.shape[0], bg_rels.shape[0]) 117 | if bg_rels.size > 0: 118 | # Sample 4x as many intersecting relationships as non-intersecting. 119 | # bg_rels_intersect = rels_intersect[bg_rels[:, 0], bg_rels[:, 1]] 120 | # p = bg_rels_intersect.astype(np.float32) 121 | # p[bg_rels_intersect == 0] = 0.2 122 | # p[bg_rels_intersect == 1] = 0.8 123 | # p /= p.sum() 124 | bg_rels = bg_rels[ 125 | np.random.choice(bg_rels.shape[0], 126 | #p=p, 127 | size=num_bg_rel, replace=False)] 128 | else: 129 | bg_rels = np.zeros((0, 3), dtype=np.int64) 130 | 131 | if fg_rels.size == 0 and bg_rels.size == 0: 132 | # Just put something here 133 | bg_rels = np.array([[0, 0, 0]], dtype=np.int64) 134 | 135 | # print("GTR {} -> AR {} vs {}".format(gt_rels.shape, fg_rels.shape, bg_rels.shape)) 136 | all_rels_i = np.concatenate((fg_rels, bg_rels), 0) 137 | all_rels_i[:,0:2] += num_box_seen 138 | all_fg_rels[:,0:2] += num_box_seen 139 | 140 | all_rels_i = all_rels_i[np.lexsort((all_rels_i[:,1], all_rels_i[:,0]))] 141 | all_fg_rels = all_fg_rels[np.lexsort((all_fg_rels[:,1], all_fg_rels[:,0]))] 142 | 143 | rel_labels.append(np.column_stack(( 144 | im_ind*np.ones(all_rels_i.shape[0], dtype=np.int64), 145 | all_rels_i, 146 | ))) 147 | fg_rel_labels.append(np.column_stack(( 148 | im_ind*np.ones(all_fg_rels.shape[0], dtype=np.int64), 149 | all_fg_rels, 150 | ))) 151 | num_box_seen += pred_boxes_i.shape[0] 152 | rel_labels = torch.LongTensor(np.concatenate(rel_labels, 0)).cuda(rpn_rois.get_device(), 153 | async=True) 154 | 155 | fg_rel_labels = torch.LongTensor(np.concatenate(fg_rel_labels, 0)).cuda(rpn_rois.get_device(), 156 | async=True) 157 | return rel_labels, fg_rel_labels 158 | -------------------------------------------------------------------------------- /lib/fpn/roi_align/Makefile: -------------------------------------------------------------------------------- 1 | all: src/cuda/roi_align.cu.o 2 | python build.py 3 | 4 | src/cuda/roi_align.cu.o: src/cuda/roi_align_kernel.cu 5 | $(MAKE) -C src/cuda 6 | 7 | clean: 8 | $(MAKE) -C src/cuda clean 9 | -------------------------------------------------------------------------------- /lib/fpn/roi_align/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/lib/fpn/roi_align/__init__.py -------------------------------------------------------------------------------- /lib/fpn/roi_align/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/lib/fpn/roi_align/_ext/__init__.py -------------------------------------------------------------------------------- /lib/fpn/roi_align/_ext/roi_align/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._roi_align import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | locals[symbol] = _wrap_function(fn, _ffi) 10 | __all__.append(symbol) 11 | 12 | _import_symbols(locals()) 13 | -------------------------------------------------------------------------------- /lib/fpn/roi_align/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.ffi import create_extension 4 | # Might have to export PATH=/usr/local/cuda-8.0/bin${PATH:+:${PATH}} 5 | 6 | # sources = ['src/roi_align.c'] 7 | # headers = ['src/roi_align.h'] 8 | sources = [] 9 | headers = [] 10 | defines = [] 11 | with_cuda = False 12 | 13 | if torch.cuda.is_available(): 14 | print('Including CUDA code.') 15 | sources += ['src/roi_align_cuda.c'] 16 | headers += ['src/roi_align_cuda.h'] 17 | defines += [('WITH_CUDA', None)] 18 | with_cuda = True 19 | 20 | this_file = os.path.dirname(os.path.realpath(__file__)) 21 | print(this_file) 22 | extra_objects = ['src/cuda/roi_align.cu.o'] 23 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 24 | 25 | ffi = create_extension( 26 | '_ext.roi_align', 27 | headers=headers, 28 | sources=sources, 29 | define_macros=defines, 30 | relative_to=__file__, 31 | with_cuda=with_cuda, 32 | extra_objects=extra_objects 33 | ) 34 | 35 | if __name__ == '__main__': 36 | ffi.build() 37 | -------------------------------------------------------------------------------- /lib/fpn/roi_align/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/lib/fpn/roi_align/functions/__init__.py -------------------------------------------------------------------------------- /lib/fpn/roi_align/functions/roi_align.py: -------------------------------------------------------------------------------- 1 | """ 2 | performs ROI aligning 3 | """ 4 | 5 | import torch 6 | from torch.autograd import Function 7 | from .._ext import roi_align 8 | 9 | class RoIAlignFunction(Function): 10 | def __init__(self, aligned_height, aligned_width, spatial_scale): 11 | self.aligned_width = int(aligned_width) 12 | self.aligned_height = int(aligned_height) 13 | self.spatial_scale = float(spatial_scale) 14 | 15 | self.feature_size = None 16 | 17 | def forward(self, features, rois): 18 | self.save_for_backward(rois) 19 | 20 | rois_normalized = rois.clone() 21 | 22 | self.feature_size = features.size() 23 | batch_size, num_channels, data_height, data_width = 
self.feature_size 24 | 25 | height = (data_height -1) / self.spatial_scale 26 | width = (data_width - 1) / self.spatial_scale 27 | 28 | rois_normalized[:,1] /= width 29 | rois_normalized[:,2] /= height 30 | rois_normalized[:,3] /= width 31 | rois_normalized[:,4] /= height 32 | 33 | 34 | num_rois = rois.size(0) 35 | 36 | output = features.new(num_rois, num_channels, self.aligned_height, 37 | self.aligned_width).zero_() 38 | 39 | if features.is_cuda: 40 | res = roi_align.roi_align_forward_cuda(self.aligned_height, 41 | self.aligned_width, 42 | self.spatial_scale, features, 43 | rois_normalized, output) 44 | assert res == 1 45 | else: 46 | raise ValueError 47 | 48 | return output 49 | 50 | def backward(self, grad_output): 51 | assert(self.feature_size is not None and grad_output.is_cuda) 52 | 53 | rois = self.saved_tensors[0] 54 | 55 | rois_normalized = rois.clone() 56 | 57 | batch_size, num_channels, data_height, data_width = self.feature_size 58 | 59 | height = (data_height -1) / self.spatial_scale 60 | width = (data_width - 1) / self.spatial_scale 61 | 62 | rois_normalized[:,1] /= width 63 | rois_normalized[:,2] /= height 64 | rois_normalized[:,3] /= width 65 | rois_normalized[:,4] /= height 66 | 67 | grad_input = rois_normalized.new(batch_size, num_channels, data_height, 68 | data_width).zero_() 69 | res = roi_align.roi_align_backward_cuda(self.aligned_height, 70 | self.aligned_width, 71 | self.spatial_scale, grad_output, 72 | rois_normalized, grad_input) 73 | assert res == 1 74 | return grad_input, None 75 | -------------------------------------------------------------------------------- /lib/fpn/roi_align/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/lib/fpn/roi_align/modules/__init__.py -------------------------------------------------------------------------------- /lib/fpn/roi_align/modules/roi_align.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from torch.nn.functional import avg_pool2d, max_pool2d 3 | from ..functions.roi_align import RoIAlignFunction 4 | 5 | 6 | class RoIAlign(Module): 7 | def __init__(self, aligned_height, aligned_width, spatial_scale): 8 | super(RoIAlign, self).__init__() 9 | 10 | self.aligned_width = int(aligned_width) 11 | self.aligned_height = int(aligned_height) 12 | self.spatial_scale = float(spatial_scale) 13 | 14 | def forward(self, features, rois): 15 | return RoIAlignFunction(self.aligned_height, self.aligned_width, 16 | self.spatial_scale)(features, rois) 17 | 18 | class RoIAlignAvg(Module): 19 | def __init__(self, aligned_height, aligned_width, spatial_scale): 20 | super(RoIAlignAvg, self).__init__() 21 | 22 | self.aligned_width = int(aligned_width) 23 | self.aligned_height = int(aligned_height) 24 | self.spatial_scale = float(spatial_scale) 25 | 26 | def forward(self, features, rois): 27 | x = RoIAlignFunction(self.aligned_height+1, self.aligned_width+1, 28 | self.spatial_scale)(features, rois) 29 | return avg_pool2d(x, kernel_size=2, stride=1) 30 | 31 | class RoIAlignMax(Module): 32 | def __init__(self, aligned_height, aligned_width, spatial_scale): 33 | super(RoIAlignMax, self).__init__() 34 | 35 | self.aligned_width = int(aligned_width) 36 | self.aligned_height = int(aligned_height) 37 | self.spatial_scale = float(spatial_scale) 38 | 39 | def forward(self, features, rois): 40 | x = 
RoIAlignFunction(self.aligned_height+1, self.aligned_width+1, 41 | self.spatial_scale)(features, rois) 42 | return max_pool2d(x, kernel_size=2, stride=1) 43 | -------------------------------------------------------------------------------- /lib/fpn/roi_align/src/cuda/Makefile: -------------------------------------------------------------------------------- 1 | all: roi_align_kernel.cu roi_align_kernel.h 2 | /usr/local/cuda/bin/nvcc -c -o roi_align.cu.o roi_align_kernel.cu --compiler-options -fPIC -gencode arch=compute_61,code=sm_61 3 | clean: 4 | rm roi_align.cu.o 5 | -------------------------------------------------------------------------------- /lib/fpn/roi_align/src/cuda/roi_align_kernel.cu: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | #include 6 | #include 7 | #include 8 | #include "roi_align_kernel.h" 9 | 10 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 11 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 12 | i += blockDim.x * gridDim.x) 13 | 14 | 15 | __global__ void ROIAlignForward(const int nthreads, const float* image_ptr, const float* boxes_ptr, 16 | int num_boxes, int batch, int image_height, int image_width, int crop_height, 17 | int crop_width, int depth, float extrapolation_value, float* crops_ptr) { 18 | CUDA_1D_KERNEL_LOOP(out_idx, nthreads) { 19 | // (n, c, ph, pw) is an element in the aligned output 20 | int idx = out_idx; 21 | const int x = idx % crop_width; 22 | idx /= crop_width; 23 | const int y = idx % crop_height; 24 | idx /= crop_height; 25 | const int d = idx % depth; 26 | const int b = idx / depth; 27 | 28 | const int b_in = int(boxes_ptr[b*5]); 29 | const float x1 = boxes_ptr[b * 5 + 1]; 30 | const float y1 = boxes_ptr[b * 5 + 2]; 31 | const float x2 = boxes_ptr[b * 5 + 3]; 32 | const float y2 = boxes_ptr[b * 5 + 4]; 33 | if (b_in < 0 || b_in >= batch) { 34 | continue; 35 | } 36 | 37 | const float height_scale = 38 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 39 | : 0; 40 | const float width_scale = 41 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; 42 | 43 | const float in_y = (crop_height > 1) 44 | ? y1 * (image_height - 1) + y * height_scale 45 | : 0.5 * (y1 + y2) * (image_height - 1); 46 | if (in_y < 0 || in_y > image_height - 1) { 47 | crops_ptr[out_idx] = extrapolation_value; 48 | continue; 49 | } 50 | 51 | const float in_x = (crop_width > 1) 52 | ? 
x1 * (image_width - 1) + x * width_scale 53 | : 0.5 * (x1 + x2) * (image_width - 1); 54 | if (in_x < 0 || in_x > image_width - 1) { 55 | crops_ptr[out_idx] = extrapolation_value; 56 | continue; 57 | } 58 | 59 | const int top_y_index = floorf(in_y); 60 | const int bottom_y_index = ceilf(in_y); 61 | const float y_lerp = in_y - top_y_index; 62 | 63 | const int left_x_index = floorf(in_x); 64 | const int right_x_index = ceilf(in_x); 65 | const float x_lerp = in_x - left_x_index; 66 | 67 | const float top_left = image_ptr[((b_in*depth + d) * image_height 68 | + top_y_index) * image_width + left_x_index]; 69 | const float top_right = image_ptr[((b_in*depth + d) * image_height 70 | + top_y_index) * image_width + right_x_index]; 71 | const float bottom_left = image_ptr[((b_in*depth + d) * image_height 72 | + bottom_y_index) * image_width + left_x_index]; 73 | const float bottom_right = image_ptr[((b_in*depth + d) * image_height 74 | + bottom_y_index) * image_width + right_x_index]; 75 | 76 | const float top = top_left + (top_right - top_left) * x_lerp; 77 | const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; 78 | crops_ptr[out_idx] = top + (bottom - top) * y_lerp; 79 | } 80 | } 81 | 82 | int ROIAlignForwardLaucher(const float* image_ptr, const float* boxes_ptr, 83 | int num_boxes, int batch, int image_height, int image_width, int crop_height, 84 | int crop_width, int depth, float extrapolation_value, float* crops_ptr, cudaStream_t stream) { 85 | 86 | const int kThreadsPerBlock = 1024; 87 | const int output_size = num_boxes * crop_height * crop_width * depth; 88 | cudaError_t err; 89 | 90 | ROIAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>> 91 | (output_size, image_ptr, boxes_ptr, num_boxes, batch, image_height, image_width, 92 | crop_height, crop_width, depth, extrapolation_value, crops_ptr); 93 | 94 | err = cudaGetLastError(); 95 | if(cudaSuccess != err) { 96 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 97 | exit( -1 ); 98 | } 99 | 100 | return 1; 101 | } 102 | 103 | __global__ void ROIAlignBackward( 104 | const int nthreads, const float* grads_ptr, const float* boxes_ptr, 105 | int num_boxes, int batch, int image_height, 106 | int image_width, int crop_height, int crop_width, int depth, 107 | float* grads_image_ptr) { 108 | CUDA_1D_KERNEL_LOOP(out_idx, nthreads) { 109 | 110 | // out_idx = d + depth * (w + crop_width * (h + crop_height * b)) 111 | int idx = out_idx; 112 | const int x = idx % crop_width; 113 | idx /= crop_width; 114 | const int y = idx % crop_height; 115 | idx /= crop_height; 116 | const int d = idx % depth; 117 | const int b = idx / depth; 118 | 119 | const int b_in = boxes_ptr[b * 5]; 120 | const float x1 = boxes_ptr[b * 5 + 1]; 121 | const float y1 = boxes_ptr[b * 5 + 2]; 122 | const float x2 = boxes_ptr[b * 5 + 3]; 123 | const float y2 = boxes_ptr[b * 5 + 4]; 124 | if (b_in < 0 || b_in >= batch) { 125 | continue; 126 | } 127 | 128 | const float height_scale = 129 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 130 | : 0; 131 | const float width_scale = 132 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; 133 | 134 | const float in_y = (crop_height > 1) 135 | ? y1 * (image_height - 1) + y * height_scale 136 | : 0.5 * (y1 + y2) * (image_height - 1); 137 | if (in_y < 0 || in_y > image_height - 1) { 138 | continue; 139 | } 140 | 141 | const float in_x = (crop_width > 1) 142 | ? 
x1 * (image_width - 1) + x * width_scale 143 | : 0.5 * (x1 + x2) * (image_width - 1); 144 | if (in_x < 0 || in_x > image_width - 1) { 145 | continue; 146 | } 147 | 148 | const int top_y_index = floorf(in_y); 149 | const int bottom_y_index = ceilf(in_y); 150 | const float y_lerp = in_y - top_y_index; 151 | 152 | const int left_x_index = floorf(in_x); 153 | const int right_x_index = ceilf(in_x); 154 | const float x_lerp = in_x - left_x_index; 155 | 156 | const float dtop = (1 - y_lerp) * grads_ptr[out_idx]; 157 | atomicAdd( 158 | grads_image_ptr + ((b_in*depth + d)*image_height + top_y_index) * image_width + left_x_index, 159 | (1 - x_lerp) * dtop); 160 | atomicAdd(grads_image_ptr + 161 | ((b_in * depth + d)*image_height+top_y_index)*image_width + right_x_index, 162 | x_lerp * dtop); 163 | 164 | const float dbottom = y_lerp * grads_ptr[out_idx]; 165 | atomicAdd(grads_image_ptr + ((b_in*depth+d)*image_height+bottom_y_index)*image_width+left_x_index, 166 | (1 - x_lerp) * dbottom); 167 | atomicAdd(grads_image_ptr + ((b_in*depth+d)*image_height+bottom_y_index)*image_width+right_x_index, 168 | x_lerp * dbottom); 169 | } 170 | } 171 | 172 | int ROIAlignBackwardLaucher(const float* grads_ptr, const float* boxes_ptr, int num_boxes, 173 | int batch, int image_height, int image_width, int crop_height, int crop_width, int depth, 174 | float* grads_image_ptr, cudaStream_t stream) { 175 | const int kThreadsPerBlock = 1024; 176 | const int output_size = num_boxes * crop_height * crop_width * depth; 177 | cudaError_t err; 178 | 179 | ROIAlignBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>> 180 | (output_size, grads_ptr, boxes_ptr, num_boxes, batch, image_height, image_width, crop_height, 181 | crop_width, depth, grads_image_ptr); 182 | 183 | err = cudaGetLastError(); 184 | if(cudaSuccess != err) { 185 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 186 | exit( -1 ); 187 | } 188 | 189 | return 1; 190 | } 191 | 192 | 193 | #ifdef __cplusplus 194 | } 195 | #endif 196 | 197 | 198 | -------------------------------------------------------------------------------- /lib/fpn/roi_align/src/cuda/roi_align_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ROI_ALIGN_KERNEL 2 | #define _ROI_ALIGN_KERNEL 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | __global__ void ROIAlignForward(const int nthreads, const float* image_ptr, const float* boxes_ptr, int num_boxes, int batch, int image_height, int image_width, int crop_height, 9 | int crop_width, int depth, float extrapolation_value, float* crops_ptr); 10 | 11 | int ROIAlignForwardLaucher( 12 | const float* image_ptr, const float* boxes_ptr, 13 | int num_boxes, int batch, int image_height, int image_width, int crop_height, 14 | int crop_width, int depth, float extrapolation_value, float* crops_ptr, cudaStream_t stream); 15 | 16 | __global__ void ROIAlignBackward(const int nthreads, const float* grads_ptr, 17 | const float* boxes_ptr, int num_boxes, int batch, int image_height, 18 | int image_width, int crop_height, int crop_width, int depth, 19 | float* grads_image_ptr); 20 | 21 | int ROIAlignBackwardLaucher(const float* grads_ptr, const float* boxes_ptr, int num_boxes, 22 | int batch, int image_height, int image_width, int crop_height, 23 | int crop_width, int depth, float* grads_image_ptr, cudaStream_t stream); 24 | 25 | #ifdef __cplusplus 26 | } 27 | #endif 28 | 29 | #endif 30 | 31 | 
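The forward kernel above is the crop-and-resize variant of RoI Align: each output cell of a crop is bilinearly interpolated from the four feature-map pixels surrounding a sampling point computed from a box given in normalized [0, 1] coordinates (the normalization is done by RoIAlignFunction before the CUDA call). The snippet below is a minimal NumPy sketch of that forward sampling for a single box; the helper name roi_align_forward_ref and its exact signature are our own illustration and are not part of this repository.

import numpy as np

def roi_align_forward_ref(image, box, crop_height, crop_width, extrapolation_value=0.0):
    # image: [depth, H, W] feature map; box: normalized (x1, y1, x2, y2) in [0, 1].
    depth, H, W = image.shape
    x1, y1, x2, y2 = box
    # Out-of-range sampling points keep extrapolation_value, as in ROIAlignForward.
    crops = np.full((depth, crop_height, crop_width), extrapolation_value, dtype=image.dtype)
    h_scale = (y2 - y1) * (H - 1) / (crop_height - 1) if crop_height > 1 else 0.0
    w_scale = (x2 - x1) * (W - 1) / (crop_width - 1) if crop_width > 1 else 0.0
    for y in range(crop_height):
        in_y = y1 * (H - 1) + y * h_scale if crop_height > 1 else 0.5 * (y1 + y2) * (H - 1)
        if in_y < 0 or in_y > H - 1:
            continue
        for x in range(crop_width):
            in_x = x1 * (W - 1) + x * w_scale if crop_width > 1 else 0.5 * (x1 + x2) * (W - 1)
            if in_x < 0 or in_x > W - 1:
                continue
            top, left = int(np.floor(in_y)), int(np.floor(in_x))
            bottom, right = int(np.ceil(in_y)), int(np.ceil(in_x))
            y_lerp, x_lerp = in_y - top, in_x - left
            # Bilinear interpolation over the four neighbouring pixels, all channels at once.
            top_val = image[:, top, left] + (image[:, top, right] - image[:, top, left]) * x_lerp
            bot_val = image[:, bottom, left] + (image[:, bottom, right] - image[:, bottom, left]) * x_lerp
            crops[:, y, x] = top_val + (bot_val - top_val) * y_lerp
    return crops

The backward kernel scatters grad * (1 - y_lerp) * (1 - x_lerp), grad * (1 - y_lerp) * x_lerp, grad * y_lerp * (1 - x_lerp) and grad * y_lerp * x_lerp back onto those same four pixels with atomicAdd, so gradients flow only to the locations that contributed to each output cell.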
--------------------------------------------------------------------------------
/lib/fpn/roi_align/src/roi_align_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | #include <math.h>
3 | #include "cuda/roi_align_kernel.h"
4 |
5 | extern THCState *state;
6 |
7 | int roi_align_forward_cuda(int crop_height, int crop_width, float spatial_scale,
8 | THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output)
9 | {
10 | // Grab the input tensor
11 | float * image_ptr = THCudaTensor_data(state, features);
12 | float * boxes_ptr = THCudaTensor_data(state, rois);
13 |
14 | float * crops_ptr = THCudaTensor_data(state, output);
15 |
16 | // Number of ROIs
17 | int num_boxes = THCudaTensor_size(state, rois, 0);
18 | int size_rois = THCudaTensor_size(state, rois, 1);
19 | if (size_rois != 5)
20 | {
21 | return 0;
22 | }
23 |
24 | // batch size
25 | int batch = THCudaTensor_size(state, features, 0);
26 | // data height
27 | int image_height = THCudaTensor_size(state, features, 2);
28 | // data width
29 | int image_width = THCudaTensor_size(state, features, 3);
30 | // Number of channels
31 | int depth = THCudaTensor_size(state, features, 1);
32 |
33 | cudaStream_t stream = THCState_getCurrentStream(state);
34 | float extrapolation_value = 0.0;
35 |
36 | ROIAlignForwardLaucher(
37 | image_ptr, boxes_ptr, num_boxes, batch, image_height, image_width,
38 | crop_height, crop_width, depth, extrapolation_value, crops_ptr,
39 | stream);
40 |
41 | return 1;
42 | }
43 |
44 | int roi_align_backward_cuda(int crop_height, int crop_width, float spatial_scale,
45 | THCudaTensor * top_grad, THCudaTensor * rois, THCudaTensor * bottom_grad)
46 | {
47 | // Grab the input tensor
48 | float * grads_ptr = THCudaTensor_data(state, top_grad);
49 | float * boxes_ptr = THCudaTensor_data(state, rois);
50 |
51 | float * grads_image_ptr = THCudaTensor_data(state, bottom_grad);
52 |
53 | // Number of ROIs
54 | int num_boxes = THCudaTensor_size(state, rois, 0);
55 | int size_rois = THCudaTensor_size(state, rois, 1);
56 | if (size_rois != 5)
57 | {
58 | return 0;
59 | }
60 |
61 | // batch size
62 | int batch = THCudaTensor_size(state, bottom_grad, 0);
63 | // data height
64 | int image_height = THCudaTensor_size(state, bottom_grad, 2);
65 | // data width
66 | int image_width = THCudaTensor_size(state, bottom_grad, 3);
67 | // Number of channels
68 | int depth = THCudaTensor_size(state, bottom_grad, 1);
69 |
70 | cudaStream_t stream = THCState_getCurrentStream(state);
71 |
72 | ROIAlignBackwardLaucher(
73 | grads_ptr, boxes_ptr, num_boxes, batch, image_height, image_width,
74 | crop_height, crop_width, depth, grads_image_ptr, stream);
75 | return 1;
76 | }
77 |
--------------------------------------------------------------------------------
/lib/fpn/roi_align/src/roi_align_cuda.h:
--------------------------------------------------------------------------------
1 | int roi_align_forward_cuda(int crop_height, int crop_width, float spatial_scale,
2 | THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output);
3 |
4 | int roi_align_backward_cuda(int crop_height, int crop_width, float spatial_scale,
5 | THCudaTensor * top_grad, THCudaTensor * rois,
6 | THCudaTensor * bottom_grad);
7 |
--------------------------------------------------------------------------------
/lib/get_dataset_counts.py:
--------------------------------------------------------------------------------
1 | """
2 | Get counts of all of the examples in the dataset.
Used for creating the baseline 3 | dictionary model 4 | """ 5 | 6 | import numpy as np 7 | from dataloaders.visual_genome import VG 8 | from lib.fpn.box_intersections_cpu.bbox import bbox_overlaps 9 | from lib.pytorch_misc import nonintersecting_2d_inds 10 | 11 | 12 | def get_counts(train_data=VG(mode='train', filter_duplicate_rels=False, num_val_im=5000), must_overlap=True): 13 | """ 14 | Get counts of all of the relations. Used for modeling directly P(rel | o1, o2) 15 | :param train_data: 16 | :param must_overlap: 17 | :return: 18 | """ 19 | fg_matrix = np.zeros(( 20 | train_data.num_classes, 21 | train_data.num_classes, 22 | train_data.num_predicates, 23 | ), dtype=np.int64) 24 | 25 | bg_matrix = np.zeros(( 26 | train_data.num_classes, 27 | train_data.num_classes, 28 | ), dtype=np.int64) 29 | 30 | for ex_ind in range(len(train_data)): 31 | gt_classes = train_data.gt_classes[ex_ind].copy() 32 | gt_relations = train_data.relationships[ex_ind].copy() 33 | gt_boxes = train_data.gt_boxes[ex_ind].copy() 34 | 35 | # For the foreground, we'll just look at everything 36 | o1o2 = gt_classes[gt_relations[:, :2]] 37 | for (o1, o2), gtr in zip(o1o2, gt_relations[:,2]): 38 | fg_matrix[o1, o2, gtr] += 1 39 | 40 | # For the background, get all of the things that overlap. 41 | o1o2_total = gt_classes[np.array( 42 | box_filter(gt_boxes, must_overlap=must_overlap), dtype=int)] 43 | for (o1, o2) in o1o2_total: 44 | bg_matrix[o1, o2] += 1 45 | 46 | return fg_matrix, bg_matrix 47 | 48 | 49 | def box_filter(boxes, must_overlap=False): 50 | """ Only include boxes that overlap as possible relations. 51 | If no overlapping boxes, use all of them.""" 52 | n_cands = boxes.shape[0] 53 | 54 | overlaps = bbox_overlaps(boxes.astype(np.float), boxes.astype(np.float)) > 0 55 | np.fill_diagonal(overlaps, 0) 56 | 57 | all_possib = np.ones_like(overlaps, dtype=np.bool) 58 | np.fill_diagonal(all_possib, 0) 59 | 60 | if must_overlap: 61 | possible_boxes = np.column_stack(np.where(overlaps)) 62 | 63 | if possible_boxes.size == 0: 64 | possible_boxes = np.column_stack(np.where(all_possib)) 65 | else: 66 | possible_boxes = np.column_stack(np.where(all_possib)) 67 | return possible_boxes 68 | 69 | if __name__ == '__main__': 70 | fg, bg = get_counts(must_overlap=False) 71 | -------------------------------------------------------------------------------- /lib/get_union_boxes.py: -------------------------------------------------------------------------------- 1 | """ 2 | credits to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/network.py#L91 3 | """ 4 | 5 | import torch 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | from lib.fpn.roi_align.functions.roi_align import RoIAlignFunction 9 | from lib.draw_rectangles.draw_rectangles import draw_union_boxes 10 | import numpy as np 11 | from torch.nn.modules.module import Module 12 | from torch import nn 13 | from config import BATCHNORM_MOMENTUM 14 | 15 | class UnionBoxesAndFeats(Module): 16 | def __init__(self, pooling_size=7, stride=16, dim=256, concat=False, use_feats=True): 17 | """ 18 | :param pooling_size: Pool the union boxes to this dimension 19 | :param stride: pixel spacing in the entire image 20 | :param dim: Dimension of the feats 21 | :param concat: Whether to concat (yes) or add (False) the representations 22 | """ 23 | super(UnionBoxesAndFeats, self).__init__() 24 | 25 | self.pooling_size = pooling_size 26 | self.stride = stride 27 | 28 | self.dim = dim 29 | self.use_feats = use_feats 30 | 31 | self.conv = 
nn.Sequential( 32 | nn.Conv2d(2, dim //2, kernel_size=7, stride=2, padding=3, bias=True), 33 | nn.ReLU(inplace=True), 34 | nn.BatchNorm2d(dim//2, momentum=BATCHNORM_MOMENTUM), 35 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 36 | nn.Conv2d(dim // 2, dim, kernel_size=3, stride=1, padding=1, bias=True), 37 | nn.ReLU(inplace=True), 38 | nn.BatchNorm2d(dim, momentum=BATCHNORM_MOMENTUM), 39 | ) 40 | self.concat = concat 41 | 42 | def forward(self, fmap, rois, union_inds): 43 | union_pools = union_boxes(fmap, rois, union_inds, pooling_size=self.pooling_size, stride=self.stride) 44 | if not self.use_feats: 45 | return union_pools.detach() 46 | 47 | pair_rois = torch.cat((rois[:, 1:][union_inds[:, 0]], rois[:, 1:][union_inds[:, 1]]),1).data.cpu().numpy() 48 | # rects_np = get_rect_features(pair_rois, self.pooling_size*2-1) - 0.5 49 | rects_np = draw_union_boxes(pair_rois, self.pooling_size*4-1) - 0.5 50 | rects = Variable(torch.FloatTensor(rects_np).cuda(fmap.get_device()), volatile=fmap.volatile) 51 | if self.concat: 52 | return torch.cat((union_pools, self.conv(rects)), 1) 53 | return union_pools + self.conv(rects) 54 | 55 | # def get_rect_features(roi_pairs, pooling_size): 56 | # rects_np = draw_union_boxes(roi_pairs, pooling_size) 57 | # # add union + intersection 58 | # stuff_to_cat = [ 59 | # rects_np.max(1), 60 | # rects_np.min(1), 61 | # np.minimum(1-rects_np[:,0], rects_np[:,1]), 62 | # np.maximum(1-rects_np[:,0], rects_np[:,1]), 63 | # np.minimum(rects_np[:,0], 1-rects_np[:,1]), 64 | # np.maximum(rects_np[:,0], 1-rects_np[:,1]), 65 | # np.minimum(1-rects_np[:,0], 1-rects_np[:,1]), 66 | # np.maximum(1-rects_np[:,0], 1-rects_np[:,1]), 67 | # ] 68 | # rects_np = np.concatenate([rects_np] + [x[:,None] for x in stuff_to_cat], 1) 69 | # return rects_np 70 | 71 | 72 | def union_boxes(fmap, rois, union_inds, pooling_size=14, stride=16): 73 | """ 74 | :param fmap: (batch_size, d, IM_SIZE/stride, IM_SIZE/stride) 75 | :param rois: (num_rois, 5) with [im_ind, x1, y1, x2, y2] 76 | :param union_inds: (num_urois, 2) with [roi_ind1, roi_ind2] 77 | :param pooling_size: we'll resize to this 78 | :param stride: 79 | :return: 80 | """ 81 | assert union_inds.size(1) == 2 82 | im_inds = rois[:,0][union_inds[:,0]] 83 | assert (im_inds.data == rois.data[:,0][union_inds[:,1]]).sum() == union_inds.size(0) 84 | union_rois = torch.cat(( 85 | im_inds[:,None], 86 | torch.min(rois[:, 1:3][union_inds[:, 0]], rois[:, 1:3][union_inds[:, 1]]), 87 | torch.max(rois[:, 3:5][union_inds[:, 0]], rois[:, 3:5][union_inds[:, 1]]), 88 | ),1) 89 | 90 | # (num_rois, d, pooling_size, pooling_size) 91 | union_pools = RoIAlignFunction(pooling_size, pooling_size, 92 | spatial_scale=1/stride)(fmap, union_rois) 93 | return union_pools 94 | 95 | -------------------------------------------------------------------------------- /lib/lstm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/lib/lstm/__init__.py -------------------------------------------------------------------------------- /lib/lstm/highway_lstm_cuda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/lib/lstm/highway_lstm_cuda/__init__.py -------------------------------------------------------------------------------- 
/lib/lstm/highway_lstm_cuda/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/lib/lstm/highway_lstm_cuda/_ext/__init__.py -------------------------------------------------------------------------------- /lib/lstm/highway_lstm_cuda/_ext/highway_lstm_layer/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._highway_lstm_layer import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | locals[symbol] = _wrap_function(fn, _ffi) 10 | __all__.append(symbol) 11 | 12 | _import_symbols(locals()) 13 | -------------------------------------------------------------------------------- /lib/lstm/highway_lstm_cuda/build.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=invalid-name 2 | import os 3 | import torch 4 | from torch.utils.ffi import create_extension 5 | 6 | if not torch.cuda.is_available(): 7 | raise Exception('HighwayLSTM can only be compiled with CUDA') 8 | 9 | sources = ['src/highway_lstm_cuda.c'] 10 | headers = ['src/highway_lstm_cuda.h'] 11 | defines = [('WITH_CUDA', None)] 12 | with_cuda = True 13 | 14 | this_file = os.path.dirname(os.path.realpath(__file__)) 15 | extra_objects = ['src/highway_lstm_kernel.cu.o'] 16 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 17 | 18 | ffi = create_extension( 19 | '_ext.highway_lstm_layer', 20 | headers=headers, 21 | sources=sources, 22 | define_macros=defines, 23 | relative_to=__file__, 24 | with_cuda=with_cuda, 25 | extra_objects=extra_objects 26 | ) 27 | 28 | if __name__ == '__main__': 29 | ffi.build() 30 | -------------------------------------------------------------------------------- /lib/lstm/highway_lstm_cuda/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_PATH=/usr/local/cuda/ 4 | 5 | # Which CUDA capabilities do we want to pre-build for? 6 | # https://developer.nvidia.com/cuda-gpus 7 | # Compute/shader model Cards 8 | # 61 P4, P40, Titan X 9 | # 60 P100 10 | # 52 M40 11 | # 37 K80 12 | # 35 K40, K20 13 | # 30 K10, Grid K520 (AWS G2) 14 | 15 | CUDA_MODELS=(52 61) 16 | 17 | # Nvidia doesn't guarantee binary compatability across GPU versions. 18 | # However, binary compatibility within one GPU generation can be guaranteed 19 | # under certain conditions because they share the basic instruction set. 20 | # This is the case between two GPU versions that do not show functional 21 | # differences at all (for instance when one version is a scaled down version 22 | # of the other), or when one version is functionally included in the other. 23 | 24 | # To fix this problem, we can create a 'fat binary' which generates multiple 25 | # translations of the CUDA source. The most appropriate version is chosen at 26 | # runtime by the CUDA driver. 
See: 27 | # http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-compilation 28 | # http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#fatbinaries 29 | CUDA_MODEL_TARGETS="" 30 | for i in "${CUDA_MODELS[@]}" 31 | do 32 | CUDA_MODEL_TARGETS+=" -gencode arch=compute_${i},code=sm_${i}" 33 | done 34 | 35 | echo "Building kernel for following target architectures: " 36 | echo $CUDA_MODEL_TARGETS 37 | 38 | cd src 39 | echo "Compiling kernel" 40 | /usr/local/cuda/bin/nvcc -c -o highway_lstm_kernel.cu.o highway_lstm_kernel.cu --compiler-options -fPIC $CUDA_MODEL_TARGETS 41 | cd ../ 42 | python build.py 43 | -------------------------------------------------------------------------------- /lib/lstm/highway_lstm_cuda/src/highway_lstm_cuda.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include "highway_lstm_kernel.h" 3 | 4 | extern THCState *state; 5 | 6 | int highway_lstm_forward_cuda(int inputSize, int hiddenSize, int miniBatch, 7 | int numLayers, int seqLength, 8 | THCudaTensor *x, 9 | THIntTensor *lengths, 10 | THCudaTensor *h_data, 11 | THCudaTensor *c_data, 12 | THCudaTensor *tmp_i, 13 | THCudaTensor *tmp_h, 14 | THCudaTensor *T, 15 | THCudaTensor *bias, 16 | THCudaTensor *dropout, 17 | THCudaTensor *gates, 18 | int isTraining) { 19 | 20 | float * x_ptr = THCudaTensor_data(state, x); 21 | int * lengths_ptr = THIntTensor_data(lengths); 22 | float * h_data_ptr = THCudaTensor_data(state, h_data); 23 | float * c_data_ptr = THCudaTensor_data(state, c_data); 24 | float * tmp_i_ptr = THCudaTensor_data(state, tmp_i); 25 | float * tmp_h_ptr = THCudaTensor_data(state, tmp_h); 26 | float * T_ptr = THCudaTensor_data(state, T); 27 | float * bias_ptr = THCudaTensor_data(state, bias); 28 | float * dropout_ptr = THCudaTensor_data(state, dropout); 29 | float * gates_ptr; 30 | if (isTraining == 1) { 31 | gates_ptr = THCudaTensor_data(state, gates); 32 | } else { 33 | gates_ptr = NULL; 34 | } 35 | 36 | cudaStream_t stream = THCState_getCurrentStream(state); 37 | cublasHandle_t handle = THCState_getCurrentBlasHandle(state); 38 | 39 | highway_lstm_forward_ongpu(inputSize, hiddenSize, miniBatch, numLayers, 40 | seqLength, x_ptr, lengths_ptr, h_data_ptr, c_data_ptr, tmp_i_ptr, 41 | tmp_h_ptr, T_ptr, bias_ptr, dropout_ptr, gates_ptr, 42 | isTraining, stream, handle); 43 | 44 | return 1; 45 | 46 | } 47 | 48 | int highway_lstm_backward_cuda(int inputSize, int hiddenSize, int miniBatch, int numLayers, int seqLength, 49 | THCudaTensor *out_grad, 50 | THIntTensor *lengths, 51 | THCudaTensor *h_data_grad, 52 | THCudaTensor *c_data_grad, 53 | THCudaTensor *x, 54 | THCudaTensor *h_data, 55 | THCudaTensor *c_data, 56 | THCudaTensor *T, 57 | THCudaTensor *gates_out, 58 | THCudaTensor *dropout_in, 59 | THCudaTensor *h_gates_grad, 60 | THCudaTensor *i_gates_grad, 61 | THCudaTensor *h_out_grad, 62 | THCudaTensor *x_grad, 63 | THCudaTensor *T_grad, 64 | THCudaTensor *bias_grad, 65 | int isTraining, 66 | int do_weight_grad) { 67 | 68 | float * out_grad_ptr = THCudaTensor_data(state, out_grad); 69 | int * lengths_ptr = THIntTensor_data(lengths); 70 | float * h_data_grad_ptr = THCudaTensor_data(state, h_data_grad); 71 | float * c_data_grad_ptr = THCudaTensor_data(state, c_data_grad); 72 | float * x_ptr = THCudaTensor_data(state, x); 73 | float * h_data_ptr = THCudaTensor_data(state, h_data); 74 | float * c_data_ptr = THCudaTensor_data(state, c_data); 75 | float * T_ptr = THCudaTensor_data(state, T); 76 | float * gates_out_ptr = THCudaTensor_data(state, 
gates_out); 77 | float * dropout_in_ptr = THCudaTensor_data(state, dropout_in); 78 | float * h_gates_grad_ptr = THCudaTensor_data(state, h_gates_grad); 79 | float * i_gates_grad_ptr = THCudaTensor_data(state, i_gates_grad); 80 | float * h_out_grad_ptr = THCudaTensor_data(state, h_out_grad); 81 | float * x_grad_ptr = THCudaTensor_data(state, x_grad); 82 | float * T_grad_ptr = THCudaTensor_data(state, T_grad); 83 | float * bias_grad_ptr = THCudaTensor_data(state, bias_grad); 84 | 85 | cudaStream_t stream = THCState_getCurrentStream(state); 86 | cublasHandle_t handle = THCState_getCurrentBlasHandle(state); 87 | 88 | highway_lstm_backward_ongpu(inputSize, hiddenSize, miniBatch, numLayers, 89 | seqLength, out_grad_ptr, lengths_ptr, h_data_grad_ptr, c_data_grad_ptr, 90 | x_ptr, h_data_ptr, c_data_ptr, T_ptr, gates_out_ptr, dropout_in_ptr, 91 | h_gates_grad_ptr, i_gates_grad_ptr, h_out_grad_ptr, 92 | x_grad_ptr, T_grad_ptr, bias_grad_ptr, isTraining, do_weight_grad, 93 | stream, handle); 94 | 95 | return 1; 96 | 97 | } 98 | -------------------------------------------------------------------------------- /lib/lstm/highway_lstm_cuda/src/highway_lstm_cuda.h: -------------------------------------------------------------------------------- 1 | int highway_lstm_forward_cuda(int inputSize, int hiddenSize, int miniBatch, int numLayers, int seqLength, 2 | THCudaTensor *x, THIntTensor *lengths, THCudaTensor *h_data, 3 | THCudaTensor *c_data, THCudaTensor *tmp_i, 4 | THCudaTensor *tmp_h, THCudaTensor *T, THCudaTensor *bias, 5 | THCudaTensor *dropout, THCudaTensor *gates, int isTraining); 6 | 7 | int highway_lstm_backward_cuda(int inputSize, int hiddenSize, int miniBatch, 8 | int numLayers, int seqLength, THCudaTensor *out_grad, THIntTensor *lengths, 9 | THCudaTensor *h_data_grad, THCudaTensor *c_data_grad, THCudaTensor *x, 10 | THCudaTensor *h_data, THCudaTensor *c_data, THCudaTensor *T, 11 | THCudaTensor *gates_out, THCudaTensor *dropout_in, 12 | THCudaTensor *h_gates_grad, THCudaTensor *i_gates_grad, 13 | THCudaTensor *h_out_grad, THCudaTensor *x_grad, THCudaTensor *T_grad, 14 | THCudaTensor *bias_grad, int isTraining, int do_weight_grad); 15 | -------------------------------------------------------------------------------- /lib/lstm/highway_lstm_cuda/src/highway_lstm_kernel.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | void highway_lstm_forward_ongpu(int inputSize, int hiddenSize, int miniBatch, int numLayers, int seqLength, float *x, int *lengths, float*h_data, float *c_data, float *tmp_i, float *tmp_h, float *T, float *bias, float *dropout, float *gates, int is_training, cudaStream_t stream, cublasHandle_t handle); 8 | 9 | void highway_lstm_backward_ongpu(int inputSize, int hiddenSize, int miniBatch, int numLayers, int seqLength, float *out_grad, int *lengths, float *h_data_grad, float *c_data_grad, float *x, float *h_data, float *c_data, float *T, float *gates_out, float *dropout_in, float *h_gates_grad, float *i_gates_grad, float *h_out_grad, float *x_grad, float *T_grad, float *bias_grad, int isTraining, int do_weight_grad, cudaStream_t stream, cublasHandle_t handle); 10 | 11 | #ifdef __cplusplus 12 | } 13 | #endif 14 | -------------------------------------------------------------------------------- /lib/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 
from torchvision.models.resnet import model_urls, conv3x3, BasicBlock 5 | from torchvision.models.vgg import vgg16 6 | from config import BATCHNORM_MOMENTUM 7 | 8 | class Bottleneck(nn.Module): 9 | expansion = 4 10 | 11 | def __init__(self, inplanes, planes, stride=1, downsample=None, relu_end=True): 12 | super(Bottleneck, self).__init__() 13 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes, momentum=BATCHNORM_MOMENTUM) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | self.bn2 = nn.BatchNorm2d(planes, momentum=BATCHNORM_MOMENTUM) 18 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 19 | self.bn3 = nn.BatchNorm2d(planes * 4, momentum=BATCHNORM_MOMENTUM) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.downsample = downsample 22 | self.stride = stride 23 | self.relu_end = relu_end 24 | 25 | def forward(self, x): 26 | residual = x 27 | 28 | out = self.conv1(x) 29 | out = self.bn1(out) 30 | out = self.relu(out) 31 | 32 | out = self.conv2(out) 33 | out = self.bn2(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv3(out) 37 | out = self.bn3(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | 44 | if self.relu_end: 45 | out = self.relu(out) 46 | return out 47 | 48 | 49 | class ResNet(nn.Module): 50 | 51 | def __init__(self, block, layers, num_classes=1000): 52 | self.inplanes = 64 53 | super(ResNet, self).__init__() 54 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 55 | bias=False) 56 | self.bn1 = nn.BatchNorm2d(64, momentum=BATCHNORM_MOMENTUM) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 59 | self.layer1 = self._make_layer(block, 64, layers[0]) 60 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 61 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 62 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1) # HACK 63 | self.avgpool = nn.AvgPool2d(7) 64 | self.fc = nn.Linear(512 * block.expansion, num_classes) 65 | 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 69 | m.weight.data.normal_(0, math.sqrt(2. / n)) 70 | elif isinstance(m, nn.BatchNorm2d): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | def _make_layer(self, block, planes, blocks, stride=1): 75 | downsample = None 76 | if stride != 1 or self.inplanes != planes * block.expansion: 77 | downsample = nn.Sequential( 78 | nn.Conv2d(self.inplanes, planes * block.expansion, 79 | kernel_size=1, stride=stride, bias=False), 80 | nn.BatchNorm2d(planes * block.expansion, momentum=BATCHNORM_MOMENTUM), 81 | ) 82 | 83 | layers = [] 84 | layers.append(block(self.inplanes, planes, stride, downsample)) 85 | self.inplanes = planes * block.expansion 86 | for i in range(1, blocks): 87 | layers.append(block(self.inplanes, planes)) 88 | 89 | return nn.Sequential(*layers) 90 | 91 | def forward(self, x): 92 | x = self.conv1(x) 93 | x = self.bn1(x) 94 | x = self.relu(x) 95 | x = self.maxpool(x) 96 | 97 | x = self.layer1(x) 98 | x = self.layer2(x) 99 | x = self.layer3(x) 100 | x = self.layer4(x) 101 | 102 | x = self.avgpool(x) 103 | x = x.view(x.size(0), -1) 104 | x = self.fc(x) 105 | 106 | return x 107 | 108 | def resnet101(pretrained=False, **kwargs): 109 | """Constructs a ResNet-101 model. 
110 | 111 | Args: 112 | pretrained (bool): If True, returns a model pre-trained on ImageNet 113 | """ 114 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 115 | if pretrained: 116 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 117 | return model 118 | 119 | def resnet_l123(): 120 | model = resnet101(pretrained=True) 121 | del model.layer4 122 | del model.avgpool 123 | del model.fc 124 | return model 125 | 126 | def resnet_l4(relu_end=True): 127 | model = resnet101(pretrained=True) 128 | l4 = model.layer4 129 | if not relu_end: 130 | l4[-1].relu_end = False 131 | l4[0].conv2.stride = (1, 1) 132 | l4[0].downsample[0].stride = (1, 1) 133 | return l4 134 | 135 | def vgg_fc(relu_end=True, linear_end=True): 136 | model = vgg16(pretrained=True) 137 | vfc = model.classifier 138 | del vfc._modules['6'] # Get rid of linear layer 139 | del vfc._modules['5'] # Get rid of linear layer 140 | if not relu_end: 141 | del vfc._modules['4'] # Get rid of linear layer 142 | if not linear_end: 143 | del vfc._modules['3'] 144 | return vfc 145 | 146 | 147 | -------------------------------------------------------------------------------- /lib/sparse_targets.py: -------------------------------------------------------------------------------- 1 | from lib.word_vectors import obj_edge_vectors 2 | import torch.nn as nn 3 | import torch 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from config import DATA_PATH 7 | import os 8 | from lib.get_dataset_counts import get_counts 9 | 10 | 11 | class FrequencyBias(nn.Module): 12 | """ 13 | The goal of this is to provide a simplified way of computing 14 | P(predicate | obj1, obj2, img). 15 | """ 16 | 17 | def __init__(self, eps=1e-3): 18 | super(FrequencyBias, self).__init__() 19 | 20 | fg_matrix, bg_matrix = get_counts(must_overlap=True) 21 | bg_matrix += 1 22 | fg_matrix[:, :, 0] = bg_matrix 23 | 24 | pred_dist = np.log(fg_matrix / fg_matrix.sum(2)[:, :, None] + eps) 25 | 26 | self.num_objs = pred_dist.shape[0] 27 | pred_dist = torch.FloatTensor(pred_dist).view(-1, pred_dist.shape[2]) 28 | 29 | self.obj_baseline = nn.Embedding(pred_dist.size(0), pred_dist.size(1)) 30 | self.obj_baseline.weight.data = pred_dist 31 | 32 | def index_with_labels(self, labels): 33 | """ 34 | :param labels: [batch_size, 2] 35 | :return: 36 | """ 37 | return self.obj_baseline(labels[:, 0] * self.num_objs + labels[:, 1]) 38 | 39 | def forward(self, obj_cands0, obj_cands1): 40 | """ 41 | :param obj_cands0: [batch_size, 151] prob distibution over cands. 42 | :param obj_cands1: [batch_size, 151] prob distibution over cands. 43 | :return: [batch_size, #predicates] array, which contains potentials for 44 | each possibility 45 | """ 46 | # [batch_size, 151, 151] repr of the joint distribution 47 | joint_cands = obj_cands0[:, :, None] * obj_cands1[:, None] 48 | 49 | # [151, 151, 51] of targets per. 50 | baseline = joint_cands.view(joint_cands.size(0), -1) @ self.obj_baseline.weight 51 | 52 | return baseline 53 | -------------------------------------------------------------------------------- /lib/surgery.py: -------------------------------------------------------------------------------- 1 | # create predictions from the other stuff 2 | """ 3 | Go from proposals + scores to relationships. 
4 | 5 | pred-cls: No bbox regression, obj dist is exactly known 6 | sg-cls : No bbox regression 7 | sg-det : Bbox regression 8 | 9 | in all cases we'll return: 10 | boxes, objs, rels, pred_scores 11 | 12 | """ 13 | 14 | import numpy as np 15 | import torch 16 | from lib.pytorch_misc import unravel_index 17 | from lib.fpn.box_utils import bbox_overlaps 18 | # from ad3 import factor_graph as fg 19 | from time import time 20 | 21 | def filter_dets(boxes, obj_scores, obj_classes, rel_inds, pred_scores, gt_boxes, gt_classes, gt_rels): 22 | """ 23 | Filters detections.... 24 | :param boxes: [num_box, topk, 4] if bbox regression else [num_box, 4] 25 | :param obj_scores: [num_box] probabilities for the scores 26 | :param obj_classes: [num_box] class labels for the topk 27 | :param rel_inds: [num_rel, 2] TENSOR consisting of (im_ind0, im_ind1) 28 | :param pred_scores: [topk, topk, num_rel, num_predicates] 29 | :param use_nms: True if use NMS to filter dets. 30 | :return: boxes, objs, rels, pred_scores 31 | 32 | """ 33 | if boxes.dim() != 2: 34 | raise ValueError("Boxes needs to be [num_box, 4] but its {}".format(boxes.size())) 35 | 36 | num_box = boxes.size(0) 37 | assert obj_scores.size(0) == num_box 38 | 39 | assert obj_classes.size() == obj_scores.size() 40 | num_rel = rel_inds.size(0) 41 | assert rel_inds.size(1) == 2 42 | assert pred_scores.size(0) == num_rel 43 | 44 | obj_scores0 = obj_scores.data[rel_inds[:,0]] 45 | obj_scores1 = obj_scores.data[rel_inds[:,1]] 46 | 47 | pred_scores_max, pred_classes_argmax = pred_scores.data[:,1:].max(1) 48 | pred_classes_argmax = pred_classes_argmax + 1 49 | 50 | rel_scores_argmaxed = pred_scores_max * obj_scores0 * obj_scores1 51 | rel_scores_vs, rel_scores_idx = torch.sort(rel_scores_argmaxed.view(-1), dim=0, descending=True) 52 | 53 | rels = rel_inds[rel_scores_idx].cpu().numpy() 54 | pred_scores_sorted = pred_scores[rel_scores_idx].data.cpu().numpy() 55 | obj_scores_np = obj_scores.data.cpu().numpy() 56 | objs_np = obj_classes.data.cpu().numpy() 57 | boxes_out = boxes.data.cpu().numpy() 58 | 59 | return boxes_out, objs_np, obj_scores_np, rels, pred_scores_sorted, gt_boxes, gt_classes, gt_rels 60 | 61 | # def _get_similar_boxes(boxes, obj_classes_topk, nms_thresh=0.3): 62 | # """ 63 | # Assuming bg is NOT A LABEL. 64 | # :param boxes: [num_box, topk, 4] if bbox regression else [num_box, 4] 65 | # :param obj_classes: [num_box, topk] class labels 66 | # :return: num_box, topk, num_box, topk array containing similarities. 67 | # """ 68 | # topk = obj_classes_topk.size(1) 69 | # num_box = boxes.size(0) 70 | # 71 | # box_flat = boxes.view(-1, 4) if boxes.dim() == 3 else boxes[:, None].expand( 72 | # num_box, topk, 4).contiguous().view(-1, 4) 73 | # jax = bbox_overlaps(box_flat, box_flat).data > nms_thresh 74 | # # Filter out things that are not gonna compete. 
75 | # classes_eq = obj_classes_topk.data.view(-1)[:, None] == obj_classes_topk.data.view(-1)[None, :] 76 | # jax &= classes_eq 77 | # boxes_are_similar = jax.view(num_box, topk, num_box, topk) 78 | # return boxes_are_similar.cpu().numpy().astype(np.bool) 79 | -------------------------------------------------------------------------------- /lib/tree_lstm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Scene-Graph-Generation/1d7983123c534f72786b241046f8d2a007756d31/lib/tree_lstm/__init__.py -------------------------------------------------------------------------------- /lib/tree_lstm/decoder_tree_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.nn.utils.rnn import PackedSequence 6 | from typing import Optional, Tuple 7 | 8 | from lib.tree_lstm import tree_utils 9 | from lib.tree_lstm.tree_lstm import BiTreeLSTM_Backward, BiTreeLSTM_Foreward 10 | from lib.fpn.box_utils import nms_overlaps 11 | from lib.word_vectors import obj_edge_vectors 12 | from lib.lstm.highway_lstm_cuda.alternating_highway_lstm import block_orthogonal 13 | import numpy as np 14 | 15 | class DecoderTreeLSTM(torch.nn.Module): 16 | def __init__(self, classes, embed_dim, inputs_dim, hidden_dim, direction='backward', dropout=0.2, pass_root = False, not_rl = True): 17 | super(DecoderTreeLSTM, self).__init__() 18 | """ 19 | Initializes the RNN 20 | :param embed_dim: Dimension of the embeddings 21 | :param encoder_hidden_dim: Hidden dim of the encoder, for attention purposes 22 | :param hidden_dim: Hidden dim of the decoder 23 | :param vocab_size: Number of words in the vocab 24 | :param bos_token: To use during decoding (non teacher forcing mode)) 25 | :param bos: beginning of sentence token 26 | :param unk: unknown token (not used) 27 | direction = foreward | backward 28 | """ 29 | self.classes = classes 30 | self.hidden_size = hidden_dim 31 | self.inputs_dim = inputs_dim 32 | self.nms_thresh = 0.5 33 | self.dropout = dropout 34 | self.pass_root = pass_root 35 | # generate embed layer 36 | embed_vecs = obj_edge_vectors(['start'] + self.classes, wv_dim=embed_dim) 37 | self.obj_embed = nn.Embedding(len(self.classes), embed_dim) 38 | self.obj_embed.weight.data = embed_vecs 39 | # generate out layer 40 | self.out = nn.Linear(self.hidden_size, len(self.classes)) 41 | self.out.weight = torch.nn.init.xavier_normal(self.out.weight, gain=1.0) 42 | self.out.bias.data.fill_(0.0) 43 | self.not_rl = not_rl 44 | 45 | if direction == 'backward': 46 | self.input_size = inputs_dim + embed_dim 47 | self.decoderLSTM = BiTreeLSTM_Backward(self.input_size, self.hidden_size, self.pass_root, is_pass_embed=True, embed_layer=self.obj_embed, embed_out_layer=self.out, not_rl=not_rl) 48 | elif direction == 'foreward': 49 | self.input_size = inputs_dim + embed_dim * 2 50 | self.decoderLSTM = BiTreeLSTM_Foreward(self.input_size, self.hidden_size, self.pass_root, is_pass_embed=True, embed_layer=self.obj_embed, embed_out_layer=self.out, not_rl=not_rl) 51 | else: 52 | print('Error Decoder LSTM Direction') 53 | 54 | 55 | def forward(self, forest, features, num_obj, labels=None, boxes_for_nms=None, batch_size=0): 56 | # generate dropout 57 | if self.dropout > 0.0: 58 | dropout_mask = get_dropout_mask(self.dropout, self.hidden_size) 59 | else: 60 | dropout_mask = None 61 | 62 | # generate 
tree lstm input/output class 63 | out_h = None 64 | out_dists = None 65 | out_commitments = None 66 | h_order = Variable(torch.LongTensor(num_obj).zero_().cuda()) 67 | order_idx = 0 68 | lstm_io = tree_utils.TreeLSTM_IO(out_h, h_order, order_idx, out_dists, out_commitments, dropout_mask) 69 | 70 | for idx in range(len(forest)): 71 | self.decoderLSTM(forest[idx], features, lstm_io) 72 | 73 | out_h = torch.index_select(lstm_io.hidden, 0, lstm_io.order.long()) 74 | out_dists = torch.index_select(lstm_io.dists, 0, lstm_io.order.long())[:-batch_size] 75 | out_commitments = torch.index_select(lstm_io.commitments, 0, lstm_io.order.long())[:-batch_size] 76 | 77 | # Do NMS here as a post-processing step 78 | if boxes_for_nms is not None and not self.training and self.not_rl: 79 | is_overlap = nms_overlaps(boxes_for_nms.data).view( 80 | boxes_for_nms.size(0), boxes_for_nms.size(0), boxes_for_nms.size(1) 81 | ).cpu().numpy() >= self.nms_thresh 82 | # is_overlap[np.arange(boxes_for_nms.size(0)), np.arange(boxes_for_nms.size(0))] = False 83 | 84 | out_dists_sampled = F.softmax(out_dists, 1).data.cpu().numpy() 85 | out_dists_sampled[:,0] = 0 86 | 87 | out_commitments = out_commitments.data.new(out_commitments.shape[0]).fill_(0) 88 | 89 | for i in range(out_commitments.size(0)): 90 | box_ind, cls_ind = np.unravel_index(out_dists_sampled.argmax(), out_dists_sampled.shape) 91 | out_commitments[int(box_ind)] = int(cls_ind) 92 | out_dists_sampled[is_overlap[box_ind,:,cls_ind], cls_ind] = 0.0 93 | out_dists_sampled[box_ind] = -1.0 # This way we won't re-sample 94 | 95 | out_commitments = Variable(out_commitments.view(-1)) 96 | else: 97 | out_commitments = out_commitments.view(-1) 98 | 99 | if self.training and self.not_rl and (labels is not None): 100 | out_commitments = labels.clone() 101 | else: 102 | out_commitments = torch.cat((out_commitments, Variable(torch.randn(batch_size).long().fill_(0).cuda()).view(-1)), 0) 103 | 104 | return out_dists, out_commitments 105 | 106 | 107 | def get_dropout_mask(dropout_probability: float, h_dim: int): 108 | """ 109 | Computes and returns an element-wise dropout mask for a given tensor, where 110 | each element in the mask is dropped out with probability dropout_probability. 111 | Note that the mask is NOT applied to the tensor - the tensor is passed to retain 112 | the correct CUDA tensor type for the mask. 113 | 114 | Parameters 115 | ---------- 116 | dropout_probability : float, required. 117 | Probability of dropping a dimension of the input. 118 | tensor_for_masking : torch.Variable, required. 119 | 120 | 121 | Returns 122 | ------- 123 | A torch.FloatTensor consisting of the binary mask scaled by 1/ (1 - dropout_probability). 124 | This scaling ensures expected values and variances of the output of applying this mask 125 | and the original tensor are the same. 126 | """ 127 | binary_mask = Variable(torch.FloatTensor(h_dim).cuda().fill_(0.0)) 128 | binary_mask.data.copy_(torch.rand(h_dim) > dropout_probability) 129 | # Scale mask by 1/keep_prob to preserve output statistics. 
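# Illustrative note (added for clarity, not in the original source): with
# dropout_probability = 0.2, each surviving entry of binary_mask is divided by
# (1 - 0.2), i.e. scaled by 1.25, so E[dropout_mask * h] equals h.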
130 | dropout_mask = binary_mask.float().div(1.0 - dropout_probability) 131 | return dropout_mask -------------------------------------------------------------------------------- /lib/tree_lstm/def_tree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.nn.utils.rnn import PackedSequence 6 | import numpy as np 7 | 8 | from config import BOX_SCALE, IM_SCALE 9 | 10 | 11 | class BasicBiTree(object): 12 | def __init__(self, idx, is_root=False): 13 | self.index = int(idx) 14 | self.is_root = is_root 15 | self.left_child = None 16 | self.right_child = None 17 | self.parent = None 18 | self.num_child = 0 19 | 20 | def add_left_child(self, child): 21 | if self.left_child is not None: 22 | print('Left child already exist') 23 | return 24 | child.parent = self 25 | self.num_child += 1 26 | self.left_child = child 27 | 28 | def add_right_child(self, child): 29 | if self.right_child is not None: 30 | print('Right child already exist') 31 | return 32 | child.parent = self 33 | self.num_child += 1 34 | self.right_child = child 35 | 36 | def get_total_child(self): 37 | sum = 0 38 | sum += self.num_child 39 | if self.left_child is not None: 40 | sum += self.left_child.get_total_child() 41 | if self.right_child is not None: 42 | sum += self.right_child.get_total_child() 43 | return sum 44 | 45 | def depth(self): 46 | if hasattr(self, '_depth'): 47 | return self._depth 48 | if self.parent is None: 49 | count = 1 50 | else: 51 | count = self.parent.depth() + 1 52 | self._depth = count 53 | return self._depth 54 | 55 | def max_depth(self): 56 | if hasattr(self, '_max_depth'): 57 | return self._max_depth 58 | count = 0 59 | if self.left_child is not None: 60 | left_depth = self.left_child.max_depth() 61 | if left_depth > count: 62 | count = left_depth 63 | if self.right_child is not None: 64 | right_depth = self.right_child.max_depth() 65 | if right_depth > count: 66 | count = right_depth 67 | count += 1 68 | self._max_depth = count 69 | return self._max_depth 70 | 71 | # by index 72 | def is_descendant(self, idx): 73 | left_flag = False 74 | right_flag = False 75 | # node is left child 76 | if self.left_child is not None: 77 | if self.left_child.index is idx: 78 | return True 79 | else: 80 | left_flag = self.left_child.is_descendant(idx) 81 | # node is right child 82 | if self.right_child is not None: 83 | if self.right_child.index is idx: 84 | return True 85 | else: 86 | right_flag = self.right_child.is_descendant(idx) 87 | # node is descendant 88 | if left_flag or right_flag: 89 | return True 90 | else: 91 | return False 92 | 93 | # whether input node is under left sub tree 94 | def is_left_descendant(self, idx): 95 | if self.left_child is not None: 96 | if self.left_child.index is idx: 97 | return True 98 | else: 99 | return self.left_child.is_descendant(idx) 100 | else: 101 | return False 102 | 103 | # whether input node is under right sub tree 104 | def is_right_descendant(self, idx): 105 | if self.right_child is not None: 106 | if self.right_child.index is idx: 107 | return True 108 | else: 109 | return self.right_child.is_descendant(idx) 110 | else: 111 | return False 112 | 113 | 114 | class ArbitraryTree(object): 115 | def __init__(self, idx, score, label=-1, box=None, im_idx=-1, is_root=False): 116 | self.index = int(idx) 117 | self.is_root = is_root 118 | self.score = float(score) 119 | self.children = [] 120 | self.label = label 121 
| self.embeded_label = None 122 | self.box = box.view(-1) if box is not None else None #[x1,y1,x2,y2] 123 | self.im_idx = int(im_idx) # which image it comes from 124 | self.parent = None 125 | self.node_order = -1 # the n_th node added to the tree 126 | 127 | def generate_bi_tree(self): 128 | # generate a BiTree node, parent/child relationship are not inherited 129 | return BiTree(self.index, self.score, self.label, self.box, self.im_idx, self.is_root) 130 | 131 | def add_child(self, child): 132 | child.parent = self 133 | self.children.append(child) 134 | 135 | def print(self): 136 | print('index: ', self.index) 137 | print('node_order: ', self.node_order) 138 | print('num of child: ', len(self.children)) 139 | for node in self.children: 140 | node.print() 141 | 142 | def find_node_by_order(self, order, result_node): 143 | if self.node_order == order: 144 | result_node = self 145 | elif len(self.children) > 0: 146 | for i in range(len(self.children)): 147 | result_node = self.children[i].find_node_by_order(order, result_node) 148 | 149 | return result_node 150 | 151 | def find_node_by_index(self, index, result_node): 152 | if self.index == index: 153 | result_node = self 154 | elif len(self.children) > 0: 155 | for i in range(len(self.children)): 156 | result_node = self.children[i].find_node_by_index(index, result_node) 157 | 158 | return result_node 159 | 160 | def search_best_insert(self, score_map, best_score, insert_node, best_depend_node, best_insert_node, ignore_root = True): 161 | if self.is_root and ignore_root: 162 | pass 163 | elif float(score_map[self.index, insert_node.index]) > float(best_score): 164 | best_score = score_map[self.index, insert_node.index] 165 | best_depend_node = self 166 | best_insert_node = insert_node 167 | 168 | # iteratively search child 169 | for i in range(self.get_child_num()): 170 | best_score, best_depend_node, best_insert_node = \ 171 | self.children[i].search_best_insert(score_map, best_score, insert_node, best_depend_node, best_insert_node) 172 | 173 | return best_score, best_depend_node, best_insert_node 174 | 175 | def get_child_num(self): 176 | return len(self.children) 177 | 178 | def get_total_child(self): 179 | sum = 0 180 | num_current_child = self.get_child_num() 181 | sum += num_current_child 182 | for i in range(num_current_child): 183 | sum += self.children[i].get_total_child() 184 | return sum 185 | 186 | # only support binary tree 187 | class BiTree(BasicBiTree): 188 | def __init__(self, idx, node_score, label, box, im_idx, is_root=False): 189 | super(BiTree, self).__init__(idx, is_root) 190 | self.state_c = None 191 | self.state_h = None 192 | self.state_c_backward = None 193 | self.state_h_backward = None 194 | # used to select node 195 | self.node_score = float(node_score) 196 | self.label = label 197 | self.embeded_label = None 198 | self.box = box.view(-1) #[x1,y1,x2,y2] 199 | self.im_idx = int(im_idx) # which image it comes from 200 | 201 | -------------------------------------------------------------------------------- /lib/tree_lstm/draw_tree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.nn.utils.rnn import PackedSequence 6 | import numpy as np 7 | from PIL import Image, ImageFont, ImageDraw, ImageEnhance 8 | from skimage.transform import resize 9 | from config import IM_SCALE 10 | 11 | tree_max_depth = 9 12 | 13 | label_to_text = {"-1":"unknown-1", 
"0":"unknown-0", "1": "airplane", "2": "animal", "3": "arm", "4": "bag", "5": "banana", "6": "basket", "7": "beach", "8": "bear", "9": "bed", "10": "bench", "11": "bike", "12": "bird", "13": "board", "14": "boat", "15": "book", "16": "boot", "17": "bottle", "18": "bowl", "19": "box", "20": "boy", "21": "branch", "22": "building", "23": "bus", "24": "cabinet", "25": "cap", "26": "car", "27": "cat", "28": "chair", "29": "child", "30": "clock", "31": "coat", "32": "counter", "33": "cow", "34": "cup", "35": "curtain", "36": "desk", "37": "dog", "38": "door", "39": "drawer", "40": "ear", "41": "elephant", "42": "engine", "43": "eye", "44": "face", "45": "fence", "46": "finger", "47": "flag", "48": "flower", "49": "food", "50": "fork", "51": "fruit", "52": "giraffe", "53": "girl", "54": "glass", "55": "glove", "56": "guy", "57": "hair", "58": "hand", "59": "handle", "60": "hat", "61": "head", "62": "helmet", "63": "hill", "64": "horse", "65": "house", "66": "jacket", "67": "jean", "68": "kid", "69": "kite", "70": "lady", "71": "lamp", "72": "laptop", "73": "leaf", "74": "leg", "75": "letter", "76": "light", "77": "logo", "78": "man", "79": "men", "80": "motorcycle", "81": "mountain", "82": "mouth", "83": "neck", "84": "nose", "85": "number", "86": "orange", "87": "pant", "88": "paper", "89": "paw", "90": "people", "91": "person", "92": "phone", "93": "pillow", "94": "pizza", "95": "plane", "96": "plant", "97": "plate", "98": "player", "99": "pole", "100": "post", "101": "pot", "102": "racket", "103": "railing", "104": "rock", "105": "roof", "106": "room", "107": "screen", "108": "seat", "109": "sheep", "110": "shelf", "111": "shirt", "112": "shoe", "113": "short", "114": "sidewalk", "115": "sign", "116": "sink", "117": "skateboard", "118": "ski", "119": "skier", "120": "sneaker", "121": "snow", "122": "sock", "123": "stand", "124": "street", "125": "surfboard", "126": "table", "127": "tail", "128": "tie", "129": "tile", "130": "tire", "131": "toilet", "132": "towel", "133": "tower", "134": "track", "135": "train", "136": "tree", "137": "truck", "138": "trunk", "139": "umbrella", "140": "vase", "141": "vegetable", "142": "vehicle", "143": "wave", "144": "wheel", "145": "window", "146": "windshield", "147": "wing", "148": "wire", "149": "woman", "150": "zebra"} 14 | 15 | def draw_tree_region_v2(tree, image, example_id, pred_labels): 16 | """ 17 | tree: A tree structure 18 | image: origin image batch [batch_size, 3, IM_SIZE, IM_SIZE] 19 | output: a image with roi bbox, the color of box correspond to the depth of roi node 20 | """ 21 | sample_image = image[tree.im_idx].view(image.shape[1:]).clone() 22 | sample_image = (revert_normalize(sample_image) * 255).int() 23 | sample_image = torch.clamp(sample_image, 0, 255) 24 | sample_image = sample_image.permute(1,2,0).contiguous().data.cpu().numpy().astype(dtype = np.uint8) 25 | sample_image = Image.fromarray(sample_image, 'RGB').convert("RGBA") 26 | 27 | draw = ImageDraw.Draw(sample_image) 28 | draw_box(draw, tree, pred_labels) 29 | 30 | sample_image.save('./output/example/'+str(example_id)+'_box'+'.png') 31 | 32 | #print('saved img ' + str(example_id)) 33 | 34 | 35 | def draw_box(draw, tree, pred_labels): 36 | x1,y1,x2,y2 = int(tree.box[0]), int(tree.box[1]), int(tree.box[2]), int(tree.box[3]) 37 | draw.rectangle(((x1, y1), (x2, y2)), outline="red") 38 | draw.rectangle(((x1, y1), (x1+50, y1+10)), fill="red") 39 | node_label = int(pred_labels[int(tree.index)]) 40 | draw.text((x1, y1), label_to_text[str(node_label)]) 41 | 42 | if (tree.left_child is not 
None): 43 | draw_box(draw, tree.left_child, pred_labels) 44 | if (tree.right_child is not None): 45 | draw_box(draw, tree.right_child, pred_labels) 46 | 47 | 48 | 49 | def draw_tree_region(tree, image, example_id): 50 | """ 51 | tree: A tree structure 52 | image: origin image batch [batch_size, 3, IM_SIZE, IM_SIZE] 53 | output: a image display regions in a tree structure 54 | """ 55 | sample_image = image[tree.im_idx].view(image.shape[1:]).clone() 56 | sample_image = (revert_normalize(sample_image) * 255).int() 57 | sample_image = torch.clamp(sample_image, 0, 255) 58 | sample_image = sample_image.permute(1,2,0).contiguous().data.cpu().numpy().astype(dtype = np.uint8) 59 | 60 | global tree_max_depth 61 | 62 | depth = min(tree.max_depth(), tree_max_depth) 63 | tree_img = create_tree_img(depth, 64) 64 | tree_img = write_cell(sample_image, tree_img, (0,0,tree_img.shape[1], tree_img.shape[0]), tree, 64) 65 | 66 | im = Image.fromarray(sample_image, 'RGB') 67 | tree_img = Image.fromarray(tree_img, 'RGB') 68 | im.save('./output/example/'+str(example_id)+'_origin'+'.jpg') 69 | tree_img.save('./output/example/'+str(example_id)+'_tree'+'.jpg') 70 | 71 | if example_id % 200 == 0: 72 | print('saved img ' + str(example_id)) 73 | 74 | def write_cell(orig_img, tree_img, draw_box, tree, cell_size): 75 | """ 76 | orig_img: original image 77 | tree_img: draw roi tree 78 | draw_box: the whole bbox used to draw this sub-tree [x1,y1,x2,y2] 79 | tree: a sub-tree 80 | cell_size: size of each roi 81 | """ 82 | x1,y1,x2,y2 = draw_box 83 | if tree is None: 84 | return tree_img 85 | # draw 86 | roi = orig_img[int(tree.box[1]):int(tree.box[3]), int(tree.box[0]):int(tree.box[2]), :] 87 | roi = Image.fromarray(roi, 'RGB') 88 | roi = roi.resize((cell_size, cell_size)) 89 | roi = np.array(roi) 90 | draw_x1 = int(max((x1+x2)/2 - cell_size/2, 0)) 91 | draw_x2 = int(min(draw_x1 + cell_size, x2)) 92 | draw_y1 = y1 93 | draw_y2 = min(y1 + cell_size, y2) 94 | tree_img[draw_y1:draw_y2, draw_x1:draw_x2,:] = roi[:draw_y2-draw_y1,:draw_x2-draw_x1,:] 95 | # recursive draw 96 | global tree_max_depth 97 | if (tree.left_child is not None) and tree.left_child.depth() <= tree_max_depth: 98 | tree_img = write_cell(orig_img, tree_img, (x1,draw_y2,int((x1+x2)/2),y2), tree.left_child, cell_size) 99 | if (tree.right_child is not None) and tree.right_child.depth() <= tree_max_depth: 100 | tree_img = write_cell(orig_img, tree_img, (int((x1+x2)/2),draw_y2,x2,y2), tree.right_child, cell_size) 101 | 102 | return tree_img 103 | 104 | def create_tree_img(depth, cell_size): 105 | height = cell_size * (depth) 106 | width = cell_size * (2**(depth-1)) 107 | return np.zeros((height,width,3)).astype(dtype=np.uint8) 108 | 109 | def revert_normalize(image): 110 | image[0,:,:] = image[0,:,:] * 0.229 111 | image[1,:,:] = image[1,:,:] * 0.224 112 | image[2,:,:] = image[2,:,:] * 0.225 113 | 114 | image[0,:,:] = image[0,:,:] + 0.485 115 | image[1,:,:] = image[1,:,:] + 0.456 116 | image[2,:,:] = image[2,:,:] + 0.406 117 | 118 | return image 119 | 120 | def print_tree(tree): 121 | if tree is None: 122 | return 123 | if(tree.left_child is not None): 124 | print_node(tree.left_child) 125 | if(tree.right_child is not None): 126 | print_node(tree.right_child) 127 | 128 | print_tree(tree.left_child) 129 | print_tree(tree.right_child) 130 | 131 | return 132 | 133 | 134 | def print_node(tree): 135 | print(' depth: ', tree.depth(), end="") 136 | print(' label: ', tree.label, end="") 137 | print(' index: ', int(tree.index), end="") 138 | print(' score: ', tree.score(), 
end="") 139 | print(' center_x: ', tree.center_x) 140 | -------------------------------------------------------------------------------- /lib/tree_lstm/gen_tree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.nn.utils.rnn import PackedSequence 6 | from lib.word_vectors import obj_edge_vectors 7 | import numpy as np 8 | from config import IM_SCALE 9 | import random 10 | 11 | from lib.tree_lstm import tree_utils 12 | from lib.lstm.highway_lstm_cuda.alternating_highway_lstm import block_orthogonal 13 | from lib.tree_lstm.def_tree import ArbitraryTree 14 | from config import LOG_SOFTMAX 15 | 16 | class RLFeatPreprocessNet(nn.Module): 17 | def __init__(self, feature_size, embed_size, box_info_size, overlap_info_size, output_size): 18 | super(RLFeatPreprocessNet, self).__init__() 19 | self.feature_size = feature_size 20 | self.embed_size = embed_size 21 | self.box_info_size = box_info_size 22 | self.overlap_info_size = overlap_info_size 23 | self.output_size = output_size 24 | 25 | # linear layers 26 | self.resize_feat = nn.Linear(self.feature_size, int(output_size / 4)) 27 | self.resize_embed = nn.Linear(self.embed_size, int(output_size / 4)) 28 | self.resize_box = nn.Linear(self.box_info_size, int(output_size / 4)) 29 | self.resize_overlap = nn.Linear(self.overlap_info_size, int(output_size / 4)) 30 | 31 | # init 32 | self.resize_feat.weight.data.normal_(0, 0.001) 33 | self.resize_embed.weight.data.normal_(0, 0.01) 34 | self.resize_box.weight.data.normal_(0, 1) 35 | self.resize_overlap.weight.data.normal_(0, 1) 36 | self.resize_feat.bias.data.zero_() 37 | self.resize_embed.bias.data.zero_() 38 | self.resize_box.bias.data.zero_() 39 | self.resize_overlap.bias.data.zero_() 40 | 41 | def forward(self, obj_feat, obj_embed, box_info, overlap_info): 42 | resized_obj = self.resize_feat(obj_feat) 43 | resized_embed = self.resize_embed(obj_embed) 44 | resized_box = self.resize_box(box_info) 45 | resized_overlap = self.resize_overlap(overlap_info) 46 | 47 | output_feat = torch.cat((resized_obj, resized_embed, resized_box, resized_overlap), 1) 48 | return output_feat 49 | 50 | def generate_forest(im_inds, gt_forest, pair_scores, box_priors, obj_label, use_rl_tree, is_training, mode): 51 | """ 52 | generate a list of trees that covers all the objects in a batch 53 | im_inds: [obj_num] 54 | box_priors: [obj_num, (x1, y1, x2, y2)] 55 | pair_scores: [obj_num, obj_num] 56 | 57 | output: list of trees, each present a chunk of overlaping objects 58 | """ 59 | output_forest = [] # the list of trees, each one is a chunk of overlapping objects 60 | num_obj = box_priors.shape[0] 61 | 62 | # node score: accumulate parent scores 63 | node_scores = pair_scores.mean(1).view(-1) 64 | # make forest 65 | group_id = 0 66 | 67 | gen_tree_loss_per_batch = [] 68 | entropy_loss = [] 69 | 70 | while(torch.nonzero(im_inds == group_id).numel() > 0): 71 | # select the nodes from the same image 72 | rl_node_container = [] 73 | remain_index = [] 74 | picked_list = torch.nonzero(im_inds == group_id).view(-1) 75 | root_idx = picked_list[-1] 76 | 77 | rl_root = ArbitraryTree(root_idx, node_scores[int(root_idx)], -1, box_priors[int(root_idx)], im_inds[int(root_idx)], is_root=True) 78 | 79 | # put all nodes into node container 80 | for idx in picked_list[:-1]: 81 | if obj_label is not None: 82 | label = int(obj_label[idx]) 83 | else: 84 | label = -1 85 | new_node = 
ArbitraryTree(idx, node_scores[idx], label, box_priors[idx], im_inds[idx]) 86 | rl_node_container.append(new_node) 87 | remain_index.append(int(idx)) 88 | 89 | # iteratively generate tree 90 | rl_gen_tree(rl_node_container, pair_scores, node_scores, gen_tree_loss_per_batch, entropy_loss, rl_root, remain_index, (is_training and use_rl_tree), mode) 91 | 92 | output_forest.append(rl_root) 93 | group_id += 1 94 | 95 | return output_forest, gen_tree_loss_per_batch, entropy_loss 96 | 97 | 98 | def rl_gen_tree(node_container, pair_scores, node_scores, gen_tree_loss_per_batch, entropy_loss, rl_root, remain_index, rl_training, mode): 99 | """ 100 | use reinforcement learning to generate loss (without baseline tree) 101 | Calculate the log(pr) for each decision (not cross entropy) 102 | Step 1: Devide all nodes into left child container and right child container 103 | Step 2: From left child container and right child container, select their respective sub roots 104 | 105 | pair_scores: [obj_num, obj_num] 106 | node_scores: [obj_num] 107 | """ 108 | num_nodes = len(node_container) 109 | # Step 0 110 | if num_nodes == 0: 111 | return 112 | # Step 1 113 | select_node = [] 114 | select_index = [] 115 | select_node.append(rl_root) 116 | select_index.append(rl_root.index) 117 | 118 | if mode == 'predcls': 119 | first_score = node_scores[remain_index].contiguous().view(-1) 120 | _, inds = F.softmax(first_score, 0).max(0) 121 | first_node = node_container[int(inds)] 122 | rl_root.add_child(first_node) 123 | select_node.append(first_node) 124 | select_index.append(first_node.index) 125 | node_container.remove(first_node) 126 | remain_index.remove(first_node.index) 127 | 128 | not_sampled = True 129 | 130 | while len(node_container) > 0: 131 | wid = len(remain_index) 132 | select_index_var = Variable(torch.LongTensor(select_index).cuda()) 133 | remain_index_var = Variable(torch.LongTensor(remain_index).cuda()) 134 | select_score_map = torch.index_select( torch.index_select(pair_scores, 0, select_index_var), 1, remain_index_var ).contiguous().view(-1) 135 | #select_score_map = pair_scores[select_index][:,remain_index].contiguous().view(-1) 136 | if rl_training and not_sampled: 137 | dist = F.softmax(select_score_map, 0) 138 | greedy_id = dist.max(0)[1] 139 | best_id = torch.multinomial(dist, 1)[0] 140 | if int(greedy_id) != int(best_id): 141 | not_sampled = False 142 | if LOG_SOFTMAX: 143 | prob = dist[best_id] + 1e-20 144 | else: 145 | prob = select_score_map[best_id] + 1e-20 146 | gen_tree_loss_per_batch.append(prob.log()) 147 | #neg_entropy = dist * (dist + 1e-20).log() 148 | #entropy_loss.append(neg_entropy.sum()) 149 | else: 150 | best_score, best_id = select_score_map.max(0) 151 | depend_id = int(best_id) // wid 152 | insert_id = int(best_id) % wid 153 | best_depend_node = select_node[depend_id] 154 | best_insert_node = node_container[insert_id] 155 | best_depend_node.add_child(best_insert_node) 156 | 157 | select_node.append(best_insert_node) 158 | select_index.append(best_insert_node.index) 159 | node_container.remove(best_insert_node) 160 | remain_index.remove(best_insert_node.index) 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /lib/tree_lstm/graph_to_tree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from 
torch.nn.utils.rnn import PackedSequence 6 | from lib.word_vectors import obj_edge_vectors 7 | import numpy as np 8 | from config import IM_SCALE, ROOT_PATH 9 | import random 10 | 11 | from lib.tree_lstm.def_tree import ArbitraryTree 12 | 13 | def graph_to_trees(co_occour_prob, rel_labels, obj_labels): 14 | """ 15 | Generate arbitrary trees according to the ground truth graph 16 | 17 | co_occour: [num_obj_classes, num_obj_classes] 18 | rel_labels: [image_index, i, j, rel_label], i is not j, i & j are under same image 19 | obj_labels: [image_index, obj_label] 20 | 21 | output: [forest] 22 | """ 23 | num_nodes = obj_labels.shape[0] 24 | num_edges = rel_labels.shape[0] 25 | output = [] 26 | 27 | # calculate the score of each edge 28 | edge_scores = rel_labels.float().clone() 29 | for i in range(num_edges): 30 | if int(rel_labels[i, 3]) == 0: 31 | edge_scores[i, 3] = 0 32 | else: 33 | sub_id = int(obj_labels[int(rel_labels[i, 1]), 1]) - 1 34 | obj_id = int(obj_labels[int(rel_labels[i, 2]), 1]) - 1 35 | edge_scores[i, 3] = co_occour_prob[sub_id, obj_id] 36 | 37 | # generate score map (num_obj * num_obj) 38 | score_map = np.zeros((num_nodes, num_nodes)) 39 | for i in range(num_edges): 40 | sub_id = int(edge_scores[i, 1]) 41 | obj_id = int(edge_scores[i, 2]) 42 | score_map[sub_id, obj_id] = float(score_map[sub_id, obj_id]) + float(edge_scores[i, 3]) 43 | score_map[obj_id, sub_id] = float(score_map[obj_id, sub_id]) + float(edge_scores[i, 3]) 44 | 45 | # calculate the score of each node 46 | node_scores = obj_labels.float().clone() 47 | node_scores[:,1] = 0 48 | for i in range(num_edges): 49 | node_id_1 = int(edge_scores[i, 1]) 50 | node_id_2 = int(edge_scores[i, 2]) 51 | node_scores[node_id_1, 1] = float(node_scores[node_id_1, 1]) + float(edge_scores[i, 3]) 52 | node_scores[node_id_2, 1] = float(node_scores[node_id_2, 1]) + float(edge_scores[i, 3]) 53 | 54 | # generate arbitrary tree 55 | group_id = 0 56 | im_inds = obj_labels[:, 0].contiguous() 57 | while(torch.nonzero(im_inds == group_id).numel() > 0): 58 | # generate node container 59 | node_container = [] 60 | picked_list = torch.nonzero(im_inds == group_id).view(-1) 61 | for idx in picked_list: 62 | node_container.append(ArbitraryTree(idx, node_scores[int(idx), 1])) 63 | # use virtual root (entire image), index is -1, score is almost infinity 64 | tree_root = ArbitraryTree(-1, 10e10, is_root=True) 65 | # node insert order 66 | node_order = 0 67 | # find first & best node to start 68 | best_node = find_best_node(node_container) 69 | tree_root.add_child(best_node) 70 | best_node.node_order = node_order 71 | node_order += 1 72 | node_container.remove(best_node) 73 | # generate tree 74 | while(len(node_container) > 0): 75 | best_depend_node = None 76 | best_insert_node = None 77 | best_score = -1 78 | for i in range(len(node_container)): 79 | best_score, best_depend_node, best_insert_node = \ 80 | tree_root.search_best_insert(score_map, best_score, node_container[i], best_depend_node, best_insert_node) 81 | 82 | # if not in current tree, add to root, else insert 83 | if best_score == 0: 84 | best_node = find_best_node(node_container) 85 | tree_root.add_child(best_node) 86 | best_node.node_order = node_order 87 | node_order += 1 88 | node_container.remove(best_node) 89 | else: 90 | best_depend_node.add_child(best_insert_node) 91 | best_insert_node.node_order = node_order 92 | node_order += 1 93 | node_container.remove(best_insert_node) 94 | 95 | # add tree to forest 96 | output.append(tree_root) 97 | # next image 98 | group_id += 1 99 | 100 | 
return output 101 | 102 | 103 | def arbitraryForest_to_biForest(forest): 104 | """ 105 | forest: a list of arbitrary trees 106 | output: a list of corresponding binary trees 107 | """ 108 | output = [] 109 | for i in range(len(forest)): 110 | result_tree = arTree_to_biTree(forest[i]) 111 | # make sure they are equivalent trees 112 | # assert(result_tree.get_total_child() == forest[i].get_total_child()) 113 | output.append(result_tree) 114 | 115 | return output 116 | 117 | 118 | def arTree_to_biTree(arTree): 119 | root_node = arTree.generate_bi_tree() 120 | arNode_to_biNode(arTree, root_node) 121 | 122 | return root_node 123 | 124 | def arNode_to_biNode(arNode, biNode): 125 | if arNode.get_child_num() >= 1: 126 | new_bi_node = arNode.children[0].generate_bi_tree() 127 | biNode.add_left_child(new_bi_node) 128 | arNode_to_biNode(arNode.children[0], biNode.left_child) 129 | 130 | if arNode.get_child_num() > 1: 131 | current_bi_node = biNode.left_child 132 | for i in range(arNode.get_child_num() - 1): 133 | new_bi_node = arNode.children[i+1].generate_bi_tree() 134 | current_bi_node.add_right_child(new_bi_node) 135 | current_bi_node = current_bi_node.right_child 136 | arNode_to_biNode(arNode.children[i+1], current_bi_node) 137 | 138 | def find_best_node(node_container): 139 | max_node_score = -1 140 | best_node = None 141 | for i in range(len(node_container)): 142 | if node_container[i].score > max_node_score: 143 | max_node_score = node_container[i].score 144 | best_node = node_container[i] 145 | return best_node 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /lib/tree_lstm/tree_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.nn.utils.rnn import PackedSequence 6 | import numpy as np 7 | 8 | from config import BOX_SCALE, IM_SCALE 9 | from lib.fpn.box_utils import bbox_overlaps, bbox_intersections, center_size 10 | 11 | class TreeLSTM_IO(object): 12 | def __init__(self, hidden_tensor, order_tensor, order_count, dists_tensor, commitments_tensor, dropout_mask): 13 | self.hidden = hidden_tensor # Float tensor [num_obj, self.out_dim] 14 | self.order = order_tensor # Long tensor [num_obj] 15 | self.order_count = order_count # int 16 | self.dists = dists_tensor # Float tensor [num_obj, len(self.classes)] 17 | self.commitments = commitments_tensor 18 | self.dropout_mask = dropout_mask 19 | 20 | def get_overlap_info(im_inds, box_priors): 21 | """ 22 | input: 23 | im_inds: [num_obj] 24 | box_priors: [num_obj, 4] 25 | output: [num_obj, 6] 26 | number of overlapped objects (self not included) 27 | sum of all intersection areas (self not included) 28 | sum of IoU (Intersection over Union) 29 | average intersection area (self not included) 30 | average IoU (Intersection over Union) 31 | ROI area 32 | """ 33 | # compute pairwise intersection / IoU; boxes are offset by image index so boxes from different images never overlap 34 | num_obj = box_priors.shape[0] 35 | inds_offset = (im_inds * 1000).view(-1,1).expand(box_priors.shape) 36 | offset_box = box_priors + inds_offset.float() 37 | intersection = bbox_intersections(offset_box, offset_box) 38 | overlap = bbox_overlaps(offset_box, offset_box) 39 | # [obj_num, obj_num]; the diagonal (self) elements should be removed 40 | reverse_eye = Variable(1.0 - torch.eye(num_obj).float().cuda()) 41 | intersection = intersection * reverse_eye 42 | overlap = overlap * reverse_eye 43 | box_area = bbox_area(offset_box) 44 | # 
generate the input features 45 | boxes_info = Variable(torch.FloatTensor(num_obj, 6).zero_().cuda()) # per-object overlap statistics (see docstring above) 46 | 47 | for obj_idx in range(num_obj): 48 | boxes_info[obj_idx, 0] = torch.nonzero(intersection[obj_idx]).numel() 49 | boxes_info[obj_idx, 1] = intersection[obj_idx].view(-1).sum() / float(IM_SCALE * IM_SCALE) 50 | boxes_info[obj_idx, 2] = overlap[obj_idx].view(-1).sum() 51 | boxes_info[obj_idx, 3] = boxes_info[obj_idx, 1] / (boxes_info[obj_idx, 0] + 1e-9) 52 | boxes_info[obj_idx, 4] = boxes_info[obj_idx, 2] / (boxes_info[obj_idx, 0] + 1e-9) 53 | boxes_info[obj_idx, 5] = box_area[obj_idx] / float(IM_SCALE * IM_SCALE) 54 | 55 | return boxes_info, intersection 56 | 57 | def get_box_info(boxes): 58 | """ 59 | input: [batch_size, (x1,y1,x2,y2)] 60 | output: [batch_size, (x1,y1,x2,y2,cx,cy,w,h)] 61 | """ 62 | return torch.cat((boxes / float(IM_SCALE), center_size(boxes) / float(IM_SCALE)), 1) 63 | 64 | 65 | def get_box_pair_info(box1, box2): 66 | """ 67 | input: 68 | box1 [batch_size, (x1,y1,x2,y2,cx,cy,w,h)] 69 | box2 [batch_size, (x1,y1,x2,y2,cx,cy,w,h)] 70 | output: 71 | 32-dimensional vector: [box1, box2, union box, intersection box] 72 | """ 73 | # union box 74 | unionbox = box1[:,:4].clone() 75 | unionbox[:, 0] = torch.min(box1[:, 0], box2[:, 0]) 76 | unionbox[:, 1] = torch.min(box1[:, 1], box2[:, 1]) 77 | unionbox[:, 2] = torch.max(box1[:, 2], box2[:, 2]) 78 | unionbox[:, 3] = torch.max(box1[:, 3], box2[:, 3]) 79 | union_info = get_box_info(unionbox) 80 | 81 | # intersection box (zeroed out when the two boxes do not overlap) 82 | intersection_box = box1[:,:4].clone() 83 | intersection_box[:, 0] = torch.max(box1[:, 0], box2[:, 0]) 84 | intersection_box[:, 1] = torch.max(box1[:, 1], box2[:, 1]) 85 | intersection_box[:, 2] = torch.min(box1[:, 2], box2[:, 2]) 86 | intersection_box[:, 3] = torch.min(box1[:, 3], box2[:, 3]) 87 | case1 = torch.nonzero(intersection_box[:, 2].contiguous().view(-1) < intersection_box[:, 0].contiguous().view(-1)).view(-1) 88 | case2 = torch.nonzero(intersection_box[:, 3].contiguous().view(-1) < intersection_box[:, 1].contiguous().view(-1)).view(-1) 89 | intersection_info = get_box_info(intersection_box) 90 | if case1.numel() > 0: 91 | intersection_info[case1, :] = 0 92 | if case2.numel() > 0: 93 | intersection_info[case2, :] = 0 94 | 95 | return torch.cat((box1, box2, union_info, intersection_info), 1) 96 | 97 | 98 | def bbox_area(gt_boxes): 99 | """ 100 | gt_boxes: (K, 4) ndarray of float 101 | 102 | area: (K,) 103 | """ 104 | K = gt_boxes.size(0) 105 | gt_boxes_area = ((gt_boxes[:,2] - gt_boxes[:,0] + 1) * 106 | (gt_boxes[:,3] - gt_boxes[:,1] + 1)).view(K) 107 | 108 | return gt_boxes_area 109 | 110 | def bbox_center(gt_boxes): 111 | """ 112 | gt_boxes: (K, 4) ndarray of float 113 | 114 | center: (K, 2) 115 | """ 116 | center = (gt_boxes[:, 2:] + gt_boxes[:, :2]) / 2.0 117 | assert(center.shape[1] == 2) 118 | return center 119 | 120 | def print_tree(tree): 121 | if tree is None: 122 | return 123 | if(tree.left_child is not None): 124 | print_node(tree.left_child) 125 | if(tree.right_child is not None): 126 | print_node(tree.right_child) 127 | 128 | print_tree(tree.left_child) 129 | print_tree(tree.right_child) 130 | 131 | return 132 | 133 | 134 | def print_node(tree): 135 | print(' depth: ', tree.depth(), end="") 136 | print(' label: ', tree.label, end="") 137 | print(' index: ', int(tree.index), end="") 138 | print(' score: ', tree.score(), end="") 139 | print(' center_x: ', tree.center_x) 140 | -------------------------------------------------------------------------------- 
/lib/word_vectors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from PyTorch's text library. 3 | """ 4 | 5 | import array 6 | import os 7 | import zipfile 8 | 9 | import six 10 | import torch 11 | from six.moves.urllib.request import urlretrieve 12 | from tqdm import tqdm 13 | 14 | from config import DATA_PATH 15 | 16 | 17 | def obj_edge_vectors(names, wv_type='glove.6B', wv_dir=DATA_PATH, wv_dim=300): 18 | wv_dict, wv_arr, wv_size = load_word_vectors(wv_dir, wv_type, wv_dim) 19 | 20 | vectors = torch.Tensor(len(names), wv_dim) 21 | vectors.normal_(0,1) 22 | 23 | for i, token in enumerate(names): 24 | wv_index = wv_dict.get(token, None) 25 | if wv_index is not None: 26 | vectors[i] = wv_arr[wv_index] 27 | else: 28 | # Try the longest word (hopefully won't be a preposition 29 | lw_token = sorted(token.split(' '), key=lambda x: len(x), reverse=True)[0] 30 | print("{} -> {} ".format(token, lw_token)) 31 | wv_index = wv_dict.get(lw_token, None) 32 | if wv_index is not None: 33 | vectors[i] = wv_arr[wv_index] 34 | else: 35 | print("fail on {}".format(token)) 36 | 37 | return vectors 38 | 39 | URL = { 40 | 'glove.42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 41 | 'glove.840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', 42 | 'glove.twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', 43 | 'glove.6B': 'http://nlp.stanford.edu/data/glove.6B.zip', 44 | } 45 | 46 | 47 | def load_word_vectors(root, wv_type, dim): 48 | """Load word vectors from a path, trying .pt, .txt, and .zip extensions.""" 49 | if isinstance(dim, int): 50 | dim = str(dim) + 'd' 51 | fname = os.path.join(root, wv_type + '.' + dim) 52 | if os.path.isfile(fname + '.pt'): 53 | fname_pt = fname + '.pt' 54 | print('loading word vectors from', fname_pt) 55 | return torch.load(fname_pt) 56 | if os.path.isfile(fname + '.txt'): 57 | fname_txt = fname + '.txt' 58 | cm = open(fname_txt, 'rb') 59 | cm = [line for line in cm] 60 | elif os.path.basename(wv_type) in URL: 61 | url = URL[wv_type] 62 | print('downloading word vectors from {}'.format(url)) 63 | filename = os.path.basename(fname) 64 | if not os.path.exists(root): 65 | os.makedirs(root) 66 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=filename) as t: 67 | fname, _ = urlretrieve(url, fname, reporthook=reporthook(t)) 68 | with zipfile.ZipFile(fname, "r") as zf: 69 | print('extracting word vectors into {}'.format(root)) 70 | zf.extractall(root) 71 | if not os.path.isfile(fname + '.txt'): 72 | raise RuntimeError('no word vectors of requested dimension found') 73 | return load_word_vectors(root, wv_type, dim) 74 | else: 75 | raise RuntimeError('unable to load word vectors') 76 | 77 | wv_tokens, wv_arr, wv_size = [], array.array('d'), None 78 | if cm is not None: 79 | for line in tqdm(range(len(cm)), desc="loading word vectors from {}".format(fname_txt)): 80 | entries = cm[line].strip().split(b' ') 81 | word, entries = entries[0], entries[1:] 82 | if wv_size is None: 83 | wv_size = len(entries) 84 | try: 85 | if isinstance(word, six.binary_type): 86 | word = word.decode('utf-8') 87 | except: 88 | print('non-UTF8 token', repr(word), 'ignored') 89 | continue 90 | wv_arr.extend(float(x) for x in entries) 91 | wv_tokens.append(word) 92 | 93 | wv_dict = {word: i for i, word in enumerate(wv_tokens)} 94 | wv_arr = torch.Tensor(wv_arr).view(-1, wv_size) 95 | ret = (wv_dict, wv_arr, wv_size) 96 | torch.save(ret, fname + '.pt') 97 | return ret 98 | 99 | def reporthook(t): 100 | 
"""https://github.com/tqdm/tqdm""" 101 | last_b = [0] 102 | 103 | def inner(b=1, bsize=1, tsize=None): 104 | """ 105 | b: int, optionala 106 | Number of blocks just transferred [default: 1]. 107 | bsize: int, optional 108 | Size of each block (in tqdm units) [default: 1]. 109 | tsize: int, optional 110 | Total size (in tqdm units). If [default: None] remains unchanged. 111 | """ 112 | if tsize is not None: 113 | t.total = tsize 114 | t.update((b - last_b[0]) * bsize) 115 | last_b[0] = b 116 | return inner 117 | -------------------------------------------------------------------------------- /models/eval_rels.py: -------------------------------------------------------------------------------- 1 | 2 | from dataloaders.visual_genome import VGDataLoader, VG 3 | import numpy as np 4 | import torch 5 | 6 | from config import ModelConfig 7 | from lib.pytorch_misc import optimistic_restore 8 | from lib.evaluation.sg_eval import BasicSceneGraphEvaluator 9 | from tqdm import tqdm 10 | from config import BOX_SCALE, IM_SCALE 11 | import dill as pkl 12 | import os 13 | 14 | conf = ModelConfig() 15 | if conf.model == 'motifnet': 16 | from lib.rel_model import RelModel 17 | elif conf.model == 'stanford': 18 | from lib.rel_model_stanford import RelModelStanford as RelModel 19 | else: 20 | raise ValueError() 21 | 22 | train, val, test = VG.splits(num_val_im=conf.val_size, filter_duplicate_rels=True, 23 | use_proposals=conf.use_proposals, 24 | filter_non_overlap=conf.mode == 'sgdet') 25 | if conf.test: 26 | val = test 27 | train_loader, val_loader = VGDataLoader.splits(train, val, mode='rel', 28 | batch_size=conf.batch_size, 29 | num_workers=conf.num_workers, 30 | num_gpus=conf.num_gpus) 31 | 32 | detector = RelModel(classes=train.ind_to_classes, rel_classes=train.ind_to_predicates, 33 | num_gpus=conf.num_gpus, mode=conf.mode, require_overlap_det=True, 34 | use_resnet=conf.use_resnet, order=conf.order, 35 | nl_edge=conf.nl_edge, nl_obj=conf.nl_obj, hidden_dim=conf.hidden_dim, 36 | use_proposals=conf.use_proposals, 37 | pass_in_obj_feats_to_decoder=conf.pass_in_obj_feats_to_decoder, 38 | pass_in_obj_feats_to_edge=conf.pass_in_obj_feats_to_edge, 39 | pooling_dim=conf.pooling_dim, 40 | rec_dropout=conf.rec_dropout, 41 | use_bias=conf.use_bias, 42 | use_tanh=conf.use_tanh, 43 | use_encoded_box = conf.use_encoded_box, 44 | use_rl_tree = conf.use_rl_tree, 45 | draw_tree = conf.draw_tree, 46 | limit_vision=conf.limit_vision 47 | ) 48 | 49 | 50 | detector.cuda() 51 | ckpt = torch.load(conf.ckpt) 52 | 53 | optimistic_restore(detector, ckpt['state_dict']) 54 | # if conf.mode == 'sgdet': 55 | # det_ckpt = torch.load('checkpoints/new_vgdet/vg-19.tar')['state_dict'] 56 | # detector.detector.bbox_fc.weight.data.copy_(det_ckpt['bbox_fc.weight']) 57 | # detector.detector.bbox_fc.bias.data.copy_(det_ckpt['bbox_fc.bias']) 58 | # detector.detector.score_fc.weight.data.copy_(det_ckpt['score_fc.weight']) 59 | # detector.detector.score_fc.bias.data.copy_(det_ckpt['score_fc.bias']) 60 | 61 | all_pred_entries = [] 62 | def val_batch(batch_num, b, evaluator, thrs=(20, 50, 100)): 63 | det_res = detector[b] 64 | if conf.num_gpus == 1: 65 | det_res = [det_res] 66 | 67 | for i, (boxes_i, objs_i, obj_scores_i, rels_i, pred_scores_i, gt_boxes, gt_classes, gt_rels) in enumerate(det_res): 68 | gt_entry = { 69 | 'gt_classes': val.gt_classes[batch_num + i].copy(), 70 | 'gt_relations': val.relationships[batch_num + i].copy(), 71 | 'gt_boxes': val.gt_boxes[batch_num + i].copy(), 72 | } 73 | assert np.all(objs_i[rels_i[:,0]] > 0) and 
np.all(objs_i[rels_i[:,1]] > 0) 74 | # assert np.all(rels_i[:,2] > 0) 75 | 76 | pred_entry = { 77 | 'pred_boxes': boxes_i * BOX_SCALE/IM_SCALE, 78 | 'pred_classes': objs_i, 79 | 'pred_rel_inds': rels_i, 80 | 'obj_scores': obj_scores_i, 81 | 'rel_scores': pred_scores_i, 82 | } 83 | all_pred_entries.append(pred_entry) 84 | 85 | evaluator[conf.mode].evaluate_scene_graph_entry( 86 | gt_entry, 87 | pred_entry, 88 | ) 89 | 90 | evaluator = BasicSceneGraphEvaluator.all_modes(multiple_preds=conf.multi_pred) 91 | if conf.cache is not None and os.path.exists(conf.cache): 92 | print("Found {}! Loading from it".format(conf.cache)) 93 | with open(conf.cache,'rb') as f: 94 | all_pred_entries = pkl.load(f) 95 | for i, pred_entry in enumerate(tqdm(all_pred_entries)): 96 | gt_entry = { 97 | 'gt_classes': val.gt_classes[i].copy(), 98 | 'gt_relations': val.relationships[i].copy(), 99 | 'gt_boxes': val.gt_boxes[i].copy(), 100 | } 101 | evaluator[conf.mode].evaluate_scene_graph_entry( 102 | gt_entry, 103 | pred_entry, 104 | ) 105 | evaluator[conf.mode].print_stats() 106 | else: 107 | detector.eval() 108 | for val_b, batch in enumerate(tqdm(val_loader)): 109 | val_batch(conf.num_gpus*val_b, batch, evaluator) 110 | 111 | evaluator[conf.mode].print_stats() 112 | 113 | if conf.cache is not None: 114 | with open(conf.cache,'wb') as f: 115 | pkl.dump(all_pred_entries, f) 116 | -------------------------------------------------------------------------------- /models/train_detector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script 4 Detection 3 | """ 4 | from dataloaders.mscoco import CocoDetection, CocoDataLoader 5 | from dataloaders.visual_genome import VGDataLoader, VG 6 | from lib.object_detector import ObjectDetector 7 | import numpy as np 8 | from torch import optim 9 | import torch 10 | import pandas as pd 11 | import time 12 | import os 13 | from config import ModelConfig, FG_FRACTION, RPN_FG_FRACTION, IM_SCALE, BOX_SCALE 14 | from torch.nn import functional as F 15 | from lib.fpn.box_utils import bbox_loss 16 | import torch.backends.cudnn as cudnn 17 | from pycocotools.cocoeval import COCOeval 18 | from lib.pytorch_misc import optimistic_restore, clip_grad_norm 19 | from torch.optim.lr_scheduler import ReduceLROnPlateau 20 | 21 | cudnn.benchmark = True 22 | conf = ModelConfig() 23 | 24 | if conf.coco: 25 | train, val = CocoDetection.splits() 26 | val.ids = val.ids[:conf.val_size] 27 | train.ids = train.ids 28 | train_loader, val_loader = CocoDataLoader.splits(train, val, batch_size=conf.batch_size, 29 | num_workers=conf.num_workers, 30 | num_gpus=conf.num_gpus) 31 | else: 32 | train, val, _ = VG.splits(num_val_im=conf.val_size, filter_non_overlap=False, 33 | filter_empty_rels=False, use_proposals=conf.use_proposals) 34 | train_loader, val_loader = VGDataLoader.splits(train, val, batch_size=conf.batch_size, 35 | num_workers=conf.num_workers, 36 | num_gpus=conf.num_gpus) 37 | 38 | detector = ObjectDetector(classes=train.ind_to_classes, num_gpus=conf.num_gpus, 39 | mode='rpntrain' if not conf.use_proposals else 'proposals', use_resnet=conf.use_resnet) 40 | detector.cuda() 41 | 42 | # Note: if you're doing the stanford setup, you'll need to change this to freeze the lower layers 43 | if conf.use_proposals: 44 | for n, param in detector.named_parameters(): 45 | if n.startswith('features'): 46 | param.requires_grad = False 47 | 48 | optimizer = optim.SGD([p for p in detector.parameters() if p.requires_grad], 49 | weight_decay=conf.l2, lr=conf.lr * 
conf.num_gpus * conf.batch_size, momentum=0.9) 50 | scheduler = ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.1, 51 | verbose=True, threshold=0.001, threshold_mode='abs', cooldown=1) 52 | 53 | start_epoch = -1 54 | if conf.ckpt is not None: 55 | ckpt = torch.load(conf.ckpt) 56 | if optimistic_restore(detector, ckpt['state_dict']): 57 | start_epoch = ckpt['epoch'] 58 | 59 | 60 | def train_epoch(epoch_num): 61 | detector.train() 62 | tr = [] 63 | start = time.time() 64 | for b, batch in enumerate(train_loader): 65 | tr.append(train_batch(batch)) 66 | 67 | if b % conf.print_interval == 0 and b >= conf.print_interval: 68 | mn = pd.concat(tr[-conf.print_interval:], axis=1).mean(1) 69 | time_per_batch = (time.time() - start) / conf.print_interval 70 | print("\ne{:2d}b{:5d}/{:5d} {:.3f}s/batch, {:.1f}m/epoch".format( 71 | epoch_num, b, len(train_loader), time_per_batch, len(train_loader) * time_per_batch / 60)) 72 | print(mn) 73 | print('-----------', flush=True) 74 | start = time.time() 75 | return pd.concat(tr, axis=1) 76 | 77 | 78 | def train_batch(b): 79 | """ 80 | :param b: contains: 81 | :param imgs: the image, [batch_size, 3, IM_SIZE, IM_SIZE] 82 | :param all_anchors: [num_anchors, 4] the boxes of all anchors that we'll be using 83 | :param all_anchor_inds: [num_anchors, 2] array of the indices into the concatenated 84 | RPN feature vector that give us all_anchors, 85 | each one (img_ind, fpn_idx) 86 | :param im_sizes: a [batch_size, 4] numpy array of (h, w, scale, num_good_anchors) for each image. 87 | 88 | :param num_anchors_per_img: int, number of anchors in total over the feature pyramid per img 89 | 90 | Training parameters: 91 | :param train_anchor_inds: a [num_train, 5] array of indices for the anchors that will 92 | be used to compute the training loss (img_ind, fpn_idx) 93 | :param gt_boxes: [num_gt, 4] GT boxes over the batch. 94 | :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class) 95 | 96 | :return: 97 | """ 98 | result = detector[b] 99 | scores = result.od_obj_dists 100 | box_deltas = result.od_box_deltas 101 | labels = result.od_obj_labels 102 | roi_boxes = result.od_box_priors 103 | bbox_targets = result.od_box_targets 104 | rpn_scores = result.rpn_scores 105 | rpn_box_deltas = result.rpn_box_deltas 106 | 107 | # detector loss 108 | valid_inds = (labels.data != 0).nonzero().squeeze(1) 109 | fg_cnt = valid_inds.size(0) 110 | bg_cnt = labels.size(0) - fg_cnt 111 | class_loss = F.cross_entropy(scores, labels) 112 | 113 | # No gather_nd in pytorch so instead convert first 2 dims of tensor to 1d 114 | box_reg_mult = 2 * (1. / FG_FRACTION) * fg_cnt / (fg_cnt + bg_cnt + 1e-4) 115 | twod_inds = valid_inds * box_deltas.size(1) + labels[valid_inds].data 116 | 117 | box_loss = bbox_loss(roi_boxes[valid_inds], box_deltas.view(-1, 4)[twod_inds], 118 | bbox_targets[valid_inds]) * box_reg_mult 119 | 120 | loss = class_loss + box_loss 121 | 122 | # RPN loss 123 | if not conf.use_proposals: 124 | train_anchor_labels = b.train_anchor_labels[:, -1] 125 | train_anchors = b.train_anchors[:, :4] 126 | train_anchor_targets = b.train_anchors[:, 4:] 127 | 128 | train_valid_inds = (train_anchor_labels.data == 1).nonzero().squeeze(1) 129 | rpn_class_loss = F.cross_entropy(rpn_scores, train_anchor_labels) 130 | 131 | # print("{} fg {} bg, ratio of {:.3f} vs {:.3f}. 
RPN {}fg {}bg ratio of {:.3f} vs {:.3f}".format( 132 | # fg_cnt, bg_cnt, fg_cnt / (fg_cnt + bg_cnt + 1e-4), FG_FRACTION, 133 | # train_valid_inds.size(0), train_anchor_labels.size(0)-train_valid_inds.size(0), 134 | # train_valid_inds.size(0) / (train_anchor_labels.size(0) + 1e-4), RPN_FG_FRACTION), flush=True) 135 | rpn_box_mult = 2 * (1. / RPN_FG_FRACTION) * train_valid_inds.size(0) / (train_anchor_labels.size(0) + 1e-4) 136 | rpn_box_loss = bbox_loss(train_anchors[train_valid_inds], 137 | rpn_box_deltas[train_valid_inds], 138 | train_anchor_targets[train_valid_inds]) * rpn_box_mult 139 | 140 | loss += rpn_class_loss + rpn_box_loss 141 | res = pd.Series([rpn_class_loss.data[0], rpn_box_loss.data[0], 142 | class_loss.data[0], box_loss.data[0], loss.data[0]], 143 | ['rpn_class_loss', 'rpn_box_loss', 'class_loss', 'box_loss', 'total']) 144 | else: 145 | res = pd.Series([class_loss.data[0], box_loss.data[0], loss.data[0]], 146 | ['class_loss', 'box_loss', 'total']) 147 | 148 | optimizer.zero_grad() 149 | loss.backward() 150 | clip_grad_norm( 151 | [(n, p) for n, p in detector.named_parameters() if p.grad is not None], 152 | max_norm=conf.clip, clip=True) 153 | optimizer.step() 154 | 155 | return res 156 | 157 | 158 | def val_epoch(): 159 | detector.eval() 160 | # all_boxes is a list of length number-of-classes. 161 | # Each list element is a list of length number-of-images. 162 | # Each of those list elements is either an empty list [] 163 | # or a numpy array of detection. 164 | vr = [] 165 | for val_b, batch in enumerate(val_loader): 166 | vr.append(val_batch(val_b, batch)) 167 | vr = np.concatenate(vr, 0) 168 | if vr.shape[0] == 0: 169 | print("No detections anywhere") 170 | return 0.0 171 | 172 | val_coco = val.coco 173 | coco_dt = val_coco.loadRes(vr) 174 | coco_eval = COCOeval(val_coco, coco_dt, 'bbox') 175 | coco_eval.params.imgIds = val.ids if conf.coco else [x for x in range(len(val))] 176 | 177 | coco_eval.evaluate() 178 | coco_eval.accumulate() 179 | coco_eval.summarize() 180 | mAp = coco_eval.stats[1] 181 | return mAp 182 | 183 | 184 | def val_batch(batch_num, b): 185 | result = detector[b] 186 | if result is None: 187 | return np.zeros((0, 7)) 188 | scores_np = result.obj_scores.data.cpu().numpy() 189 | cls_preds_np = result.obj_preds.data.cpu().numpy() 190 | boxes_np = result.boxes_assigned.data.cpu().numpy() 191 | im_inds_np = result.im_inds.data.cpu().numpy() 192 | im_scales = b.im_sizes.reshape((-1, 3))[:, 2] 193 | if conf.coco: 194 | boxes_np /= im_scales[im_inds_np][:, None] 195 | boxes_np[:, 2:4] = boxes_np[:, 2:4] - boxes_np[:, 0:2] + 1 196 | cls_preds_np[:] = [val.ind_to_id[c_ind] for c_ind in cls_preds_np] 197 | im_inds_np[:] = [val.ids[im_ind + batch_num * conf.batch_size * conf.num_gpus] 198 | for im_ind in im_inds_np] 199 | else: 200 | boxes_np *= BOX_SCALE / IM_SCALE 201 | boxes_np[:, 2:4] = boxes_np[:, 2:4] - boxes_np[:, 0:2] + 1 202 | im_inds_np += batch_num * conf.batch_size * conf.num_gpus 203 | 204 | return np.column_stack((im_inds_np, boxes_np, scores_np, cls_preds_np)) 205 | 206 | 207 | print("Training starts now!") 208 | for epoch in range(start_epoch + 1, start_epoch + 1 + conf.num_epochs): 209 | rez = train_epoch(epoch) 210 | print("overall{:2d}: ({:.3f})\n{}".format(epoch, rez.mean(1)['total'], rez.mean(1)), flush=True) 211 | mAp = val_epoch() 212 | scheduler.step(mAp) 213 | 214 | torch.save({ 215 | 'epoch': epoch, 216 | 'state_dict': detector.state_dict(), 217 | 'optimizer': optimizer.state_dict(), 218 | }, os.path.join(conf.save_dir, 
'{}-{}.tar'.format('coco' if conf.coco else 'vg', epoch))) 219 | -------------------------------------------------------------------------------- /scripts/eval_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This is a script that will evaluate all models for SGCLS 4 | export CUDA_VISIBLE_DEVICES=1 5 | 6 | if [ $1 == "0" ]; then 7 | python models/eval_rels.py -m predcls -model motifnet -order leftright -nl_obj 1 -nl_edge 1 -b 6 -clip 5 \ 8 | -p 100 -hidden_dim 512 -pooling_dim 4096 -lr 1e-3 -ngpu 1 -test -ckpt checkpoints/predcls.tar -nepoch 50 -use_bias -use_encoded_box 9 | 10 | elif [ $1 == "1" ]; then 11 | python models/eval_rels.py -m sgcls -model motifnet -order leftright -nl_obj 1 -nl_edge 1 -b 6 -clip 5 \ 12 | -p 100 -hidden_dim 512 -pooling_dim 4096 -lr 1e-3 -ngpu 1 -test -ckpt checkpoints/sgcls.tar -nepoch 50 -use_bias -use_encoded_box 13 | 14 | elif [ $1 == "2" ]; then 15 | python models/eval_rels.py -m sgdet -model motifnet -order leftright -nl_obj 1 -nl_edge 1 -b 6 -clip 5 \ 16 | -p 100 -hidden_dim 512 -pooling_dim 4096 -lr 1e-3 -ngpu 1 -test -ckpt checkpoints/sgdet.tar -nepoch 50 -use_bias -use_encoded_box 17 | fi 18 | -------------------------------------------------------------------------------- /scripts/pretrain_detector.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Train the model without COCO pretraining 3 | python models/train_detector.py -b 6 -lr 1e-3 -save_dir checkpoints/vgdet -nepoch 50 -ngpu 3 -nwork 3 -p 100 -clip 5 4 | 5 | # If you want to evaluate on the frequency baseline now, run this command (replace the checkpoint with the 6 | # best checkpoint you found). 7 | #export CUDA_VISIBLE_DEVICES=0 8 | #python models/eval_rel_count.py -ngpu 1 -b 6 -ckpt checkpoints/vgdet/vg-24.tar -nwork 1 -p 100 -test 9 | #export CUDA_VISIBLE_DEVICES=1 10 | #python models/eval_rel_count.py -ngpu 1 -b 6 -ckpt checkpoints/vgdet/vg-28.tar -nwork 1 -p 100 -test 11 | #export CUDA_VISIBLE_DEVICES=2 12 | #python models/eval_rel_count.py -ngpu 1 -b 6 -ckpt checkpoints/vgdet/vg-28.tar -nwork 1 -p 100 -test 13 | # 14 | # 15 | -------------------------------------------------------------------------------- /scripts/refine_for_detection.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Refine Motifnet for detection 4 | 5 | 6 | export CUDA_VISIBLE_DEVICES=$1 7 | 8 | if [ $1 == "0" ]; then 9 | echo "TRAINING THE BASELINE" 10 | python models/train_rels.py -m sgdet -model motifnet -nl_obj 0 -nl_edge 0 -b 6 \ 11 | -clip 5 -p 100 -pooling_dim 4096 -lr 1e-4 -ngpu 1 -ckpt checkpoints/baseline-sgcls/vgrel-11.tar -save_dir checkpoints/baseline-sgdet \ 12 | -nepoch 50 -use_bias 13 | elif [ $1 == "1" ]; then 14 | echo "TRAINING STANFORD" 15 | python models/train_rels.py -m sgdet -model stanford -b 6 -p 100 -lr 1e-4 -ngpu 1 -clip 5 \ 16 | -ckpt checkpoints/stanford-sgcls/vgrel-11.tar -save_dir checkpoints/stanford-sgdet 17 | elif [ $1 == "2" ]; then 18 | echo "Refining Motifnet for detection!" 
19 | python models/train_rels.py -m sgdet -model motifnet -order leftright -nl_obj 2 -nl_edge 4 -b 6 -clip 5 \ 20 | -p 100 -hidden_dim 512 -pooling_dim 4096 -lr 1e-4 -ngpu 1 -ckpt checkpoints/motifnet-sgcls/vgrel-7.tar \ 21 | -save_dir checkpoints/motifnet-sgdet -nepoch 10 -use_bias 22 | fi -------------------------------------------------------------------------------- /scripts/train_models_sgcls.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This is a script that will train all of the models for scene graph classification and then evaluate them. 4 | export CUDA_VISIBLE_DEVICES=$1 5 | 6 | if [ $1 == "0" ]; then 7 | echo "TRAINING THE BASELINE" 8 | python models/train_rels.py -m sgcls -model motifnet -nl_obj 0 -nl_edge 0 -b 6 \ 9 | -clip 5 -p 100 -pooling_dim 4096 -lr 1e-3 -ngpu 1 -ckpt checkpoints/vgdet/vg-24.tar -save_dir checkpoints/baseline2 \ 10 | -nepoch 50 -use_bias 11 | elif [ $1 == "1" ]; then 12 | echo "TRAINING MESSAGE PASSING" 13 | 14 | python models/train_rels.py -m sgcls -model stanford -b 6 -p 100 -lr 1e-3 -ngpu 1 -clip 5 \ 15 | -ckpt checkpoints/vgdet/vg-24.tar -save_dir checkpoints/stanford2 16 | elif [ $1 == "2" ]; then 17 | echo "TRAINING MOTIFNET" 18 | 19 | python models/train_rels.py -m sgcls -model motifnet -order leftright -nl_obj 2 -nl_edge 4 -b 6 -clip 5 \ 20 | -p 100 -hidden_dim 512 -pooling_dim 4096 -lr 1e-3 -ngpu 1 -ckpt checkpoints/vgdet/vg-24.tar \ 21 | -save_dir checkpoints/motifnet2 -nepoch 50 -use_bias 22 | fi 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /scripts/train_stanford.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python models/train_rels.py -m sgcls -model stanford -b 4 -p 400 -lr 1e-4 -ngpu 1 -ckpt checkpoints/vgdet/vg-24.tar -save_dir checkpoints/stanford -adam 4 | 5 | # To test you can run this command 6 | # python models/eval_rels.py -m sgcls -model stanford -ngpu 1 -ckpt checkpoints/stanford/vgrel-28.tar -test 7 | -------------------------------------------------------------------------------- /scripts/train_vctreenet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Train VCTREE using different orderings 4 | 5 | export CUDA_VISIBLE_DEVICES=1 6 | 7 | if [ $1 == "0" ]; then 8 | echo "TRAINING VCTREE V1" 9 | python models/train_rels.py -m sgdet -model motifnet -order confidence -nl_obj 1 -nl_edge 1 -b 1 -clip 5 \ 10 | -p 2000 -hidden_dim 512 -pooling_dim 4096 -lr 1e-5 -ngpu 1 -ckpt checkpoints/vctree-sgdet-19.tar\ 11 | -save_dir checkpoints/vctree-sgdet-rl -nepoch 50 -use_bias -use_encoded_box -use_rl_tree 12 | elif [ $1 == "1" ]; then 13 | echo "TRAINING VCTREE V2" 14 | python models/train_rels.py -m sgdet -model motifnet -order confidence -nl_obj 1 -nl_edge 1 -b 2 -clip 3 \ 15 | -p 1000 -hidden_dim 512 -pooling_dim 4096 -lr 1e-5 -ngpu 1 -ckpt checkpoints/vg-faster-rcnn.tar\ 16 | -save_dir checkpoints/vctree-sgdet -nepoch 50 -use_bias -use_encoded_box 17 | 18 | 19 | elif [ $1 == "2" ]; then 20 | echo "TRAINING VCTREE V3" 21 | python models/train_rels.py -m sgcls -model motifnet -order confidence -nl_obj 1 -nl_edge 1 -b 1 -clip 5 \ 22 | -p 2000 -hidden_dim 512 -pooling_dim 4096 -lr 1e-4 -ngpu 1 -ckpt checkpoints/vctree-sgcls-19.tar\ 23 | -save_dir checkpoints/vctree-sgcls-rl -nepoch 50 -use_bias -use_encoded_box -use_rl_tree 24 | elif [ $1 == "3" ]; then 25 | echo "TRAINING VCTREE 
V3" 26 | python models/train_rels.py -m sgcls -model motifnet -order confidence -nl_obj 1 -nl_edge 1 -b 5 -clip 5 \ 27 | -p 500 -hidden_dim 512 -pooling_dim 4096 -lr 1e-4 -ngpu 1 -ckpt checkpoints/vg-faster-rcnn.tar\ 28 | -save_dir checkpoints/vctree-sgcls -nepoch 50 -use_bias -use_encoded_box 29 | 30 | elif [ $1 == "4" ]; then 31 | echo "TRAINING VCTREE V3" 32 | python models/train_rels.py -m predcls -model motifnet -order confidence -nl_obj 1 -nl_edge 1 -b 1 -clip 5 \ 33 | -p 2000 -hidden_dim 512 -pooling_dim 4096 -lr 1e-4 -ngpu 1 -ckpt checkpoints/vctree-predcls-18.tar\ 34 | -save_dir checkpoints/vctree-predcls-rl -nepoch 50 -use_bias -use_encoded_box -use_rl_tree 35 | elif [ $1 == "5" ]; then 36 | echo "TRAINING VCTREE V3" 37 | python models/train_rels.py -m predcls -model motifnet -order confidence -nl_obj 1 -nl_edge 1 -b 5 -clip 5 \ 38 | -p 2000 -hidden_dim 512 -pooling_dim 4096 -lr 1e-4 -ngpu 1 -ckpt checkpoints/vg-faster-rcnn.tar\ 39 | -save_dir checkpoints/vctree-predcls -nepoch 50 -use_bias -use_encoded_box 40 | 41 | fi 42 | --------------------------------------------------------------------------------