├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── backbone.py ├── data ├── __init__.py ├── coco.py ├── config.py ├── grid.npy ├── scripts │ ├── COCO.sh │ ├── COCO_test.sh │ └── mix_sets.py ├── yolact_example_0.png ├── yolact_example_1.png └── yolact_example_2.png ├── environment.yml ├── eval.py ├── external └── DCNv2 │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── dcn_v2.py │ ├── setup.py │ ├── src │ ├── cpu │ │ ├── dcn_v2_cpu.cpp │ │ └── vision.h │ ├── cuda │ │ ├── dcn_v2_cuda.cu │ │ ├── dcn_v2_im2col_cuda.cu │ │ ├── dcn_v2_im2col_cuda.h │ │ ├── dcn_v2_psroi_pooling_cuda.cu │ │ └── vision.h │ ├── dcn_v2.h │ └── vision.cpp │ └── test.py ├── layers ├── __init__.py ├── box_utils.py ├── functions │ ├── __init__.py │ └── detection.py ├── interpolate.py ├── modules │ ├── __init__.py │ └── multibox_loss.py └── output_utils.py ├── run_coco_eval.py ├── scripts ├── augment_bbox.py ├── bbox_recall.py ├── cluster_bbox_sizes.py ├── compute_masks.py ├── convert_darknet.py ├── convert_sbd.py ├── eval.sh ├── make_grid.py ├── optimize_bboxes.py ├── parse_eval.py ├── plot_loss.py ├── resume.sh ├── save_bboxes.py ├── train.sh └── unpack_statedict.py ├── train.py ├── utils ├── __init__.py ├── augmentations.py ├── cython_nms.pyx ├── functions.py ├── logger.py ├── nvinfo.py └── timer.py ├── web ├── css │ ├── index.css │ ├── list.css │ ├── toggle.css │ └── viewer.css ├── dets │ ├── ssd300.json │ ├── ssd550.json │ ├── ssd550_resnet101.json │ ├── test.json │ ├── yolact_base.json │ ├── yolact_darknet53.json │ ├── yolact_im700.json │ ├── yolact_resnet101_conv4.json │ ├── yolact_resnet101_maskrcnn.json │ ├── yolact_resnet101_maskrcnn_1.json │ ├── yolact_resnet50.json │ ├── yrm12.json │ ├── yrm13.json │ ├── yrm16_2.json │ ├── yrm18.json │ ├── yrm19.json │ ├── yrm21.json │ ├── yrm25_b.json │ ├── yrm28_2_perfect.json │ ├── yrm35_crop.json │ └── yrm35_retina.json ├── index.html ├── iou.html ├── scripts │ ├── index.js │ ├── iou.js │ ├── jquery.js │ ├── js.cookie.js │ ├── utils.js │ └── viewer.js ├── server.py └── viewer.html └── yolact.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # atom remote-sync package 92 | .remote-sync.json 93 | 94 | # weights 95 | weights/ 96 | 97 | #DS_Store 98 | .DS_Store 99 | 100 | # dev stuff 101 | eval/ 102 | eval.ipynb 103 | dev.ipynb 104 | .vscode/ 105 | 106 | # not ready 107 | videos/ 108 | templates/ 109 | data/ssd_dataloader.py 110 | data/datasets/ 111 | doc/visualize.py 112 | read_results.py 113 | ssd300_120000/ 114 | demos/live 115 | webdemo.py 116 | test_data_aug.py 117 | 118 | # attributes 119 | 120 | # pycharm 121 | .idea/ 122 | 123 | # temp checkout soln 124 | data/coco 125 | data/sbd 126 | data/cityscapes 127 | 128 | # pylint 129 | .pylintrc 130 | 131 | # ssd.pytorch master branch (for merging) 132 | ssd.pytorch/ 133 | 134 | # some datasets 135 | data/VOCdevkit/ 136 | data/coco/images/ 137 | data/coco/annotations/ 138 | ap_data.pkl 139 | results/ 140 | logs/ 141 | scripts/aws/ 142 | scripts/gt.npy 143 | scripts/proto.npy 144 | scripts/info.txt 145 | test.pkl 146 | testeval.py 147 | scripts/aws2/ 148 | status.sh 149 | train.sh 150 | img/ 151 | scripts/aws-ohio/ 152 | scripts/aws3/ 153 | data/config_dev.py 154 | data/coco/ 155 | data/sbd/ 156 | 157 | vid/ 158 | vidres/ 159 | 160 | _.py 161 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # YOLACT Change Log 2 | 3 | This document will detail all changes I make. 4 | I don't know how I'm going to be versioning things yet, so you get dates for now. 5 | 6 | ``` 7 | 2020.01.25: 8 | - Fixed the mask IoU branch crashing when all masks in a batch are discarded (fixes #302, #259). 9 | 2020.01.24: 10 | - Fixed the conv layer detection during initialization to work with pytorch 1.4 (fixes #292). 11 | 2020.01.23: 12 | - Fixed the video playback crashing if there's nothing in the scene (fixes #266). 13 | - Fixed the logger logging the last loss as total loss instead of the actual total (fixes #254). 14 | 15 | 2019.12.16 (v1.2): 16 | - Added YOLACT++ implementation, paper, and code. 17 | - Added DCN support (need to compile CUDA kernels if you want to use them, see README). 18 | - Added a mask rescoring network trained with mask iou. 19 | - Added configs with more anchors. 20 | 21 | 2019.12.06: 22 | - Made training much more stable (no more infs and hopefully fewer loss explosions) by ignoring 23 | augmented boxes with < 4px of height and width (this includes 0 area boxes which caused the inf). 24 | See #222 for details. 25 | 26 | 2019.11.20: 27 | - Fixed bug where saving videos wouldn't work when using cv2 not compiled with display support (#197). 
28 | 
29 | 2019.11.06:
30 |  - Changed the Cython import to only be active when using traditional nms.
31 |  - Added cross-class fast NMS.
32 | 
33 | 2019.11.04:
34 |  - Fixed a bug where the learning rate auto-scaling wasn't being applied properly.
35 |  - Fixed a logging bug where lr was sometimes not properly logged after a resume.
36 | 
37 | 2019.10.25 (v1.1):
38 |  - Added proper Multi-GPU support. Simply increase your batch size to 8*num_gpus and everything will scale.
39 |    - I get an ~1.8x training speed increase when using 2 gpus and an ~3x increase when using 4.
40 |  - Added a logger that logs everything about your training.
41 |    - Check the Logging section of the README to see how to visualize your logs. (Not written yet)
42 |  - Savevideo now uses the evalvideo framework and supports --video_multiframe. It's much faster now!
43 |  - Added the ability to display fps right on the videos themselves by using --display_fps.
44 |  - Evalvideo now doesn't crash when it runs out of frames.
45 |  - Pascal SBD is now officially supported! Check the training section for more details.
46 |  - Preserve_aspect_ratio kinda sorta works now, but it's iffy and the way I have it set up doesn't perform better.
47 |  - Added a ton of new config settings, most of which don't improve performance :/
48 | 
49 | 2019.09.20:
50 |  - Fixed a bug where custom label maps weren't being applied properly because of global default argument initialization.
51 | 2019.08.29:
52 |  - Fixed a bug where the fpn conv layers weren't getting initialized with xavier since they were being overwritten by jit modules (see #127).
53 | 2019.08.04:
54 |  - Improved the matching algorithm used to match anchors to gt by making it less greedy (see #104).
55 | 2019.06.27:
56 |  - Sped up save video by ~8 ms per frame because I forgot to apply a speed fix I applied to the other modes.
57 | ```
58 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Daniel Bolya
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **Y**ou **O**nly **L**ook **A**t **C**oefficien**T**s 2 | ``` 3 | ██╗ ██╗ ██████╗ ██╗ █████╗ ██████╗████████╗ 4 | ╚██╗ ██╔╝██╔═══██╗██║ ██╔══██╗██╔════╝╚══██╔══╝ 5 | ╚████╔╝ ██║ ██║██║ ███████║██║ ██║ 6 | ╚██╔╝ ██║ ██║██║ ██╔══██║██║ ██║ 7 | ██║ ╚██████╔╝███████╗██║ ██║╚██████╗ ██║ 8 | ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ 9 | ``` 10 | 11 | A simple, fully convolutional model for real-time instance segmentation. This is the code for our papers: 12 | - [YOLACT: Real-time Instance Segmentation](https://arxiv.org/abs/1904.02689) 13 | - [YOLACT++: Better Real-time Instance Segmentation](https://arxiv.org/abs/1912.06218) 14 | 15 | #### YOLACT++ (v1.2) released! ([Changelog](CHANGELOG.md)) 16 | YOLACT++'s resnet50 model runs at 33.5 fps on a Titan Xp and achieves 34.1 mAP on COCO's `test-dev` (check out our journal paper [here](https://arxiv.org/abs/1912.06218)). 17 | 18 | In order to use YOLACT++, make sure you compile the DCNv2 code. (See [Installation](https://github.com/dbolya/yolact#installation)) 19 | 20 | #### For a real-time demo, check out our ICCV video: 21 | [![IMAGE ALT TEXT HERE](https://img.youtube.com/vi/0pMfmo8qfpQ/0.jpg)](https://www.youtube.com/watch?v=0pMfmo8qfpQ) 22 | 23 | Some examples from our YOLACT base model (33.5 fps on a Titan Xp and 29.8 mAP on COCO's `test-dev`): 24 | 25 | ![Example 0](data/yolact_example_0.png) 26 | 27 | ![Example 1](data/yolact_example_1.png) 28 | 29 | ![Example 2](data/yolact_example_2.png) 30 | 31 | # Installation 32 | - Clone this repository and enter it: 33 | ```Shell 34 | git clone https://github.com/dbolya/yolact.git 35 | cd yolact 36 | ``` 37 | - Set up the environment using one of the following methods: 38 | - Using [Anaconda](https://www.anaconda.com/distribution/) 39 | - Run `conda env create -f environment.yml` 40 | - Manually with pip 41 | - Set up a Python3 environment (e.g., using virtenv). 42 | - Install [Pytorch](http://pytorch.org/) 1.0.1 (or higher) and TorchVision. 43 | - Install some other packages: 44 | ```Shell 45 | # Cython needs to be installed before pycocotools 46 | pip install cython 47 | pip install opencv-python pillow pycocotools matplotlib 48 | ``` 49 | - If you'd like to train YOLACT, download the COCO dataset and the 2014/2017 annotations. Note that this script will take a while and dump 21gb of files into `./data/coco`. 50 | ```Shell 51 | sh data/scripts/COCO.sh 52 | ``` 53 | - If you'd like to evaluate YOLACT on `test-dev`, download `test-dev` with this script. 54 | ```Shell 55 | sh data/scripts/COCO_test.sh 56 | ``` 57 | - If you want to use YOLACT++, compile deformable convolutional layers (from [DCNv2](https://github.com/CharlesShang/DCNv2/tree/pytorch_1.0)). 58 | Make sure you have the latest CUDA toolkit installed from [NVidia's Website](https://developer.nvidia.com/cuda-toolkit). 
59 | ```Shell 60 | cd external/DCNv2 61 | python setup.py build develop 62 | ``` 63 | 64 | 65 | # Evaluation 66 | Here are our YOLACT models (released on April 5th, 2019) along with their FPS on a Titan Xp and mAP on `test-dev`: 67 | 68 | | Image Size | Backbone | FPS | mAP | Weights | | 69 | |:----------:|:-------------:|:----:|:----:|----------------------------------------------------------------------------------------------------------------------|--------| 70 | | 550 | Resnet50-FPN | 42.5 | 28.2 | [yolact_resnet50_54_800000.pth](https://drive.google.com/file/d/1yp7ZbbDwvMiFJEq4ptVKTYTI2VeRDXl0/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EUVpxoSXaqNIlssoLKOEoCcB1m0RpzGq_Khp5n1VX3zcUw) | 71 | | 550 | Darknet53-FPN | 40.0 | 28.7 | [yolact_darknet53_54_800000.pth](https://drive.google.com/file/d/1dukLrTzZQEuhzitGkHaGjphlmRJOjVnP/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/ERrao26c8llJn25dIyZPhwMBxUp2GdZTKIMUQA3t0djHLw) 72 | | 550 | Resnet101-FPN | 33.5 | 29.8 | [yolact_base_54_800000.pth](https://drive.google.com/file/d/1UYy3dMapbH1BnmtZU4WH1zbYgOzzHHf_/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EYRWxBEoKU9DiblrWx2M89MBGFkVVB_drlRd_v5sdT3Hgg) 73 | | 700 | Resnet101-FPN | 23.6 | 31.2 | [yolact_im700_54_800000.pth](https://drive.google.com/file/d/1lE4Lz5p25teiXV-6HdTiOJSnS7u7GBzg/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/Eagg5RSc5hFEhp7sPtvLNyoBjhlf2feog7t8OQzHKKphjw) 74 | 75 | YOLACT++ models (released on December 16th, 2019): 76 | 77 | | Image Size | Backbone | FPS | mAP | Weights | | 78 | |:----------:|:-------------:|:----:|:----:|----------------------------------------------------------------------------------------------------------------------|--------| 79 | | 550 | Resnet50-FPN | 33.5 | 34.1 | [yolact_plus_resnet50_54_800000.pth](https://drive.google.com/file/d/1ZPu1YR2UzGHQD0o1rEqy-j5bmEm3lbyP/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EcJAtMiEFlhAnVsDf00yWRIBUC4m8iE9NEEiV05XwtEoGw) | 80 | | 550 | Resnet101-FPN | 27.3 | 34.6 | [yolact_plus_base_54_800000.pth](https://drive.google.com/file/d/15id0Qq5eqRbkD-N3ZjDZXdCvRyIaHpFB/view?usp=sharing) | [Mirror](https://ucdavis365-my.sharepoint.com/:u:/g/personal/yongjaelee_ucdavis_edu/EVQ62sF0SrJPrl_68onyHF8BpG7c05A8PavV4a849sZgEA) 81 | 82 | To evalute the model, put the corresponding weights file in the `./weights` directory and run one of the following commands. The name of each config is everything before the numbers in the file name (e.g., `yolact_base` for `yolact_base_54_800000.pth`). 83 | ## Quantitative Results on COCO 84 | ```Shell 85 | # Quantitatively evaluate a trained model on the entire validation set. Make sure you have COCO downloaded as above. 86 | # This should get 29.92 validation mask mAP last time I checked. 87 | python eval.py --trained_model=weights/yolact_base_54_800000.pth 88 | 89 | # Output a COCOEval json to submit to the website or to use the run_coco_eval.py script. 90 | # This command will create './results/bbox_detections.json' and './results/mask_detections.json' for detection and instance segmentation respectively. 91 | python eval.py --trained_model=weights/yolact_base_54_800000.pth --output_coco_json 92 | 93 | # You can run COCOEval on the files created in the previous command. 
The performance should match my implementation in eval.py. 94 | python run_coco_eval.py 95 | 96 | # To output a coco json file for test-dev, make sure you have test-dev downloaded from above and go 97 | python eval.py --trained_model=weights/yolact_base_54_800000.pth --output_coco_json --dataset=coco2017_testdev_dataset 98 | ``` 99 | ## Qualitative Results on COCO 100 | ```Shell 101 | # Display qualitative results on COCO. From here on I'll use a confidence threshold of 0.15. 102 | python eval.py --trained_model=weights/yolact_base_54_800000.pth --score_threshold=0.15 --top_k=15 --display 103 | ``` 104 | ## Benchmarking on COCO 105 | ```Shell 106 | # Run just the raw model on the first 1k images of the validation set 107 | python eval.py --trained_model=weights/yolact_base_54_800000.pth --benchmark --max_images=1000 108 | ``` 109 | ## Images 110 | ```Shell 111 | # Display qualitative results on the specified image. 112 | python eval.py --trained_model=weights/yolact_base_54_800000.pth --score_threshold=0.15 --top_k=15 --image=my_image.png 113 | 114 | # Process an image and save it to another file. 115 | python eval.py --trained_model=weights/yolact_base_54_800000.pth --score_threshold=0.15 --top_k=15 --image=input_image.png:output_image.png 116 | 117 | # Process a whole folder of images. 118 | python eval.py --trained_model=weights/yolact_base_54_800000.pth --score_threshold=0.15 --top_k=15 --images=path/to/input/folder:path/to/output/folder 119 | ``` 120 | ## Video 121 | ```Shell 122 | # Display a video in real-time. "--video_multiframe" will process that many frames at once for improved performance. 123 | # If you want, use "--display_fps" to draw the FPS directly on the frame. 124 | python eval.py --trained_model=weights/yolact_base_54_800000.pth --score_threshold=0.15 --top_k=15 --video_multiframe=4 --video=my_video.mp4 125 | 126 | # Display a webcam feed in real-time. If you have multiple webcams pass the index of the webcam you want instead of 0. 127 | python eval.py --trained_model=weights/yolact_base_54_800000.pth --score_threshold=0.15 --top_k=15 --video_multiframe=4 --video=0 128 | 129 | # Process a video and save it to another file. This uses the same pipeline as the ones above now, so it's fast! 130 | python eval.py --trained_model=weights/yolact_base_54_800000.pth --score_threshold=0.15 --top_k=15 --video_multiframe=4 --video=input_video.mp4:output_video.mp4 131 | ``` 132 | As you can tell, `eval.py` can do a ton of stuff. Run the `--help` command to see everything it can do. 133 | ```Shell 134 | python eval.py --help 135 | ``` 136 | 137 | 138 | # Training 139 | By default, we train on COCO. Make sure to download the entire dataset using the commands above. 140 | - To train, grab an imagenet-pretrained model and put it in `./weights`. 141 | - For Resnet101, download `resnet101_reducedfc.pth` from [here](https://drive.google.com/file/d/1tvqFPd4bJtakOlmn-uIA492g2qurRChj/view?usp=sharing). 142 | - For Resnet50, download `resnet50-19c8e357.pth` from [here](https://drive.google.com/file/d/1Jy3yCdbatgXa5YYIdTCRrSV0S9V5g1rn/view?usp=sharing). 143 | - For Darknet53, download `darknet53.pth` from [here](https://drive.google.com/file/d/17Y431j4sagFpSReuPNoFcj9h7azDTZFf/view?usp=sharing). 144 | - Run one of the training commands below. 145 | - Note that you can press ctrl+c while training and it will save an `*_interrupt.pth` file at the current iteration. 146 | - All weights are saved in the `./weights` directory by default with the file name `__.pth`. 
147 | ```Shell
148 | # Trains using the base config with a batch size of 8 (the default).
149 | python train.py --config=yolact_base_config
150 | 
151 | # Trains yolact_base_config with a batch_size of 5. For the 550px models, 1 batch takes up around 1.5 gigs of VRAM, so specify accordingly.
152 | python train.py --config=yolact_base_config --batch_size=5
153 | 
154 | # Resume training yolact_base with a specific weight file and start from the iteration specified in the weight file's name.
155 | python train.py --config=yolact_base_config --resume=weights/yolact_base_10_32100.pth --start_iter=-1
156 | 
157 | # Use the help option to see a description of all available command line arguments.
158 | python train.py --help
159 | ```
160 | 
161 | ## Multi-GPU Support
162 | YOLACT now supports multiple GPUs seamlessly during training:
163 | 
164 |  - Before running any of the scripts, run: `export CUDA_VISIBLE_DEVICES=[gpus]`
165 |    - Replace [gpus] with a comma-separated list of the indices of the GPUs you want to use (e.g., 0,1,2,3).
166 |    - You should still do this if you're only using 1 GPU.
167 |    - You can check the indices of your GPUs with `nvidia-smi`.
168 |  - Then, simply set the batch size to `8*num_gpus` with the training commands above. The training script will automatically scale the hyperparameters to the right values.
169 |    - If you have memory to spare, you can increase the batch size further, but keep it a multiple of the number of GPUs you're using.
170 |    - If you want to allocate a specific number of images to each GPU, use `--batch_alloc=[alloc]`, where [alloc] is a comma-separated list containing the number of images on each GPU. This must sum to `batch_size`.
171 | 
172 | ## Logging
173 | YOLACT now logs training and validation information by default. You can disable this with `--no_log`. A guide on how to visualize these logs is coming soon, but for now you can look at `LogVizualizer` in `utils/logger.py` for help.
174 | 
175 | ## Pascal SBD
176 | We also include a config for training on Pascal SBD annotations (for rapid experimentation or comparing with other methods). To train on Pascal SBD, proceed with the following steps:
177 | 1. Download the dataset from [here](http://home.bharathh.info/pubs/codes/SBD/download.html). It's the first link in the top "Overview" section (the file is called `benchmark.tgz`).
178 | 2. Extract the dataset somewhere. In the dataset there should be a folder called `dataset/img`. Create the directory `./data/sbd` (where `.` is YOLACT's root) and copy `dataset/img` to `./data/sbd/img`.
179 | 3. Download the COCO-style annotations from [here](https://drive.google.com/open?id=1ExrRSPVctHW8Nxrn0SofU1lVhK5Wn0_S).
180 | 4. Extract the annotations into `./data/sbd/`.
181 | 5. Now you can train using `--config=yolact_resnet50_pascal_config`. Check that config to see how to extend it to other models.
182 | 
183 | I will automate this all with a script soon, don't worry. Also, if you want the script I used to convert the annotations, I put it in `./scripts/convert_sbd.py`, but you'll have to check how it works to be able to use it because I don't actually remember at this point.
184 | 
185 | If you want to verify our results, you can download our `yolact_resnet50_pascal_config` weights from [here](https://drive.google.com/open?id=1yLVwtkRtNxyl0kxeMCtPXJsXFFyc_FHe). This model should get 72.3 mask AP_50 and 56.2 mask AP_70.
Note that the "all" AP isn't the same as the "vol" AP reported in other papers for Pascal (they use an average over the thresholds from `0.1 - 0.9` in increments of `0.1` instead of what COCO uses).
186 | 
187 | ## Custom Datasets
188 | You can also train on your own dataset by following these steps:
189 |  - Create a COCO-style Object Detection JSON annotation file for your dataset. The specification for this can be found [here](http://cocodataset.org/#format-data). Note that we don't use some fields, so the following may be omitted:
190 |    - `info`
191 |    - `license`
192 |    - Under `image`: `license, flickr_url, coco_url, date_captured`
193 |    - `categories` (we use our own format for categories, see below)
194 |  - Create a definition for your dataset under `dataset_base` in `data/config.py` (see the comments in `dataset_base` for an explanation of each field):
195 | ```Python
196 | my_custom_dataset = dataset_base.copy({
197 |     'name': 'My Dataset',
198 | 
199 |     'train_images': 'path_to_training_images',
200 |     'train_info':   'path_to_training_annotation',
201 | 
202 |     'valid_images': 'path_to_validation_images',
203 |     'valid_info':   'path_to_validation_annotation',
204 | 
205 |     'has_gt': True,
206 |     'class_names': ('my_class_id_1', 'my_class_id_2', 'my_class_id_3', ...)
207 | })
208 | ```
209 |  - A couple of things to note:
210 |    - Class IDs in the annotation file should start at 1 and increase sequentially in the order of `class_names`. If this isn't the case for your annotation file (like in COCO), see the field `label_map` in `dataset_base` (a sketch of one is included at the end of this README).
211 |    - If you do not want to create a validation split, use the same image path and annotations file for validation. By default (see `python train.py --help`), `train.py` will output validation mAP for the first 5000 images in the dataset every 2 epochs.
212 |  - Finally, in `yolact_base_config` in the same file, change the value for `'dataset'` to `'my_custom_dataset'` or whatever you named the config object above. Then you can use any of the training commands in the previous section.
213 | 
214 | #### Creating a Custom Dataset from Scratch
215 | See [this nice post by @Amit12690](https://github.com/dbolya/yolact/issues/70#issuecomment-504283008) for tips on how to annotate a custom dataset and prepare it for use with YOLACT.
216 | 
217 | 
218 | 
219 | 
220 | # Citation
221 | If you use YOLACT or this code base in your work, please cite
222 | ```
223 | @inproceedings{yolact-iccv2019,
224 |   author    = {Daniel Bolya and Chong Zhou and Fanyi Xiao and Yong Jae Lee},
225 |   title     = {YOLACT: {Real-time} Instance Segmentation},
226 |   booktitle = {ICCV},
227 |   year      = {2019},
228 | }
229 | ```
230 | 
231 | For YOLACT++, please cite
232 | ```
233 | @article{yolact-plus-tpami2020,
234 |   author  = {Daniel Bolya and Chong Zhou and Fanyi Xiao and Yong Jae Lee},
235 |   journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
236 |   title   = {YOLACT++: Better Real-time Instance Segmentation},
237 |   year    = {2020},
238 | }
239 | ```
240 | 
241 | 
242 | 
243 | # Contact
244 | For questions about our paper or code, please contact [Daniel Bolya](mailto:dbolya@ucdavis.edu).
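
# Appendix: Example `label_map` for a Custom Dataset

This is a hedged, minimal sketch (not part of the repo) expanding on the `label_map` note in the Custom Datasets section above. It shows what an entry in `data/config.py` might look like when your annotation file uses non-sequential category IDs. Everything named `traffic_*`, and all paths and IDs below, are placeholders; the `label_map` convention (original `category_id` mapped to a contiguous 1-based index in `class_names` order) follows `get_label_map()` in `data/coco.py`.

```Python
# Hypothetical entries for data/config.py -- adapt the names, paths, and IDs to your data.

# A dataset whose JSON annotations use the non-sequential category IDs 1, 7, and 12.
traffic_dataset = dataset_base.copy({
    'name': 'Traffic Dataset',

    'train_images': './data/traffic/images/',
    'train_info':   './data/traffic/annotations/train.json',

    'valid_images': './data/traffic/images/',
    'valid_info':   './data/traffic/annotations/valid.json',

    'has_gt': True,
    'class_names': ('car', 'bicycle', 'pedestrian'),

    # Map each category_id in the JSON to a contiguous 1-based index that follows
    # the order of class_names (1 -> car, 7 -> bicycle, 12 -> pedestrian).
    'label_map': {1: 1, 7: 2, 12: 3},
})

# Point a model config at the dataset (assuming the usual "+1 for background" convention
# used by the built-in configs).
yolact_traffic_config = yolact_base_config.copy({
    'name': 'yolact_traffic',
    'dataset': traffic_dataset,
    'num_classes': len(traffic_dataset.class_names) + 1,
})
```

You would then train with `python train.py --config=yolact_traffic_config`, assuming the config name resolves the same way as the built-in configs in `data/config.py`.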
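
# Appendix: Loading the Dataset from Python (sketch)

If you want to poke at the data pipeline directly (for example, to sanity-check a dataset definition before training), here is a hedged sketch of how `COCODetection` and `detection_collate` from `data/coco.py` fit together with a PyTorch `DataLoader`. The paths assume the layout produced by `data/scripts/COCO.sh`; `train.py` normally passes an augmentation transform, which is omitted here for brevity.

```Python
import torch.utils.data as data

# data/__init__.py re-exports everything from data/coco.py and data/config.py.
from data import COCODetection, detection_collate

# Assumes data/scripts/COCO.sh has populated ./data/coco (adjust the paths otherwise).
dataset = COCODetection(image_path='./data/coco/images/',
                        info_file='./data/coco/annotations/instances_val2017.json',
                        transform=None)  # train.py passes an augmentation transform here

loader = data.DataLoader(dataset, batch_size=2, shuffle=True,
                         collate_fn=detection_collate)

# detection_collate returns (list of CHW image tensors, (targets, masks, num_crowds)):
# the images and annotations can have different shapes, so they are kept as lists
# rather than stacked into a single tensor.
imgs, (targets, masks, num_crowds) = next(iter(loader))
print(len(imgs), imgs[0].shape, targets[0].shape, masks[0].shape, num_crowds)
```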
245 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .coco import * 3 | 4 | import torch 5 | import cv2 6 | import numpy as np 7 | -------------------------------------------------------------------------------- /data/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | import cv2 8 | import numpy as np 9 | from .config import cfg 10 | from pycocotools import mask as maskUtils 11 | import random 12 | 13 | def get_label_map(): 14 | if cfg.dataset.label_map is None: 15 | return {x+1: x+1 for x in range(len(cfg.dataset.class_names))} 16 | else: 17 | return cfg.dataset.label_map 18 | 19 | class COCOAnnotationTransform(object): 20 | """Transforms a COCO annotation into a Tensor of bbox coords and label index 21 | Initilized with a dictionary lookup of classnames to indexes 22 | """ 23 | def __init__(self): 24 | self.label_map = get_label_map() 25 | 26 | def __call__(self, target, width, height): 27 | """ 28 | Args: 29 | target (dict): COCO target json annotation as a python dict 30 | height (int): height 31 | width (int): width 32 | Returns: 33 | a list containing lists of bounding boxes [bbox coords, class idx] 34 | """ 35 | scale = np.array([width, height, width, height]) 36 | res = [] 37 | for obj in target: 38 | if 'bbox' in obj: 39 | bbox = obj['bbox'] 40 | label_idx = obj['category_id'] 41 | if label_idx >= 0: 42 | label_idx = self.label_map[label_idx] - 1 43 | final_box = list(np.array([bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]])/scale) 44 | final_box.append(label_idx) 45 | res += [final_box] # [xmin, ymin, xmax, ymax, label_idx] 46 | else: 47 | print("No bbox found for object ", obj) 48 | 49 | return res 50 | 51 | 52 | class COCODetection(data.Dataset): 53 | """`MS Coco Detection `_ Dataset. 54 | Args: 55 | root (string): Root directory where images are downloaded to. 56 | set_name (string): Name of the specific set of COCO images. 57 | transform (callable, optional): A function/transform that augments the 58 | raw images` 59 | target_transform (callable, optional): A function/transform that takes 60 | in the target (bbox) and transforms it. 61 | prep_crowds (bool): Whether or not to prepare crowds for the evaluation step. 62 | """ 63 | 64 | def __init__(self, image_path, info_file, transform=None, 65 | target_transform=None, 66 | dataset_name='MS COCO', has_gt=True): 67 | # Do this here because we have too many things named COCO 68 | from pycocotools.coco import COCO 69 | 70 | if target_transform is None: 71 | target_transform = COCOAnnotationTransform() 72 | 73 | self.root = image_path 74 | self.coco = COCO(info_file) 75 | 76 | self.ids = list(self.coco.imgToAnns.keys()) 77 | if len(self.ids) == 0 or not has_gt: 78 | self.ids = list(self.coco.imgs.keys()) 79 | 80 | self.transform = transform 81 | self.target_transform = COCOAnnotationTransform() 82 | 83 | self.name = dataset_name 84 | self.has_gt = has_gt 85 | 86 | def __getitem__(self, index): 87 | """ 88 | Args: 89 | index (int): Index 90 | Returns: 91 | tuple: Tuple (image, (target, masks, num_crowds)). 92 | target is the object returned by ``coco.loadAnns``. 
93 | """ 94 | im, gt, masks, h, w, num_crowds = self.pull_item(index) 95 | return im, (gt, masks, num_crowds) 96 | 97 | def __len__(self): 98 | return len(self.ids) 99 | 100 | def pull_item(self, index): 101 | """ 102 | Args: 103 | index (int): Index 104 | Returns: 105 | tuple: Tuple (image, target, masks, height, width, crowd). 106 | target is the object returned by ``coco.loadAnns``. 107 | Note that if no crowd annotations exist, crowd will be None 108 | """ 109 | img_id = self.ids[index] 110 | 111 | if self.has_gt: 112 | ann_ids = self.coco.getAnnIds(imgIds=img_id) 113 | 114 | # Target has {'segmentation', 'area', iscrowd', 'image_id', 'bbox', 'category_id'} 115 | target = [x for x in self.coco.loadAnns(ann_ids) if x['image_id'] == img_id] 116 | else: 117 | target = [] 118 | 119 | # Separate out crowd annotations. These are annotations that signify a large crowd of 120 | # objects of said class, where there is no annotation for each individual object. Both 121 | # during testing and training, consider these crowds as neutral. 122 | crowd = [x for x in target if ('iscrowd' in x and x['iscrowd'])] 123 | target = [x for x in target if not ('iscrowd' in x and x['iscrowd'])] 124 | num_crowds = len(crowd) 125 | 126 | for x in crowd: 127 | x['category_id'] = -1 128 | 129 | # This is so we ensure that all crowd annotations are at the end of the array 130 | target += crowd 131 | 132 | # The split here is to have compatibility with both COCO2014 and 2017 annotations. 133 | # In 2014, images have the pattern COCO_{train/val}2014_%012d.jpg, while in 2017 it's %012d.jpg. 134 | # Our script downloads the images as %012d.jpg so convert accordingly. 135 | file_name = self.coco.loadImgs(img_id)[0]['file_name'] 136 | 137 | if file_name.startswith('COCO'): 138 | file_name = file_name.split('_')[-1] 139 | 140 | path = osp.join(self.root, file_name) 141 | assert osp.exists(path), 'Image path does not exist: {}'.format(path) 142 | 143 | img = cv2.imread(path) 144 | height, width, _ = img.shape 145 | 146 | if len(target) > 0: 147 | # Pool all the masks for this image into one [num_objects,height,width] matrix 148 | masks = [self.coco.annToMask(obj).reshape(-1) for obj in target] 149 | masks = np.vstack(masks) 150 | masks = masks.reshape(-1, height, width) 151 | 152 | if self.target_transform is not None and len(target) > 0: 153 | target = self.target_transform(target, width, height) 154 | 155 | if self.transform is not None: 156 | if len(target) > 0: 157 | target = np.array(target) 158 | img, masks, boxes, labels = self.transform(img, masks, target[:, :4], 159 | {'num_crowds': num_crowds, 'labels': target[:, 4]}) 160 | 161 | # I stored num_crowds in labels so I didn't have to modify the entirety of augmentations 162 | num_crowds = labels['num_crowds'] 163 | labels = labels['labels'] 164 | 165 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 166 | else: 167 | img, _, _, _ = self.transform(img, np.zeros((1, height, width), dtype=np.float), np.array([[0, 0, 1, 1]]), 168 | {'num_crowds': 0, 'labels': np.array([0])}) 169 | masks = None 170 | target = None 171 | 172 | if target.shape[0] == 0: 173 | print('Warning: Augmentation output an example with no ground truth. 
Resampling...') 174 | return self.pull_item(random.randint(0, len(self.ids)-1)) 175 | 176 | return torch.from_numpy(img).permute(2, 0, 1), target, masks, height, width, num_crowds 177 | 178 | def pull_image(self, index): 179 | '''Returns the original image object at index in PIL form 180 | 181 | Note: not using self.__getitem__(), as any transformations passed in 182 | could mess up this functionality. 183 | 184 | Argument: 185 | index (int): index of img to show 186 | Return: 187 | cv2 img 188 | ''' 189 | img_id = self.ids[index] 190 | path = self.coco.loadImgs(img_id)[0]['file_name'] 191 | return cv2.imread(osp.join(self.root, path), cv2.IMREAD_COLOR) 192 | 193 | def pull_anno(self, index): 194 | '''Returns the original annotation of image at index 195 | 196 | Note: not using self.__getitem__(), as any transformations passed in 197 | could mess up this functionality. 198 | 199 | Argument: 200 | index (int): index of img to get annotation of 201 | Return: 202 | list: [img_id, [(label, bbox coords),...]] 203 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 204 | ''' 205 | img_id = self.ids[index] 206 | ann_ids = self.coco.getAnnIds(imgIds=img_id) 207 | return self.coco.loadAnns(ann_ids) 208 | 209 | def __repr__(self): 210 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 211 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 212 | fmt_str += ' Root Location: {}\n'.format(self.root) 213 | tmp = ' Transforms (if any): ' 214 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 215 | tmp = ' Target Transforms (if any): ' 216 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 217 | return fmt_str 218 | 219 | def enforce_size(img, targets, masks, num_crowds, new_w, new_h): 220 | """ Ensures that the image is the given size without distorting aspect ratio. """ 221 | with torch.no_grad(): 222 | _, h, w = img.size() 223 | 224 | if h == new_h and w == new_w: 225 | return img, targets, masks, num_crowds 226 | 227 | # Resize the image so that it fits within new_w, new_h 228 | w_prime = new_w 229 | h_prime = h * new_w / w 230 | 231 | if h_prime > new_h: 232 | w_prime *= new_h / h_prime 233 | h_prime = new_h 234 | 235 | w_prime = int(w_prime) 236 | h_prime = int(h_prime) 237 | 238 | # Do all the resizing 239 | img = F.interpolate(img.unsqueeze(0), (h_prime, w_prime), mode='bilinear', align_corners=False) 240 | img.squeeze_(0) 241 | 242 | # Act like each object is a color channel 243 | masks = F.interpolate(masks.unsqueeze(0), (h_prime, w_prime), mode='bilinear', align_corners=False) 244 | masks.squeeze_(0) 245 | 246 | # Scale bounding boxes (this will put them in the top left corner in the case of padding) 247 | targets[:, [0, 2]] *= (w_prime / new_w) 248 | targets[:, [1, 3]] *= (h_prime / new_h) 249 | 250 | # Finally, pad everything to be the new_w, new_h 251 | pad_dims = (0, new_w - w_prime, 0, new_h - h_prime) 252 | img = F.pad( img, pad_dims, mode='constant', value=0) 253 | masks = F.pad(masks, pad_dims, mode='constant', value=0) 254 | 255 | return img, targets, masks, num_crowds 256 | 257 | 258 | 259 | 260 | def detection_collate(batch): 261 | """Custom collate fn for dealing with batches of images that have a different 262 | number of associated object annotations (bounding boxes). 
263 | 264 | Arguments: 265 | batch: (tuple) A tuple of tensor images and (lists of annotations, masks) 266 | 267 | Return: 268 | A tuple containing: 269 | 1) (tensor) batch of images stacked on their 0 dim 270 | 2) (list, list, list) annotations for a given image are stacked 271 | on 0 dim. The output gt is a tuple of annotations and masks. 272 | """ 273 | targets = [] 274 | imgs = [] 275 | masks = [] 276 | num_crowds = [] 277 | 278 | for sample in batch: 279 | imgs.append(sample[0]) 280 | targets.append(torch.FloatTensor(sample[1][0])) 281 | masks.append(torch.FloatTensor(sample[1][1])) 282 | num_crowds.append(sample[1][2]) 283 | 284 | return imgs, (targets, masks, num_crowds) 285 | -------------------------------------------------------------------------------- /data/grid.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbolya/yolact/57b8f2d95e62e2e649b382f516ab41f949b57239/data/grid.npy -------------------------------------------------------------------------------- /data/scripts/COCO.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | start=`date +%s` 4 | 5 | # handle optional download dir 6 | if [ -z "$1" ] 7 | then 8 | # navigate to ./data 9 | echo "navigating to ./data/ ..." 10 | mkdir -p ./data 11 | cd ./data/ 12 | mkdir -p ./coco 13 | cd ./coco 14 | mkdir -p ./images 15 | mkdir -p ./annotations 16 | else 17 | # check if specified dir is valid 18 | if [ ! -d $1 ]; then 19 | echo $1 " is not a valid directory" 20 | exit 0 21 | fi 22 | echo "navigating to " $1 " ..." 23 | cd $1 24 | fi 25 | 26 | if [ ! -d images ] 27 | then 28 | mkdir -p ./images 29 | fi 30 | 31 | # Download the image data. 32 | cd ./images 33 | echo "Downloading MSCOCO train images ..." 34 | curl -LO http://images.cocodataset.org/zips/train2017.zip 35 | echo "Downloading MSCOCO val images ..." 36 | curl -LO http://images.cocodataset.org/zips/val2017.zip 37 | 38 | cd ../ 39 | if [ ! -d annotations ] 40 | then 41 | mkdir -p ./annotations 42 | fi 43 | 44 | # Download the annotation data. 45 | cd ./annotations 46 | echo "Downloading MSCOCO train/val annotations ..." 47 | curl -LO http://images.cocodataset.org/annotations/annotations_trainval2014.zip 48 | curl -LO http://images.cocodataset.org/annotations/annotations_trainval2017.zip 49 | echo "Finished downloading. Now extracting ..." 50 | 51 | # Unzip data 52 | echo "Extracting train images ..." 53 | unzip -qqjd ../images ../images/train2017.zip 54 | echo "Extracting val images ..." 55 | unzip -qqjd ../images ../images/val2017.zip 56 | echo "Extracting annotations ..." 57 | unzip -qqd .. ./annotations_trainval2014.zip 58 | unzip -qqd .. ./annotations_trainval2017.zip 59 | 60 | echo "Removing zip files ..." 61 | rm ../images/train2017.zip 62 | rm ../images/val2017.zip 63 | rm ./annotations_trainval2014.zip 64 | rm ./annotations_trainval2017.zip 65 | 66 | 67 | end=`date +%s` 68 | runtime=$((end-start)) 69 | 70 | echo "Completed in " $runtime " seconds" 71 | -------------------------------------------------------------------------------- /data/scripts/COCO_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | start=`date +%s` 4 | 5 | # handle optional download dir 6 | if [ -z "$1" ] 7 | then 8 | # navigate to ./data 9 | echo "navigating to ./data/ ..." 
10 | mkdir -p ./data 11 | cd ./data/ 12 | mkdir -p ./coco 13 | cd ./coco 14 | mkdir -p ./images 15 | mkdir -p ./annotations 16 | else 17 | # check if specified dir is valid 18 | if [ ! -d $1 ]; then 19 | echo $1 " is not a valid directory" 20 | exit 0 21 | fi 22 | echo "navigating to " $1 " ..." 23 | cd $1 24 | fi 25 | 26 | if [ ! -d images ] 27 | then 28 | mkdir -p ./images 29 | fi 30 | 31 | # Download the image data. 32 | cd ./images 33 | echo "Downloading MSCOCO test images ..." 34 | curl -LO http://images.cocodataset.org/zips/test2017.zip 35 | 36 | cd ../ 37 | if [ ! -d annotations ] 38 | then 39 | mkdir -p ./annotations 40 | fi 41 | 42 | # Download the annotation data. 43 | cd ./annotations 44 | echo "Downloading MSCOCO test info ..." 45 | curl -LO http://images.cocodataset.org/annotations/image_info_test2017.zip 46 | echo "Finished downloading. Now extracting ..." 47 | 48 | # Unzip data 49 | echo "Extracting train images ..." 50 | unzip -qqjd ../images ../images/test2017.zip 51 | echo "Extracting info ..." 52 | unzip -qqd .. ./image_info_test2017.zip 53 | 54 | echo "Removing zip files ..." 55 | rm ../images/test2017.zip 56 | rm ./image_info_test2017.zip 57 | 58 | 59 | end=`date +%s` 60 | runtime=$((end-start)) 61 | 62 | echo "Completed in " $runtime " seconds" 63 | -------------------------------------------------------------------------------- /data/scripts/mix_sets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from collections import defaultdict 5 | 6 | usage_text = """ 7 | This script creates a coco annotation file by mixing one or more existing annotation files. 8 | 9 | Usage: python data/scripts/mix_sets.py output_name [set1 range1 [set2 range2 [...]]] 10 | 11 | To use, specify the output annotation name and any number of set + range pairs, where the sets 12 | are in the form instances_.json and ranges are python-evalable ranges. The resulting 13 | json will be spit out as instances_.json in the same folder as the input sets. 14 | 15 | For instance, 16 | python data/scripts/mix_sets.py trainval35k train2014 : val2014 :-5000 17 | 18 | This will create an instance_trainval35k.json file with all images and corresponding annotations 19 | from train2014 and the first 35000 images from val2014. 20 | 21 | You can also specify only one set: 22 | python data/scripts/mix_sets.py minival5k val2014 -5000: 23 | 24 | This will take the last 5k images from val2014 and put it in instances_minival5k.json. 25 | """ 26 | 27 | annotations_path = 'data/coco/annotations/instances_%s.json' 28 | fields_to_combine = ('images', 'annotations') 29 | fields_to_steal = ('info', 'categories', 'licenses') 30 | 31 | if __name__ == '__main__': 32 | if len(sys.argv) < 4 or len(sys.argv) % 2 != 0: 33 | print(usage_text) 34 | exit() 35 | 36 | out_name = sys.argv[1] 37 | sets = sys.argv[2:] 38 | sets = [(sets[2*i], sets[2*i+1]) for i in range(len(sets)//2)] 39 | 40 | out = {x: [] for x in fields_to_combine} 41 | 42 | for idx, (set_name, range_str) in enumerate(sets): 43 | print('Loading set %s...' 
% set_name) 44 | with open(annotations_path % set_name, 'r') as f: 45 | set_json = json.load(f) 46 | 47 | # "Steal" some fields that don't need to be combined from the first set 48 | if idx == 0: 49 | for field in fields_to_steal: 50 | out[field] = set_json[field] 51 | 52 | print('Building image index...') 53 | image_idx = {x['id']: x for x in set_json['images']} 54 | 55 | print('Collecting annotations...') 56 | anns_idx = defaultdict(lambda: []) 57 | 58 | for ann in set_json['annotations']: 59 | anns_idx[ann['image_id']].append(ann) 60 | 61 | export_ids = list(image_idx.keys()) 62 | export_ids.sort() 63 | export_ids = eval('export_ids[%s]' % range_str, {}, {'export_ids': export_ids}) 64 | 65 | print('Adding %d images...' % len(export_ids)) 66 | for _id in export_ids: 67 | out['images'].append(image_idx[_id]) 68 | out['annotations'] += anns_idx[_id] 69 | 70 | print('Done.\n') 71 | 72 | print('Saving result...') 73 | with open(annotations_path % (out_name), 'w') as out_file: 74 | json.dump(out, out_file) 75 | -------------------------------------------------------------------------------- /data/yolact_example_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbolya/yolact/57b8f2d95e62e2e649b382f516ab41f949b57239/data/yolact_example_0.png -------------------------------------------------------------------------------- /data/yolact_example_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbolya/yolact/57b8f2d95e62e2e649b382f516ab41f949b57239/data/yolact_example_1.png -------------------------------------------------------------------------------- /data/yolact_example_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbolya/yolact/57b8f2d95e62e2e649b382f516ab41f949b57239/data/yolact_example_2.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # Installs dependencies for YOLACT managed by Anaconda. 2 | # Advantage is you get working CUDA+cuDNN+pytorch+torchvison versions. 3 | # 4 | # TODO: you must additionally install nVidia drivers, eg. on Ubuntu linux 5 | # `apt install nvidia-driver-440` (change the 440 for whatever version you need/have). 6 | # 7 | name: yolact-env 8 | #prefix: /your/custom/path/envs/yolact-env 9 | channels: 10 | - conda-forge 11 | - pytorch 12 | - defaults 13 | dependencies: 14 | - python==3.7 15 | - pip 16 | - cython 17 | - pytorch::torchvision 18 | - pytorch::pytorch >=1.0.1 19 | - cudatoolkit 20 | - cudnn 21 | - pytorch::cuda100 22 | - matplotlib 23 | - git # to download COCO dataset 24 | - curl # to download COCO dataset 25 | - unzip # to download COCO dataset 26 | - conda-forge::bash # to download COCO dataset 27 | - pip: 28 | - opencv-python 29 | - pillow <7.0 # bug PILLOW_VERSION in torchvision, must be < 7.0 until torchvision is upgraded 30 | - pycocotools 31 | - PyQt5 # needed on KDE/Qt envs for matplotlib 32 | 33 | -------------------------------------------------------------------------------- /external/DCNv2/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Charles Shang 4 | All rights reserved. 
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /external/DCNv2/README.md: -------------------------------------------------------------------------------- 1 | ## Deformable Convolutional Networks V2 with Pytorch 1.0 2 | 3 | ### Build 4 | ```bash 5 | ./make.sh # build 6 | python test.py # run examples and gradient check 7 | ``` 8 | 9 | ### An Example 10 | - deformable conv 11 | ```python 12 | from dcn_v2 import DCN 13 | input = torch.randn(2, 64, 128, 128).cuda() 14 | # wrap all things (offset and mask) in DCN 15 | dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda() 16 | output = dcn(input) 17 | print(output.shape) 18 | ``` 19 | - deformable roi pooling 20 | ```python 21 | from dcn_v2 import DCNPooling 22 | input = torch.randn(2, 32, 64, 64).cuda() 23 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 24 | x = torch.randint(256, (20, 1)).cuda().float() 25 | y = torch.randint(256, (20, 1)).cuda().float() 26 | w = torch.randint(64, (20, 1)).cuda().float() 27 | h = torch.randint(64, (20, 1)).cuda().float() 28 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 29 | 30 | # mdformable pooling (V2) 31 | # wrap all things (offset and mask) in DCNPooling 32 | dpooling = DCNPooling(spatial_scale=1.0 / 4, 33 | pooled_size=7, 34 | output_dim=32, 35 | no_trans=False, 36 | group_size=1, 37 | trans_std=0.1).cuda() 38 | 39 | dout = dpooling(input, rois) 40 | ``` 41 | ### Note 42 | Now the master branch is for pytorch 1.0 (new ATen API), you can switch back to pytorch 0.4 with, 43 | ```bash 44 | git checkout pytorch_0.4 45 | ``` 46 | 47 | ### Known Issues: 48 | 49 | - [x] Gradient check w.r.t offset (solved) 50 | - [ ] Backward is not reentrant (minor) 51 | 52 | This is an adaption of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op). 53 | 54 | I have ran the gradient check for many times with DOUBLE type. Every tensor **except offset** passes. 
55 | However, when I set the offset to 0.5, it passes. I'm still wondering what cause this problem. Is it because some 56 | non-differential points? 57 | 58 | Update: all gradient check passes with double precision. 59 | 60 | Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for 61 | float `<1e-15` for double), 62 | so it may not be a serious problem (?) 63 | 64 | Please post an issue or PR if you have any comments. 65 | -------------------------------------------------------------------------------- /external/DCNv2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbolya/yolact/57b8f2d95e62e2e649b382f516ab41f949b57239/external/DCNv2/__init__.py -------------------------------------------------------------------------------- /external/DCNv2/dcn_v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.autograd import Function 10 | from torch.nn.modules.utils import _pair 11 | from torch.autograd.function import once_differentiable 12 | 13 | import _ext as _backend 14 | 15 | 16 | class _DCNv2(Function): 17 | @staticmethod 18 | def forward(ctx, input, offset, mask, weight, bias, 19 | stride, padding, dilation, deformable_groups): 20 | ctx.stride = _pair(stride) 21 | ctx.padding = _pair(padding) 22 | ctx.dilation = _pair(dilation) 23 | ctx.kernel_size = _pair(weight.shape[2:4]) 24 | ctx.deformable_groups = deformable_groups 25 | output = _backend.dcn_v2_forward(input, weight, bias, 26 | offset, mask, 27 | ctx.kernel_size[0], ctx.kernel_size[1], 28 | ctx.stride[0], ctx.stride[1], 29 | ctx.padding[0], ctx.padding[1], 30 | ctx.dilation[0], ctx.dilation[1], 31 | ctx.deformable_groups) 32 | ctx.save_for_backward(input, offset, mask, weight, bias) 33 | return output 34 | 35 | @staticmethod 36 | @once_differentiable 37 | def backward(ctx, grad_output): 38 | input, offset, mask, weight, bias = ctx.saved_tensors 39 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \ 40 | _backend.dcn_v2_backward(input, weight, 41 | bias, 42 | offset, mask, 43 | grad_output, 44 | ctx.kernel_size[0], ctx.kernel_size[1], 45 | ctx.stride[0], ctx.stride[1], 46 | ctx.padding[0], ctx.padding[1], 47 | ctx.dilation[0], ctx.dilation[1], 48 | ctx.deformable_groups) 49 | 50 | return grad_input, grad_offset, grad_mask, grad_weight, grad_bias,\ 51 | None, None, None, None, 52 | 53 | 54 | dcn_v2_conv = _DCNv2.apply 55 | 56 | 57 | class DCNv2(nn.Module): 58 | 59 | def __init__(self, in_channels, out_channels, 60 | kernel_size, stride, padding, dilation=1, deformable_groups=1): 61 | super(DCNv2, self).__init__() 62 | self.in_channels = in_channels 63 | self.out_channels = out_channels 64 | self.kernel_size = _pair(kernel_size) 65 | self.stride = _pair(stride) 66 | self.padding = _pair(padding) 67 | self.dilation = _pair(dilation) 68 | self.deformable_groups = deformable_groups 69 | 70 | self.weight = nn.Parameter(torch.Tensor( 71 | out_channels, in_channels, *self.kernel_size)) 72 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 73 | self.reset_parameters() 74 | 75 | def reset_parameters(self): 76 | n = self.in_channels 77 | for k in self.kernel_size: 78 | n *= k 79 | stdv = 1. 
/ math.sqrt(n) 80 | self.weight.data.uniform_(-stdv, stdv) 81 | self.bias.data.zero_() 82 | 83 | def forward(self, input, offset, mask): 84 | assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ 85 | offset.shape[1] 86 | assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ 87 | mask.shape[1] 88 | return dcn_v2_conv(input, offset, mask, 89 | self.weight, 90 | self.bias, 91 | self.stride, 92 | self.padding, 93 | self.dilation, 94 | self.deformable_groups) 95 | 96 | 97 | class DCN(DCNv2): 98 | 99 | def __init__(self, in_channels, out_channels, 100 | kernel_size, stride, padding, 101 | dilation=1, deformable_groups=1): 102 | super(DCN, self).__init__(in_channels, out_channels, 103 | kernel_size, stride, padding, dilation, deformable_groups) 104 | 105 | channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] 106 | self.conv_offset_mask = nn.Conv2d(self.in_channels, 107 | channels_, 108 | kernel_size=self.kernel_size, 109 | stride=self.stride, 110 | padding=self.padding, 111 | bias=True) 112 | self.init_offset() 113 | 114 | def init_offset(self): 115 | self.conv_offset_mask.weight.data.zero_() 116 | self.conv_offset_mask.bias.data.zero_() 117 | 118 | def forward(self, input): 119 | out = self.conv_offset_mask(input) 120 | o1, o2, mask = torch.chunk(out, 3, dim=1) 121 | offset = torch.cat((o1, o2), dim=1) 122 | mask = torch.sigmoid(mask) 123 | return dcn_v2_conv(input, offset, mask, 124 | self.weight, self.bias, 125 | self.stride, 126 | self.padding, 127 | self.dilation, 128 | self.deformable_groups) 129 | 130 | 131 | 132 | class _DCNv2Pooling(Function): 133 | @staticmethod 134 | def forward(ctx, input, rois, offset, 135 | spatial_scale, 136 | pooled_size, 137 | output_dim, 138 | no_trans, 139 | group_size=1, 140 | part_size=None, 141 | sample_per_part=4, 142 | trans_std=.0): 143 | ctx.spatial_scale = spatial_scale 144 | ctx.no_trans = int(no_trans) 145 | ctx.output_dim = output_dim 146 | ctx.group_size = group_size 147 | ctx.pooled_size = pooled_size 148 | ctx.part_size = pooled_size if part_size is None else part_size 149 | ctx.sample_per_part = sample_per_part 150 | ctx.trans_std = trans_std 151 | 152 | output, output_count = \ 153 | _backend.dcn_v2_psroi_pooling_forward(input, rois, offset, 154 | ctx.no_trans, ctx.spatial_scale, 155 | ctx.output_dim, ctx.group_size, 156 | ctx.pooled_size, ctx.part_size, 157 | ctx.sample_per_part, ctx.trans_std) 158 | ctx.save_for_backward(input, rois, offset, output_count) 159 | return output 160 | 161 | @staticmethod 162 | @once_differentiable 163 | def backward(ctx, grad_output): 164 | input, rois, offset, output_count = ctx.saved_tensors 165 | grad_input, grad_offset = \ 166 | _backend.dcn_v2_psroi_pooling_backward(grad_output, 167 | input, 168 | rois, 169 | offset, 170 | output_count, 171 | ctx.no_trans, 172 | ctx.spatial_scale, 173 | ctx.output_dim, 174 | ctx.group_size, 175 | ctx.pooled_size, 176 | ctx.part_size, 177 | ctx.sample_per_part, 178 | ctx.trans_std) 179 | 180 | return grad_input, None, grad_offset, \ 181 | None, None, None, None, None, None, None, None 182 | 183 | 184 | dcn_v2_pooling = _DCNv2Pooling.apply 185 | 186 | 187 | class DCNv2Pooling(nn.Module): 188 | 189 | def __init__(self, 190 | spatial_scale, 191 | pooled_size, 192 | output_dim, 193 | no_trans, 194 | group_size=1, 195 | part_size=None, 196 | sample_per_part=4, 197 | trans_std=.0): 198 | super(DCNv2Pooling, self).__init__() 199 | self.spatial_scale = spatial_scale 200 | self.pooled_size = 
pooled_size 201 | self.output_dim = output_dim 202 | self.no_trans = no_trans 203 | self.group_size = group_size 204 | self.part_size = pooled_size if part_size is None else part_size 205 | self.sample_per_part = sample_per_part 206 | self.trans_std = trans_std 207 | 208 | def forward(self, input, rois, offset): 209 | assert input.shape[1] == self.output_dim 210 | if self.no_trans: 211 | offset = input.new() 212 | return dcn_v2_pooling(input, rois, offset, 213 | self.spatial_scale, 214 | self.pooled_size, 215 | self.output_dim, 216 | self.no_trans, 217 | self.group_size, 218 | self.part_size, 219 | self.sample_per_part, 220 | self.trans_std) 221 | 222 | 223 | class DCNPooling(DCNv2Pooling): 224 | 225 | def __init__(self, 226 | spatial_scale, 227 | pooled_size, 228 | output_dim, 229 | no_trans, 230 | group_size=1, 231 | part_size=None, 232 | sample_per_part=4, 233 | trans_std=.0, 234 | deform_fc_dim=1024): 235 | super(DCNPooling, self).__init__(spatial_scale, 236 | pooled_size, 237 | output_dim, 238 | no_trans, 239 | group_size, 240 | part_size, 241 | sample_per_part, 242 | trans_std) 243 | 244 | self.deform_fc_dim = deform_fc_dim 245 | 246 | if not no_trans: 247 | self.offset_mask_fc = nn.Sequential( 248 | nn.Linear(self.pooled_size * self.pooled_size * 249 | self.output_dim, self.deform_fc_dim), 250 | nn.ReLU(inplace=True), 251 | nn.Linear(self.deform_fc_dim, self.deform_fc_dim), 252 | nn.ReLU(inplace=True), 253 | nn.Linear(self.deform_fc_dim, self.pooled_size * 254 | self.pooled_size * 3) 255 | ) 256 | self.offset_mask_fc[4].weight.data.zero_() 257 | self.offset_mask_fc[4].bias.data.zero_() 258 | 259 | def forward(self, input, rois): 260 | offset = input.new() 261 | 262 | if not self.no_trans: 263 | 264 | # do roi_align first 265 | n = rois.shape[0] 266 | roi = dcn_v2_pooling(input, rois, offset, 267 | self.spatial_scale, 268 | self.pooled_size, 269 | self.output_dim, 270 | True, # no trans 271 | self.group_size, 272 | self.part_size, 273 | self.sample_per_part, 274 | self.trans_std) 275 | 276 | # build mask and offset 277 | offset_mask = self.offset_mask_fc(roi.view(n, -1)) 278 | offset_mask = offset_mask.view( 279 | n, 3, self.pooled_size, self.pooled_size) 280 | o1, o2, mask = torch.chunk(offset_mask, 3, dim=1) 281 | offset = torch.cat((o1, o2), dim=1) 282 | mask = torch.sigmoid(mask) 283 | 284 | # do pooling with offset and mask 285 | return dcn_v2_pooling(input, rois, offset, 286 | self.spatial_scale, 287 | self.pooled_size, 288 | self.output_dim, 289 | self.no_trans, 290 | self.group_size, 291 | self.part_size, 292 | self.sample_per_part, 293 | self.trans_std) * mask 294 | # only roi_align 295 | return dcn_v2_pooling(input, rois, offset, 296 | self.spatial_scale, 297 | self.pooled_size, 298 | self.output_dim, 299 | self.no_trans, 300 | self.group_size, 301 | self.part_size, 302 | self.sample_per_part, 303 | self.trans_std) 304 | -------------------------------------------------------------------------------- /external/DCNv2/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import glob 5 | 6 | import torch 7 | 8 | from torch.utils.cpp_extension import CUDA_HOME 9 | from torch.utils.cpp_extension import CppExtension 10 | from torch.utils.cpp_extension import CUDAExtension 11 | 12 | from setuptools import find_packages 13 | from setuptools import setup 14 | 15 | requirements = ["torch", "torchvision"] 16 | 17 | def get_extensions(): 18 | this_dir = os.path.dirname(os.path.abspath(__file__)) 19 | 
extensions_dir = os.path.join(this_dir, "src") 20 | 21 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 22 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 23 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 24 | 25 | sources = main_file + source_cpu 26 | extension = CppExtension 27 | extra_compile_args = {"cxx": []} 28 | define_macros = [] 29 | 30 | if torch.cuda.is_available() and CUDA_HOME is not None: 31 | extension = CUDAExtension 32 | sources += source_cuda 33 | define_macros += [("WITH_CUDA", None)] 34 | extra_compile_args["nvcc"] = [ 35 | "-DCUDA_HAS_FP16=1", 36 | "-D__CUDA_NO_HALF_OPERATORS__", 37 | "-D__CUDA_NO_HALF_CONVERSIONS__", 38 | "-D__CUDA_NO_HALF2_OPERATORS__", 39 | ] 40 | else: 41 | raise NotImplementedError('Cuda is not available') 42 | 43 | sources = [os.path.join(extensions_dir, s) for s in sources] 44 | include_dirs = [extensions_dir] 45 | ext_modules = [ 46 | extension( 47 | "_ext", 48 | sources, 49 | include_dirs=include_dirs, 50 | define_macros=define_macros, 51 | extra_compile_args=extra_compile_args, 52 | ) 53 | ] 54 | return ext_modules 55 | 56 | setup( 57 | name="DCNv2", 58 | version="0.1", 59 | author="charlesshang", 60 | url="https://github.com/charlesshang/DCNv2", 61 | description="deformable convolutional networks", 62 | packages=find_packages(exclude=("configs", "tests",)), 63 | # install_requires=requirements, 64 | ext_modules=get_extensions(), 65 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 66 | ) 67 | -------------------------------------------------------------------------------- /external/DCNv2/src/cpu/dcn_v2_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | 7 | at::Tensor 8 | dcn_v2_cpu_forward(const at::Tensor &input, 9 | const at::Tensor &weight, 10 | const at::Tensor &bias, 11 | const at::Tensor &offset, 12 | const at::Tensor &mask, 13 | const int kernel_h, 14 | const int kernel_w, 15 | const int stride_h, 16 | const int stride_w, 17 | const int pad_h, 18 | const int pad_w, 19 | const int dilation_h, 20 | const int dilation_w, 21 | const int deformable_group) 22 | { 23 | AT_ERROR("Not implement on cpu"); 24 | } 25 | 26 | std::vector 27 | dcn_v2_cpu_backward(const at::Tensor &input, 28 | const at::Tensor &weight, 29 | const at::Tensor &bias, 30 | const at::Tensor &offset, 31 | const at::Tensor &mask, 32 | const at::Tensor &grad_output, 33 | int kernel_h, int kernel_w, 34 | int stride_h, int stride_w, 35 | int pad_h, int pad_w, 36 | int dilation_h, int dilation_w, 37 | int deformable_group) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | std::tuple 43 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, 44 | const at::Tensor &bbox, 45 | const at::Tensor &trans, 46 | const int no_trans, 47 | const float spatial_scale, 48 | const int output_dim, 49 | const int group_size, 50 | const int pooled_size, 51 | const int part_size, 52 | const int sample_per_part, 53 | const float trans_std) 54 | { 55 | AT_ERROR("Not implement on cpu"); 56 | } 57 | 58 | std::tuple 59 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 60 | const at::Tensor &input, 61 | const at::Tensor &bbox, 62 | const at::Tensor &trans, 63 | const at::Tensor &top_count, 64 | const int no_trans, 65 | const float spatial_scale, 66 | const int output_dim, 67 | const int group_size, 68 | const int pooled_size, 69 | const int part_size, 70 | const int sample_per_part, 71 | const float 
trans_std) 72 | { 73 | AT_ERROR("Not implement on cpu"); 74 | } -------------------------------------------------------------------------------- /external/DCNv2/src/cpu/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | dcn_v2_cpu_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector 21 | dcn_v2_cpu_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple 35 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple 48 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 | const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); -------------------------------------------------------------------------------- /external/DCNv2/src/cuda/dcn_v2_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cuda/dcn_v2_im2col_cuda.h" 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | THCState *state = at::globalContext().lazyInitCUDA(); 12 | 13 | // author: Charles Shang 14 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu 15 | 16 | // [batch gemm] 17 | // https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu 18 | 19 | __global__ void createBatchGemmBuffer(const float **input_b, float **output_b, 20 | float **columns_b, const float **ones_b, 21 | const float **weight_b, const float **bias_b, 22 | float *input, float *output, 23 | float *columns, float *ones, 24 | float *weight, float *bias, 25 | const int input_stride, const int output_stride, 26 | const int columns_stride, const int ones_stride, 27 | const int num_batches) 28 | { 29 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 30 | if (idx < num_batches) 31 | { 32 | input_b[idx] = input + idx * input_stride; 33 | output_b[idx] = output + idx * output_stride; 34 | columns_b[idx] = columns + idx * columns_stride; 35 | ones_b[idx] = ones + idx * ones_stride; 36 | // share weights and bias within a Mini-Batch 37 | weight_b[idx] = weight; 38 | bias_b[idx] = bias; 39 | } 40 | } 41 | 42 | at::Tensor 43 | dcn_v2_cuda_forward(const at::Tensor &input, 44 | const at::Tensor &weight, 45 | const at::Tensor &bias, 46 | const at::Tensor &offset, 47 
| const at::Tensor &mask, 48 | const int kernel_h, 49 | const int kernel_w, 50 | const int stride_h, 51 | const int stride_w, 52 | const int pad_h, 53 | const int pad_w, 54 | const int dilation_h, 55 | const int dilation_w, 56 | const int deformable_group) 57 | { 58 | using scalar_t = float; 59 | // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask)); 60 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); 61 | AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor"); 62 | AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor"); 63 | AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); 64 | AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); 65 | 66 | const int batch = input.size(0); 67 | const int channels = input.size(1); 68 | const int height = input.size(2); 69 | const int width = input.size(3); 70 | 71 | const int channels_out = weight.size(0); 72 | const int channels_kernel = weight.size(1); 73 | const int kernel_h_ = weight.size(2); 74 | const int kernel_w_ = weight.size(3); 75 | 76 | // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h); 77 | // printf("Channels: %d %d\n", channels, channels_kernel); 78 | // printf("Channels: %d %d\n", channels_out, channels_kernel); 79 | 80 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, 81 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); 82 | 83 | AT_ASSERTM(channels == channels_kernel, 84 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); 85 | 86 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 87 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 88 | 89 | auto ones = at::ones({batch, height_out, width_out}, input.options()); 90 | auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); 91 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); 92 | 93 | // prepare for batch-wise computing, which is significantly faster than instance-wise computing 94 | // when batch size is large. 
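// Concretely, the first SgemmBatched below (k_ = 1, beta = 0) pre-fills each image's
// `output` with its per-channel bias at every output location, and the second SgemmBatched
// (beta = 1) multiplies the im2col'd deformable columns by the shared weights and
// accumulates the actual convolution result on top of that bias.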
95 | // launch batch threads 96 | int matrices_size = batch * sizeof(float *); 97 | auto input_b = static_cast(THCudaMalloc(state, matrices_size)); 98 | auto output_b = static_cast(THCudaMalloc(state, matrices_size)); 99 | auto columns_b = static_cast(THCudaMalloc(state, matrices_size)); 100 | auto ones_b = static_cast(THCudaMalloc(state, matrices_size)); 101 | auto weight_b = static_cast(THCudaMalloc(state, matrices_size)); 102 | auto bias_b = static_cast(THCudaMalloc(state, matrices_size)); 103 | 104 | const int block = 128; 105 | const int grid = (batch + block - 1) / block; 106 | 107 | createBatchGemmBuffer<<>>( 108 | input_b, output_b, 109 | columns_b, ones_b, 110 | weight_b, bias_b, 111 | input.data(), 112 | output.data(), 113 | columns.data(), 114 | ones.data(), 115 | weight.data(), 116 | bias.data(), 117 | channels * width * height, 118 | channels_out * width_out * height_out, 119 | channels * kernel_h * kernel_w * height_out * width_out, 120 | height_out * width_out, 121 | batch); 122 | 123 | long m_ = channels_out; 124 | long n_ = height_out * width_out; 125 | long k_ = 1; 126 | THCudaBlas_SgemmBatched(state, 127 | 't', 128 | 'n', 129 | n_, 130 | m_, 131 | k_, 132 | 1.0f, 133 | ones_b, k_, 134 | bias_b, k_, 135 | 0.0f, 136 | output_b, n_, 137 | batch); 138 | 139 | modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), 140 | input.data(), 141 | offset.data(), 142 | mask.data(), 143 | batch, channels, height, width, 144 | height_out, width_out, kernel_h, kernel_w, 145 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, 146 | deformable_group, 147 | columns.data()); 148 | 149 | long m = channels_out; 150 | long n = height_out * width_out; 151 | long k = channels * kernel_h * kernel_w; 152 | THCudaBlas_SgemmBatched(state, 153 | 'n', 154 | 'n', 155 | n, 156 | m, 157 | k, 158 | 1.0f, 159 | (const float **)columns_b, n, 160 | weight_b, k, 161 | 1.0f, 162 | output_b, n, 163 | batch); 164 | 165 | THCudaFree(state, input_b); 166 | THCudaFree(state, output_b); 167 | THCudaFree(state, columns_b); 168 | THCudaFree(state, ones_b); 169 | THCudaFree(state, weight_b); 170 | THCudaFree(state, bias_b); 171 | return output; 172 | } 173 | 174 | __global__ void createBatchGemmBufferBackward( 175 | float **grad_output_b, 176 | float **columns_b, 177 | float **ones_b, 178 | float **weight_b, 179 | float **grad_weight_b, 180 | float **grad_bias_b, 181 | float *grad_output, 182 | float *columns, 183 | float *ones, 184 | float *weight, 185 | float *grad_weight, 186 | float *grad_bias, 187 | const int grad_output_stride, 188 | const int columns_stride, 189 | const int ones_stride, 190 | const int num_batches) 191 | { 192 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 193 | if (idx < num_batches) 194 | { 195 | grad_output_b[idx] = grad_output + idx * grad_output_stride; 196 | columns_b[idx] = columns + idx * columns_stride; 197 | ones_b[idx] = ones + idx * ones_stride; 198 | 199 | // share weights and bias within a Mini-Batch 200 | weight_b[idx] = weight; 201 | grad_weight_b[idx] = grad_weight; 202 | grad_bias_b[idx] = grad_bias; 203 | } 204 | } 205 | 206 | std::vector dcn_v2_cuda_backward(const at::Tensor &input, 207 | const at::Tensor &weight, 208 | const at::Tensor &bias, 209 | const at::Tensor &offset, 210 | const at::Tensor &mask, 211 | const at::Tensor &grad_output, 212 | int kernel_h, int kernel_w, 213 | int stride_h, int stride_w, 214 | int pad_h, int pad_w, 215 | int dilation_h, int dilation_w, 216 | int deformable_group) 217 | { 218 | 219 | 
THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous"); 220 | THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous"); 221 | 222 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); 223 | AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor"); 224 | AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor"); 225 | AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); 226 | AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); 227 | 228 | const int batch = input.size(0); 229 | const int channels = input.size(1); 230 | const int height = input.size(2); 231 | const int width = input.size(3); 232 | 233 | const int channels_out = weight.size(0); 234 | const int channels_kernel = weight.size(1); 235 | const int kernel_h_ = weight.size(2); 236 | const int kernel_w_ = weight.size(3); 237 | 238 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, 239 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); 240 | 241 | AT_ASSERTM(channels == channels_kernel, 242 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); 243 | 244 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 245 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 246 | 247 | auto ones = at::ones({height_out, width_out}, input.options()); 248 | auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); 249 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); 250 | 251 | auto grad_input = at::zeros_like(input); 252 | auto grad_weight = at::zeros_like(weight); 253 | auto grad_bias = at::zeros_like(bias); 254 | auto grad_offset = at::zeros_like(offset); 255 | auto grad_mask = at::zeros_like(mask); 256 | 257 | using scalar_t = float; 258 | 259 | for (int b = 0; b < batch; b++) 260 | { 261 | auto input_n = input.select(0, b); 262 | auto offset_n = offset.select(0, b); 263 | auto mask_n = mask.select(0, b); 264 | auto grad_output_n = grad_output.select(0, b); 265 | auto grad_input_n = grad_input.select(0, b); 266 | auto grad_offset_n = grad_offset.select(0, b); 267 | auto grad_mask_n = grad_mask.select(0, b); 268 | 269 | long m = channels * kernel_h * kernel_w; 270 | long n = height_out * width_out; 271 | long k = channels_out; 272 | 273 | THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, 274 | grad_output_n.data(), n, 275 | weight.data(), m, 0.0f, 276 | columns.data(), n); 277 | 278 | // gradient w.r.t. input coordinate data 279 | modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), 280 | columns.data(), 281 | input_n.data(), 282 | offset_n.data(), 283 | mask_n.data(), 284 | 1, channels, height, width, 285 | height_out, width_out, kernel_h, kernel_w, 286 | pad_h, pad_w, stride_h, stride_w, 287 | dilation_h, dilation_w, deformable_group, 288 | grad_offset_n.data(), 289 | grad_mask_n.data()); 290 | // gradient w.r.t. input data 291 | modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), 292 | columns.data(), 293 | offset_n.data(), 294 | mask_n.data(), 295 | 1, channels, height, width, 296 | height_out, width_out, kernel_h, kernel_w, 297 | pad_h, pad_w, stride_h, stride_w, 298 | dilation_h, dilation_w, deformable_group, 299 | grad_input_n.data()); 300 | 301 | // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group 302 | modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), 303 | input_n.data(), 304 | offset_n.data(), 305 | mask_n.data(), 306 | 1, channels, height, width, 307 | height_out, width_out, kernel_h, kernel_w, 308 | pad_h, pad_w, stride_h, stride_w, 309 | dilation_h, dilation_w, deformable_group, 310 | columns.data()); 311 | 312 | long m_ = channels_out; 313 | long n_ = channels * kernel_h * kernel_w; 314 | long k_ = height_out * width_out; 315 | 316 | THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, 317 | columns.data(), k_, 318 | grad_output_n.data(), k_, 1.0f, 319 | grad_weight.data(), n_); 320 | 321 | // gradient w.r.t. bias 322 | // long m_ = channels_out; 323 | // long k__ = height_out * width_out; 324 | THCudaBlas_Sgemv(state, 325 | 't', 326 | k_, m_, 1.0f, 327 | grad_output_n.data(), k_, 328 | ones.data(), 1, 1.0f, 329 | grad_bias.data(), 1); 330 | } 331 | 332 | return { 333 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias 334 | }; 335 | } 336 | -------------------------------------------------------------------------------- /external/DCNv2/src/cuda/dcn_v2_im2col_cuda.h: -------------------------------------------------------------------------------- 1 | 2 | /*! 3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, this 28 | * list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 51 | * 52 | * Copyright (c) 2018 Microsoft 53 | * Licensed under The MIT License [see LICENSE for details] 54 | * \file modulated_deformable_im2col.h 55 | * \brief Function definitions of converting an image to 56 | * column matrix based on kernel, padding, dilation, and offset. 57 | * These functions are mainly used in deformable convolution operators. 58 | * \ref: https://arxiv.org/abs/1811.11168 59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 60 | */ 61 | 62 | /***************** Adapted by Charles Shang *********************/ 63 | 64 | #ifndef DCN_V2_IM2COL_CUDA 65 | #define DCN_V2_IM2COL_CUDA 66 | 67 | #ifdef __cplusplus 68 | extern "C" 69 | { 70 | #endif 71 | 72 | void modulated_deformable_im2col_cuda(cudaStream_t stream, 73 | const float *data_im, const float *data_offset, const float *data_mask, 74 | const int batch_size, const int channels, const int height_im, const int width_im, 75 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 77 | const int dilation_h, const int dilation_w, 78 | const int deformable_group, float *data_col); 79 | 80 | void modulated_deformable_col2im_cuda(cudaStream_t stream, 81 | const float *data_col, const float *data_offset, const float *data_mask, 82 | const int batch_size, const int channels, const int height_im, const int width_im, 83 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 84 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 85 | const int dilation_h, const int dilation_w, 86 | const int deformable_group, float *grad_im); 87 | 88 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, 89 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 90 | const int batch_size, const int channels, const int height_im, const int width_im, 91 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 92 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 93 | const int dilation_h, const int dilation_w, 94 | const int deformable_group, 95 | float *grad_offset, float *grad_mask); 96 | 97 | #ifdef __cplusplus 98 | } 99 | #endif 100 | 101 | #endif -------------------------------------------------------------------------------- /external/DCNv2/src/cuda/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | dcn_v2_cuda_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector 21 | dcn_v2_cuda_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | 
int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple 35 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple 48 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 | const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); -------------------------------------------------------------------------------- /external/DCNv2/src/dcn_v2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/vision.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/vision.h" 7 | #endif 8 | 9 | at::Tensor 10 | dcn_v2_forward(const at::Tensor &input, 11 | const at::Tensor &weight, 12 | const at::Tensor &bias, 13 | const at::Tensor &offset, 14 | const at::Tensor &mask, 15 | const int kernel_h, 16 | const int kernel_w, 17 | const int stride_h, 18 | const int stride_w, 19 | const int pad_h, 20 | const int pad_w, 21 | const int dilation_h, 22 | const int dilation_w, 23 | const int deformable_group) 24 | { 25 | if (input.type().is_cuda()) 26 | { 27 | #ifdef WITH_CUDA 28 | return dcn_v2_cuda_forward(input, weight, bias, offset, mask, 29 | kernel_h, kernel_w, 30 | stride_h, stride_w, 31 | pad_h, pad_w, 32 | dilation_h, dilation_w, 33 | deformable_group); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | dcn_v2_backward(const at::Tensor &input, 43 | const at::Tensor &weight, 44 | const at::Tensor &bias, 45 | const at::Tensor &offset, 46 | const at::Tensor &mask, 47 | const at::Tensor &grad_output, 48 | int kernel_h, int kernel_w, 49 | int stride_h, int stride_w, 50 | int pad_h, int pad_w, 51 | int dilation_h, int dilation_w, 52 | int deformable_group) 53 | { 54 | if (input.type().is_cuda()) 55 | { 56 | #ifdef WITH_CUDA 57 | return dcn_v2_cuda_backward(input, 58 | weight, 59 | bias, 60 | offset, 61 | mask, 62 | grad_output, 63 | kernel_h, kernel_w, 64 | stride_h, stride_w, 65 | pad_h, pad_w, 66 | dilation_h, dilation_w, 67 | deformable_group); 68 | #else 69 | AT_ERROR("Not compiled with GPU support"); 70 | #endif 71 | } 72 | AT_ERROR("Not implemented on the CPU"); 73 | } 74 | 75 | std::tuple 76 | dcn_v2_psroi_pooling_forward(const at::Tensor &input, 77 | const at::Tensor &bbox, 78 | const at::Tensor &trans, 79 | const int no_trans, 80 | const float spatial_scale, 81 | const int output_dim, 82 | const int group_size, 83 | const int pooled_size, 84 | const int part_size, 85 | const int sample_per_part, 86 | const float trans_std) 87 | { 88 | if (input.type().is_cuda()) 89 | { 90 | #ifdef WITH_CUDA 91 | return dcn_v2_psroi_pooling_cuda_forward(input, 92 | bbox, 93 | trans, 94 | no_trans, 95 | spatial_scale, 96 | output_dim, 97 | group_size, 98 | pooled_size, 99 | part_size, 100 | sample_per_part, 101 | trans_std); 102 | #else 103 | AT_ERROR("Not compiled with GPU 
support"); 104 | #endif 105 | } 106 | AT_ERROR("Not implemented on the CPU"); 107 | } 108 | 109 | std::tuple 110 | dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad, 111 | const at::Tensor &input, 112 | const at::Tensor &bbox, 113 | const at::Tensor &trans, 114 | const at::Tensor &top_count, 115 | const int no_trans, 116 | const float spatial_scale, 117 | const int output_dim, 118 | const int group_size, 119 | const int pooled_size, 120 | const int part_size, 121 | const int sample_per_part, 122 | const float trans_std) 123 | { 124 | if (input.type().is_cuda()) 125 | { 126 | #ifdef WITH_CUDA 127 | return dcn_v2_psroi_pooling_cuda_backward(out_grad, 128 | input, 129 | bbox, 130 | trans, 131 | top_count, 132 | no_trans, 133 | spatial_scale, 134 | output_dim, 135 | group_size, 136 | pooled_size, 137 | part_size, 138 | sample_per_part, 139 | trans_std); 140 | #else 141 | AT_ERROR("Not compiled with GPU support"); 142 | #endif 143 | } 144 | AT_ERROR("Not implemented on the CPU"); 145 | } -------------------------------------------------------------------------------- /external/DCNv2/src/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "dcn_v2.h" 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 5 | m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward"); 6 | m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward"); 7 | m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward"); 8 | m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward"); 9 | } 10 | -------------------------------------------------------------------------------- /external/DCNv2/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import gradcheck 10 | 11 | from dcn_v2 import dcn_v2_conv, DCNv2, DCN 12 | from dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling 13 | 14 | deformable_groups = 1 15 | N, inC, inH, inW = 2, 2, 4, 4 16 | outC = 2 17 | kH, kW = 3, 3 18 | 19 | 20 | def conv_identify(weight, bias): 21 | weight.data.zero_() 22 | bias.data.zero_() 23 | o, i, h, w = weight.shape 24 | y = h//2 25 | x = w//2 26 | for p in range(i): 27 | for q in range(o): 28 | if p == q: 29 | weight.data[q, p, y, x] = 1.0 30 | 31 | 32 | def check_zero_offset(): 33 | conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW, 34 | kernel_size=(kH, kW), 35 | stride=(1, 1), 36 | padding=(1, 1), 37 | bias=True).cuda() 38 | 39 | conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW, 40 | kernel_size=(kH, kW), 41 | stride=(1, 1), 42 | padding=(1, 1), 43 | bias=True).cuda() 44 | 45 | dcn_v2 = DCNv2(inC, outC, (kH, kW), 46 | stride=1, padding=1, dilation=1, 47 | deformable_groups=deformable_groups).cuda() 48 | 49 | conv_offset.weight.data.zero_() 50 | conv_offset.bias.data.zero_() 51 | conv_mask.weight.data.zero_() 52 | conv_mask.bias.data.zero_() 53 | conv_identify(dcn_v2.weight, dcn_v2.bias) 54 | 55 | input = torch.randn(N, inC, inH, inW).cuda() 56 | offset = conv_offset(input) 57 | mask = conv_mask(input) 58 | mask = torch.sigmoid(mask) 59 | output = dcn_v2(input, offset, mask) 60 | output *= 2 61 | d = (input - output).abs().max() 62 | if d < 1e-10: 63 | print('Zero offset passed') 64 | else: 65 | print('Zero 
offset failed') 66 | print(input) 67 | print(output) 68 | 69 | def check_gradient_dconv(): 70 | 71 | input = torch.rand(N, inC, inH, inW).cuda() * 0.01 72 | input.requires_grad = True 73 | 74 | offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2 75 | # offset.data.zero_() 76 | # offset.data -= 0.5 77 | offset.requires_grad = True 78 | 79 | mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda() 80 | # mask.data.zero_() 81 | mask.requires_grad = True 82 | mask = torch.sigmoid(mask) 83 | 84 | weight = torch.randn(outC, inC, kH, kW).cuda() 85 | weight.requires_grad = True 86 | 87 | bias = torch.rand(outC).cuda() 88 | bias.requires_grad = True 89 | 90 | stride = 1 91 | padding = 1 92 | dilation = 1 93 | 94 | print('check_gradient_dconv: ', 95 | gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias, 96 | stride, padding, dilation, deformable_groups), 97 | eps=1e-3, atol=1e-4, rtol=1e-2)) 98 | 99 | 100 | def check_pooling_zero_offset(): 101 | 102 | input = torch.randn(2, 16, 64, 64).cuda().zero_() 103 | input[0, :, 16:26, 16:26] = 1. 104 | input[1, :, 10:20, 20:30] = 2. 105 | rois = torch.tensor([ 106 | [0, 65, 65, 103, 103], 107 | [1, 81, 41, 119, 79], 108 | ]).cuda().float() 109 | pooling = DCNv2Pooling(spatial_scale=1.0 / 4, 110 | pooled_size=7, 111 | output_dim=16, 112 | no_trans=True, 113 | group_size=1, 114 | trans_std=0.0).cuda() 115 | 116 | out = pooling(input, rois, input.new()) 117 | s = ', '.join(['%f' % out[i, :, :, :].mean().item() 118 | for i in range(rois.shape[0])]) 119 | print(s) 120 | 121 | dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, 122 | pooled_size=7, 123 | output_dim=16, 124 | no_trans=False, 125 | group_size=1, 126 | trans_std=0.0).cuda() 127 | offset = torch.randn(20, 2, 7, 7).cuda().zero_() 128 | dout = dpooling(input, rois, offset) 129 | s = ', '.join(['%f' % dout[i, :, :, :].mean().item() 130 | for i in range(rois.shape[0])]) 131 | print(s) 132 | 133 | 134 | def check_gradient_dpooling(): 135 | input = torch.randn(2, 3, 5, 5).cuda() * 0.01 136 | N = 4 137 | batch_inds = torch.randint(2, (N, 1)).cuda().float() 138 | x = torch.rand((N, 1)).cuda().float() * 15 139 | y = torch.rand((N, 1)).cuda().float() * 15 140 | w = torch.rand((N, 1)).cuda().float() * 10 141 | h = torch.rand((N, 1)).cuda().float() * 10 142 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 143 | offset = torch.randn(N, 2, 3, 3).cuda() 144 | input.requires_grad = True 145 | offset.requires_grad = True 146 | 147 | spatial_scale = 1.0 / 4 148 | pooled_size = 3 149 | output_dim = 3 150 | no_trans = 0 151 | group_size = 1 152 | trans_std = 0.0 153 | sample_per_part = 4 154 | part_size = pooled_size 155 | 156 | print('check_gradient_dpooling:', 157 | gradcheck(dcn_v2_pooling, (input, rois, offset, 158 | spatial_scale, 159 | pooled_size, 160 | output_dim, 161 | no_trans, 162 | group_size, 163 | part_size, 164 | sample_per_part, 165 | trans_std), 166 | eps=1e-4)) 167 | 168 | 169 | def example_dconv(): 170 | input = torch.randn(2, 64, 128, 128).cuda() 171 | # wrap all things (offset and mask) in DCN 172 | dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, 173 | padding=1, deformable_groups=2).cuda() 174 | # print(dcn.weight.shape, input.shape) 175 | output = dcn(input) 176 | targert = output.new(*output.size()) 177 | targert.data.uniform_(-0.01, 0.01) 178 | error = (targert - output).mean() 179 | error.backward() 180 | print(output.shape) 181 | 182 | 183 | def example_dpooling(): 184 | input = torch.randn(2, 32, 64, 64).cuda() 185 | batch_inds = 
torch.randint(2, (20, 1)).cuda().float() 186 | x = torch.randint(256, (20, 1)).cuda().float() 187 | y = torch.randint(256, (20, 1)).cuda().float() 188 | w = torch.randint(64, (20, 1)).cuda().float() 189 | h = torch.randint(64, (20, 1)).cuda().float() 190 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 191 | offset = torch.randn(20, 2, 7, 7).cuda() 192 | input.requires_grad = True 193 | offset.requires_grad = True 194 | 195 | # normal roi_align 196 | pooling = DCNv2Pooling(spatial_scale=1.0 / 4, 197 | pooled_size=7, 198 | output_dim=32, 199 | no_trans=True, 200 | group_size=1, 201 | trans_std=0.1).cuda() 202 | 203 | # deformable pooling 204 | dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, 205 | pooled_size=7, 206 | output_dim=32, 207 | no_trans=False, 208 | group_size=1, 209 | trans_std=0.1).cuda() 210 | 211 | out = pooling(input, rois, offset) 212 | dout = dpooling(input, rois, offset) 213 | print(out.shape) 214 | print(dout.shape) 215 | 216 | target_out = out.new(*out.size()) 217 | target_out.data.uniform_(-0.01, 0.01) 218 | target_dout = dout.new(*dout.size()) 219 | target_dout.data.uniform_(-0.01, 0.01) 220 | e = (target_out - out).mean() 221 | e.backward() 222 | e = (target_dout - dout).mean() 223 | e.backward() 224 | 225 | 226 | def example_mdpooling(): 227 | input = torch.randn(2, 32, 64, 64).cuda() 228 | input.requires_grad = True 229 | batch_inds = torch.randint(2, (20, 1)).cuda().float() 230 | x = torch.randint(256, (20, 1)).cuda().float() 231 | y = torch.randint(256, (20, 1)).cuda().float() 232 | w = torch.randint(64, (20, 1)).cuda().float() 233 | h = torch.randint(64, (20, 1)).cuda().float() 234 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) 235 | 236 | # mdformable pooling (V2) 237 | dpooling = DCNPooling(spatial_scale=1.0 / 4, 238 | pooled_size=7, 239 | output_dim=32, 240 | no_trans=False, 241 | group_size=1, 242 | trans_std=0.1, 243 | deform_fc_dim=1024).cuda() 244 | 245 | dout = dpooling(input, rois) 246 | target = dout.new(*dout.size()) 247 | target.data.uniform_(-0.1, 0.1) 248 | error = (target - dout).mean() 249 | error.backward() 250 | print(dout.shape) 251 | 252 | 253 | if __name__ == '__main__': 254 | 255 | example_dconv() 256 | example_dpooling() 257 | example_mdpooling() 258 | 259 | check_pooling_zero_offset() 260 | # zero offset check 261 | if inC == outC: 262 | check_zero_offset() 263 | 264 | check_gradient_dpooling() 265 | check_gradient_dconv() 266 | # """ 267 | # ****** Note: backward is not reentrant error may not be a serious problem, 268 | # ****** since the max error is less than 1e-7, 269 | # ****** Still looking for what trigger this problem 270 | # """ 271 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /layers/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection import Detect 2 | 3 | 4 | __all__ = ['Detect'] 5 | -------------------------------------------------------------------------------- /layers/functions/detection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from ..box_utils import decode, jaccard, index2d 4 | from utils import timer 5 | 6 | from data import cfg, mask_type 7 | 8 | import 
numpy as np 9 | 10 | 11 | class Detect(object): 12 | """At test time, Detect is the final layer of SSD. Decode location preds, 13 | apply non-maximum suppression to location predictions based on conf 14 | scores and threshold to a top_k number of output predictions for both 15 | confidence score and locations, as the predicted masks. 16 | """ 17 | # TODO: Refactor this whole class away. It needs to go. 18 | 19 | def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh): 20 | self.num_classes = num_classes 21 | self.background_label = bkg_label 22 | self.top_k = top_k 23 | # Parameters used in nms. 24 | self.nms_thresh = nms_thresh 25 | if nms_thresh <= 0: 26 | raise ValueError('nms_threshold must be non negative.') 27 | self.conf_thresh = conf_thresh 28 | 29 | self.use_cross_class_nms = False 30 | self.use_fast_nms = False 31 | 32 | def __call__(self, predictions, net): 33 | """ 34 | Args: 35 | loc_data: (tensor) Loc preds from loc layers 36 | Shape: [batch, num_priors, 4] 37 | conf_data: (tensor) Shape: Conf preds from conf layers 38 | Shape: [batch, num_priors, num_classes] 39 | mask_data: (tensor) Mask preds from mask layers 40 | Shape: [batch, num_priors, mask_dim] 41 | prior_data: (tensor) Prior boxes and variances from priorbox layers 42 | Shape: [num_priors, 4] 43 | proto_data: (tensor) If using mask_type.lincomb, the prototype masks 44 | Shape: [batch, mask_h, mask_w, mask_dim] 45 | 46 | Returns: 47 | output of shape (batch_size, top_k, 1 + 1 + 4 + mask_dim) 48 | These outputs are in the order: class idx, confidence, bbox coords, and mask. 49 | 50 | Note that the outputs are sorted only if cross_class_nms is False 51 | """ 52 | 53 | loc_data = predictions['loc'] 54 | conf_data = predictions['conf'] 55 | mask_data = predictions['mask'] 56 | prior_data = predictions['priors'] 57 | 58 | proto_data = predictions['proto'] if 'proto' in predictions else None 59 | inst_data = predictions['inst'] if 'inst' in predictions else None 60 | 61 | out = [] 62 | 63 | with timer.env('Detect'): 64 | batch_size = loc_data.size(0) 65 | num_priors = prior_data.size(0) 66 | 67 | conf_preds = conf_data.view(batch_size, num_priors, self.num_classes).transpose(2, 1).contiguous() 68 | 69 | for batch_idx in range(batch_size): 70 | decoded_boxes = decode(loc_data[batch_idx], prior_data) 71 | result = self.detect(batch_idx, conf_preds, decoded_boxes, mask_data, inst_data) 72 | 73 | if result is not None and proto_data is not None: 74 | result['proto'] = proto_data[batch_idx] 75 | 76 | out.append({'detection': result, 'net': net}) 77 | 78 | return out 79 | 80 | 81 | def detect(self, batch_idx, conf_preds, decoded_boxes, mask_data, inst_data): 82 | """ Perform nms for only the max scoring class that isn't background (class 0) """ 83 | cur_scores = conf_preds[batch_idx, 1:, :] 84 | conf_scores, _ = torch.max(cur_scores, dim=0) 85 | 86 | keep = (conf_scores > self.conf_thresh) 87 | scores = cur_scores[:, keep] 88 | boxes = decoded_boxes[keep, :] 89 | masks = mask_data[batch_idx, keep, :] 90 | 91 | if inst_data is not None: 92 | inst = inst_data[batch_idx, keep, :] 93 | 94 | if scores.size(1) == 0: 95 | return None 96 | 97 | if self.use_fast_nms: 98 | if self.use_cross_class_nms: 99 | boxes, masks, classes, scores = self.cc_fast_nms(boxes, masks, scores, self.nms_thresh, self.top_k) 100 | else: 101 | boxes, masks, classes, scores = self.fast_nms(boxes, masks, scores, self.nms_thresh, self.top_k) 102 | else: 103 | boxes, masks, classes, scores = self.traditional_nms(boxes, masks, scores, 
self.nms_thresh, self.conf_thresh) 104 | 105 | if self.use_cross_class_nms: 106 | print('Warning: Cross Class Traditional NMS is not implemented.') 107 | 108 | return {'box': boxes, 'mask': masks, 'class': classes, 'score': scores} 109 | 110 | 111 | def cc_fast_nms(self, boxes, masks, scores, iou_threshold:float=0.5, top_k:int=200): 112 | # Collapse all the classes into 1 113 | scores, classes = scores.max(dim=0) 114 | 115 | _, idx = scores.sort(0, descending=True) 116 | idx = idx[:top_k] 117 | 118 | boxes_idx = boxes[idx] 119 | 120 | # Compute the pairwise IoU between the boxes 121 | iou = jaccard(boxes_idx, boxes_idx) 122 | 123 | # Zero out the lower triangle of the cosine similarity matrix and diagonal 124 | iou.triu_(diagonal=1) 125 | 126 | # Now that everything in the diagonal and below is zeroed out, if we take the max 127 | # of the IoU matrix along the columns, each column will represent the maximum IoU 128 | # between this element and every element with a higher score than this element. 129 | iou_max, _ = torch.max(iou, dim=0) 130 | 131 | # Now just filter out the ones greater than the threshold, i.e., only keep boxes that 132 | # don't have a higher scoring box that would supress it in normal NMS. 133 | idx_out = idx[iou_max <= iou_threshold] 134 | 135 | return boxes[idx_out], masks[idx_out], classes[idx_out], scores[idx_out] 136 | 137 | def fast_nms(self, boxes, masks, scores, iou_threshold:float=0.5, top_k:int=200, second_threshold:bool=False): 138 | scores, idx = scores.sort(1, descending=True) 139 | 140 | idx = idx[:, :top_k].contiguous() 141 | scores = scores[:, :top_k] 142 | 143 | num_classes, num_dets = idx.size() 144 | 145 | boxes = boxes[idx.view(-1), :].view(num_classes, num_dets, 4) 146 | masks = masks[idx.view(-1), :].view(num_classes, num_dets, -1) 147 | 148 | iou = jaccard(boxes, boxes) 149 | iou.triu_(diagonal=1) 150 | iou_max, _ = iou.max(dim=1) 151 | 152 | # Now just filter out the ones higher than the threshold 153 | keep = (iou_max <= iou_threshold) 154 | 155 | # We should also only keep detections over the confidence threshold, but at the cost of 156 | # maxing out your detection count for every image, you can just not do that. Because we 157 | # have such a minimal amount of computation per detection (matrix mulitplication only), 158 | # this increase doesn't affect us much (+0.2 mAP for 34 -> 33 fps), so we leave it out. 159 | # However, when you implement this in your method, you should do this second threshold. 
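        # `keep` is a [num_classes, num_dets] boolean mask; multiplying it by the score test
        # below additionally drops detections whose confidence is at or below self.conf_thresh.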
160 | if second_threshold: 161 | keep *= (scores > self.conf_thresh) 162 | 163 | # Assign each kept detection to its corresponding class 164 | classes = torch.arange(num_classes, device=boxes.device)[:, None].expand_as(keep) 165 | classes = classes[keep] 166 | 167 | boxes = boxes[keep] 168 | masks = masks[keep] 169 | scores = scores[keep] 170 | 171 | # Only keep the top cfg.max_num_detections highest scores across all classes 172 | scores, idx = scores.sort(0, descending=True) 173 | idx = idx[:cfg.max_num_detections] 174 | scores = scores[:cfg.max_num_detections] 175 | 176 | classes = classes[idx] 177 | boxes = boxes[idx] 178 | masks = masks[idx] 179 | 180 | return boxes, masks, classes, scores 181 | 182 | def traditional_nms(self, boxes, masks, scores, iou_threshold=0.5, conf_thresh=0.05): 183 | import pyximport 184 | pyximport.install(setup_args={"include_dirs":np.get_include()}, reload_support=True) 185 | 186 | from utils.cython_nms import nms as cnms 187 | 188 | num_classes = scores.size(0) 189 | 190 | idx_lst = [] 191 | cls_lst = [] 192 | scr_lst = [] 193 | 194 | # Multiplying by max_size is necessary because of how cnms computes its area and intersections 195 | boxes = boxes * cfg.max_size 196 | 197 | for _cls in range(num_classes): 198 | cls_scores = scores[_cls, :] 199 | conf_mask = cls_scores > conf_thresh 200 | idx = torch.arange(cls_scores.size(0), device=boxes.device) 201 | 202 | cls_scores = cls_scores[conf_mask] 203 | idx = idx[conf_mask] 204 | 205 | if cls_scores.size(0) == 0: 206 | continue 207 | 208 | preds = torch.cat([boxes[conf_mask], cls_scores[:, None]], dim=1).cpu().numpy() 209 | keep = cnms(preds, iou_threshold) 210 | keep = torch.Tensor(keep, device=boxes.device).long() 211 | 212 | idx_lst.append(idx[keep]) 213 | cls_lst.append(keep * 0 + _cls) 214 | scr_lst.append(cls_scores[keep]) 215 | 216 | idx = torch.cat(idx_lst, dim=0) 217 | classes = torch.cat(cls_lst, dim=0) 218 | scores = torch.cat(scr_lst, dim=0) 219 | 220 | scores, idx2 = scores.sort(0, descending=True) 221 | idx2 = idx2[:cfg.max_num_detections] 222 | scores = scores[:cfg.max_num_detections] 223 | 224 | idx = idx[idx2] 225 | classes = classes[idx2] 226 | 227 | # Undo the multiplication above 228 | return boxes[idx] / cfg.max_size, masks[idx], classes, scores 229 | -------------------------------------------------------------------------------- /layers/interpolate.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class InterpolateModule(nn.Module): 5 | """ 6 | This is a module version of F.interpolate (rip nn.Upsampling). 7 | Any arguments you give it just get passed along for the ride. 8 | """ 9 | 10 | def __init__(self, *args, **kwdargs): 11 | super().__init__() 12 | 13 | self.args = args 14 | self.kwdargs = kwdargs 15 | 16 | def forward(self, x): 17 | return F.interpolate(x, *self.args, **self.kwdargs) 18 | -------------------------------------------------------------------------------- /layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .multibox_loss import MultiBoxLoss 2 | 3 | __all__ = ['MultiBoxLoss'] 4 | -------------------------------------------------------------------------------- /layers/output_utils.py: -------------------------------------------------------------------------------- 1 | """ Contains functions used to sanitize and prepare the output of Yolact. 
""" 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import cv2 9 | 10 | from data import cfg, mask_type, MEANS, STD, activation_func 11 | from utils.augmentations import Resize 12 | from utils import timer 13 | from .box_utils import crop, sanitize_coordinates 14 | 15 | def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear', 16 | visualize_lincomb=False, crop_masks=True, score_threshold=0): 17 | """ 18 | Postprocesses the output of Yolact on testing mode into a format that makes sense, 19 | accounting for all the possible configuration settings. 20 | 21 | Args: 22 | - det_output: The lost of dicts that Detect outputs. 23 | - w: The real with of the image. 24 | - h: The real height of the image. 25 | - batch_idx: If you have multiple images for this batch, the image's index in the batch. 26 | - interpolation_mode: Can be 'nearest' | 'area' | 'bilinear' (see torch.nn.functional.interpolate) 27 | 28 | Returns 4 torch Tensors (in the following order): 29 | - classes [num_det]: The class idx for each detection. 30 | - scores [num_det]: The confidence score for each detection. 31 | - boxes [num_det, 4]: The bounding box for each detection in absolute point form. 32 | - masks [num_det, h, w]: Full image masks for each detection. 33 | """ 34 | 35 | dets = det_output[batch_idx] 36 | net = dets['net'] 37 | dets = dets['detection'] 38 | 39 | if dets is None: 40 | return [torch.Tensor()] * 4 # Warning, this is 4 copies of the same thing 41 | 42 | if score_threshold > 0: 43 | keep = dets['score'] > score_threshold 44 | 45 | for k in dets: 46 | if k != 'proto': 47 | dets[k] = dets[k][keep] 48 | 49 | if dets['score'].size(0) == 0: 50 | return [torch.Tensor()] * 4 51 | 52 | # Actually extract everything from dets now 53 | classes = dets['class'] 54 | boxes = dets['box'] 55 | scores = dets['score'] 56 | masks = dets['mask'] 57 | 58 | if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch: 59 | # At this points masks is only the coefficients 60 | proto_data = dets['proto'] 61 | 62 | # Test flag, do not upvote 63 | if cfg.mask_proto_debug: 64 | np.save('scripts/proto.npy', proto_data.cpu().numpy()) 65 | 66 | if visualize_lincomb: 67 | display_lincomb(proto_data, masks) 68 | 69 | masks = proto_data @ masks.t() 70 | masks = cfg.mask_proto_mask_activation(masks) 71 | 72 | # Crop masks before upsampling because you know why 73 | if crop_masks: 74 | masks = crop(masks, boxes) 75 | 76 | # Permute into the correct output shape [num_dets, proto_h, proto_w] 77 | masks = masks.permute(2, 0, 1).contiguous() 78 | 79 | if cfg.use_maskiou: 80 | with timer.env('maskiou_net'): 81 | with torch.no_grad(): 82 | maskiou_p = net.maskiou_net(masks.unsqueeze(1)) 83 | maskiou_p = torch.gather(maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1) 84 | if cfg.rescore_mask: 85 | if cfg.rescore_bbox: 86 | scores = scores * maskiou_p 87 | else: 88 | scores = [scores, scores * maskiou_p] 89 | 90 | # Scale masks up to the full image 91 | masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode, align_corners=False).squeeze(0) 92 | 93 | # Binarize the masks 94 | masks.gt_(0.5) 95 | 96 | 97 | boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], w, cast=False) 98 | boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], h, cast=False) 99 | boxes = boxes.long() 100 | 101 | if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch: 102 | # Upscale masks 103 | full_masks = 
torch.zeros(masks.size(0), h, w) 104 | 105 | for jdx in range(masks.size(0)): 106 | x1, y1, x2, y2 = boxes[jdx, :] 107 | 108 | mask_w = x2 - x1 109 | mask_h = y2 - y1 110 | 111 | # Just in case 112 | if mask_w * mask_h <= 0 or mask_w < 0: 113 | continue 114 | 115 | mask = masks[jdx, :].view(1, 1, cfg.mask_size, cfg.mask_size) 116 | mask = F.interpolate(mask, (mask_h, mask_w), mode=interpolation_mode, align_corners=False) 117 | mask = mask.gt(0.5).float() 118 | full_masks[jdx, y1:y2, x1:x2] = mask 119 | 120 | masks = full_masks 121 | 122 | return classes, scores, boxes, masks 123 | 124 | 125 | 126 | 127 | 128 | def undo_image_transformation(img, w, h): 129 | """ 130 | Takes a transformed image tensor and returns a numpy ndarray that is untransformed. 131 | Arguments w and h are the original height and width of the image. 132 | """ 133 | img_numpy = img.permute(1, 2, 0).cpu().numpy() 134 | img_numpy = img_numpy[:, :, (2, 1, 0)] # To BRG 135 | 136 | if cfg.backbone.transform.normalize: 137 | img_numpy = (img_numpy * np.array(STD) + np.array(MEANS)) / 255.0 138 | elif cfg.backbone.transform.subtract_means: 139 | img_numpy = (img_numpy / 255.0 + np.array(MEANS) / 255.0).astype(np.float32) 140 | 141 | img_numpy = img_numpy[:, :, (2, 1, 0)] # To RGB 142 | img_numpy = np.clip(img_numpy, 0, 1) 143 | 144 | return cv2.resize(img_numpy, (w,h)) 145 | 146 | 147 | def display_lincomb(proto_data, masks): 148 | out_masks = torch.matmul(proto_data, masks.t()) 149 | # out_masks = cfg.mask_proto_mask_activation(out_masks) 150 | 151 | for kdx in range(1): 152 | jdx = kdx + 0 153 | import matplotlib.pyplot as plt 154 | coeffs = masks[jdx, :].cpu().numpy() 155 | idx = np.argsort(-np.abs(coeffs)) 156 | # plt.bar(list(range(idx.shape[0])), coeffs[idx]) 157 | # plt.show() 158 | 159 | coeffs_sort = coeffs[idx] 160 | arr_h, arr_w = (4,8) 161 | proto_h, proto_w, _ = proto_data.size() 162 | arr_img = np.zeros([proto_h*arr_h, proto_w*arr_w]) 163 | arr_run = np.zeros([proto_h*arr_h, proto_w*arr_w]) 164 | test = torch.sum(proto_data, -1).cpu().numpy() 165 | 166 | for y in range(arr_h): 167 | for x in range(arr_w): 168 | i = arr_w * y + x 169 | 170 | if i == 0: 171 | running_total = proto_data[:, :, idx[i]].cpu().numpy() * coeffs_sort[i] 172 | else: 173 | running_total += proto_data[:, :, idx[i]].cpu().numpy() * coeffs_sort[i] 174 | 175 | running_total_nonlin = running_total 176 | if cfg.mask_proto_mask_activation == activation_func.sigmoid: 177 | running_total_nonlin = (1/(1+np.exp(-running_total_nonlin))) 178 | 179 | arr_img[y*proto_h:(y+1)*proto_h, x*proto_w:(x+1)*proto_w] = (proto_data[:, :, idx[i]] / torch.max(proto_data[:, :, idx[i]])).cpu().numpy() * coeffs_sort[i] 180 | arr_run[y*proto_h:(y+1)*proto_h, x*proto_w:(x+1)*proto_w] = (running_total_nonlin > 0.5).astype(np.float) 181 | plt.imshow(arr_img) 182 | plt.show() 183 | # plt.imshow(arr_run) 184 | # plt.show() 185 | # plt.imshow(test) 186 | # plt.show() 187 | plt.imshow(out_masks[:, :, jdx].cpu().numpy()) 188 | plt.show() 189 | -------------------------------------------------------------------------------- /run_coco_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Runs the coco-supplied cocoeval script to evaluate detections 3 | outputted by using the output_coco_json flag in eval.py. 
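Example (the paths shown are this script's argparse defaults, so pass them only if yours differ):
    python run_coco_eval.py --gt_ann_file data/coco/annotations/instances_val2017.json --eval_type both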
4 | """ 5 | 6 | 7 | import argparse 8 | 9 | from pycocotools.coco import COCO 10 | from pycocotools.cocoeval import COCOeval 11 | 12 | 13 | parser = argparse.ArgumentParser(description='COCO Detections Evaluator') 14 | parser.add_argument('--bbox_det_file', default='results/bbox_detections.json', type=str) 15 | parser.add_argument('--mask_det_file', default='results/mask_detections.json', type=str) 16 | parser.add_argument('--gt_ann_file', default='data/coco/annotations/instances_val2017.json', type=str) 17 | parser.add_argument('--eval_type', default='both', choices=['bbox', 'mask', 'both'], type=str) 18 | args = parser.parse_args() 19 | 20 | 21 | 22 | if __name__ == '__main__': 23 | 24 | eval_bbox = (args.eval_type in ('bbox', 'both')) 25 | eval_mask = (args.eval_type in ('mask', 'both')) 26 | 27 | print('Loading annotations...') 28 | gt_annotations = COCO(args.gt_ann_file) 29 | if eval_bbox: 30 | bbox_dets = gt_annotations.loadRes(args.bbox_det_file) 31 | if eval_mask: 32 | mask_dets = gt_annotations.loadRes(args.mask_det_file) 33 | 34 | if eval_bbox: 35 | print('\nEvaluating BBoxes:') 36 | bbox_eval = COCOeval(gt_annotations, bbox_dets, 'bbox') 37 | bbox_eval.evaluate() 38 | bbox_eval.accumulate() 39 | bbox_eval.summarize() 40 | 41 | if eval_mask: 42 | print('\nEvaluating Masks:') 43 | bbox_eval = COCOeval(gt_annotations, mask_dets, 'segm') 44 | bbox_eval.evaluate() 45 | bbox_eval.accumulate() 46 | bbox_eval.summarize() 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /scripts/augment_bbox.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path as osp 3 | import json, pickle 4 | import sys 5 | from math import sqrt 6 | from itertools import product 7 | import torch 8 | from numpy import random 9 | 10 | import numpy as np 11 | 12 | 13 | max_image_size = 550 14 | augment_idx = 0 15 | dump_file = 'weights/bboxes_aug.pkl' 16 | box_file = 'weights/bboxes.pkl' 17 | 18 | def augment_boxes(bboxes): 19 | bboxes_rel = [] 20 | for box in bboxes: 21 | bboxes_rel.append(prep_box(box)) 22 | bboxes_rel = np.concatenate(bboxes_rel, axis=0) 23 | 24 | with open(dump_file, 'wb') as f: 25 | pickle.dump(bboxes_rel, f) 26 | 27 | def prep_box(box_list): 28 | global augment_idx 29 | boxes = np.array([box_list[2:]], dtype=np.float32) 30 | 31 | # Image width and height 32 | width, height = box_list[:2] 33 | 34 | # To point form 35 | boxes[:, 2:] += boxes[:, :2] 36 | 37 | 38 | # Expand 39 | ratio = random.uniform(1, 4) 40 | left = random.uniform(0, width*ratio - width) 41 | top = random.uniform(0, height*ratio - height) 42 | 43 | height *= ratio 44 | width *= ratio 45 | 46 | boxes[:, :2] += (int(left), int(top)) 47 | boxes[:, 2:] += (int(left), int(top)) 48 | 49 | 50 | # RandomSampleCrop 51 | height, width, boxes = random_sample_crop(height, width, boxes) 52 | 53 | 54 | # RandomMirror 55 | if random.randint(0, 2): 56 | boxes[:, 0::2] = width - boxes[:, 2::-2] 57 | 58 | 59 | # Resize 60 | boxes[:, [0, 2]] *= (max_image_size / width) 61 | boxes[:, [1, 3]] *= (max_image_size / height) 62 | width = height = max_image_size 63 | 64 | 65 | # ToPercentCoords 66 | boxes[:, [0, 2]] /= width 67 | boxes[:, [1, 3]] /= height 68 | 69 | if augment_idx % 50000 == 0: 70 | print('Current idx: %d' % augment_idx) 71 | 72 | augment_idx += 1 73 | 74 | return boxes 75 | 76 | 77 | 78 | 79 | sample_options = ( 80 | # using entire original input image 81 | None, 82 | # sample a patch s.t. 
MIN jaccard w/ obj in .1,.3,.4,.7,.9 83 | (0.1, None), 84 | (0.3, None), 85 | (0.7, None), 86 | (0.9, None), 87 | # randomly sample a patch 88 | (None, None), 89 | ) 90 | 91 | def intersect(box_a, box_b): 92 | max_xy = np.minimum(box_a[:, 2:], box_b[2:]) 93 | min_xy = np.maximum(box_a[:, :2], box_b[:2]) 94 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) 95 | return inter[:, 0] * inter[:, 1] 96 | 97 | 98 | def jaccard_numpy(box_a, box_b): 99 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 100 | is simply the intersection over union of two boxes. 101 | E.g.: 102 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 103 | Args: 104 | box_a: Multiple bounding boxes, Shape: [num_boxes,4] 105 | box_b: Single bounding box, Shape: [4] 106 | Return: 107 | jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] 108 | """ 109 | inter = intersect(box_a, box_b) 110 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 111 | (box_a[:, 3]-box_a[:, 1])) # [A,B] 112 | area_b = ((box_b[2]-box_b[0]) * 113 | (box_b[3]-box_b[1])) # [A,B] 114 | union = area_a + area_b - inter 115 | return inter / union # [A,B] 116 | 117 | 118 | def random_sample_crop(height, width, boxes=None): 119 | global sample_options 120 | 121 | while True: 122 | # randomly choose a mode 123 | mode = random.choice(sample_options) 124 | if mode is None: 125 | return height, width, boxes 126 | 127 | min_iou, max_iou = mode 128 | if min_iou is None: 129 | min_iou = float('-inf') 130 | if max_iou is None: 131 | max_iou = float('inf') 132 | 133 | for _ in range(50): 134 | w = random.uniform(0.3 * width, width) 135 | h = random.uniform(0.3 * height, height) 136 | 137 | if h / w < 0.5 or h / w > 2: 138 | continue 139 | 140 | left = random.uniform(0, width - w) 141 | top = random.uniform(0, height - h) 142 | 143 | rect = np.array([int(left), int(top), int(left+w), int(top+h)]) 144 | overlap = jaccard_numpy(boxes, rect) 145 | if overlap.min() < min_iou and max_iou < overlap.max(): 146 | continue 147 | 148 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 149 | 150 | m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) 151 | m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) 152 | mask = m1 * m2 153 | 154 | if not mask.any(): 155 | continue 156 | 157 | current_boxes = boxes[mask, :].copy() 158 | current_boxes[:, :2] = np.maximum(current_boxes[:, :2], rect[:2]) 159 | current_boxes[:, :2] -= rect[:2] 160 | current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], rect[2:]) 161 | current_boxes[:, 2:] -= rect[:2] 162 | 163 | return h, w, current_boxes 164 | 165 | 166 | if __name__ == '__main__': 167 | 168 | with open(box_file, 'rb') as f: 169 | bboxes = pickle.load(f) 170 | 171 | augment_boxes(bboxes) 172 | -------------------------------------------------------------------------------- /scripts/bbox_recall.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script compiles all the bounding boxes in the training data and 3 | clusters them for each convout resolution on which they're used. 4 | 5 | Run this script from the Yolact root directory. 
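(Note: despite the header above, nothing is clustered here. The script builds the anchors
configured below in point form and reports the fraction of ground-truth boxes matched by
at least one anchor at IoU > 0.5, broken down into small / medium / large boxes.)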
6 | """ 7 | 8 | import os.path as osp 9 | import json, pickle 10 | import sys 11 | from math import sqrt 12 | from itertools import product 13 | import torch 14 | import random 15 | 16 | import numpy as np 17 | 18 | dump_file = 'weights/bboxes.pkl' 19 | aug_file = 'weights/bboxes_aug.pkl' 20 | 21 | use_augmented_boxes = True 22 | 23 | 24 | def intersect(box_a, box_b): 25 | """ We resize both tensors to [A,B,2] without new malloc: 26 | [A,2] -> [A,1,2] -> [A,B,2] 27 | [B,2] -> [1,B,2] -> [A,B,2] 28 | Then we compute the area of intersect between box_a and box_b. 29 | Args: 30 | box_a: (tensor) bounding boxes, Shape: [A,4]. 31 | box_b: (tensor) bounding boxes, Shape: [B,4]. 32 | Return: 33 | (tensor) intersection area, Shape: [A,B]. 34 | """ 35 | A = box_a.size(0) 36 | B = box_b.size(0) 37 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 38 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 39 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 40 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 41 | inter = torch.clamp((max_xy - min_xy), min=0) 42 | return inter[:, :, 0] * inter[:, :, 1] 43 | 44 | 45 | def jaccard(box_a, box_b, iscrowd=False): 46 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 47 | is simply the intersection over union of two boxes. Here we operate on 48 | ground truth boxes and default boxes. If iscrowd=True, put the crowd in box_b. 49 | E.g.: 50 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 51 | Args: 52 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 53 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 54 | Return: 55 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 56 | """ 57 | inter = intersect(box_a, box_b) 58 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 59 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 60 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 61 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 62 | union = area_a + area_b - inter 63 | 64 | if iscrowd: 65 | return inter / area_a 66 | else: 67 | return inter / union # [A,B] 68 | 69 | # Also convert to point form 70 | def to_relative(bboxes): 71 | return np.concatenate((bboxes[:, 2:4] / bboxes[:, :2], (bboxes[:, 2:4] + bboxes[:, 4:]) / bboxes[:, :2]), axis=1) 72 | 73 | 74 | def make_priors(conv_size, scales, aspect_ratios): 75 | prior_data = [] 76 | conv_h = conv_size[0] 77 | conv_w = conv_size[1] 78 | 79 | # Iteration order is important (it has to sync up with the convout) 80 | for j, i in product(range(conv_h), range(conv_w)): 81 | x = (i + 0.5) / conv_w 82 | y = (j + 0.5) / conv_h 83 | 84 | for scale, ars in zip(scales, aspect_ratios): 85 | for ar in ars: 86 | w = scale * ar / conv_w 87 | h = scale / ar / conv_h 88 | 89 | # Point form 90 | prior_data += [x - w/2, y - h/2, x + w/2, y + h/2] 91 | 92 | return np.array(prior_data).reshape(-1, 4) 93 | 94 | # fixed_ssd_config 95 | # scales = [[3.5, 4.95], [3.6, 4.90], [3.3, 4.02], [2.7, 3.10], [2.1, 2.37], [2.1, 2.37], [1.8, 1.92]] 96 | # aspect_ratios = [ [[1, sqrt(2), 1/sqrt(2), sqrt(3), 1/sqrt(3)][:n], [1]] for n in [3, 5, 5, 5, 3, 3, 3] ] 97 | # conv_sizes = [(35, 35), (18, 18), (9, 9), (5, 5), (3, 3), (2, 2)] 98 | 99 | scales = [[1.68, 2.91], 100 | [2.95, 2.22, 0.84], 101 | [2.23, 2.17, 3.12], 102 | [0.76, 1.94, 2.72], 103 | [2.10, 2.65], 104 | [1.80, 1.92]] 105 | aspect_ratios = [[[0.72, 0.96], [0.68, 1.17]], 106 | [[1.28, 0.66], [0.63, 1.23], [0.89, 1.40]], 107 | [[2.05, 1.24], [0.57, 0.83], [0.61, 1.15]], 108 | 
[[1.00, 2.21], [0.47, 1.60], [1.44, 0.79]], 109 | [[1.00, 1.41, 0.71, 1.73, 0.58], [1.08]], 110 | [[1.00, 1.41, 0.71, 1.73, 0.58], [1.00]]] 111 | conv_sizes = [(35, 35), (18, 18), (9, 9), (5, 5), (3, 3), (2, 2)] 112 | 113 | # yrm33_config 114 | # scales = [ [5.3] ] * 5 115 | # aspect_ratios = [ [[1, 1/sqrt(2), sqrt(2)]] ]*5 116 | # conv_sizes = [(136, 136), (67, 67), (33, 33), (16, 16), (8, 8)] 117 | 118 | 119 | SMALL = 0 120 | MEDIUM = 1 121 | LARGE = 2 122 | 123 | if __name__ == '__main__': 124 | 125 | with open(dump_file, 'rb') as f: 126 | bboxes = pickle.load(f) 127 | 128 | sizes = [] 129 | smalls = [] 130 | for i in range(len(bboxes)): 131 | area = bboxes[i][4] * bboxes[i][5] 132 | if area < 32 ** 2: 133 | sizes.append(SMALL) 134 | smalls.append(area) 135 | elif area < 96 ** 2: 136 | sizes.append(MEDIUM) 137 | else: 138 | sizes.append(LARGE) 139 | 140 | # Each box is in the form [im_w, im_h, pos_x, pos_y, size_x, size_y] 141 | 142 | if use_augmented_boxes: 143 | with open(aug_file, 'rb') as f: 144 | bboxes_rel = pickle.load(f) 145 | else: 146 | bboxes_rel = to_relative(np.array(bboxes)) 147 | 148 | 149 | with torch.no_grad(): 150 | sizes = torch.Tensor(sizes) 151 | 152 | anchors = [make_priors(cs, s, ar) for cs, s, ar in zip(conv_sizes, scales, aspect_ratios)] 153 | anchors = np.concatenate(anchors, axis=0) 154 | anchors = torch.Tensor(anchors).cuda() 155 | 156 | bboxes_rel = torch.Tensor(bboxes_rel).cuda() 157 | perGTAnchorMax = torch.zeros(bboxes_rel.shape[0]).cuda() 158 | 159 | chunk_size = 1000 160 | for i in range((bboxes_rel.size(0) // chunk_size) + 1): 161 | start = i * chunk_size 162 | end = min((i + 1) * chunk_size, bboxes_rel.size(0)) 163 | 164 | ious = jaccard(bboxes_rel[start:end, :], anchors) 165 | maxes, maxidx = torch.max(ious, dim=1) 166 | 167 | perGTAnchorMax[start:end] = maxes 168 | 169 | 170 | hits = (perGTAnchorMax > 0.5).float() 171 | 172 | print('Total recall: %.2f' % (torch.sum(hits) / hits.size(0) * 100)) 173 | print() 174 | 175 | for i, metric in zip(range(3), ('small', 'medium', 'large')): 176 | _hits = hits[sizes == i] 177 | _size = (1 if _hits.size(0) == 0 else _hits.size(0)) 178 | print(metric + ' recall: %.2f' % ((torch.sum(_hits) / _size) * 100)) 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /scripts/cluster_bbox_sizes.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script compiles all the bounding boxes in the training data and 3 | clusters them for each convout resolution on which they're used. 4 | 5 | Run this script from the Yolact root directory. 
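(Concretely: box sizes are converted to fractions of their image and scaled to max_size,
their scales sqrt(w * h) are k-means clustered into num_scale_clusters groups, and the
aspect ratios within each scale cluster are clustered again into num_aspect_ratio_clusters
groups, which are then printed per scale center.)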
6 | """ 7 | 8 | import os.path as osp 9 | import json, pickle 10 | import sys 11 | 12 | import numpy as np 13 | import sklearn.cluster as cluster 14 | 15 | dump_file = 'weights/bboxes.pkl' 16 | max_size = 550 17 | 18 | num_scale_clusters = 5 19 | num_aspect_ratio_clusters = 3 20 | 21 | def to_relative(bboxes): 22 | return bboxes[:, 2:4] / bboxes[:, :2] 23 | 24 | def process(bboxes): 25 | return to_relative(bboxes) * max_size 26 | 27 | if __name__ == '__main__': 28 | 29 | with open(dump_file, 'rb') as f: 30 | bboxes = pickle.load(f) 31 | 32 | bboxes = np.array(bboxes) 33 | bboxes = process(bboxes) 34 | bboxes = bboxes[(bboxes[:, 0] > 1) * (bboxes[:, 1] > 1)] 35 | 36 | scale = np.sqrt(bboxes[:, 0] * bboxes[:, 1]).reshape(-1, 1) 37 | 38 | clusterer = cluster.KMeans(num_scale_clusters, random_state=99, n_jobs=4) 39 | assignments = clusterer.fit_predict(scale) 40 | counts = np.bincount(assignments) 41 | 42 | cluster_centers = clusterer.cluster_centers_ 43 | 44 | center_indices = list(range(num_scale_clusters)) 45 | center_indices.sort(key=lambda x: cluster_centers[x, 0]) 46 | 47 | for idx in center_indices: 48 | center = cluster_centers[idx, 0] 49 | boxes_for_center = bboxes[assignments == idx] 50 | aspect_ratios = (boxes_for_center[:,0] / boxes_for_center[:,1]).reshape(-1, 1) 51 | 52 | c = cluster.KMeans(num_aspect_ratio_clusters, random_state=idx, n_jobs=4) 53 | ca = c.fit_predict(aspect_ratios) 54 | cc = np.bincount(ca) 55 | 56 | c = list(c.cluster_centers_.reshape(-1)) 57 | cidx = list(range(num_aspect_ratio_clusters)) 58 | cidx.sort(key=lambda x: -cc[x]) 59 | 60 | # import code 61 | # code.interact(local=locals()) 62 | 63 | print('%.3f (%d) aspect ratios:' % (center, counts[idx])) 64 | for idx in cidx: 65 | print('\t%.2f (%d)' % (c[idx], cc[idx])) 66 | print() 67 | # exit() 68 | 69 | 70 | -------------------------------------------------------------------------------- /scripts/compute_masks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import cv2 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128), 8 | (0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128)) 9 | 10 | def mask_iou(mask1, mask2): 11 | """ 12 | Inputs inputs are matricies of size _ x N. Output is size _1 x _2. 13 | Note: if iscrowd is True, then mask2 should be the crowd. 
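Concretely, mask1 is [N1, H*W] and mask2 is [N2, H*W] (flattened 0/1 masks), and the
result is the [N1, N2] matrix of pairwise IoUs.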
14 | """ 15 | intersection = torch.matmul(mask1, mask2.t()) 16 | area1 = torch.sum(mask1, dim=1).view(1, -1) 17 | area2 = torch.sum(mask2, dim=1).view(1, -1) 18 | union = (area1.t() + area2) - intersection 19 | 20 | return intersection / union 21 | 22 | def paint_mask(img_numpy, mask, color): 23 | h, w, _ = img_numpy.shape 24 | img_numpy = img_numpy.copy() 25 | 26 | mask = np.tile(mask.reshape(h, w, 1), (1, 1, 3)) 27 | color_np = np.array(color[:3]).reshape(1, 1, 3) 28 | color_np = np.tile(color_np, (h, w, 1)) 29 | mask_color = mask * color_np 30 | 31 | mask_alpha = 0.3 32 | 33 | # Blend image and mask 34 | image_crop = img_numpy * mask 35 | img_numpy *= (1-mask) 36 | img_numpy += image_crop * (1-mask_alpha) + mask_color * mask_alpha 37 | 38 | return img_numpy 39 | 40 | # Inverse sigmoid 41 | def logit(x): 42 | return np.log(x / (1-x + 0.0001) + 0.0001) 43 | 44 | def sigmoid(x): 45 | return 1 / (1 + np.exp(-x)) 46 | 47 | img_fmt = '../data/coco/images/%012d.jpg' 48 | with open('info.txt', 'r') as f: 49 | img_id = int(f.read()) 50 | 51 | img = plt.imread(img_fmt % img_id).astype(np.float32) 52 | h, w, _ = img.shape 53 | 54 | gt_masks = np.load('gt.npy').astype(np.float32).transpose(1, 2, 0) 55 | proto_masks = np.load('proto.npy').astype(np.float32) 56 | 57 | proto_masks = torch.Tensor(proto_masks).permute(2, 0, 1).contiguous().unsqueeze(0) 58 | proto_masks = F.interpolate(proto_masks, (h, w), mode='bilinear', align_corners=False).squeeze(0) 59 | proto_masks = proto_masks.permute(1, 2, 0).numpy() 60 | 61 | # # A x = b 62 | ls_A = proto_masks.reshape(-1, proto_masks.shape[-1]) 63 | ls_b = gt_masks.reshape(-1, gt_masks.shape[-1]) 64 | 65 | # x is size [256, num_gt] 66 | x = np.linalg.lstsq(ls_A, ls_b, rcond=None)[0] 67 | 68 | approximated_masks = (np.matmul(proto_masks, x) > 0.5).astype(np.float32) 69 | 70 | num_gt = approximated_masks.shape[2] 71 | ious = mask_iou(torch.Tensor(approximated_masks.reshape(-1, num_gt).T), 72 | torch.Tensor(gt_masks.reshape(-1, num_gt).T)) 73 | 74 | ious = [int(ious[i, i].item() * 100) for i in range(num_gt)] 75 | ious.sort(key=lambda x: -x) 76 | 77 | print(ious) 78 | 79 | gt_img = img.copy() 80 | 81 | for i in range(num_gt): 82 | gt_img = paint_mask(gt_img, gt_masks[:, :, i], COLORS[i % len(COLORS)]) 83 | 84 | plt.imshow(gt_img / 255) 85 | plt.title('GT') 86 | plt.show() 87 | 88 | for i in range(num_gt): 89 | img = paint_mask(img, approximated_masks[:, :, i], COLORS[i % len(COLORS)]) 90 | 91 | plt.imshow(img / 255) 92 | plt.title('Approximated') 93 | plt.show() 94 | -------------------------------------------------------------------------------- /scripts/convert_darknet.py: -------------------------------------------------------------------------------- 1 | from backbone import DarkNetBackbone 2 | import h5py 3 | import torch 4 | 5 | f = h5py.File('darknet53.h5', 'r') 6 | m = f['model_weights'] 7 | 8 | yolo_keys = list(m.keys()) 9 | yolo_keys = [x for x in yolo_keys if len(m[x].keys()) > 0] 10 | yolo_keys.sort() 11 | 12 | sd = DarkNetBackbone().state_dict() 13 | 14 | sd_keys = list(sd.keys()) 15 | sd_keys.sort() 16 | 17 | # Note this won't work if there are 10 elements in some list but whatever that doesn't happen 18 | layer_keys = list(set(['.'.join(x.split('.')[:-2]) for x in sd_keys])) 19 | layer_keys.sort() 20 | 21 | # print([x for x in sd_keys if x.startswith(layer_keys[0])]) 22 | 23 | mapping = { 24 | '.0.weight' : ('conv2d_%d', 'kernel:0'), 25 | '.1.bias' : ('batch_normalization_%d', 'beta:0'), 26 | '.1.weight' : ('batch_normalization_%d', 'gamma:0'), 27 
| '.1.running_var' : ('batch_normalization_%d', 'moving_variance:0'), 28 | '.1.running_mean': ('batch_normalization_%d', 'moving_mean:0'), 29 | '.1.num_batches_tracked': None, 30 | } 31 | 32 | for i, layer_key in zip(range(1, len(layer_keys) + 1), layer_keys): 33 | # This is pretty inefficient but I don't care 34 | for weight_key in [x for x in sd_keys if x.startswith(layer_key)]: 35 | diff = weight_key[len(layer_key):] 36 | 37 | if mapping[diff] is not None: 38 | yolo_key = mapping[diff][0] % i 39 | sub_key = mapping[diff][1] 40 | 41 | yolo_weight = torch.Tensor(m[yolo_key][yolo_key][sub_key].value) 42 | if (len(yolo_weight.size()) == 4): 43 | yolo_weight = yolo_weight.permute(3, 2, 0, 1).contiguous() 44 | 45 | sd[weight_key] = yolo_weight 46 | 47 | torch.save(sd, 'weights/darknet53.pth') 48 | 49 | -------------------------------------------------------------------------------- /scripts/convert_sbd.py: -------------------------------------------------------------------------------- 1 | import scipy.io, scipy.ndimage 2 | import os.path, json 3 | import pycocotools.mask 4 | import numpy as np 5 | 6 | def mask2bbox(mask): 7 | rows = np.any(mask, axis=1) 8 | cols = np.any(mask, axis=0) 9 | rmin, rmax = np.where(rows)[0][[0, -1]] 10 | cmin, cmax = np.where(cols)[0][[0, -1]] 11 | 12 | return cmin, rmin, cmax - cmin, rmax - rmin 13 | 14 | 15 | 16 | inst_path = './inst/' 17 | img_path = './img/' 18 | img_name_fmt = '%s.jpg' 19 | ann_name_fmt = '%s.mat' 20 | 21 | image_id = 1 22 | ann_id = 1 23 | 24 | types = ['train', 'val'] 25 | 26 | for t in types: 27 | with open('%s.txt' % t, 'r') as f: 28 | names = f.read().strip().split('\n') 29 | 30 | images = [] 31 | annotations = [] 32 | 33 | for name in names: 34 | img_name = img_name_fmt % name 35 | 36 | ann_path = os.path.join(inst_path, ann_name_fmt % name) 37 | ann = scipy.io.loadmat(ann_path)['GTinst'][0][0] 38 | 39 | classes = [int(x[0]) for x in ann[2]] 40 | seg = ann[0] 41 | 42 | for idx in range(len(classes)): 43 | mask = (seg == (idx + 1)).astype(np.float) 44 | 45 | rle = pycocotools.mask.encode(np.asfortranarray(mask.astype(np.uint8))) 46 | rle['counts'] = rle['counts'].decode('ascii') 47 | 48 | annotations.append({ 49 | 'id': ann_id, 50 | 'image_id': image_id, 51 | 'category_id': classes[idx], 52 | 'segmentation': rle, 53 | 'area': float(mask.sum()), 54 | 'bbox': [int(x) for x in mask2bbox(mask)], 55 | 'iscrowd': 0 56 | }) 57 | 58 | ann_id += 1 59 | 60 | img_name = img_name_fmt % name 61 | img = scipy.ndimage.imread(os.path.join(img_path, img_name)) 62 | 63 | images.append({ 64 | 'id': image_id, 65 | 'width': img.shape[1], 66 | 'height': img.shape[0], 67 | 'file_name': img_name 68 | }) 69 | 70 | image_id += 1 71 | 72 | info = { 73 | 'year': 2012, 74 | 'version': 1, 75 | 'description': 'Pascal SBD', 76 | } 77 | 78 | categories = [{'id': x+1} for x in range(20)] 79 | 80 | with open('pascal_sbd_%s.json' % t, 'w') as f: 81 | json.dump({ 82 | 'info': info, 83 | 'images': images, 84 | 'annotations': annotations, 85 | 'licenses': {}, 86 | 'categories': categories 87 | }, f) 88 | 89 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p GPU-small 3 | #SBATCH -t 2:00:00 4 | #SBATCH --gres=gpu:p100:1 5 | #SBATCH --no-requeue 6 | 7 | # Usage: ./eval.sh weights extra_args 8 | 9 | module load python/3.6.4_gcc5_np1.14.5 10 | module load cuda/9.0 11 | 12 | cd $SCRATCH/yolact 13 | 14 | python3 eval.py 
--trained_model=$1 --no_bar $2 > logs/eval/$(basename -- $1).log 2>&1 15 | -------------------------------------------------------------------------------- /scripts/make_grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math, random 3 | 4 | import matplotlib.pyplot as plt 5 | from matplotlib.widgets import Slider, Button 6 | 7 | 8 | fig, ax = plt.subplots() 9 | plt.subplots_adjust(bottom=0.24) 10 | im_handle = None 11 | 12 | save_path = 'grid.npy' 13 | 14 | center_x, center_y = (0.5, 0.5) 15 | grid_w, grid_h = (35, 35) 16 | spacing = 0 17 | scale = 4 18 | angle = 0 19 | grid = None 20 | 21 | all_grids = [] 22 | unique = False 23 | 24 | # A hack 25 | disable_render = False 26 | 27 | def render(): 28 | if disable_render: 29 | return 30 | 31 | x = np.tile(np.array(list(range(grid_w)), dtype=np.float).reshape(1, grid_w), [grid_h, 1]) - grid_w * center_x 32 | y = np.tile(np.array(list(range(grid_h)), dtype=np.float).reshape(grid_h, 1), [1, grid_w]) - grid_h * center_y 33 | 34 | x /= scale 35 | y /= scale 36 | 37 | a1 = angle + math.pi / 3 38 | a2 = -angle + math.pi / 3 39 | a3 = angle 40 | 41 | z1 = x * math.sin(a1) + y * math.cos(a1) 42 | z2 = x * math.sin(a2) - y * math.cos(a2) 43 | z3 = x * math.sin(a3) + y * math.cos(a3) 44 | 45 | s1 = np.square(np.sin(z1)) 46 | s2 = np.square(np.sin(z2)) 47 | s3 = np.square(np.sin(z3)) 48 | 49 | line_1 = np.exp(s1 * spacing) * s1 50 | line_2 = np.exp(s2 * spacing) * s2 51 | line_3 = np.exp(s3 * spacing) * s3 52 | 53 | global grid 54 | grid = np.clip(1 - (line_1 + line_2 + line_3) / 3, 0, 1) 55 | 56 | global im_handle 57 | if im_handle is None: 58 | im_handle = plt.imshow(grid) 59 | else: 60 | im_handle.set_data(grid) 61 | fig.canvas.draw_idle() 62 | 63 | def update_scale(val): 64 | global scale 65 | scale = val 66 | 67 | render() 68 | 69 | def update_angle(val): 70 | global angle 71 | angle = val 72 | 73 | render() 74 | 75 | def update_centerx(val): 76 | global center_x 77 | center_x = val 78 | 79 | render() 80 | 81 | def update_centery(val): 82 | global center_y 83 | center_y = val 84 | 85 | render() 86 | 87 | def update_spacing(val): 88 | global spacing 89 | spacing = val 90 | 91 | render() 92 | 93 | def randomize(val): 94 | global center_x, center_y, spacing, scale, angle, disable_render 95 | 96 | center_x, center_y = (random.uniform(0, 1), random.uniform(0, 1)) 97 | spacing = random.uniform(-0.2, 2) 98 | scale = 4 * math.exp(random.uniform(-1, 1)) 99 | angle = random.uniform(-math.pi, math.pi) 100 | 101 | disable_render = True 102 | 103 | scale_slider.set_val(scale) 104 | angle_slider.set_val(angle) 105 | centx_slider.set_val(center_x) 106 | centy_slider.set_val(center_y) 107 | spaci_slider.set_val(spacing) 108 | 109 | disable_render = False 110 | 111 | render() 112 | 113 | def add(val): 114 | all_grids.append(grid) 115 | 116 | global unique 117 | if not unique: 118 | unique = test_uniqueness(np.stack(all_grids)) 119 | 120 | export_len_text.set_text('Num Grids: ' + str(len(all_grids))) 121 | fig.canvas.draw_idle() 122 | 123 | def add_randomize(val): 124 | add(val) 125 | randomize(val) 126 | 127 | def export(val): 128 | np.save(save_path, np.stack(all_grids)) 129 | print('Saved %d grids to "%s"' % (len(all_grids), save_path)) 130 | 131 | global unique 132 | unique = False 133 | all_grids.clear() 134 | 135 | export_len_text.set_text('Num Grids: ' + str(len(all_grids))) 136 | fig.canvas.draw_idle() 137 | 138 | def test_uniqueness(grids): 139 | # Grids shape [ngrids, h, w] 140 | 
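# (Each (y, x) location is treated as a length-ngrids feature vector; if any two
    # locations share the same vector, the saved grids would not uniquely encode position.)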
grids = grids.reshape((-1, grid_h, grid_w)) 141 | 142 | for y in range(grid_h): 143 | for x in range(grid_h): 144 | pixel_features = grids[:, y, x] 145 | 146 | # l1 distance for this pixel with every other 147 | l1_dist = np.sum(np.abs(grids - np.tile(pixel_features, grid_h*grid_w).reshape((-1, grid_h, grid_w))), axis=0) 148 | 149 | # Equal if l1 distance is really small. Note that this will include this pixel 150 | num_equal = np.sum((l1_dist < 0.0001).astype(np.int32)) 151 | 152 | if num_equal > 1: 153 | print('Pixel at (%d, %d) has %d other pixel%s with the same representation.' % (x, y, num_equal-1, '' if num_equal==2 else 's')) 154 | return False 155 | 156 | print('Each pixel has a distinct representation.') 157 | return True 158 | 159 | 160 | 161 | render() 162 | 163 | axis = plt.axes([0.22, 0.19, 0.59, 0.03], facecolor='lightgoldenrodyellow') 164 | scale_slider = Slider(axis, 'Scale', 0.1, 20, valinit=scale, valstep=0.1) 165 | scale_slider.on_changed(update_scale) 166 | 167 | axis = plt.axes([0.22, 0.15, 0.59, 0.03], facecolor='lightgoldenrodyellow') 168 | angle_slider = Slider(axis, 'Angle', -math.pi, math.pi, valinit=angle, valstep=0.1) 169 | angle_slider.on_changed(update_angle) 170 | 171 | axis = plt.axes([0.22, 0.11, 0.59, 0.03], facecolor='lightgoldenrodyellow') 172 | centx_slider = Slider(axis, 'Center X', 0, 1, valinit=center_x, valstep=0.05) 173 | centx_slider.on_changed(update_centerx) 174 | 175 | axis = plt.axes([0.22, 0.07, 0.59, 0.03], facecolor='lightgoldenrodyellow') 176 | centy_slider = Slider(axis, 'Center Y', 0, 1, valinit=center_y, valstep=0.05) 177 | centy_slider.on_changed(update_centery) 178 | 179 | axis = plt.axes([0.22, 0.03, 0.59, 0.03], facecolor='lightgoldenrodyellow') 180 | spaci_slider = Slider(axis, 'Spacing', -1, 2, valinit=spacing, valstep=0.05) 181 | spaci_slider.on_changed(update_spacing) 182 | 183 | axis = plt.axes([0.8, 0.54, 0.15, 0.05], facecolor='lightgoldenrodyellow') 184 | rando_button = Button(axis, 'Randomize') 185 | rando_button.on_clicked(randomize) 186 | 187 | axis = plt.axes([0.8, 0.48, 0.15, 0.05], facecolor='lightgoldenrodyellow') 188 | addgr_button = Button(axis, 'Add') 189 | addgr_button.on_clicked(add) 190 | 191 | # Likely not a good way to do this but whatever 192 | export_len_text = plt.text(0, 3, 'Num Grids: 0') 193 | 194 | axis = plt.axes([0.8, 0.42, 0.15, 0.05], facecolor='lightgoldenrodyellow') 195 | addra_button = Button(axis, 'Add / Rand') 196 | addra_button.on_clicked(add_randomize) 197 | 198 | axis = plt.axes([0.8, 0.36, 0.15, 0.05], facecolor='lightgoldenrodyellow') 199 | saveg_button = Button(axis, 'Save') 200 | saveg_button.on_clicked(export) 201 | 202 | 203 | 204 | plt.show() 205 | -------------------------------------------------------------------------------- /scripts/optimize_bboxes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instead of clustering bbox widths and heights, this script 3 | directly optimizes average IoU across the training set given 4 | the specified number of anchor boxes. 5 | 6 | Run this script from the Yolact root directory. 
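(Roughly: using scipy's Powell minimizer, the script alternates between optimizing the
per-layer scales and the per-layer aspect ratios, one prediction layer at a time, with
recall at IoU > 0.5 over batches of training boxes as the objective.)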
7 | """ 8 | 9 | import pickle 10 | import random 11 | from itertools import product 12 | from math import sqrt 13 | 14 | import numpy as np 15 | import torch 16 | from scipy.optimize import minimize 17 | 18 | dump_file = 'weights/bboxes.pkl' 19 | aug_file = 'weights/bboxes_aug.pkl' 20 | 21 | use_augmented_boxes = True 22 | 23 | 24 | def intersect(box_a, box_b): 25 | """ We resize both tensors to [A,B,2] without new malloc: 26 | [A,2] -> [A,1,2] -> [A,B,2] 27 | [B,2] -> [1,B,2] -> [A,B,2] 28 | Then we compute the area of intersect between box_a and box_b. 29 | Args: 30 | box_a: (tensor) bounding boxes, Shape: [A,4]. 31 | box_b: (tensor) bounding boxes, Shape: [B,4]. 32 | Return: 33 | (tensor) intersection area, Shape: [A,B]. 34 | """ 35 | A = box_a.size(0) 36 | B = box_b.size(0) 37 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 38 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 39 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 40 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 41 | inter = torch.clamp((max_xy - min_xy), min=0) 42 | return inter[:, :, 0] * inter[:, :, 1] 43 | 44 | 45 | def jaccard(box_a, box_b, iscrowd=False): 46 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 47 | is simply the intersection over union of two boxes. Here we operate on 48 | ground truth boxes and default boxes. If iscrowd=True, put the crowd in box_b. 49 | E.g.: 50 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 51 | Args: 52 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 53 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 54 | Return: 55 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 56 | """ 57 | inter = intersect(box_a, box_b) 58 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 59 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 60 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 61 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 62 | union = area_a + area_b - inter 63 | 64 | if iscrowd: 65 | return inter / area_a 66 | else: 67 | return inter / union # [A,B] 68 | 69 | # Also convert to point form 70 | def to_relative(bboxes): 71 | return np.concatenate((bboxes[:, 2:4] / bboxes[:, :2], (bboxes[:, 2:4] + bboxes[:, 4:]) / bboxes[:, :2]), axis=1) 72 | 73 | 74 | def make_priors(conv_size, scales, aspect_ratios): 75 | prior_data = [] 76 | conv_h = conv_size[0] 77 | conv_w = conv_size[1] 78 | 79 | # Iteration order is important (it has to sync up with the convout) 80 | for j, i in product(range(conv_h), range(conv_w)): 81 | x = (i + 0.5) / conv_w 82 | y = (j + 0.5) / conv_h 83 | 84 | for scale, ars in zip(scales, aspect_ratios): 85 | for ar in ars: 86 | w = scale * ar / conv_w 87 | h = scale / ar / conv_h 88 | 89 | # Point form 90 | prior_data += [x - w/2, y - h/2, x + w/2, y + h/2] 91 | return torch.Tensor(prior_data).view(-1, 4).cuda() 92 | 93 | 94 | 95 | scales = [[1.68, 2.91], [2.95, 2.22, 0.84], [2.17, 2.22, 3.22], [0.76, 2.06, 2.81], [5.33, 2.79], [13.69]] 96 | aspect_ratios = [[[0.72, 0.96], [0.68, 1.17]], [[1.30, 0.66], [0.63, 1.23], [0.87, 1.41]], [[1.96, 1.23], [0.58, 0.84], [0.61, 1.15]], [[19.79, 2.21], [0.47, 1.76], [1.38, 0.79]], [[4.79, 17.96], [1.04]], [[14.82]]] 97 | conv_sizes = [(35, 35), (18, 18), (9, 9), (5, 5), (3, 3), (2, 2)] 98 | 99 | optimize_scales = False 100 | 101 | batch_idx = 0 102 | 103 | 104 | def compute_hits(bboxes, anchors, iou_threshold=0.5): 105 | ious = jaccard(bboxes, anchors) 106 | perGTAnchorMax, _ = torch.max(ious, dim=1) 107 | 
108 | return (perGTAnchorMax > iou_threshold) 109 | 110 | def compute_recall(hits, base_hits): 111 | hits = (hits | base_hits).float() 112 | return torch.sum(hits) / hits.size(0) 113 | 114 | 115 | def step(x, x_func, bboxes, base_hits, optim_idx): 116 | # This should set the scale and aspect ratio 117 | x_func(x, scales[optim_idx], aspect_ratios[optim_idx]) 118 | 119 | anchors = make_priors(conv_sizes[optim_idx], scales[optim_idx], aspect_ratios[optim_idx]) 120 | 121 | return -float(compute_recall(compute_hits(bboxes, anchors), base_hits).cpu()) 122 | 123 | 124 | def optimize(full_bboxes, optim_idx, batch_size=5000): 125 | global batch_idx, scales, aspect_ratios, conv_sizes 126 | 127 | start = batch_idx * batch_size 128 | end = min((batch_idx + 1) * batch_size, full_bboxes.size(0)) 129 | 130 | if batch_idx > (full_bboxes.size(0) // batch_size): 131 | batch_idx = 0 132 | 133 | bboxes = full_bboxes[start:end, :] 134 | 135 | anchor_base = [ 136 | make_priors(conv_sizes[idx], scales[idx], aspect_ratios[idx]) 137 | for idx in range(len(conv_sizes)) if idx != optim_idx] 138 | base_hits = compute_hits(bboxes, torch.cat(anchor_base, dim=0)) 139 | 140 | 141 | def set_x(x, scales, aspect_ratios): 142 | if optimize_scales: 143 | for i in range(len(scales)): 144 | scales[i] = max(x[i], 0) 145 | else: 146 | k = 0 147 | for i in range(len(aspect_ratios)): 148 | for j in range(len(aspect_ratios[i])): 149 | aspect_ratios[i][j] = x[k] 150 | k += 1 151 | 152 | 153 | res = minimize(step, x0=scales[optim_idx] if optimize_scales else sum(aspect_ratios[optim_idx], []), method='Powell', 154 | args = (set_x, bboxes, base_hits, optim_idx),) 155 | 156 | 157 | def pretty_str(x:list): 158 | if isinstance(x, list): 159 | return '[' + ', '.join([pretty_str(y) for y in x]) + ']' 160 | elif isinstance(x, np.ndarray): 161 | return pretty_str(list(x)) 162 | else: 163 | return '%.2f' % x 164 | 165 | if __name__ == '__main__': 166 | 167 | if use_augmented_boxes: 168 | with open(aug_file, 'rb') as f: 169 | bboxes = pickle.load(f) 170 | else: 171 | # Load widths and heights from a dump file. 
Obtain this with 172 | # python3 scripts/save_bboxes.py 173 | with open(dump_file, 'rb') as f: 174 | bboxes = pickle.load(f) 175 | 176 | bboxes = np.array(bboxes) 177 | bboxes = to_relative(bboxes) 178 | 179 | with torch.no_grad(): 180 | bboxes = torch.Tensor(bboxes).cuda() 181 | 182 | def print_out(): 183 | if optimize_scales: 184 | print('Scales: ' + pretty_str(scales)) 185 | else: 186 | print('Aspect Ratios: ' + pretty_str(aspect_ratios)) 187 | 188 | for p in range(10): 189 | print('(Sub Iteration) ', end='') 190 | for i in range(len(conv_sizes)): 191 | print('%d ' % i, end='', flush=True) 192 | optimize(bboxes, i) 193 | print('Done', end='\r') 194 | 195 | print('(Iteration %d) ' % p, end='') 196 | print_out() 197 | print() 198 | 199 | optimize_scales = not optimize_scales 200 | 201 | print('scales = ' + pretty_str(scales)) 202 | print('aspect_ratios = ' + pretty_str(aspect_ratios)) 203 | 204 | 205 | -------------------------------------------------------------------------------- /scripts/parse_eval.py: -------------------------------------------------------------------------------- 1 | import re, sys, os 2 | import matplotlib.pyplot as plt 3 | from matplotlib._color_data import XKCD_COLORS 4 | 5 | with open(sys.argv[1], 'r') as f: 6 | txt = f.read() 7 | 8 | txt, overall = txt.split('overall performance') 9 | 10 | class_names = [] 11 | mAP_overall = [] 12 | mAP_small = [] 13 | mAP_medium = [] 14 | mAP_large = [] 15 | 16 | for class_result in txt.split('evaluate category: ')[1:]: 17 | lines = class_result.split('\n') 18 | class_names.append(lines[0]) 19 | 20 | def grabMAP(string): 21 | return float(string.split('] = ')[1]) * 100 22 | 23 | mAP_overall.append(grabMAP(lines[ 7])) 24 | mAP_small .append(grabMAP(lines[10])) 25 | mAP_medium .append(grabMAP(lines[11])) 26 | mAP_large .append(grabMAP(lines[12])) 27 | 28 | mAP_map = { 29 | 'small': mAP_small, 30 | 'medium': mAP_medium, 31 | 'large': mAP_large, 32 | } 33 | 34 | if len(sys.argv) > 2: 35 | bars = plt.bar(class_names, mAP_map[sys.argv[2]]) 36 | plt.title(sys.argv[2] + ' mAP per class') 37 | else: 38 | bars = plt.bar(class_names, mAP_overall) 39 | plt.title('overall mAP per class') 40 | 41 | colors = list(XKCD_COLORS.values()) 42 | 43 | for idx, bar in enumerate(bars): 44 | # Mmm pseudorandom colors 45 | char_sum = sum([ord(char) for char in class_names[idx]]) 46 | bar.set_color(colors[char_sum % len(colors)]) 47 | 48 | plt.xticks(rotation='vertical') 49 | plt.show() 50 | -------------------------------------------------------------------------------- /scripts/plot_loss.py: -------------------------------------------------------------------------------- 1 | import re, sys, os 2 | import matplotlib.pyplot as plt 3 | 4 | from utils.functions import MovingAverage 5 | 6 | with open(sys.argv[1], 'r') as f: 7 | inp = f.read() 8 | 9 | patterns = { 10 | 'train': re.compile(r'\[\s*(?P\d+)\]\s*(?P\d+) \|\| B: (?P\S+) \| C: (?P\S+) \| M: (?P\S+) \|( S: (?P\S+) \|)? 
T: (?P\S+)'), 11 | 'val': re.compile(r'\s*(?P[a-z]+) \|\s*(?P\S+)') 12 | } 13 | data = {key: [] for key in patterns} 14 | 15 | for line in inp.split('\n'): 16 | for key, pattern in patterns.items(): 17 | f = pattern.search(line) 18 | 19 | if f is not None: 20 | datum = f.groupdict() 21 | for k, v in datum.items(): 22 | if v is not None: 23 | try: 24 | v = float(v) 25 | except ValueError: 26 | pass 27 | datum[k] = v 28 | 29 | if key == 'val': 30 | datum = (datum, data['train'][-1]) 31 | data[key].append(datum) 32 | break 33 | 34 | 35 | def smoother(y, interval=100): 36 | avg = MovingAverage(interval) 37 | 38 | for i in range(len(y)): 39 | avg.append(y[i]) 40 | y[i] = avg.get_avg() 41 | 42 | return y 43 | 44 | def plot_train(data): 45 | plt.title(os.path.basename(sys.argv[1]) + ' Training Loss') 46 | plt.xlabel('Iteration') 47 | plt.ylabel('Loss') 48 | 49 | loss_names = ['BBox Loss', 'Conf Loss', 'Mask Loss'] 50 | 51 | x = [x['iteration'] for x in data] 52 | plt.plot(x, smoother([y['b'] for y in data])) 53 | plt.plot(x, smoother([y['c'] for y in data])) 54 | plt.plot(x, smoother([y['m'] for y in data])) 55 | 56 | if data[0]['s'] is not None: 57 | plt.plot(x, smoother([y['s'] for y in data])) 58 | loss_names.append('Segmentation Loss') 59 | 60 | plt.legend(loss_names) 61 | plt.show() 62 | 63 | def plot_val(data): 64 | plt.title(os.path.basename(sys.argv[1]) + ' Validation mAP') 65 | plt.xlabel('Epoch') 66 | plt.ylabel('mAP') 67 | 68 | x = [x[1]['epoch'] for x in data if x[0]['type'] == 'box'] 69 | plt.plot(x, [x[0]['all'] for x in data if x[0]['type'] == 'box']) 70 | plt.plot(x, [x[0]['all'] for x in data if x[0]['type'] == 'mask']) 71 | 72 | plt.legend(['BBox mAP', 'Mask mAP']) 73 | plt.show() 74 | 75 | if len(sys.argv) > 2 and sys.argv[2] == 'val': 76 | plot_val(data['val']) 77 | else: 78 | plot_train(data['train']) 79 | -------------------------------------------------------------------------------- /scripts/resume.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p GPU-shared 3 | #SBATCH -t 48:00:00 4 | #SBATCH --gres=gpu:p100:1 5 | #SBATCH --no-requeue 6 | 7 | # Usage: ./resume.sh config batch_size resume_file 8 | 9 | module load python/3.6.4_gcc5_np1.14.5 10 | module load cuda/9.0 11 | 12 | cd $SCRATCH/yolact 13 | 14 | python3 train.py --config $1 --batch_size $2 --resume=$3 --save_interval 5000 --start_iter=-1 >>logs/$1_log 2>&1 15 | -------------------------------------------------------------------------------- /scripts/save_bboxes.py: -------------------------------------------------------------------------------- 1 | """ This script transforms and saves bbox coordinates into a pickle object for easy loading. 
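Each entry has the form [im_w, im_h, x, y, box_w, box_h] (the image size followed by the
raw COCO bbox); a hypothetical entry might look like [640, 480, 12.0, 40.5, 100.0, 62.3].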
""" 2 | 3 | 4 | import os.path as osp 5 | import json, pickle 6 | import sys 7 | 8 | import numpy as np 9 | 10 | COCO_ROOT = osp.join('.', 'data/coco/') 11 | 12 | annotation_file = 'instances_train2017.json' 13 | annotation_path = osp.join(COCO_ROOT, 'annotations/', annotation_file) 14 | 15 | dump_file = 'weights/bboxes.pkl' 16 | 17 | with open(annotation_path, 'r') as f: 18 | annotations_json = json.load(f) 19 | 20 | annotations = annotations_json['annotations'] 21 | images = annotations_json['images'] 22 | images = {image['id']: image for image in images} 23 | bboxes = [] 24 | 25 | for ann in annotations: 26 | image = images[ann['image_id']] 27 | w,h = (image['width'], image['height']) 28 | 29 | if 'bbox' in ann: 30 | bboxes.append([w, h] + ann['bbox']) 31 | 32 | with open(dump_file, 'wb') as f: 33 | pickle.dump(bboxes, f) 34 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p GPU-shared 3 | #SBATCH -t 48:00:00 4 | #SBATCH --gres=gpu:p100:1 5 | #SBATCH --no-requeue 6 | 7 | # Usage: ./train.sh config batch_size 8 | 9 | module load python/3.6.4_gcc5_np1.14.5 10 | module load cuda/9.0 11 | 12 | cd $SCRATCH/yolact 13 | 14 | python3 train.py --config $1 --batch_size $2 --save_interval 5000 &>logs/$1_log 15 | -------------------------------------------------------------------------------- /scripts/unpack_statedict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys, os 3 | 4 | # Usage python scripts/unpack_statedict.py path_to_pth out_folder/ 5 | # Make sure to include that slash after your out folder, since I can't 6 | # be arsed to do path concatenation so I'd rather type out this comment 7 | 8 | print('Loading state dict...') 9 | state = torch.load(sys.argv[1]) 10 | 11 | if not os.path.exists(sys.argv[2]): 12 | os.mkdir(sys.argv[2]) 13 | 14 | print('Saving stuff...') 15 | for key, val in state.items(): 16 | torch.save(val, sys.argv[2] + key) 17 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmentations import SSDAugmentation -------------------------------------------------------------------------------- /utils/cython_nms.pyx: -------------------------------------------------------------------------------- 1 | ## Note: Figure out the license details later. 
2 | # 3 | # Based on: 4 | # -------------------------------------------------------- 5 | # Fast R-CNN 6 | # Copyright (c) 2015 Microsoft 7 | # Licensed under The MIT License [see LICENSE for details] 8 | # Written by Ross Girshick 9 | # -------------------------------------------------------- 10 | 11 | cimport cython 12 | import numpy as np 13 | cimport numpy as np 14 | 15 | cdef inline np.float32_t max(np.float32_t a, np.float32_t b) nogil: 16 | return a if a >= b else b 17 | 18 | cdef inline np.float32_t min(np.float32_t a, np.float32_t b) nogil: 19 | return a if a <= b else b 20 | 21 | @cython.boundscheck(False) 22 | @cython.cdivision(True) 23 | @cython.wraparound(False) 24 | def nms(np.ndarray[np.float32_t, ndim=2] dets, np.float32_t thresh): 25 | cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] 26 | cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] 27 | cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] 28 | cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] 29 | cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] 30 | 31 | cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) 32 | cdef np.ndarray[np.int64_t, ndim=1] order = scores.argsort()[::-1] 33 | 34 | cdef int ndets = dets.shape[0] 35 | cdef np.ndarray[np.int_t, ndim=1] suppressed = \ 36 | np.zeros((ndets), dtype=np.int) 37 | 38 | # nominal indices 39 | cdef int _i, _j 40 | # sorted indices 41 | cdef int i, j 42 | # temp variables for box i's (the box currently under consideration) 43 | cdef np.float32_t ix1, iy1, ix2, iy2, iarea 44 | # variables for computing overlap with box j (lower scoring box) 45 | cdef np.float32_t xx1, yy1, xx2, yy2 46 | cdef np.float32_t w, h 47 | cdef np.float32_t inter, ovr 48 | 49 | with nogil: 50 | for _i in range(ndets): 51 | i = order[_i] 52 | if suppressed[i] == 1: 53 | continue 54 | ix1 = x1[i] 55 | iy1 = y1[i] 56 | ix2 = x2[i] 57 | iy2 = y2[i] 58 | iarea = areas[i] 59 | for _j in range(_i + 1, ndets): 60 | j = order[_j] 61 | if suppressed[j] == 1: 62 | continue 63 | xx1 = max(ix1, x1[j]) 64 | yy1 = max(iy1, y1[j]) 65 | xx2 = min(ix2, x2[j]) 66 | yy2 = min(iy2, y2[j]) 67 | w = max(0.0, xx2 - xx1 + 1) 68 | h = max(0.0, yy2 - yy1 + 1) 69 | inter = w * h 70 | ovr = inter / (iarea + areas[j] - inter) 71 | if ovr >= thresh: 72 | suppressed[j] = 1 73 | 74 | return np.where(suppressed == 0)[0] 75 | -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import math 5 | from collections import deque 6 | from pathlib import Path 7 | from layers.interpolate import InterpolateModule 8 | 9 | class MovingAverage(): 10 | """ Keeps an average window of the specified number of items. """ 11 | 12 | def __init__(self, max_window_size=1000): 13 | self.max_window_size = max_window_size 14 | self.reset() 15 | 16 | def add(self, elem): 17 | """ Adds an element to the window, removing the earliest element if necessary. """ 18 | if not math.isfinite(elem): 19 | print('Warning: Moving average ignored a value of %f' % elem) 20 | return 21 | 22 | self.window.append(elem) 23 | self.sum += elem 24 | 25 | if len(self.window) > self.max_window_size: 26 | self.sum -= self.window.popleft() 27 | 28 | def append(self, elem): 29 | """ Same as add just more pythonic. """ 30 | self.add(elem) 31 | 32 | def reset(self): 33 | """ Resets the MovingAverage to its initial state. 
""" 34 | self.window = deque() 35 | self.sum = 0 36 | 37 | def get_avg(self): 38 | """ Returns the average of the elements in the window. """ 39 | return self.sum / max(len(self.window), 1) 40 | 41 | def __str__(self): 42 | return str(self.get_avg()) 43 | 44 | def __repr__(self): 45 | return repr(self.get_avg()) 46 | 47 | def __len__(self): 48 | return len(self.window) 49 | 50 | 51 | class ProgressBar(): 52 | """ A simple progress bar that just outputs a string. """ 53 | 54 | def __init__(self, length, max_val): 55 | self.max_val = max_val 56 | self.length = length 57 | self.cur_val = 0 58 | 59 | self.cur_num_bars = -1 60 | self._update_str() 61 | 62 | def set_val(self, new_val): 63 | self.cur_val = new_val 64 | 65 | if self.cur_val > self.max_val: 66 | self.cur_val = self.max_val 67 | if self.cur_val < 0: 68 | self.cur_val = 0 69 | 70 | self._update_str() 71 | 72 | def is_finished(self): 73 | return self.cur_val == self.max_val 74 | 75 | def _update_str(self): 76 | num_bars = int(self.length * (self.cur_val / self.max_val)) 77 | 78 | if num_bars != self.cur_num_bars: 79 | self.cur_num_bars = num_bars 80 | self.string = '█' * num_bars + '░' * (self.length - num_bars) 81 | 82 | def __repr__(self): 83 | return self.string 84 | 85 | def __str__(self): 86 | return self.string 87 | 88 | 89 | def init_console(): 90 | """ 91 | Initialize the console to be able to use ANSI escape characters on Windows. 92 | """ 93 | if os.name == 'nt': 94 | from colorama import init 95 | init() 96 | 97 | 98 | class SavePath: 99 | """ 100 | Why is this a class? 101 | Why do I have a class for creating and parsing save paths? 102 | What am I doing with my life? 103 | """ 104 | 105 | def __init__(self, model_name:str, epoch:int, iteration:int): 106 | self.model_name = model_name 107 | self.epoch = epoch 108 | self.iteration = iteration 109 | 110 | def get_path(self, root:str=''): 111 | file_name = self.model_name + '_' + str(self.epoch) + '_' + str(self.iteration) + '.pth' 112 | return os.path.join(root, file_name) 113 | 114 | @staticmethod 115 | def from_str(path:str): 116 | file_name = os.path.basename(path) 117 | 118 | if file_name.endswith('.pth'): 119 | file_name = file_name[:-4] 120 | 121 | params = file_name.split('_') 122 | 123 | if file_name.endswith('interrupt'): 124 | params = params[:-1] 125 | 126 | model_name = '_'.join(params[:-2]) 127 | epoch = params[-2] 128 | iteration = params[-1] 129 | 130 | return SavePath(model_name, int(epoch), int(iteration)) 131 | 132 | @staticmethod 133 | def remove_interrupt(save_folder): 134 | for p in Path(save_folder).glob('*_interrupt.pth'): 135 | p.unlink() 136 | 137 | @staticmethod 138 | def get_interrupt(save_folder): 139 | for p in Path(save_folder).glob('*_interrupt.pth'): 140 | return str(p) 141 | return None 142 | 143 | @staticmethod 144 | def get_latest(save_folder, config): 145 | """ Note: config should be config.name. """ 146 | max_iter = -1 147 | max_name = None 148 | 149 | for p in Path(save_folder).glob(config + '_*'): 150 | path_name = str(p) 151 | 152 | try: 153 | save = SavePath.from_str(path_name) 154 | except: 155 | continue 156 | 157 | if save.model_name == config and save.iteration > max_iter: 158 | max_iter = save.iteration 159 | max_name = path_name 160 | 161 | return max_name 162 | 163 | def make_net(in_channels, conf, include_last_relu=True): 164 | """ 165 | A helper function to take a config setting and turn it into a network. 166 | Used by protonet and extrahead. 
Returns (network, out_channels) 167 | """ 168 | def make_layer(layer_cfg): 169 | nonlocal in_channels 170 | 171 | # Possible patterns: 172 | # ( 256, 3, {}) -> conv 173 | # ( 256,-2, {}) -> deconv 174 | # (None,-2, {}) -> bilinear interpolate 175 | # ('cat',[],{}) -> concat the subnetworks in the list 176 | # 177 | # You know it would have probably been simpler just to adopt a 'c' 'd' 'u' naming scheme. 178 | # Whatever, it's too late now. 179 | if isinstance(layer_cfg[0], str): 180 | layer_name = layer_cfg[0] 181 | 182 | if layer_name == 'cat': 183 | nets = [make_net(in_channels, x) for x in layer_cfg[1]] 184 | layer = Concat([net[0] for net in nets], layer_cfg[2]) 185 | num_channels = sum([net[1] for net in nets]) 186 | else: 187 | num_channels = layer_cfg[0] 188 | kernel_size = layer_cfg[1] 189 | 190 | if kernel_size > 0: 191 | layer = nn.Conv2d(in_channels, num_channels, kernel_size, **layer_cfg[2]) 192 | else: 193 | if num_channels is None: 194 | layer = InterpolateModule(scale_factor=-kernel_size, mode='bilinear', align_corners=False, **layer_cfg[2]) 195 | else: 196 | layer = nn.ConvTranspose2d(in_channels, num_channels, -kernel_size, **layer_cfg[2]) 197 | 198 | in_channels = num_channels if num_channels is not None else in_channels 199 | 200 | # Don't return a ReLU layer if we're doing an upsample. This probably doesn't affect anything 201 | # output-wise, but there's no need to go through a ReLU here. 202 | # Commented out for backwards compatibility with previous models 203 | # if num_channels is None: 204 | # return [layer] 205 | # else: 206 | return [layer, nn.ReLU(inplace=True)] 207 | 208 | # Use sum to concat together all the component layer lists 209 | net = sum([make_layer(x) for x in conf], []) 210 | if not include_last_relu: 211 | net = net[:-1] 212 | 213 | return nn.Sequential(*(net)), in_channels -------------------------------------------------------------------------------- /utils/nvinfo.py: -------------------------------------------------------------------------------- 1 | # My version of nvgpu because nvgpu didn't have all the information I was looking for. 2 | import re 3 | import subprocess 4 | import shutil 5 | import os 6 | 7 | def gpu_info() -> list: 8 | """ 9 | Returns a dictionary of stats mined from nvidia-smi for each gpu in a list. 10 | Adapted from nvgpu: https://pypi.org/project/nvgpu/, but mine has more info. 
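Each entry is a dict shaped roughly like this (values purely illustrative):
        {'idx': 0, 'name': 'GeForce GTX 1080 Ti', 'uuid': 'GPU-...', 'fan_spd': 35,
         'temp': 62, 'pwr_used': 120, 'pwr_cap': 250,
         'mem_used': 4043, 'mem_total': 11178, 'util': 87}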
11 | """ 12 | gpus = [line for line in _run_cmd(['nvidia-smi', '-L']) if line] 13 | gpu_infos = [re.match('GPU ([0-9]+): ([^(]+) \(UUID: ([^)]+)\)', gpu).groups() for gpu in gpus] 14 | gpu_infos = [dict(zip(['idx', 'name', 'uuid'], info)) for info in gpu_infos] 15 | gpu_count = len(gpus) 16 | 17 | lines = _run_cmd(['nvidia-smi']) 18 | selected_lines = lines[7:7 + 3 * gpu_count] 19 | for i in range(gpu_count): 20 | mem_used, mem_total = [int(m.strip().replace('MiB', '')) for m in 21 | selected_lines[3 * i + 1].split('|')[2].strip().split('/')] 22 | 23 | pw_tmp_info, mem_info, util_info = [x.strip() for x in selected_lines[3 * i + 1].split('|')[1:-1]] 24 | 25 | pw_tmp_info = [x[:-1] for x in pw_tmp_info.split(' ') if len(x) > 0] 26 | fan_speed, temperature, pwr_used, pwr_cap = [int(pw_tmp_info[i]) for i in (0, 1, 3, 5)] 27 | gpu_infos[i]['fan_spd' ] = fan_speed 28 | gpu_infos[i]['temp' ] = temperature 29 | gpu_infos[i]['pwr_used'] = pwr_used 30 | gpu_infos[i]['pwr_cap' ] = pwr_cap 31 | 32 | mem_used, mem_total = [int(x) for x in mem_info.replace('MiB', '').split(' / ')] 33 | gpu_infos[i]['mem_used' ] = mem_used 34 | gpu_infos[i]['mem_total'] = mem_total 35 | 36 | utilization = int(util_info.split(' ')[0][:-1]) 37 | gpu_infos[i]['util'] = utilization 38 | 39 | gpu_infos[i]['idx'] = int(gpu_infos[i]['idx']) 40 | 41 | return gpu_infos 42 | 43 | def nvsmi_available() -> bool: 44 | """ Returns whether or not nvidia-smi is present in this system's PATH. """ 45 | return shutil.which('nvidia-smi') is not None 46 | 47 | 48 | def visible_gpus() -> list: 49 | """ Returns a list of the indexes of all the gpus visible to pytorch. """ 50 | 51 | if 'CUDA_VISIBLE_DEVICES' not in os.environ: 52 | return list(range(len(gpu_info()))) 53 | else: 54 | return [int(x.strip()) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(',')] 55 | 56 | 57 | 58 | 59 | def _run_cmd(cmd:list) -> list: 60 | """ Runs a command and returns a list of output lines. """ 61 | output = subprocess.check_output(cmd) 62 | output = output.decode('UTF-8') 63 | return output.split('\n') -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import defaultdict 3 | 4 | _total_times = defaultdict(lambda: 0) 5 | _start_times = defaultdict(lambda: -1) 6 | _disabled_names = set() 7 | _timer_stack = [] 8 | _running_timer = None 9 | _disable_all = False 10 | 11 | def disable_all(): 12 | global _disable_all 13 | _disable_all = True 14 | 15 | def enable_all(): 16 | global _disable_all 17 | _disable_all = False 18 | 19 | def disable(fn_name): 20 | """ Disables the given function name fom being considered for the average or outputted in print_stats. """ 21 | _disabled_names.add(fn_name) 22 | 23 | def enable(fn_name): 24 | """ Enables function names disabled by disable. """ 25 | _disabled_names.remove(fn_name) 26 | 27 | def reset(): 28 | """ Resets the current timer. Call this at the start of an iteration. """ 29 | global _running_timer 30 | _total_times.clear() 31 | _start_times.clear() 32 | _timer_stack.clear() 33 | _running_timer = None 34 | 35 | def start(fn_name, use_stack=True): 36 | """ 37 | Start timing the specific function. 38 | Note: If use_stack is True, only one timer can be active at a time. 39 | Once you stop this timer, the previous one will start again. 
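Illustrative use:
        timer.start('backbone')   # pauses whatever timer is currently running
        ...
        timer.stop()              # 'backbone' stops and the paused timer resumes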
40 | """ 41 | global _running_timer, _disable_all 42 | 43 | if _disable_all: 44 | return 45 | 46 | if use_stack: 47 | if _running_timer is not None: 48 | stop(_running_timer, use_stack=False) 49 | _timer_stack.append(_running_timer) 50 | start(fn_name, use_stack=False) 51 | _running_timer = fn_name 52 | else: 53 | _start_times[fn_name] = time.perf_counter() 54 | 55 | def stop(fn_name=None, use_stack=True): 56 | """ 57 | If use_stack is True, this will stop the currently running timer and restore 58 | the previous timer on the stack if that exists. Note if use_stack is True, 59 | fn_name will be ignored. 60 | 61 | If use_stack is False, this will just stop timing the timer fn_name. 62 | """ 63 | global _running_timer, _disable_all 64 | 65 | if _disable_all: 66 | return 67 | 68 | if use_stack: 69 | if _running_timer is not None: 70 | stop(_running_timer, use_stack=False) 71 | if len(_timer_stack) > 0: 72 | _running_timer = _timer_stack.pop() 73 | start(_running_timer, use_stack=False) 74 | else: 75 | _running_timer = None 76 | else: 77 | print('Warning: timer stopped with no timer running!') 78 | else: 79 | if _start_times[fn_name] > -1: 80 | _total_times[fn_name] += time.perf_counter() - _start_times[fn_name] 81 | else: 82 | print('Warning: timer for %s stopped before starting!' % fn_name) 83 | 84 | 85 | def print_stats(): 86 | """ Prints the current timing information into a table. """ 87 | print() 88 | 89 | all_fn_names = [k for k in _total_times.keys() if k not in _disabled_names] 90 | 91 | max_name_width = max([len(k) for k in all_fn_names] + [4]) 92 | if max_name_width % 2 == 1: max_name_width += 1 93 | format_str = ' {:>%d} | {:>10.4f} ' % max_name_width 94 | 95 | header = (' {:^%d} | {:^10} ' % max_name_width).format('Name', 'Time (ms)') 96 | print(header) 97 | 98 | sep_idx = header.find('|') 99 | sep_text = ('-' * sep_idx) + '+' + '-' * (len(header)-sep_idx-1) 100 | print(sep_text) 101 | 102 | for name in all_fn_names: 103 | print(format_str.format(name, _total_times[name]*1000)) 104 | 105 | print(sep_text) 106 | print(format_str.format('Total', total_time()*1000)) 107 | print() 108 | 109 | def total_time(): 110 | """ Returns the total amount accumulated across all functions in seconds. """ 111 | return sum([elapsed_time for name, elapsed_time in _total_times.items() if name not in _disabled_names]) 112 | 113 | 114 | class env(): 115 | """ 116 | A class that lets you go: 117 | with timer.env(fn_name): 118 | # (...) 119 | That automatically manages a timer start and stop for you. 
120 | """ 121 | 122 | def __init__(self, fn_name, use_stack=True): 123 | self.fn_name = fn_name 124 | self.use_stack = use_stack 125 | 126 | def __enter__(self): 127 | start(self.fn_name, use_stack=self.use_stack) 128 | 129 | def __exit__(self, e, ev, t): 130 | stop(self.fn_name, use_stack=self.use_stack) 131 | 132 | -------------------------------------------------------------------------------- /web/css/index.css: -------------------------------------------------------------------------------- 1 | 2 | /* 3 | Pallete: 4 | 5 | FFFFFF 6 | D2CBCB 7 | 7D8491 8 | 003459 9 | 274C77 10 | 161925 11 | */ 12 | 13 | * { box-sizing: border-box; } 14 | 15 | .big { 16 | font-size:72px; 17 | margin-bottom: 20px; 18 | } 19 | 20 | .list_wrapper { 21 | width: 500px; 22 | padding-top: 2px; 23 | padding-bottom: 20px; 24 | } 25 | 26 | 27 | body { 28 | margin:0; 29 | padding:0; 30 | vertical-align: top; 31 | 32 | background-color: #274C77; 33 | color: #ffffff; 34 | font-family: 'Open Sans', sans-serif; 35 | font-size: 24px; 36 | width: 100%; 37 | height: 99vh; 38 | 39 | display: grid; 40 | grid-template-areas: 41 | 'header' 42 | 'main' 43 | 'footer'; 44 | 45 | grid-template-rows: 100px auto 25px; 46 | 47 | text-align: center; 48 | } 49 | 50 | .box { 51 | background-color: #23395B; 52 | border-radius: 10px; 53 | } 54 | 55 | .header { grid-area: header; } 56 | .main { grid-area: main; } 57 | .footer { grid-area: footer; } 58 | 59 | span { 60 | margin:0; 61 | padding:0; 62 | vertical-align: top; 63 | } 64 | -------------------------------------------------------------------------------- /web/css/list.css: -------------------------------------------------------------------------------- 1 | ul { 2 | list-style-type: none; 3 | margin: 0; 4 | padding: 0; 5 | } 6 | 7 | li { 8 | /* font: 200 24px/1.5 Helvetica, Verdana, sans-serif; */ 9 | font-size: 22px; 10 | } 11 | 12 | li a { 13 | text-decoration: none; 14 | color: #fff; 15 | display: block; 16 | width: 100%; 17 | 18 | -webkit-transition: font-size 0.2s ease, background-color 0.2s ease; 19 | -moz-transition: font-size 0.2s ease, background-color 0.2s ease; 20 | -o-transition: font-size 0.2s ease, background-color 0.2s ease; 21 | -ms-transition: font-size 0.2s ease, background-color 0.2s ease; 22 | transition: font-size 0.2s ease, background-color 0.2s ease; 23 | } 24 | 25 | li a:hover { 26 | font-size: 30px; 27 | background: rgb(95, 138, 219); 28 | } 29 | -------------------------------------------------------------------------------- /web/css/toggle.css: -------------------------------------------------------------------------------- 1 | .switch { 2 | position: relative; 3 | top: 5; 4 | } 5 | 6 | .switch input {display:none;} 7 | 8 | .slider { 9 | position: relative; 10 | display: inline-block; 11 | width: 60px; 12 | height: 26px; 13 | cursor: pointer; 14 | top: 0; 15 | left: 0; 16 | right: 0; 17 | bottom: 0; 18 | background-color: #ccc; 19 | -webkit-transition: .4s; 20 | transition: .4s; 21 | } 22 | 23 | .slider:before { 24 | position: absolute; 25 | content: ""; 26 | height: 20px; 27 | width: 20px; 28 | left: 3px; 29 | bottom: 3px; 30 | background-color: white; 31 | -webkit-transition: .4s; 32 | transition: .4s; 33 | } 34 | 35 | input:checked + .slider { 36 | background-color: #2196F3; 37 | } 38 | 39 | input:focus + .slider { 40 | box-shadow: 0 0 1px #2196F3; 41 | } 42 | 43 | input:checked + .slider:before { 44 | -webkit-transform: translateX(34px); 45 | -ms-transform: translateX(34px); 46 | transform: translateX(34px); 47 | } 48 | 49 | /* Rounded sliders */ 50 | 
.slider.round { 51 | border-radius: 34px; 52 | } 53 | 54 | .slider.round:before { 55 | border-radius: 50%; 56 | } 57 | -------------------------------------------------------------------------------- /web/css/viewer.css: -------------------------------------------------------------------------------- 1 | 2 | .info { grid-area: info; } 3 | .image { grid-area: image; } 4 | .controls { grid-area: controls; } 5 | 6 | 7 | #viewer { 8 | display: grid; 9 | grid-template-areas: 'info image controls'; 10 | grid-template-columns: 1fr 2fr 1fr; 11 | grid-gap: 0; 12 | } 13 | 14 | #viewer > div.box { 15 | padding: 10px; 16 | margin: 0 10px 10px 10px; 17 | } 18 | 19 | .image_box { 20 | display: grid; 21 | grid-template-rows: max-content auto; 22 | grid-gap: 10px; 23 | } 24 | 25 | #image_idx, #config_name, .info_value { 26 | color: rgb(152, 160, 175); 27 | } 28 | 29 | .info_section { 30 | text-align: center; 31 | border-bottom: 1px solid #fff; 32 | } 33 | 34 | a { 35 | text-decoration: none; 36 | color: #fff; 37 | } 38 | 39 | a:hover { 40 | color: rgb(152, 160, 175); 41 | } 42 | 43 | .setting { 44 | display: grid; 45 | grid-template-areas: 'label value input'; 46 | grid-template-columns: max-content 30px 1fr; 47 | grid-gap: 20px; 48 | padding: 0 10px 0 10px; 49 | text-align: left; 50 | } 51 | .setting_label { grid-area: label; } 52 | .setting_input { 53 | grid-area: input; 54 | } 55 | .setting_value { 56 | grid-area: value; 57 | color: rgb(152, 160, 175); 58 | } 59 | 60 | .box_title { 61 | width: 100%; 62 | border-bottom: 1px solid #fff; 63 | } 64 | 65 | -------------------------------------------------------------------------------- /web/dets/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "info": { 3 | "Cross-Class NMS": {"BBox mAP": 0.9523,"Mask mAP": 0.4231},"Per-Class NMS": {"BBox mAP": 0.1,"Mask mAP": 0.2}, 4 | "Config" : {} 5 | }, 6 | "images": [ 7 | { 8 | "image_id": 42, 9 | "dets": [ 10 | { 11 | "score": 0.14, 12 | "bbox": [20, 20, 100, 100] 13 | }, 14 | { 15 | "score": 0.09, 16 | "bbox": [40, 100, 230, 10] 17 | } 18 | ] 19 | } 20 | ] 21 | } 22 | -------------------------------------------------------------------------------- /web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Configurations 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | Detections Viewer 20 | 21 |
22 |
23 |

Select a configuration

24 |
    25 |
26 |
27 |
28 | 29 | By Daniel Bolya 30 | 31 | 32 | -------------------------------------------------------------------------------- /web/iou.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | IoU thingy 11 | 12 | 13 | 14 | 15 | 16 | 17 | 35 | 36 | 37 | 38 |
39 | 40 | This text is displayed if your browser does not support HTML5 Canvas. 41 | 42 |
43 |
44 |

IoU:

45 |


46 |

Bbox manipulation sourced from here

47 |
48 | 49 | 50 | -------------------------------------------------------------------------------- /web/scripts/index.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function() { 2 | // Load in det_index and fill the config list with the appropriate elements 3 | $.ajax({ 4 | url: 'detindex', 5 | dataType: 'text', 6 | success: function (data) { 7 | data = data.trim().split('\n'); 8 | for (let i = 0; i < data.length; i++) { 9 | name = data[i]; 10 | 11 | $('#config_list').append( 12 | '
  • ' + name + '
  • ' 13 | ); 14 | } 15 | } 16 | }); 17 | }); 18 | -------------------------------------------------------------------------------- /web/scripts/iou.js: -------------------------------------------------------------------------------- 1 | // IoU added by Daniel Bolya 2 | // 3 | // Last updated November 2010 by Simon Sarris 4 | // www.simonsarris.com 5 | // sarris@acm.org 6 | // 7 | // Free to use and distribute at will 8 | // So long as you are nice to people, etc 9 | 10 | // This is a self-executing function that I added only to stop this 11 | // new script from interfering with the old one. It's a good idea in general, but not 12 | // something I wanted to go over during this tutorial 13 | (function(window) { 14 | 15 | 16 | // holds all our boxes 17 | var boxes2 = []; 18 | 19 | // New, holds the 8 tiny boxes that will be our selection handles 20 | // the selection handles will be in this order: 21 | // 0 1 2 22 | // 3 4 23 | // 5 6 7 24 | var selectionHandles = []; 25 | 26 | // Hold canvas information 27 | var canvas; 28 | var ctx; 29 | var WIDTH; 30 | var HEIGHT; 31 | var INTERVAL = 20; // how often, in milliseconds, we check to see if a redraw is needed 32 | 33 | var isDrag = false; 34 | var isResizeDrag = false; 35 | var expectResize = -1; // New, will save the # of the selection handle if the mouse is over one. 36 | var mx, my; // mouse coordinates 37 | 38 | // when set to true, the canvas will redraw everything 39 | // invalidate() just sets this to false right now 40 | // we want to call invalidate() whenever we make a change 41 | var canvasValid = false; 42 | 43 | // The node (if any) being selected. 44 | // If in the future we want to select multiple objects, this will get turned into an array 45 | var mySel = null; 46 | 47 | // The selection color and width. Right now we have a red selection with a small width 48 | var mySelColor = '#CC0000'; 49 | var mySelWidth = 2; 50 | var mySelBoxColor = 'darkred'; // New for selection boxes 51 | var mySelBoxSize = 6; 52 | 53 | // we use a fake canvas to draw individual shapes for selection testing 54 | var ghostcanvas; 55 | var gctx; // fake canvas context 56 | 57 | // since we can drag from anywhere in a node 58 | // instead of just its x/y corner, we need to save 59 | // the offset of the mouse when we start dragging. 60 | var offsetx, offsety; 61 | 62 | // Padding and border style widths for mouse offsets 63 | var stylePaddingLeft, stylePaddingTop, styleBorderLeft, styleBorderTop; 64 | 65 | 66 | 67 | 68 | // Box object to hold data 69 | function Box2() { 70 | this.x = 0; 71 | this.y = 0; 72 | this.w = 1; // default width and height? 
73 | this.h = 1; 74 | this.fill = '#444444'; 75 | } 76 | 77 | // New methods on the Box class 78 | Box2.prototype = { 79 | // we used to have a solo draw function 80 | // but now each box is responsible for its own drawing 81 | // mainDraw() will call this with the normal canvas 82 | // myDown will call this with the ghost canvas with 'black' 83 | draw: function(context, optionalColor) { 84 | if (context === gctx) { 85 | context.fillStyle = 'black'; // always want black for the ghost canvas 86 | } else { 87 | context.fillStyle = this.fill; 88 | } 89 | 90 | // We can skip the drawing of elements that have moved off the screen: 91 | if (this.x > WIDTH || this.y > HEIGHT) return; 92 | if (this.x + this.w < 0 || this.y + this.h < 0) return; 93 | 94 | context.fillRect(this.x,this.y,this.w,this.h); 95 | 96 | // draw selection 97 | // this is a stroke along the box and also 8 new selection handles 98 | if (mySel === this) { 99 | context.strokeStyle = mySelColor; 100 | context.lineWidth = mySelWidth; 101 | context.strokeRect(this.x,this.y,this.w,this.h); 102 | 103 | // draw the boxes 104 | 105 | var half = mySelBoxSize / 2; 106 | 107 | // 0 1 2 108 | // 3 4 109 | // 5 6 7 110 | 111 | // top left, middle, right 112 | selectionHandles[0].x = this.x-half; 113 | selectionHandles[0].y = this.y-half; 114 | 115 | selectionHandles[1].x = this.x+this.w/2-half; 116 | selectionHandles[1].y = this.y-half; 117 | 118 | selectionHandles[2].x = this.x+this.w-half; 119 | selectionHandles[2].y = this.y-half; 120 | 121 | //middle left 122 | selectionHandles[3].x = this.x-half; 123 | selectionHandles[3].y = this.y+this.h/2-half; 124 | 125 | //middle right 126 | selectionHandles[4].x = this.x+this.w-half; 127 | selectionHandles[4].y = this.y+this.h/2-half; 128 | 129 | //bottom left, middle, right 130 | selectionHandles[6].x = this.x+this.w/2-half; 131 | selectionHandles[6].y = this.y+this.h-half; 132 | 133 | selectionHandles[5].x = this.x-half; 134 | selectionHandles[5].y = this.y+this.h-half; 135 | 136 | selectionHandles[7].x = this.x+this.w-half; 137 | selectionHandles[7].y = this.y+this.h-half; 138 | 139 | 140 | context.fillStyle = mySelBoxColor; 141 | for (var i = 0; i < 8; i ++) { 142 | var cur = selectionHandles[i]; 143 | context.fillRect(cur.x, cur.y, mySelBoxSize, mySelBoxSize); 144 | } 145 | } 146 | 147 | } // end draw 148 | 149 | } 150 | 151 | //Initialize a new Box, add it, and invalidate the canvas 152 | function addRect(x, y, w, h, fill) { 153 | var rect = new Box2; 154 | rect.x = x; 155 | rect.y = y; 156 | rect.w = w 157 | rect.h = h; 158 | rect.fill = fill; 159 | boxes2.push(rect); 160 | invalidate(); 161 | } 162 | 163 | // initialize our canvas, add a ghost canvas, set draw loop 164 | // then add everything we want to intially exist on the canvas 165 | function init2() { 166 | canvas = document.getElementById('canvas2'); 167 | canvas.style.width='50%'; 168 | canvas.style.height='65%'; 169 | canvas.width = canvas.offsetWidth; 170 | canvas.height = canvas.offsetHeight; 171 | HEIGHT = canvas.height; 172 | WIDTH = canvas.width; 173 | ctx = canvas.getContext('2d'); 174 | ghostcanvas = document.createElement('canvas'); 175 | ghostcanvas.height = HEIGHT; 176 | ghostcanvas.width = WIDTH; 177 | gctx = ghostcanvas.getContext('2d'); 178 | 179 | //fixes a problem where double clicking causes text to get selected on the canvas 180 | canvas.onselectstart = function () { return false; } 181 | 182 | // fixes mouse co-ordinate problems when there's a border or padding 183 | // see getMouse for more detail 184 | if 
(document.defaultView && document.defaultView.getComputedStyle) { 185 | stylePaddingLeft = parseInt(document.defaultView.getComputedStyle(canvas, null)['paddingLeft'], 10) || 0; 186 | stylePaddingTop = parseInt(document.defaultView.getComputedStyle(canvas, null)['paddingTop'], 10) || 0; 187 | styleBorderLeft = parseInt(document.defaultView.getComputedStyle(canvas, null)['borderLeftWidth'], 10) || 0; 188 | styleBorderTop = parseInt(document.defaultView.getComputedStyle(canvas, null)['borderTopWidth'], 10) || 0; 189 | } 190 | 191 | // make mainDraw() fire every INTERVAL milliseconds 192 | setInterval(mainDraw, INTERVAL); 193 | 194 | // set our events. Up and down are for dragging, 195 | // double click is for making new boxes 196 | canvas.onmousedown = myDown; 197 | canvas.onmouseup = myUp; 198 | canvas.ondblclick = myDblClick; 199 | canvas.onmousemove = myMove; 200 | 201 | // set up the selection handle boxes 202 | for (var i = 0; i < 8; i ++) { 203 | var rect = new Box2; 204 | selectionHandles.push(rect); 205 | } 206 | 207 | // add custom initialization here: 208 | 209 | 210 | // add a large green rectangle 211 | addRect(260, 70, WIDTH/2, HEIGHT/2, 'rgba(255, 210, 75, 0.7)'); 212 | 213 | // add a green-blue rectangle 214 | addRect(240, 120, WIDTH/2, HEIGHT/2, 'rgba(255, 210, 75, 0.7)'); 215 | 216 | // add a smaller purple rectangle 217 | // addRect(45, 60, 25, 25, 'rgba(150,150,250,0.7)'); 218 | } 219 | 220 | 221 | //wipes the canvas context 222 | function clear(c) { 223 | c.clearRect(0, 0, WIDTH, HEIGHT); 224 | } 225 | 226 | // Main draw loop. 227 | // While draw is called as often as the INTERVAL variable demands, 228 | // It only ever does something if the canvas gets invalidated by our code 229 | function mainDraw() { 230 | if (canvasValid == false) { 231 | clear(ctx); 232 | 233 | // Add stuff you want drawn in the background all the time here 234 | 235 | // draw all boxes 236 | var l = boxes2.length; 237 | for (var i = 0; i < l; i++) { 238 | boxes2[i].draw(ctx); // we used to call drawshape, but now each box draws itself 239 | } 240 | 241 | // Add stuff you want drawn on top all the time here 242 | document.querySelector('#iou').innerHTML = computeIoU(boxes2[0], boxes2[1]); 243 | 244 | canvasValid = true; 245 | } 246 | } 247 | 248 | function computeIoU(a, b) { 249 | var leftX = Math.max(a.x, b.x); 250 | var rightX = Math.min(a.x+a.w, b.x+b.w); 251 | 252 | var topY = Math.max(a.y, b.y); 253 | var botY = Math.min(a.y+a.h, b.y+b.h); 254 | 255 | if (rightX < leftX || botY < topY) 256 | return 0; 257 | 258 | var inter = (rightX-leftX) * (botY-topY); 259 | var areaA = a.w * a.h; 260 | var areaB = b.w * b.h; 261 | var union = areaA + areaB - inter; 262 | 263 | var iou = inter / union; 264 | 265 | return Math.round(iou * 100) / 100; 266 | } 267 | 268 | // Happens when the mouse is moving inside the canvas 269 | function myMove(e){ 270 | if (isDrag) { 271 | getMouse(e); 272 | 273 | mySel.x = mx - offsetx; 274 | mySel.y = my - offsety; 275 | 276 | // something is changing position so we better invalidate the canvas! 277 | invalidate(); 278 | } else if (isResizeDrag) { 279 | // time ro resize! 
280 | var oldx = mySel.x; 281 | var oldy = mySel.y; 282 | 283 | // 0 1 2 284 | // 3 4 285 | // 5 6 7 286 | switch (expectResize) { 287 | case 0: 288 | mySel.x = mx; 289 | mySel.y = my; 290 | mySel.w += oldx - mx; 291 | mySel.h += oldy - my; 292 | break; 293 | case 1: 294 | mySel.y = my; 295 | mySel.h += oldy - my; 296 | break; 297 | case 2: 298 | mySel.y = my; 299 | mySel.w = mx - oldx; 300 | mySel.h += oldy - my; 301 | break; 302 | case 3: 303 | mySel.x = mx; 304 | mySel.w += oldx - mx; 305 | break; 306 | case 4: 307 | mySel.w = mx - oldx; 308 | break; 309 | case 5: 310 | mySel.x = mx; 311 | mySel.w += oldx - mx; 312 | mySel.h = my - oldy; 313 | break; 314 | case 6: 315 | mySel.h = my - oldy; 316 | break; 317 | case 7: 318 | mySel.w = mx - oldx; 319 | mySel.h = my - oldy; 320 | break; 321 | } 322 | 323 | invalidate(); 324 | } 325 | 326 | getMouse(e); 327 | // if there's a selection see if we grabbed one of the selection handles 328 | if (mySel !== null && !isResizeDrag) { 329 | for (var i = 0; i < 8; i++) { 330 | // 0 1 2 331 | // 3 4 332 | // 5 6 7 333 | 334 | var cur = selectionHandles[i]; 335 | 336 | // we dont need to use the ghost context because 337 | // selection handles will always be rectangles 338 | if (mx >= cur.x && mx <= cur.x + mySelBoxSize && 339 | my >= cur.y && my <= cur.y + mySelBoxSize) { 340 | // we found one! 341 | expectResize = i; 342 | invalidate(); 343 | 344 | switch (i) { 345 | case 0: 346 | this.style.cursor='nw-resize'; 347 | break; 348 | case 1: 349 | this.style.cursor='n-resize'; 350 | break; 351 | case 2: 352 | this.style.cursor='ne-resize'; 353 | break; 354 | case 3: 355 | this.style.cursor='w-resize'; 356 | break; 357 | case 4: 358 | this.style.cursor='e-resize'; 359 | break; 360 | case 5: 361 | this.style.cursor='sw-resize'; 362 | break; 363 | case 6: 364 | this.style.cursor='s-resize'; 365 | break; 366 | case 7: 367 | this.style.cursor='se-resize'; 368 | break; 369 | } 370 | return; 371 | } 372 | 373 | } 374 | // not over a selection box, return to normal 375 | isResizeDrag = false; 376 | expectResize = -1; 377 | this.style.cursor='auto'; 378 | } 379 | 380 | } 381 | 382 | // Happens when the mouse is clicked in the canvas 383 | function myDown(e){ 384 | getMouse(e); 385 | 386 | //we are over a selection box 387 | if (expectResize !== -1) { 388 | isResizeDrag = true; 389 | return; 390 | } 391 | 392 | clear(gctx); 393 | var l = boxes2.length; 394 | for (var i = l-1; i >= 0; i--) { 395 | // draw shape onto ghost context 396 | boxes2[i].draw(gctx, 'black'); 397 | 398 | // get image data at the mouse x,y pixel 399 | var imageData = gctx.getImageData(mx, my, 1, 1); 400 | var index = (mx + my * imageData.width) * 4; 401 | 402 | // if the mouse pixel exists, select and break 403 | if (imageData.data[3] > 0) { 404 | mySel = boxes2[i]; 405 | offsetx = mx - mySel.x; 406 | offsety = my - mySel.y; 407 | mySel.x = mx - offsetx; 408 | mySel.y = my - offsety; 409 | isDrag = true; 410 | 411 | invalidate(); 412 | clear(gctx); 413 | return; 414 | } 415 | 416 | } 417 | // havent returned means we have selected nothing 418 | mySel = null; 419 | // clear the ghost canvas for next time 420 | clear(gctx); 421 | // invalidate because we might need the selection border to disappear 422 | invalidate(); 423 | } 424 | 425 | function myUp(){ 426 | isDrag = false; 427 | isResizeDrag = false; 428 | expectResize = -1; 429 | } 430 | 431 | // adds a new node 432 | function myDblClick(e) { 433 | getMouse(e); 434 | // for this method width and height determine the starting X and Y, too. 
435 | // so I left them as vars in case someone wanted to make them args for something and copy this code 436 | // var width = 20; 437 | // var height = 20; 438 | // addRect(mx - (width / 2), my - (height / 2), width, height, 'rgba(220,205,65,0.7)'); 439 | } 440 | 441 | 442 | function invalidate() { 443 | canvasValid = false; 444 | } 445 | 446 | // Sets mx,my to the mouse position relative to the canvas 447 | // unfortunately this can be tricky, we have to worry about padding and borders 448 | function getMouse(e) { 449 | var element = canvas, offsetX = 0, offsetY = 0; 450 | 451 | if (element.offsetParent) { 452 | do { 453 | offsetX += element.offsetLeft; 454 | offsetY += element.offsetTop; 455 | } while ((element = element.offsetParent)); 456 | } 457 | 458 | // Add padding and border style widths to offset 459 | offsetX += stylePaddingLeft; 460 | offsetY += stylePaddingTop; 461 | 462 | offsetX += styleBorderLeft; 463 | offsetY += styleBorderTop; 464 | 465 | mx = e.pageX - offsetX; 466 | my = e.pageY - offsetY 467 | } 468 | 469 | // If you dont want to use 470 | // You could uncomment this init() reference and place the script reference inside the body tag 471 | //init(); 472 | window.init2 = init2; 473 | })(window); 474 | 475 | // Andy added, as a replacement for 476 | // 477 | $(document).ready(function(){ 478 | // Your code here 479 | init2(); 480 | }); 481 | 482 | -------------------------------------------------------------------------------- /web/scripts/js.cookie.js: -------------------------------------------------------------------------------- 1 | /*! 2 | * JavaScript Cookie v2.2.0 3 | * https://github.com/js-cookie/js-cookie 4 | * 5 | * Copyright 2006, 2015 Klaus Hartl & Fagner Brack 6 | * Released under the MIT license 7 | */ 8 | ;(function (factory) { 9 | var registeredInModuleLoader = false; 10 | if (typeof define === 'function' && define.amd) { 11 | define(factory); 12 | registeredInModuleLoader = true; 13 | } 14 | if (typeof exports === 'object') { 15 | module.exports = factory(); 16 | registeredInModuleLoader = true; 17 | } 18 | if (!registeredInModuleLoader) { 19 | var OldCookies = window.Cookies; 20 | var api = window.Cookies = factory(); 21 | api.noConflict = function () { 22 | window.Cookies = OldCookies; 23 | return api; 24 | }; 25 | } 26 | }(function () { 27 | function extend () { 28 | var i = 0; 29 | var result = {}; 30 | for (; i < arguments.length; i++) { 31 | var attributes = arguments[ i ]; 32 | for (var key in attributes) { 33 | result[key] = attributes[key]; 34 | } 35 | } 36 | return result; 37 | } 38 | 39 | function init (converter) { 40 | function api (key, value, attributes) { 41 | var result; 42 | if (typeof document === 'undefined') { 43 | return; 44 | } 45 | 46 | // Write 47 | 48 | if (arguments.length > 1) { 49 | attributes = extend({ 50 | path: '/' 51 | }, api.defaults, attributes); 52 | 53 | if (typeof attributes.expires === 'number') { 54 | var expires = new Date(); 55 | expires.setMilliseconds(expires.getMilliseconds() + attributes.expires * 864e+5); 56 | attributes.expires = expires; 57 | } 58 | 59 | // We're using "expires" because "max-age" is not supported by IE 60 | attributes.expires = attributes.expires ? 
attributes.expires.toUTCString() : ''; 61 | 62 | try { 63 | result = JSON.stringify(value); 64 | if (/^[\{\[]/.test(result)) { 65 | value = result; 66 | } 67 | } catch (e) {} 68 | 69 | if (!converter.write) { 70 | value = encodeURIComponent(String(value)) 71 | .replace(/%(23|24|26|2B|3A|3C|3E|3D|2F|3F|40|5B|5D|5E|60|7B|7D|7C)/g, decodeURIComponent); 72 | } else { 73 | value = converter.write(value, key); 74 | } 75 | 76 | key = encodeURIComponent(String(key)); 77 | key = key.replace(/%(23|24|26|2B|5E|60|7C)/g, decodeURIComponent); 78 | key = key.replace(/[\(\)]/g, escape); 79 | 80 | var stringifiedAttributes = ''; 81 | 82 | for (var attributeName in attributes) { 83 | if (!attributes[attributeName]) { 84 | continue; 85 | } 86 | stringifiedAttributes += '; ' + attributeName; 87 | if (attributes[attributeName] === true) { 88 | continue; 89 | } 90 | stringifiedAttributes += '=' + attributes[attributeName]; 91 | } 92 | return (document.cookie = key + '=' + value + stringifiedAttributes); 93 | } 94 | 95 | // Read 96 | 97 | if (!key) { 98 | result = {}; 99 | } 100 | 101 | // To prevent the for loop in the first place assign an empty array 102 | // in case there are no cookies at all. Also prevents odd result when 103 | // calling "get()" 104 | var cookies = document.cookie ? document.cookie.split('; ') : []; 105 | var rdecode = /(%[0-9A-Z]{2})+/g; 106 | var i = 0; 107 | 108 | for (; i < cookies.length; i++) { 109 | var parts = cookies[i].split('='); 110 | var cookie = parts.slice(1).join('='); 111 | 112 | if (!this.json && cookie.charAt(0) === '"') { 113 | cookie = cookie.slice(1, -1); 114 | } 115 | 116 | try { 117 | var name = parts[0].replace(rdecode, decodeURIComponent); 118 | cookie = converter.read ? 119 | converter.read(cookie, name) : converter(cookie, name) || 120 | cookie.replace(rdecode, decodeURIComponent); 121 | 122 | if (this.json) { 123 | try { 124 | cookie = JSON.parse(cookie); 125 | } catch (e) {} 126 | } 127 | 128 | if (key === name) { 129 | result = cookie; 130 | break; 131 | } 132 | 133 | if (!key) { 134 | result[name] = cookie; 135 | } 136 | } catch (e) {} 137 | } 138 | 139 | return result; 140 | } 141 | 142 | api.set = api; 143 | api.get = function (key) { 144 | return api.call(api, key); 145 | }; 146 | api.getJSON = function () { 147 | return api.apply({ 148 | json: true 149 | }, [].slice.call(arguments)); 150 | }; 151 | api.defaults = {}; 152 | 153 | api.remove = function (key, attributes) { 154 | api(key, '', extend(attributes, { 155 | expires: -1 156 | })); 157 | }; 158 | 159 | api.withConverter = init; 160 | 161 | return api; 162 | } 163 | 164 | return init(function () {}); 165 | })); 166 | -------------------------------------------------------------------------------- /web/scripts/utils.js: -------------------------------------------------------------------------------- 1 | function load_RLE(rle_obj, fillColor=[255, 255, 255], alpha=255) { 2 | var h = rle_obj.size[0], w = rle_obj.size[1]; 3 | var counts = uncompress_RLE(rle_obj.counts); 4 | 5 | var buffer_size = (w*h*4); 6 | var buffer = new Uint8ClampedArray(w*h*4); 7 | var bufferIdx = 0; 8 | 9 | for (var countsIdx = 0; countsIdx < counts.length; countsIdx++) { 10 | while (counts[countsIdx] > 0) { 11 | // Kind of transpose the image as we go 12 | if (bufferIdx >= buffer_size) 13 | bufferIdx = (bufferIdx % buffer_size) + 4; 14 | 15 | buffer[bufferIdx+0] = fillColor[0]; 16 | buffer[bufferIdx+1] = fillColor[1]; 17 | buffer[bufferIdx+2] = fillColor[2]; 18 | buffer[bufferIdx+3] = alpha * (countsIdx % 2); 19 | 20 | 
bufferIdx += 4*w; 21 | counts[countsIdx]--; 22 | } 23 | } 24 | 25 | // Load into an off-screen canvas and return an image with that data 26 | var canvas = document.createElement('canvas'); 27 | var ctx = canvas.getContext('2d'); 28 | 29 | canvas.width = w; 30 | canvas.height = h; 31 | 32 | var idata = ctx.createImageData(w, h); 33 | idata.data.set(buffer); 34 | 35 | ctx.putImageData(idata, 0, 0); 36 | 37 | var img = new Image(); 38 | img.src = canvas.toDataURL(); 39 | 40 | return img; 41 | } 42 | 43 | function uncompress_RLE(rle_str) { 44 | // Don't ask me how this works--I'm just transcribing from the pycocotools c api. 45 | var p = 0, m = 0; 46 | var counts = Array(rle_str.lenght); 47 | 48 | while (p < rle_str.length) { 49 | var x=0, k=0, more=1; 50 | 51 | while (more) { 52 | var c = rle_str.charCodeAt(p) - 48; 53 | x |= (c & 0x1f) << 5*k; 54 | more = c & 0x20; 55 | p++; k++; 56 | if (!more && (c & 0x10)) 57 | x |= (-1 << 5*k); 58 | } 59 | 60 | if (m > 2) 61 | x += counts[m-2]; 62 | counts[m++] = (x >>> 0); 63 | } 64 | 65 | return counts; 66 | } 67 | 68 | function hexToRgb(hex) { 69 | var result = /^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(hex); 70 | return result ? [parseInt(result[1], 16), parseInt(result[2], 16), parseInt(result[3], 16)] : null; 71 | } 72 | -------------------------------------------------------------------------------- /web/scripts/viewer.js: -------------------------------------------------------------------------------- 1 | // Global variables so I remember them 2 | config_name = null; 3 | img_idx = null; 4 | 5 | img = null; 6 | dets = null; 7 | masks = null; 8 | 9 | // Must be in hex 10 | colors = ['#FF0000', '#FF7F00', '#00FF00', '#0000FF', '#4B0082', '#9400D3']; 11 | 12 | settings = { 13 | 'top_k': 5, 14 | 'font_height': 20, 15 | 'mask_alpha': 100, 16 | 17 | 'show_class': true, 18 | 'show_score': true, 19 | 'show_bbox': true, 20 | 'show_mask': true, 21 | 22 | 'show_one': false, 23 | } 24 | 25 | function save_settings() { 26 | Cookies.set('settings', settings); 27 | } 28 | 29 | function load_settings() { 30 | var new_settings = Cookies.getJSON('settings'); 31 | 32 | for (var key in new_settings) 33 | settings[key] = new_settings[key]; 34 | } 35 | 36 | $.urlParam = function(name){ 37 | var results = new RegExp('[\?&]' + name + '=([^&#]*)').exec(window.location.href); 38 | if (results==null){ 39 | return null; 40 | } 41 | else{ 42 | return decodeURI(results[1]) || 0; 43 | } 44 | } 45 | 46 | $(document).ready(function() { 47 | config_name = $.urlParam('config'); 48 | $('#config_name').html(config_name); 49 | 50 | img_idx = $.urlParam('idx'); 51 | if (img_idx === null) img_idx = 0; 52 | img_idx = parseInt(img_idx); 53 | 54 | load_settings(); 55 | 56 | $.getJSON('dets/' + config_name + '.json', function(data) { 57 | img_idx = (img_idx+data.images.length) % data.images.length; 58 | var info = data.info; 59 | var data = data.images[img_idx]; 60 | 61 | // These are globals on purpose 62 | dets = data.dets; 63 | img = new Image(); 64 | masks = Array(dets.length); 65 | 66 | img.onload = function() { render(); } 67 | img.src = 'image' + data.image_id; 68 | 69 | $('#image_name').html(data.image_id); 70 | $('#image_idx').html(img_idx); 71 | 72 | fill_info(info); 73 | fill_controls(); 74 | }); 75 | }); 76 | 77 | function is_object(val) { return val === Object(val); } 78 | 79 | function fill_info(info) { 80 | var html = ''; 81 | 82 | var add_item = function(item, val) { 83 | html += '' + item + '' 84 | html += ' ' 85 | html += '' + val + '' 86 | html += '
    ' 87 | } 88 | 89 | for (var item in info) { 90 | var val = info[item]; 91 | 92 | if (is_object(val)) { 93 | html += '' + item + '
    '; 94 | 95 | for (var item2 in val) 96 | add_item(item2, val[item2]); 97 | 98 | html += '
    ' 99 | } else add_item(item, val); 100 | } 101 | 102 | $('#info_box').html(html); 103 | } 104 | 105 | function fill_controls() { 106 | var html = ''; 107 | 108 | var append_html = function() { 109 | $('#control_box').append(html); 110 | html = ''; 111 | } 112 | 113 | var make_slider = function (name, setting, min, max) { 114 | settings[setting] = Math.min(max, settings[setting]); 115 | var value = settings[setting]; 116 | 117 | html += '
    '; 118 | html += '' + name + ''; 119 | html += ''; 120 | html += '' + value + ''; 121 | html += '
    '; 122 | append_html(); 123 | 124 | $('input#'+setting).change(function(e) { 125 | settings[setting] = $('input#'+setting).prop('value'); 126 | $('span#'+setting).html(settings[setting]); 127 | save_settings(); 128 | render(); 129 | }); 130 | } 131 | 132 | var make_toggle = function(name, setting) { 133 | html += '
    '; 134 | html += '' + name + ''; 135 | html += '
    '; 139 | append_html(); 140 | 141 | $('input#' + setting).change(function (e) { 142 | settings[setting] = $('input#' + setting).prop('checked'); 143 | save_settings(); 144 | render(); 145 | }); 146 | } 147 | 148 | 149 | make_slider('Top K', 'top_k', 1, dets.length); 150 | make_toggle('Show One', 'show_one'); 151 | html += '
    '; 152 | make_toggle('Show BBox', 'show_bbox'); 153 | make_toggle('Show Class', 'show_class'); 154 | make_toggle('Show Score', 'show_score'); 155 | html += '
    '; 156 | make_slider('Mask Alpha', 'mask_alpha', 0, 255); 157 | make_toggle('Show Mask', 'show_mask'); 158 | 159 | html += '

    '; 160 | html += 'Prev'; 161 | html += '   '; 162 | html += 'Next'; 163 | html += '

    '; 164 | html += 'Back'; 165 | 166 | append_html(); 167 | } 168 | 169 | function render() { 170 | var canvas = document.querySelector('#image_canvas'); 171 | var ctx = canvas.getContext('2d'); 172 | 173 | canvas.style.width='100%'; 174 | canvas.style.height='94%'; 175 | canvas.width = canvas.offsetWidth; 176 | canvas.height = canvas.offsetHeight; 177 | 178 | var scale = Math.min(canvas.width / img.width, canvas.height / img.height); 179 | 180 | var im_x = canvas.width/2-img.width*scale/2; 181 | var im_y = canvas.height/2-img.height*scale/2; 182 | ctx.translate(im_x, im_y); 183 | ctx.drawImage(img, 0, 0, img.width * scale, img.height * scale); 184 | 185 | var startIdx = Math.min(dets.length, settings.top_k)-1; 186 | var endIdx = (settings.show_one ? startIdx : 0); 187 | 188 | // Draw masks behind everything 189 | for (var i = startIdx; i >= endIdx; i--) { 190 | if (settings.show_mask) { 191 | var mask = masks[i]; 192 | if (typeof mask == 'undefined') { 193 | masks[i] = load_RLE(dets[i].mask, hexToRgb(colors[i % colors.length])); 194 | masks[i].onload = function() { render(); } 195 | } else { 196 | ctx.globalAlpha = settings.mask_alpha / 255; 197 | ctx.drawImage(mask, 0, 0, mask.width * scale, mask.height * scale); 198 | ctx.globalAlpha = 1; 199 | } 200 | } 201 | } 202 | 203 | for (var i = startIdx; i >= endIdx; i--) { 204 | ctx.strokeStyle = colors[i % colors.length]; 205 | ctx.fillStyle = ctx.strokeStyle; 206 | ctx.lineWidth = 4; 207 | ctx.font = settings.font_height + 'px sans-serif'; 208 | 209 | var x = dets[i].bbox[0] * scale; 210 | var y = dets[i].bbox[1] * scale; 211 | var w = dets[i].bbox[2] * scale; 212 | var h = dets[i].bbox[3] * scale; 213 | 214 | if (settings.show_bbox) { 215 | ctx.strokeRect(x, y, w, h); 216 | ctx.stroke(); 217 | } 218 | 219 | var text_array = [] 220 | if (settings.show_class) 221 | text_array.push(dets[i].category); 222 | if (settings.show_score) 223 | text_array.push(Math.round(dets[i].score * 1000) / 1000); 224 | 225 | if (text_array.length > 0) { 226 | var text = text_array.join(' '); 227 | 228 | text_w = ctx.measureText(text).width; 229 | ctx.fillRect(x-ctx.lineWidth/2, y-settings.font_height-8, text_w+ctx.lineWidth, settings.font_height+8); 230 | 231 | ctx.fillStyle = 'white'; 232 | ctx.fillText(text, x, y-8); 233 | } 234 | } 235 | } 236 | -------------------------------------------------------------------------------- /web/server.py: -------------------------------------------------------------------------------- 1 | from http.server import SimpleHTTPRequestHandler, HTTPServer, HTTPStatus 2 | from pathlib import Path 3 | import os 4 | 5 | PORT = 6337 6 | IMAGE_PATH = '../data/coco/images/' 7 | IMAGE_FMT = '%012d.jpg' 8 | 9 | class Handler(SimpleHTTPRequestHandler): 10 | 11 | def do_GET(self): 12 | if self.path == '/detindex': 13 | self.send_str('\n'.join([p.name[:-5] for p in Path('dets/').glob('*.json')])) 14 | elif self.path.startswith('/image'): 15 | # Unsafe practices ahead! 
16 | path = self.translate_path(self.path).split('image') 17 | self.send_file(os.path.join(path[0], IMAGE_PATH, IMAGE_FMT % int(path[1]))) 18 | else: 19 | super().do_GET() 20 | 21 | def send_str(self, string): 22 | self.send_response(HTTPStatus.OK) 23 | self.send_header('Content-type', 'text/plain') 24 | self.send_header('Content-Length', str(len(string))) 25 | self.send_header('Last-Modified', self.date_time_string()) 26 | self.end_headers() 27 | 28 | self.wfile.write(string.encode()) 29 | 30 | def send_file(self, path): 31 | try: 32 | f = open(path, 'rb') 33 | except OSError: 34 | self.send_error(HTTPStatus.NOT_FOUND, "File not found") 35 | return 36 | 37 | try: 38 | self.send_response(HTTPStatus.OK) 39 | self.send_header("Content-type", self.guess_type(path)) 40 | fs = os.fstat(f.fileno()) 41 | self.send_header("Content-Length", str(fs[6])) 42 | self.send_header("Last-Modified", self.date_time_string(fs.st_mtime)) 43 | self.end_headers() 44 | 45 | self.copyfile(f, self.wfile) 46 | finally: 47 | f.close() 48 | 49 | def send_response(self, code, message=None): 50 | super().send_response(code, message) 51 | 52 | 53 | with HTTPServer(('', PORT), Handler) as httpd: 54 | print('Serving at port', PORT) 55 | try: 56 | httpd.serve_forever() 57 | except KeyboardInterrupt: 58 | pass 59 | -------------------------------------------------------------------------------- /web/viewer.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Detections Viewer 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | Detections Viewer 23 | 24 |
    25 |
    26 |
    Info  

    27 |
    28 |
    29 | 30 |
    31 |
     
    32 | 33 |
    34 | 35 |
    36 |
    Controls

    37 |
    38 |
    39 |
    40 | 41 | By Daniel Bolya 42 | 43 | 44 | --------------------------------------------------------------------------------
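The viewer pages above read per-config detection files from web/dets/, shaped like web/dets/test.json: an "info" object mapping section names to metric dicts, and an "images" list whose entries hold an "image_id" plus a "dets" array (viewer.js also looks for an optional "category" string and an RLE "mask" of the form {"size": [h, w], "counts": str} on each detection). Below is a minimal sketch of writing such a file, assuming the detections are already available as plain Python dicts; the function name, output path, and example values are hypothetical and not part of the repository.

import json

def write_viewer_dets(path, info, images):
    # info:   {"Cross-Class NMS": {"BBox mAP": ..., "Mask mAP": ...}, ...}
    # images: [{"image_id": int,
    #           "dets": [{"score": float, "bbox": [x, y, w, h],
    #                     "category": str,                       # optional, shown by viewer.js
    #                     "mask": {"size": [h, w], "counts": str}  # optional RLE, shown by viewer.js
    #                    }, ...]}, ...]
    payload = {"info": {**info, "Config": {}}, "images": images}
    with open(path, "w") as f:
        json.dump(payload, f)

# Hypothetical example mirroring the shape of web/dets/test.json
write_viewer_dets(
    "web/dets/example.json",
    {"Cross-Class NMS": {"BBox mAP": 0.95, "Mask mAP": 0.42}},
    [{"image_id": 42, "dets": [{"score": 0.14, "bbox": [20, 20, 100, 100]}]}],
)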