├── .DS_Store ├── LICENSE.txt ├── README.md ├── conda_list.txt ├── config.py ├── datasets ├── AIR │ ├── test.txt │ └── train.txt ├── CUB │ ├── test.txt │ └── train.txt └── STCAR │ ├── .DS_Store │ ├── test.txt │ └── train.txt ├── models ├── Asoftmax_linear.py ├── LoadModel.py ├── __pycache__ │ ├── LoadModel.cpython-36.pyc │ └── focal_loss.cpython-36.pyc └── focal_loss.py ├── test.py ├── train.py ├── transforms ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── functional.cpython-36.pyc │ └── transforms.cpython-36.pyc ├── functional.py └── transforms.py └── utils ├── Asoftmax_loss.py ├── __pycache__ ├── Asoftmax_loss.cpython-36.pyc ├── autoaugment.cpython-36.pyc ├── dataset_DCL.cpython-36.pyc ├── eval_model.cpython-36.pyc ├── train_model.cpython-36.pyc └── utils.cpython-36.pyc ├── autoaugment.py ├── dataset_DCL.py ├── eval_model.py ├── test_tool.py ├── train_model.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/.DS_Store -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright [2019], [京东JD.com JD AI] 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | 15 | ---------------------------------------------------------------------------------------------------------- 16 | 17 | From PyTorch: 18 | 19 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 20 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 21 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 22 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 23 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 24 | Copyright (c) 2011-2013 NYU (Clement Farabet) 25 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 26 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 27 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 28 | 29 | From Caffe2: 30 | 31 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 32 | 33 | All contributions by Facebook: 34 | Copyright (c) 2016 Facebook Inc. 35 | 36 | All contributions by Google: 37 | Copyright (c) 2015 Google Inc. 38 | All rights reserved. 39 | 40 | All contributions by Yangqing Jia: 41 | Copyright (c) 2015 Yangqing Jia 42 | All rights reserved. 43 | 44 | All contributions from Caffe: 45 | Copyright(c) 2013, 2014, 2015, the respective contributors 46 | All rights reserved. 47 | 48 | All other contributions: 49 | Copyright(c) 2015, 2016 the respective contributors 50 | All rights reserved. 51 | 52 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 53 | copyright over their contributions to Caffe2. The project versioning records 54 | all such contribution and copyright details. 
If a contributor wants to further 55 | mark their specific copyright on a particular contribution, they should 56 | indicate their copyright solely in the commit message of the change when it is 57 | committed. 58 | 59 | All rights reserved. 60 | 61 | Redistribution and use in source and binary forms, with or without 62 | modification, are permitted provided that the following conditions are met: 63 | 64 | 1. Redistributions of source code must retain the above copyright 65 | notice, this list of conditions and the following disclaimer. 66 | 67 | 2. Redistributions in binary form must reproduce the above copyright 68 | notice, this list of conditions and the following disclaimer in the 69 | documentation and/or other materials provided with the distribution. 70 | 71 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 72 | and IDIAP Research Institute nor the names of its contributors may be 73 | used to endorse or promote products derived from this software without 74 | specific prior written permission. 75 | 76 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 77 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 78 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 79 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 80 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 81 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 82 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 83 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 84 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 85 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 86 | POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Destruction and Construction Learning for Fine-grained Image Recognition 2 | 3 | By Yue Chen, Yalong Bai, Wei Zhang, Tao Mei 4 | 5 | Special thanks to [Yuanzhi Liang](https://github.com/akira-l) for code refactoring. 6 | 7 | ## UPDATE Jun. 10, 2020 8 | ```diff 9 | ! Research Intern Position Opening. Please send your cv to baiyalong[AT]jd.com if you are interested. 10 | ``` 11 | - **First Place** in [CVPR 2020 AliProducts Challenge: Large-scale Product Recognition](https://tianchi.aliyun.com/competition/entrance/231780/rankingList) 12 | 13 | ## UPDATE Jun. 21, 2019 14 | 15 | Our solution for the FGVC Challenge 2019 (The Sixth Workshop on Fine-Grained Visual Categorization in CVPR 2019) is updated! 
16 | 17 | With an ensemble of several DCL-based classification models, we won: 18 | 19 | - **First Place** in [iMaterialist Challenge on Product Recognition](https://www.kaggle.com/c/imaterialist-product-2019/leaderboard) 20 | - **First Place** in [Fieldguide Challenge: Moths & Butterflies](https://www.kaggle.com/c/fieldguide-challenge-moths-and-butterflies/leaderboard) 21 | - **Second Place** in [iFood - 2019 at FGVC6](https://www.kaggle.com/c/ifood-2019-fgvc6/leaderboard) 22 | 23 | ## Introduction 24 | 25 | This project is a PyTorch implementation of DCL, [*Destruction and Construction Learning for Fine-grained Image Recognition*](http://openaccess.thecvf.com/content_CVPR_2019/html/Chen_Destruction_and_Construction_Learning_for_Fine-Grained_Image_Recognition_CVPR_2019_paper.html), CVPR 2019. 26 | 27 | ## Requirements 28 | 29 | 1. Python 3.6 30 | 31 | 2. PyTorch 0.4.0 or 0.4.1 32 | 33 | 3. CUDA 8.0 or higher 34 | 35 | For a docker environment: 36 | 37 | ```shell 38 | docker pull pytorch/pytorch:0.4-cuda9-cudnn7-devel 39 | ``` 40 | 41 | For a conda environment: 42 | 43 | ```shell 44 | conda create --name DCL --file conda_list.txt 45 | ``` 46 | 47 | For more backbone support in DCL, please check [pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch) and install: 48 | 49 | ```shell 50 | pip install pretrainedmodels 51 | ``` 52 | 53 | 54 | ## Dataset Preparation 55 | 56 | 1. Download the corresponding dataset to the 'datasets' folder. 57 | 58 | 2. Organize the data as follows, e.g. for CUB: 59 | 60 | All image data go in './datasets/CUB/data/', 61 | e.g. './datasets/CUB/data/*.jpg' 62 | 63 | The annotation files go in './datasets/CUB/anno/', 64 | e.g. './datasets/CUB/anno/train.txt' 65 | 66 | Each annotation line has the form: 67 | 68 | ```shell 69 | name_of_image.jpg label_num\n 70 | ``` 71 | 72 | e.g. for the CUB annotations in this repository: 73 | 74 | ```shell 75 | Black_Footed_Albatross_0009_34.jpg 0 76 | Black_Footed_Albatross_0014_89.jpg 0 77 | Laysan_Albatross_0044_784.jpg 1 78 | Sooty_Albatross_0021_796339.jpg 2 79 | ... 80 | ``` 81 | 82 | Example annotations for datasets such as CUB and Stanford Cars are already provided in this repository. You can apply DCL to your own dataset by converting its annotations to train.txt/val.txt/test.txt in this format (a minimal conversion sketch is shown below) and modifying the class number in `config.py` (e.g. `numcls = 200` for CUB). 83 | 84 | ## Training 85 | 86 | Run `train.py` to train DCL.
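If your own dataset's labels are not yet in the annotation format described above, the following is a minimal conversion sketch (not part of the repository); the `labels.csv` file, its `filename,class_name` columns, and the `MyData` folder name are assumptions for illustration only:

```python
# Sketch only: build a DCL-style annotation file (one "image_name label_num"
# pair per line, space separated, as config.py reads it) from a hypothetical
# labels.csv with filename,class_name columns.
import csv

def csv_to_dcl_anno(csv_path, out_path):
    with open(csv_path) as f:
        rows = list(csv.DictReader(f))
    # map class names to contiguous integer ids (0 .. numcls-1)
    classes = sorted({r['class_name'] for r in rows})
    cls2id = {name: idx for idx, name in enumerate(classes)}
    with open(out_path, 'w') as f:
        for r in rows:
            f.write('%s %d\n' % (r['filename'], cls2id[r['class_name']]))
    return len(classes)   # set numcls in config.py to this value

# e.g. numcls = csv_to_dcl_anno('labels.csv', './datasets/MyData/anno/train.txt')
```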
87 | 88 | For training CUB / STCAR / AIR from scratch 89 | 90 | ```shell 91 | python train.py --data CUB --epoch 360 --backbone resnet50 \ 92 | --tb 16 --tnw 16 --vb 512 --vnw 16 \ 93 | --lr 0.0008 --lr_step 60 \ 94 | --cls_lr_ratio 10 --start_epoch 0 \ 95 | --detail training_descibe --size 512 \ 96 | --crop 448 --cls_mul --swap_num 7 7 97 | ``` 98 | 99 | For training CUB / STCAR / AIR from trained checkpoint 100 | 101 | ```shell 102 | python train.py --data CUB --epoch 360 --backbone resnet50 \ 103 | --tb 16 --tnw 16 --vb 512 --vnw 16 \ 104 | --lr 0.0008 --lr_step 60 \ 105 | --cls_lr_ratio 10 --start_epoch $LAST_EPOCH \ 106 | --detail training_descibe4checkpoint --size 512 \ 107 | --crop 448 --cls_mul --swap_num 7 7 108 | ``` 109 | 110 | For training FGVC product datasets from scratch 111 | 112 | ```shell 113 | python train.py --data product --epoch 60 --backbone senet154 \ 114 | --tb 96 --tnw 32 --vb 512 --vnw 32 \ 115 | --lr 0.01 --lr_step 12 \ 116 | --cls_lr_ratio 10 --start_epoch 0 \ 117 | --detail training_descibe --size 512 \ 118 | --crop 448 --cls_2 --swap_num 7 7 119 | ``` 120 | 121 | For training FGVC datasets from trained checkpoint 122 | 123 | ```shell 124 | python train.py --data product --epoch 60 --backbone senet154 \ 125 | --tb 96 --tnw 32 --vb 512 --vnw 32 \ 126 | --lr 0.01 --lr_step 12 \ 127 | --cls_lr_ratio 10 --start_epoch $LAST_EPOCH \ 128 | --detail training_descibe4checkpoint --size 512 \ 129 | --crop 448 --cls_2 --swap_num 7 7 130 | ``` 131 | To achieve the similar results of paper, please use the default parameter settings. 132 | 133 | ## Citation 134 | Please cite our CVPR19 paper if you use this codebase in your work: 135 | ``` 136 | @InProceedings{Chen_2019_CVPR, 137 | author = {Chen, Yue and Bai, Yalong and Zhang, Wei and Mei, Tao}, 138 | title = {Destruction and Construction Learning for Fine-Grained Image Recognition}, 139 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 140 | month = {June}, 141 | year = {2019} 142 | } 143 | ``` 144 | ### Find our more recent work: 145 | Look-into-Object: Self-supervised Structure Modeling for Object Recognition. 
CVPR2020 [[pdf](https://arxiv.org/abs/2003.14142), [Source Code](https://github.com/JDAI-CV/LIO)] 146 | -------------------------------------------------------------------------------- /conda_list.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://repo.continuum.io/pkgs/main/linux-64/blas-1.0-mkl.tar.bz2 6 | https://repo.continuum.io/pkgs/main/linux-64/bzip2-1.0.6-h9a117a8_4.tar.bz2 7 | https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2018.4.16-0.tar.bz2 8 | https://conda.anaconda.org/caffe2/linux-64/caffe2-cuda8.0-cudnn7-0.8.dev-py36_2018.05.14.tar.bz2 9 | https://conda.anaconda.org/conda-forge/linux-64/pandas-0.23.4-py36hf8a1672_0.tar.bz2 10 | https://repo.continuum.io/pkgs/main/linux-64/cairo-1.14.12-h7636065_2.tar.bz2 11 | https://repo.continuum.io/pkgs/main/linux-64/certifi-2018.4.16-py36_0.tar.bz2 12 | https://repo.continuum.io/pkgs/main/linux-64/cffi-1.11.5-py36h9745a5d_0.tar.bz2 13 | https://repo.continuum.io/pkgs/free/linux-64/cudatoolkit-8.0-3.tar.bz2 14 | https://repo.continuum.io/pkgs/main/linux-64/cycler-0.10.0-py36h93f1223_0.tar.bz2 15 | https://repo.continuum.io/pkgs/main/linux-64/dbus-1.13.2-h714fa37_1.tar.bz2 16 | https://repo.continuum.io/pkgs/main/linux-64/expat-2.2.5-he0dffb1_0.tar.bz2 17 | https://repo.continuum.io/pkgs/main/linux-64/ffmpeg-3.4-h7264315_0.tar.bz2 18 | https://repo.continuum.io/pkgs/main/linux-64/fontconfig-2.12.6-h49f89f6_0.tar.bz2 19 | https://repo.continuum.io/pkgs/free/linux-64/freeglut-2.8.1-0.tar.bz2 20 | https://repo.continuum.io/pkgs/main/linux-64/freetype-2.8-hab7d2ae_1.tar.bz2 21 | https://repo.continuum.io/pkgs/free/linux-64/future-0.16.0-py36_1.tar.bz2 22 | https://repo.continuum.io/pkgs/main/linux-64/gflags-2.2.1-hf484d3e_0.tar.bz2 23 | https://repo.continuum.io/pkgs/main/linux-64/glib-2.56.1-h000015b_0.tar.bz2 24 | https://repo.continuum.io/pkgs/main/linux-64/glog-0.3.5-hf484d3e_1.tar.bz2 25 | https://repo.continuum.io/pkgs/main/linux-64/graphite2-1.3.11-hf63cedd_1.tar.bz2 26 | https://repo.continuum.io/pkgs/main/linux-64/gst-plugins-base-1.14.0-hbbd80ab_1.tar.bz2 27 | https://repo.continuum.io/pkgs/main/linux-64/gstreamer-1.14.0-hb453b48_1.tar.bz2 28 | https://repo.continuum.io/pkgs/main/linux-64/h5py-2.8.0-py36hca9c191_0.tar.bz2 29 | https://repo.continuum.io/pkgs/main/linux-64/harfbuzz-1.7.6-h5f0a787_1.tar.bz2 30 | https://repo.continuum.io/pkgs/main/linux-64/hdf5-1.8.18-h6792536_1.tar.bz2 31 | https://repo.continuum.io/pkgs/main/linux-64/icu-58.2-h9c2bf20_1.tar.bz2 32 | https://repo.continuum.io/pkgs/main/linux-64/intel-openmp-2018.0.0-8.tar.bz2 33 | https://repo.continuum.io/pkgs/main/linux-64/jasper-2.0.14-h07fcdf6_0.tar.bz2 34 | https://repo.continuum.io/pkgs/main/linux-64/jpeg-9b-h024ee3a_2.tar.bz2 35 | https://repo.continuum.io/pkgs/main/linux-64/kiwisolver-1.0.1-py36h764f252_0.tar.bz2 36 | https://repo.continuum.io/pkgs/main/linux-64/libedit-3.1-heed3624_0.tar.bz2 37 | https://repo.continuum.io/pkgs/main/linux-64/libffi-3.2.1-hd88cf55_4.tar.bz2 38 | https://repo.continuum.io/pkgs/main/linux-64/libgcc-ng-7.2.0-hdf63c60_3.tar.bz2 39 | https://repo.continuum.io/pkgs/main/linux-64/libgfortran-ng-7.2.0-hdf63c60_3.tar.bz2 40 | https://repo.continuum.io/pkgs/main/linux-64/libglu-9.0.0-h0c0bdc1_1.tar.bz2 41 | https://repo.continuum.io/pkgs/main/linux-64/libopus-1.2.1-hb9ed12e_0.tar.bz2 42 | 
https://repo.continuum.io/pkgs/main/linux-64/libpng-1.6.34-hb9fc6fc_0.tar.bz2 43 | https://repo.continuum.io/pkgs/main/linux-64/libprotobuf-3.5.2-h6f1eeef_0.tar.bz2 44 | https://repo.continuum.io/pkgs/main/linux-64/libstdcxx-ng-7.2.0-hdf63c60_3.tar.bz2 45 | https://repo.continuum.io/pkgs/main/linux-64/libtiff-4.0.9-h28f6b97_0.tar.bz2 46 | https://repo.continuum.io/pkgs/main/linux-64/libvpx-1.6.1-h888fd40_0.tar.bz2 47 | https://repo.continuum.io/pkgs/main/linux-64/libxcb-1.13-h1bed415_1.tar.bz2 48 | https://repo.continuum.io/pkgs/main/linux-64/libxml2-2.9.8-hf84eae3_0.tar.bz2 49 | https://repo.continuum.io/pkgs/main/linux-64/matplotlib-2.2.2-py36h0e671d2_1.tar.bz2 50 | https://repo.continuum.io/pkgs/main/linux-64/mkl-2018.0.2-1.tar.bz2 51 | https://repo.continuum.io/pkgs/main/linux-64/mkl_fft-1.0.1-py36h3010b51_0.tar.bz2 52 | https://repo.continuum.io/pkgs/main/linux-64/mkl_random-1.0.1-py36h629b387_0.tar.bz2 53 | https://repo.continuum.io/pkgs/main/linux-64/ncurses-6.0-h9df7e31_2.tar.bz2 54 | https://repo.continuum.io/pkgs/main/linux-64/ninja-1.8.2-py36h6bb024c_1.tar.bz2 55 | https://repo.continuum.io/pkgs/main/linux-64/numpy-1.14.3-py36hcd700cb_1.tar.bz2 56 | https://repo.continuum.io/pkgs/main/linux-64/numpy-base-1.14.3-py36h9be14a7_1.tar.bz2 57 | https://repo.continuum.io/pkgs/main/linux-64/olefile-0.45.1-py36_0.tar.bz2 58 | https://repo.continuum.io/pkgs/main/linux-64/opencv-3.3.1-py36h9248ab4_2.tar.bz2 59 | https://repo.continuum.io/pkgs/main/linux-64/openssl-1.0.2o-h20670df_0.tar.bz2 60 | https://repo.continuum.io/pkgs/main/linux-64/pcre-8.42-h439df22_0.tar.bz2 61 | https://repo.continuum.io/pkgs/main/linux-64/pillow-5.1.0-py36h3deb7b8_0.tar.bz2 62 | https://repo.continuum.io/pkgs/main/linux-64/pip-10.0.1-py36_0.tar.bz2 63 | https://repo.continuum.io/pkgs/main/linux-64/pixman-0.34.0-hceecf20_3.tar.bz2 64 | https://conda.anaconda.org/conda-forge/linux-64/protobuf-3.5.2-py36_0.tar.bz2 65 | https://repo.continuum.io/pkgs/main/linux-64/pycparser-2.18-py36hf9f622e_1.tar.bz2 66 | https://repo.continuum.io/pkgs/main/linux-64/pyparsing-2.2.0-py36hee85983_1.tar.bz2 67 | https://repo.continuum.io/pkgs/main/linux-64/pyqt-5.9.2-py36h751905a_0.tar.bz2 68 | https://repo.continuum.io/pkgs/main/linux-64/python-3.6.5-hc3d631a_2.tar.bz2 69 | https://repo.continuum.io/pkgs/main/linux-64/python-dateutil-2.7.2-py36_0.tar.bz2 70 | https://conda.anaconda.org/pytorch/linux-64/pytorch-0.4.0-py36_cuda8.0.61_cudnn7.1.2_1.tar.bz2 71 | https://repo.continuum.io/pkgs/main/linux-64/pytz-2018.4-py36_0.tar.bz2 72 | https://repo.continuum.io/pkgs/main/linux-64/pyyaml-3.12-py36hafb9ca4_1.tar.bz2 73 | https://repo.continuum.io/pkgs/main/linux-64/qt-5.9.5-h7e424d6_0.tar.bz2 74 | https://repo.continuum.io/pkgs/main/linux-64/readline-7.0-ha6073c6_4.tar.bz2 75 | https://repo.continuum.io/pkgs/main/linux-64/scikit-learn-0.19.1-py36h7aa7ec6_0.tar.bz2 76 | https://repo.continuum.io/pkgs/main/linux-64/scipy-1.1.0-py36hfc37229_0.tar.bz2 77 | https://repo.continuum.io/pkgs/main/linux-64/setuptools-39.1.0-py36_0.tar.bz2 78 | https://repo.continuum.io/pkgs/main/linux-64/sip-4.19.8-py36hf484d3e_0.tar.bz2 79 | https://repo.continuum.io/pkgs/main/linux-64/six-1.11.0-py36h372c433_1.tar.bz2 80 | https://repo.continuum.io/pkgs/main/linux-64/sqlite-3.23.1-he433501_0.tar.bz2 81 | https://repo.continuum.io/pkgs/main/linux-64/tk-8.6.7-hc745277_3.tar.bz2 82 | https://conda.anaconda.org/pytorch/linux-64/torchvision-0.2.1-py36_1.tar.bz2 83 | https://repo.continuum.io/pkgs/main/linux-64/tornado-5.0.2-py36_0.tar.bz2 84 | 
https://repo.continuum.io/pkgs/main/linux-64/tqdm-4.23.0-py36_0.tar.bz2 85 | https://repo.continuum.io/pkgs/main/linux-64/wheel-0.31.0-py36_0.tar.bz2 86 | https://repo.continuum.io/pkgs/main/linux-64/xz-5.2.3-h5e939de_4.tar.bz2 87 | https://repo.continuum.io/pkgs/main/linux-64/yaml-0.1.7-had09818_2.tar.bz2 88 | https://repo.continuum.io/pkgs/main/linux-64/zlib-1.2.11-ha838bed_2.tar.bz2 89 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import torch 4 | 5 | from transforms import transforms 6 | from utils.autoaugment import ImageNetPolicy 7 | 8 | # pretrained model checkpoints 9 | pretrained_model = {'resnet50' : './models/pretrained/resnet50-19c8e357.pth',} 10 | 11 | # transforms dict 12 | def load_data_transformers(resize_reso=512, crop_reso=448, swap_num=[7, 7]): 13 | center_resize = 600 14 | Normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 15 | data_transforms = { 16 | 'swap': transforms.Compose([ 17 | transforms.Randomswap((swap_num[0], swap_num[1])), 18 | ]), 19 | 'common_aug': transforms.Compose([ 20 | transforms.Resize((resize_reso, resize_reso)), 21 | transforms.RandomRotation(degrees=15), 22 | transforms.RandomCrop((crop_reso,crop_reso)), 23 | transforms.RandomHorizontalFlip(), 24 | ]), 25 | 'train_totensor': transforms.Compose([ 26 | transforms.Resize((crop_reso, crop_reso)), 27 | # ImageNetPolicy(), 28 | transforms.ToTensor(), 29 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 30 | ]), 31 | 'val_totensor': transforms.Compose([ 32 | transforms.Resize((crop_reso, crop_reso)), 33 | transforms.ToTensor(), 34 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 35 | ]), 36 | 'test_totensor': transforms.Compose([ 37 | transforms.Resize((resize_reso, resize_reso)), 38 | transforms.CenterCrop((crop_reso, crop_reso)), 39 | transforms.ToTensor(), 40 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 41 | ]), 42 | 'None': None, 43 | } 44 | return data_transforms 45 | 46 | 47 | class LoadConfig(object): 48 | def __init__(self, args, version): 49 | if version == 'train': 50 | get_list = ['train', 'val'] 51 | elif version == 'val': 52 | get_list = ['val'] 53 | elif version == 'test': 54 | get_list = ['test'] 55 | else: 56 | raise Exception("train/val/test ???\n") 57 | 58 | ############################### 59 | #### add dataset info here #### 60 | ############################### 61 | 62 | # put image data in $PATH/data 63 | # put annotation txt file in $PATH/anno 64 | 65 | if args.dataset == 'product': 66 | self.dataset = args.dataset 67 | self.rawdata_root = './../FGVC_product/data' 68 | self.anno_root = './../FGVC_product/anno' 69 | self.numcls = 2019 70 | elif args.dataset == 'CUB': 71 | self.dataset = args.dataset 72 | self.rawdata_root = './dataset/CUB_200_2011/data' 73 | self.anno_root = './dataset/CUB_200_2011/anno' 74 | self.numcls = 200 75 | elif args.dataset == 'STCAR': 76 | self.dataset = args.dataset 77 | self.rawdata_root = './dataset/st_car/data' 78 | self.anno_root = './dataset/st_car/anno' 79 | self.numcls = 196 80 | elif args.dataset == 'AIR': 81 | self.dataset = args.dataset 82 | self.rawdata_root = './dataset/aircraft/data' 83 | self.anno_root = './dataset/aircraft/anno' 84 | self.numcls = 100 85 | else: 86 | raise Exception('dataset not defined ???') 87 | 88 | # annotation file organized as : 89 | # path/image_name 
cls_num\n 90 | 91 | if 'train' in get_list: 92 | self.train_anno = pd.read_csv(os.path.join(self.anno_root, 'ct_train.txt'),\ 93 | sep=" ",\ 94 | header=None,\ 95 | names=['ImageName', 'label']) 96 | 97 | if 'val' in get_list: 98 | self.val_anno = pd.read_csv(os.path.join(self.anno_root, 'ct_val.txt'),\ 99 | sep=" ",\ 100 | header=None,\ 101 | names=['ImageName', 'label']) 102 | 103 | if 'test' in get_list: 104 | self.test_anno = pd.read_csv(os.path.join(self.anno_root, 'ct_test.txt'),\ 105 | sep=" ",\ 106 | header=None,\ 107 | names=['ImageName', 'label']) 108 | 109 | self.swap_num = args.swap_num 110 | 111 | self.save_dir = './net_model' 112 | if not os.path.exists(self.save_dir): 113 | os.mkdir(self.save_dir) 114 | self.backbone = args.backbone 115 | 116 | self.use_dcl = True 117 | self.use_backbone = False if self.use_dcl else True 118 | self.use_Asoftmax = False 119 | self.use_focal_loss = False 120 | self.use_fpn = False 121 | self.use_hier = False 122 | 123 | self.weighted_sample = False 124 | self.cls_2 = True 125 | self.cls_2xmul = False 126 | 127 | self.log_folder = './logs' 128 | if not os.path.exists(self.log_folder): 129 | os.mkdir(self.log_folder) 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /datasets/STCAR/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/datasets/STCAR/.DS_Store -------------------------------------------------------------------------------- /models/Asoftmax_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | from torch.nn import Parameter 6 | import math 7 | 8 | def myphi(x,m): 9 | x = x * m 10 | return 1-x**2/math.factorial(2)+x**4/math.factorial(4)-x**6/math.factorial(6) + \ 11 | x**8/math.factorial(8) - x**9/math.factorial(9) 12 | 13 | class AngleLinear(nn.Module): 14 | def __init__(self, in_features, out_features, m = 4, phiflag=True): 15 | super(AngleLinear, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.weight = Parameter(torch.Tensor(in_features,out_features)) 19 | self.weight.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5) 20 | self.phiflag = phiflag 21 | self.m = m 22 | self.mlambda = [ 23 | lambda x: x**0, 24 | lambda x: x**1, 25 | lambda x: 2*x**2-1, 26 | lambda x: 4*x**3-3*x, 27 | lambda x: 8*x**4-8*x**2+1, 28 | lambda x: 16*x**5-20*x**3+5*x 29 | ] 30 | 31 | def forward(self, input): 32 | x = input # size=(B,F) F is feature len 33 | w = self.weight # size=(F,Classnum) F=in_features Classnum=out_features 34 | 35 | ww = w.renorm(2,1,1e-5).mul(1e5) 36 | xlen = x.pow(2).sum(1).pow(0.5) # size=B 37 | wlen = ww.pow(2).sum(0).pow(0.5) # size=Classnum 38 | 39 | cos_theta = x.mm(ww) # size=(B,Classnum) 40 | cos_theta = cos_theta / xlen.view(-1,1) / wlen.view(1,-1) 41 | cos_theta = cos_theta.clamp(-1,1) 42 | 43 | if self.phiflag: 44 | cos_m_theta = self.mlambda[self.m](cos_theta) 45 | theta = Variable(cos_theta.data.acos()) 46 | k = (self.m*theta/3.14159265).floor() 47 | n_one = k*0.0 - 1 48 | phi_theta = (n_one**k) * cos_m_theta - 2*k 49 | else: 50 | theta = cos_theta.acos() 51 | phi_theta = myphi(theta,self.m) 52 | phi_theta = phi_theta.clamp(-1*self.m,1) 53 | 54 | cos_theta = cos_theta * xlen.view(-1,1) 55 | phi_theta = phi_theta * 
xlen.view(-1,1) 56 | output = (cos_theta,phi_theta) 57 | return output # size=(B,Classnum,2) 58 | 59 | 60 | -------------------------------------------------------------------------------- /models/LoadModel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | import torch 4 | from torchvision import models, transforms, datasets 5 | import torch.nn.functional as F 6 | import pretrainedmodels 7 | 8 | from config import pretrained_model 9 | 10 | import pdb 11 | 12 | class MainModel(nn.Module): 13 | def __init__(self, config): 14 | super(MainModel, self).__init__() 15 | self.use_dcl = config.use_dcl 16 | self.num_classes = config.numcls 17 | self.backbone_arch = config.backbone 18 | self.use_Asoftmax = config.use_Asoftmax 19 | print(self.backbone_arch) 20 | 21 | if self.backbone_arch in dir(models): 22 | self.model = getattr(models, self.backbone_arch)() 23 | if self.backbone_arch in pretrained_model: 24 | self.model.load_state_dict(torch.load(pretrained_model[self.backbone_arch])) 25 | else: 26 | if self.backbone_arch in pretrained_model: 27 | self.model = pretrainedmodels.__dict__[self.backbone_arch](num_classes=1000, pretrained=None) 28 | else: 29 | self.model = pretrainedmodels.__dict__[self.backbone_arch](num_classes=1000) 30 | 31 | if self.backbone_arch == 'resnet50' or self.backbone_arch == 'se_resnet50': 32 | self.model = nn.Sequential(*list(self.model.children())[:-2]) 33 | if self.backbone_arch == 'senet154': 34 | self.model = nn.Sequential(*list(self.model.children())[:-3]) 35 | if self.backbone_arch == 'se_resnext101_32x4d': 36 | self.model = nn.Sequential(*list(self.model.children())[:-2]) 37 | if self.backbone_arch == 'se_resnet101': 38 | self.model = nn.Sequential(*list(self.model.children())[:-2]) 39 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=1) 40 | self.classifier = nn.Linear(2048, self.num_classes, bias=False) 41 | 42 | if self.use_dcl: 43 | if config.cls_2: 44 | self.classifier_swap = nn.Linear(2048, 2, bias=False) 45 | if config.cls_2xmul: 46 | self.classifier_swap = nn.Linear(2048, 2*self.num_classes, bias=False) 47 | self.Convmask = nn.Conv2d(2048, 1, 1, stride=1, padding=0, bias=True) 48 | self.avgpool2 = nn.AvgPool2d(2, stride=2) 49 | 50 | if self.use_Asoftmax: 51 | self.Aclassifier = AngleLinear(2048, self.num_classes, bias=False) 52 | 53 | def forward(self, x, last_cont=None): 54 | x = self.model(x) 55 | if self.use_dcl: 56 | mask = self.Convmask(x) 57 | mask = self.avgpool2(mask) 58 | mask = torch.tanh(mask) 59 | mask = mask.view(mask.size(0), -1) 60 | 61 | x = self.avgpool(x) 62 | x = x.view(x.size(0), -1) 63 | out = [] 64 | out.append(self.classifier(x)) 65 | 66 | if self.use_dcl: 67 | out.append(self.classifier_swap(x)) 68 | out.append(mask) 69 | 70 | if self.use_Asoftmax: 71 | if last_cont is None: 72 | x_size = x.size(0) 73 | out.append(self.Aclassifier(x[0:x_size:2])) 74 | else: 75 | last_x = self.model(last_cont) 76 | last_x = self.avgpool(last_x) 77 | last_x = last_x.view(last_x.size(0), -1) 78 | out.append(self.Aclassifier(last_x)) 79 | 80 | return out 81 | -------------------------------------------------------------------------------- /models/__pycache__/LoadModel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/models/__pycache__/LoadModel.cpython-36.pyc 
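A note on `models/LoadModel.py` above: when `use_dcl` is enabled, `MainModel.forward` returns a list containing the class logits, the swap (destruction) head logits, and the construction mask, and the caller is expected to fuse them for prediction. A minimal fusion helper — a sketch that mirrors the fusion used by `test.py` later in this dump, and that assumes the model was built with `cls_2xmul=True` so the swap head emits `2 * numcls` logits — could look like this:

```python
import torch

def dcl_predict(model, images, numcls):
    # Sketch only: combine the DCL heads the same way test.py does.
    # outputs = [class logits, swap-head logits, construction mask, ...]
    model.eval()
    with torch.no_grad():
        outputs = model(images)
        fused = (outputs[0]
                 + outputs[1][:, :numcls]
                 + outputs[1][:, numcls:2 * numcls])
    return fused.argmax(dim=1)
```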
-------------------------------------------------------------------------------- /models/__pycache__/focal_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/models/__pycache__/focal_loss.cpython-36.pyc -------------------------------------------------------------------------------- /models/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FocalLoss(nn.Module): #1d and 2d 6 | 7 | def __init__(self, gamma=2, size_average=True): 8 | super(FocalLoss, self).__init__() 9 | self.gamma = gamma 10 | self.size_average = size_average 11 | 12 | 13 | def forward(self, logit, target, class_weight=None, type='softmax'): 14 | target = target.view(-1, 1).long() 15 | if type=='sigmoid': 16 | if class_weight is None: 17 | class_weight = [1]*2 #[0.5, 0.5] 18 | 19 | prob = torch.sigmoid(logit) 20 | prob = prob.view(-1, 1) 21 | prob = torch.cat((1-prob, prob), 1) 22 | select = torch.FloatTensor(len(prob), 2).zero_().cuda() 23 | select.scatter_(1, target, 1.) 24 | 25 | elif type=='softmax': 26 | B,C = logit.size() 27 | if class_weight is None: 28 | class_weight =[1]*C #[1/C]*C 29 | 30 | #logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C) 31 | prob = F.softmax(logit,1) 32 | select = torch.FloatTensor(len(prob), C).zero_().cuda() 33 | select.scatter_(1, target, 1.) 34 | 35 | class_weight = torch.FloatTensor(class_weight).cuda().view(-1,1) 36 | class_weight = torch.gather(class_weight, 0, target) 37 | 38 | prob = (prob*select).sum(1).view(-1,1) 39 | prob = torch.clamp(prob,1e-8,1-1e-8) 40 | batch_loss = - class_weight *(torch.pow((1-prob), self.gamma))*prob.log() 41 | 42 | if self.size_average: 43 | loss = batch_loss.mean() 44 | else: 45 | loss = batch_loss 46 | 47 | return loss 48 | 49 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | import json 4 | import csv 5 | import argparse 6 | import pandas as pd 7 | import numpy as np 8 | from math import ceil 9 | from tqdm import tqdm 10 | import pickle 11 | import shutil 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable 16 | from torch.nn import CrossEntropyLoss 17 | from torchvision import datasets, models 18 | import torch.backends.cudnn as cudnn 19 | import torch.nn.functional as F 20 | 21 | from transforms import transforms 22 | from models.LoadModel import MainModel 23 | from utils.dataset_DCL import collate_fn4train, collate_fn4test, collate_fn4val, dataset 24 | from config import LoadConfig, load_data_transformers 25 | from utils.test_tool import set_text, save_multi_img, cls_base_acc 26 | 27 | import pdb 28 | 29 | os.environ['CUDA_DEVICE_ORDRE'] = 'PCI_BUS_ID' 30 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser(description='dcl parameters') 34 | parser.add_argument('--data', dest='dataset', 35 | default='CUB', type=str) 36 | parser.add_argument('--backbone', dest='backbone', 37 | default='resnet50', type=str) 38 | parser.add_argument('--b', dest='batch_size', 39 | default=16, type=int) 40 | parser.add_argument('--nw', dest='num_workers', 41 | default=16, type=int) 42 | parser.add_argument('--ver', dest='version', 43 | 
default='val', type=str) 44 | parser.add_argument('--save', dest='resume', 45 | default=None, type=str) 46 | parser.add_argument('--size', dest='resize_resolution', 47 | default=512, type=int) 48 | parser.add_argument('--crop', dest='crop_resolution', 49 | default=448, type=int) 50 | parser.add_argument('--ss', dest='save_suffix', 51 | default=None, type=str) 52 | parser.add_argument('--acc_report', dest='acc_report', 53 | action='store_true') 54 | parser.add_argument('--swap_num', default=[7, 7], 55 | nargs=2, metavar=('swap1', 'swap2'), 56 | type=int, help='specify a range') 57 | args = parser.parse_args() 58 | return args 59 | 60 | if __name__ == '__main__': 61 | args = parse_args() 62 | print(args) 63 | if args.submit: 64 | args.version = 'test' 65 | if args.save_suffix == '': 66 | raise Exception('**** miss --ss save suffix is needed. ') 67 | 68 | Config = LoadConfig(args, args.version) 69 | transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num) 70 | data_set = dataset(Config,\ 71 | anno=Config.val_anno if args.version == 'val' else Config.test_anno ,\ 72 | unswap=transformers["None"],\ 73 | swap=transformers["None"],\ 74 | totensor=transformers['test_totensor'],\ 75 | test=True) 76 | 77 | dataloader = torch.utils.data.DataLoader(data_set,\ 78 | batch_size=args.batch_size,\ 79 | shuffle=False,\ 80 | num_workers=args.num_workers,\ 81 | collate_fn=collate_fn4test) 82 | 83 | setattr(dataloader, 'total_item_len', len(data_set)) 84 | 85 | cudnn.benchmark = True 86 | 87 | model = MainModel(Config) 88 | model_dict=model.state_dict() 89 | pretrained_dict=torch.load(resume) 90 | pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict} 91 | model_dict.update(pretrained_dict) 92 | model.load_state_dict(model_dict) 93 | model.cuda() 94 | model = nn.DataParallel(model) 95 | 96 | model.train(False) 97 | with torch.no_grad(): 98 | val_corrects1 = 0 99 | val_corrects2 = 0 100 | val_corrects3 = 0 101 | val_size = ceil(len(data_set) / dataloader.batch_size) 102 | result_gather = {} 103 | count_bar = tqdm(total=dataloader.__len__()) 104 | for batch_cnt_val, data_val in enumerate(dataloader): 105 | count_bar.update(1) 106 | inputs, labels, img_name = data_val 107 | inputs = Variable(inputs.cuda()) 108 | labels = Variable(torch.from_numpy(np.array(labels)).long().cuda()) 109 | 110 | outputs = model(inputs) 111 | outputs_pred = outputs[0] + outputs[1][:,0:Config.numcls] + outputs[1][:,Config.numcls:2*Config.numcls] 112 | 113 | top3_val, top3_pos = torch.topk(outputs_pred, 3) 114 | 115 | if args.version == 'val': 116 | batch_corrects1 = torch.sum((top3_pos[:, 0] == labels)).data.item() 117 | val_corrects1 += batch_corrects1 118 | batch_corrects2 = torch.sum((top3_pos[:, 1] == labels)).data.item() 119 | val_corrects2 += (batch_corrects2 + batch_corrects1) 120 | batch_corrects3 = torch.sum((top3_pos[:, 2] == labels)).data.item() 121 | val_corrects3 += (batch_corrects3 + batch_corrects2 + batch_corrects1) 122 | 123 | if args.acc_report: 124 | for sub_name, sub_cat, sub_val, sub_label in zip(img_name, top3_pos.tolist(), top3_val.tolist(), labels.tolist()): 125 | result_gather[sub_name] = {'top1_cat': sub_cat[0], 'top2_cat': sub_cat[1], 'top3_cat': sub_cat[2], 126 | 'top1_val': sub_val[0], 'top2_val': sub_val[1], 'top3_val': sub_val[2], 127 | 'label': sub_label} 128 | if args.acc_report: 129 | torch.save(result_gather, 'result_gather_%s'%resume.split('/')[-1][:-4]+ '.pt') 130 | 131 | count_bar.close() 132 | 133 | if args.acc_report: 134 | 
135 | val_acc1 = val_corrects1 / len(data_set) 136 | val_acc2 = val_corrects2 / len(data_set) 137 | val_acc3 = val_corrects3 / len(data_set) 138 | print('%sacc1 %f%s\n%sacc2 %f%s\n%sacc3 %f%s\n'%(8*'-', val_acc1, 8*'-', 8*'-', val_acc2, 8*'-', 8*'-', val_acc3, 8*'-')) 139 | 140 | cls_top1, cls_top3, cls_count = cls_base_acc(result_gather) 141 | 142 | acc_report_io = open('acc_report_%s_%s.json'%(args.save_suffix, resume.split('/')[-1]), 'w') 143 | json.dump({'val_acc1':val_acc1, 144 | 'val_acc2':val_acc2, 145 | 'val_acc3':val_acc3, 146 | 'cls_top1':cls_top1, 147 | 'cls_top3':cls_top3, 148 | 'cls_count':cls_count}, acc_report_io) 149 | acc_report_io.close() 150 | 151 | 152 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | import datetime 4 | import argparse 5 | import logging 6 | import pandas as pd 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import CrossEntropyLoss 11 | import torch.utils.data as torchdata 12 | from torchvision import datasets, models 13 | import torch.optim as optim 14 | from torch.optim import lr_scheduler 15 | import torch.backends.cudnn as cudnn 16 | 17 | from transforms import transforms 18 | from utils.train_model import train 19 | from models.LoadModel import MainModel 20 | from config import LoadConfig, load_data_transformers 21 | from utils.dataset_DCL import collate_fn4train, collate_fn4val, collate_fn4test, collate_fn4backbone, dataset 22 | 23 | import pdb 24 | 25 | os.environ['CUDA_DEVICE_ORDRE'] = 'PCI_BUS_ID' 26 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' 27 | 28 | # parameters setting 29 | def parse_args(): 30 | parser = argparse.ArgumentParser(description='dcl parameters') 31 | parser.add_argument('--data', dest='dataset', 32 | default='CUB', type=str) 33 | parser.add_argument('--save', dest='resume', 34 | default=None, 35 | type=str) 36 | parser.add_argument('--backbone', dest='backbone', 37 | default='resnet50', type=str) 38 | parser.add_argument('--auto_resume', dest='auto_resume', 39 | action='store_true') 40 | parser.add_argument('--epoch', dest='epoch', 41 | default=360, type=int) 42 | parser.add_argument('--tb', dest='train_batch', 43 | default=16, type=int) 44 | parser.add_argument('--vb', dest='val_batch', 45 | default=512, type=int) 46 | parser.add_argument('--sp', dest='save_point', 47 | default=5000, type=int) 48 | parser.add_argument('--cp', dest='check_point', 49 | default=5000, type=int) 50 | parser.add_argument('--lr', dest='base_lr', 51 | default=0.0008, type=float) 52 | parser.add_argument('--lr_step', dest='decay_step', 53 | default=60, type=int) 54 | parser.add_argument('--cls_lr_ratio', dest='cls_lr_ratio', 55 | default=10.0, type=float) 56 | parser.add_argument('--start_epoch', dest='start_epoch', 57 | default=0, type=int) 58 | parser.add_argument('--tnw', dest='train_num_workers', 59 | default=16, type=int) 60 | parser.add_argument('--vnw', dest='val_num_workers', 61 | default=32, type=int) 62 | parser.add_argument('--detail', dest='discribe', 63 | default='', type=str) 64 | parser.add_argument('--size', dest='resize_resolution', 65 | default=512, type=int) 66 | parser.add_argument('--crop', dest='crop_resolution', 67 | default=448, type=int) 68 | parser.add_argument('--cls_2', dest='cls_2', 69 | action='store_true') 70 | parser.add_argument('--cls_mul', dest='cls_mul', 71 | action='store_true') 72 | parser.add_argument('--swap_num', default=[7, 7], 
73 | nargs=2, metavar=('swap1', 'swap2'), 74 | type=int, help='specify a range') 75 | args = parser.parse_args() 76 | return args 77 | 78 | def auto_load_resume(load_dir): 79 | folders = os.listdir(load_dir) 80 | date_list = [int(x.split('_')[1].replace(' ',0)) for x in folders] 81 | choosed = folders[date_list.index(max(date_list))] 82 | weight_list = os.listdir(os.path.join(load_dir, choosed)) 83 | acc_list = [x[:-4].split('_')[-1] if x[:7]=='weights' else 0 for x in weight_list] 84 | acc_list = [float(x) for x in acc_list] 85 | choosed_w = weight_list[acc_list.index(max(acc_list))] 86 | return os.path.join(load_dir, choosed, choosed_w) 87 | 88 | 89 | if __name__ == '__main__': 90 | args = parse_args() 91 | print(args, flush=True) 92 | Config = LoadConfig(args, 'train') 93 | Config.cls_2 = args.cls_2 94 | Config.cls_2xmul = args.cls_mul 95 | assert Config.cls_2 ^ Config.cls_2xmul 96 | 97 | transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num) 98 | 99 | # inital dataloader 100 | train_set = dataset(Config = Config,\ 101 | anno = Config.train_anno,\ 102 | common_aug = transformers["common_aug"],\ 103 | swap = transformers["swap"],\ 104 | totensor = transformers["train_totensor"],\ 105 | train = True) 106 | 107 | trainval_set = dataset(Config = Config,\ 108 | anno = Config.train_anno,\ 109 | common_aug = transformers["None"],\ 110 | swap = transformers["None"],\ 111 | totensor = transformers["val_totensor"],\ 112 | train = False, 113 | train_val = True) 114 | 115 | val_set = dataset(Config = Config,\ 116 | anno = Config.val_anno,\ 117 | common_aug = transformers["None"],\ 118 | swap = transformers["None"],\ 119 | totensor = transformers["test_totensor"],\ 120 | test=True) 121 | 122 | dataloader = {} 123 | dataloader['train'] = torch.utils.data.DataLoader(train_set,\ 124 | batch_size=args.train_batch,\ 125 | shuffle=True,\ 126 | num_workers=args.train_num_workers,\ 127 | collate_fn=collate_fn4train if not Config.use_backbone else collate_fn4backbone, 128 | drop_last=True if Config.use_backbone else False, 129 | pin_memory=True) 130 | 131 | setattr(dataloader['train'], 'total_item_len', len(train_set)) 132 | 133 | dataloader['trainval'] = torch.utils.data.DataLoader(trainval_set,\ 134 | batch_size=args.val_batch,\ 135 | shuffle=False,\ 136 | num_workers=args.val_num_workers,\ 137 | collate_fn=collate_fn4val if not Config.use_backbone else collate_fn4backbone, 138 | drop_last=True if Config.use_backbone else False, 139 | pin_memory=True) 140 | 141 | setattr(dataloader['trainval'], 'total_item_len', len(trainval_set)) 142 | setattr(dataloader['trainval'], 'num_cls', Config.numcls) 143 | 144 | dataloader['val'] = torch.utils.data.DataLoader(val_set,\ 145 | batch_size=args.val_batch,\ 146 | shuffle=False,\ 147 | num_workers=args.val_num_workers,\ 148 | collate_fn=collate_fn4test if not Config.use_backbone else collate_fn4backbone, 149 | drop_last=True if Config.use_backbone else False, 150 | pin_memory=True) 151 | 152 | setattr(dataloader['val'], 'total_item_len', len(val_set)) 153 | setattr(dataloader['val'], 'num_cls', Config.numcls) 154 | 155 | 156 | cudnn.benchmark = True 157 | 158 | print('Choose model and train set', flush=True) 159 | model = MainModel(Config) 160 | 161 | # load model 162 | if (args.resume is None) and (not args.auto_resume): 163 | print('train from imagenet pretrained models ...', flush=True) 164 | else: 165 | if not args.resume is None: 166 | resume = args.resume 167 | print('load from pretrained checkpoint %s ...'% resume, 
flush=True) 168 | elif args.auto_resume: 169 | resume = auto_load_resume(Config.save_dir) 170 | print('load from %s ...'%resume, flush=True) 171 | else: 172 | raise Exception("no checkpoints to load") 173 | 174 | model_dict = model.state_dict() 175 | pretrained_dict = torch.load(resume) 176 | pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict} 177 | model_dict.update(pretrained_dict) 178 | model.load_state_dict(model_dict) 179 | 180 | print('Set cache dir', flush=True) 181 | time = datetime.datetime.now() 182 | filename = '%s_%d%d%d_%s'%(args.discribe, time.month, time.day, time.hour, Config.dataset) 183 | save_dir = os.path.join(Config.save_dir, filename) 184 | if not os.path.exists(save_dir): 185 | os.makedirs(save_dir) 186 | 187 | model.cuda() 188 | model = nn.DataParallel(model) 189 | 190 | # optimizer prepare 191 | if Config.use_backbone: 192 | ignored_params = list(map(id, model.module.classifier.parameters())) 193 | else: 194 | ignored_params1 = list(map(id, model.module.classifier.parameters())) 195 | ignored_params2 = list(map(id, model.module.classifier_swap.parameters())) 196 | ignored_params3 = list(map(id, model.module.Convmask.parameters())) 197 | 198 | ignored_params = ignored_params1 + ignored_params2 + ignored_params3 199 | print('the num of new layers:', len(ignored_params), flush=True) 200 | base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters()) 201 | 202 | lr_ratio = args.cls_lr_ratio 203 | base_lr = args.base_lr 204 | if Config.use_backbone: 205 | optimizer = optim.SGD([{'params': base_params}, 206 | {'params': model.module.classifier.parameters(), 'lr': base_lr}], lr = base_lr, momentum=0.9) 207 | else: 208 | optimizer = optim.SGD([{'params': base_params}, 209 | {'params': model.module.classifier.parameters(), 'lr': lr_ratio*base_lr}, 210 | {'params': model.module.classifier_swap.parameters(), 'lr': lr_ratio*base_lr}, 211 | {'params': model.module.Convmask.parameters(), 'lr': lr_ratio*base_lr}, 212 | ], lr = base_lr, momentum=0.9) 213 | 214 | 215 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=0.1) 216 | 217 | # train entry 218 | train(Config, 219 | model, 220 | epoch_num=args.epoch, 221 | start_epoch=args.start_epoch, 222 | optimizer=optimizer, 223 | exp_lr_scheduler=exp_lr_scheduler, 224 | data_loader=dataloader, 225 | save_dir=save_dir, 226 | data_size=args.crop_resolution, 227 | savepoint=args.save_point, 228 | checkpoint=args.check_point) 229 | 230 | 231 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | -------------------------------------------------------------------------------- /transforms/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/transforms/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /transforms/__pycache__/functional.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/transforms/__pycache__/functional.cpython-36.pyc -------------------------------------------------------------------------------- 
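A note on the checkpoint loading in `train.py` above (the same pattern appears in `test.py`): `k[7:]` strips the `module.` prefix that `nn.DataParallel` prepends to every parameter name when a wrapped model is saved, so the weights can be loaded back into an unwrapped model. The slice silently assumes the prefix is always present; a slightly more defensive variant — a sketch, not the repository's code — strips it only when it is actually there:

```python
import torch

def load_checkpoint(model, ckpt_path):
    # Sketch only: load a state dict saved with or without nn.DataParallel,
    # removing the 'module.' prefix only when it is present.
    model_dict = model.state_dict()
    ckpt = torch.load(ckpt_path, map_location='cpu')
    cleaned = {}
    for k, v in ckpt.items():
        name = k[len('module.'):] if k.startswith('module.') else k
        if name in model_dict:
            cleaned[name] = v
    model_dict.update(cleaned)
    model.load_state_dict(model_dict)
    return model
```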
/transforms/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/transforms/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /transforms/functional.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION 6 | try: 7 | import accimage 8 | except ImportError: 9 | accimage = None 10 | import numpy as np 11 | import numbers 12 | import types 13 | import collections 14 | import warnings 15 | 16 | 17 | def _is_pil_image(img): 18 | if accimage is not None: 19 | return isinstance(img, (Image.Image, accimage.Image)) 20 | else: 21 | return isinstance(img, Image.Image) 22 | 23 | 24 | def _is_tensor_image(img): 25 | return torch.is_tensor(img) and img.ndimension() == 3 26 | 27 | 28 | def _is_numpy_image(img): 29 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 30 | 31 | 32 | def to_tensor(pic): 33 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 34 | 35 | See ``ToTensor`` for more details. 36 | 37 | Args: 38 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 39 | 40 | Returns: 41 | Tensor: Converted image. 42 | """ 43 | if not(_is_pil_image(pic) or _is_numpy_image(pic)): 44 | raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) 45 | 46 | if isinstance(pic, np.ndarray): 47 | # handle numpy array 48 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 49 | # backward compatibility 50 | if isinstance(img, torch.ByteTensor): 51 | return img.float().div(255) 52 | else: 53 | return img 54 | 55 | if accimage is not None and isinstance(pic, accimage.Image): 56 | nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) 57 | pic.copyto(nppic) 58 | return torch.from_numpy(nppic) 59 | 60 | # handle PIL Image 61 | if pic.mode == 'I': 62 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 63 | elif pic.mode == 'I;16': 64 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 65 | elif pic.mode == 'F': 66 | img = torch.from_numpy(np.array(pic, np.float32, copy=False)) 67 | elif pic.mode == '1': 68 | img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) 69 | else: 70 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 71 | # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK 72 | if pic.mode == 'YCbCr': 73 | nchannel = 3 74 | elif pic.mode == 'I;16': 75 | nchannel = 1 76 | else: 77 | nchannel = len(pic.mode) 78 | img = img.view(pic.size[1], pic.size[0], nchannel) 79 | # put it from HWC to CHW format 80 | # yikes, this transpose takes 80% of the loading time/CPU 81 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 82 | if isinstance(img, torch.ByteTensor): 83 | return img.float().div(255) 84 | else: 85 | return img 86 | 87 | 88 | def to_pil_image(pic, mode=None): 89 | """Convert a tensor or an ndarray to PIL Image. 90 | 91 | See :class:`~torchvision.transforms.ToPIlImage` for more details. 92 | 93 | Args: 94 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 95 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 96 | 97 | .. 
_PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 98 | 99 | Returns: 100 | PIL Image: Image converted to PIL Image. 101 | """ 102 | if not(_is_numpy_image(pic) or _is_tensor_image(pic)): 103 | raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) 104 | 105 | npimg = pic 106 | if isinstance(pic, torch.FloatTensor): 107 | pic = pic.mul(255).byte() 108 | if torch.is_tensor(pic): 109 | npimg = np.transpose(pic.numpy(), (1, 2, 0)) 110 | 111 | if not isinstance(npimg, np.ndarray): 112 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 113 | 'not {}'.format(type(npimg))) 114 | 115 | if npimg.shape[2] == 1: 116 | expected_mode = None 117 | npimg = npimg[:, :, 0] 118 | if npimg.dtype == np.uint8: 119 | expected_mode = 'L' 120 | elif npimg.dtype == np.int16: 121 | expected_mode = 'I;16' 122 | elif npimg.dtype == np.int32: 123 | expected_mode = 'I' 124 | elif npimg.dtype == np.float32: 125 | expected_mode = 'F' 126 | if mode is not None and mode != expected_mode: 127 | raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" 128 | .format(mode, np.dtype, expected_mode)) 129 | mode = expected_mode 130 | 131 | elif npimg.shape[2] == 4: 132 | permitted_4_channel_modes = ['RGBA', 'CMYK'] 133 | if mode is not None and mode not in permitted_4_channel_modes: 134 | raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) 135 | 136 | if mode is None and npimg.dtype == np.uint8: 137 | mode = 'RGBA' 138 | else: 139 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] 140 | if mode is not None and mode not in permitted_3_channel_modes: 141 | raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) 142 | if mode is None and npimg.dtype == np.uint8: 143 | mode = 'RGB' 144 | 145 | if mode is None: 146 | raise TypeError('Input type {} is not supported'.format(npimg.dtype)) 147 | 148 | return Image.fromarray(npimg, mode=mode) 149 | 150 | 151 | def normalize(tensor, mean, std): 152 | """Normalize a tensor image with mean and standard deviation. 153 | 154 | See ``Normalize`` for more details. 155 | 156 | Args: 157 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 158 | mean (sequence): Sequence of means for each channel. 159 | std (sequence): Sequence of standard deviations for each channely. 160 | 161 | Returns: 162 | Tensor: Normalized Tensor image. 163 | """ 164 | if not _is_tensor_image(tensor): 165 | raise TypeError('tensor is not a torch image.') 166 | # TODO: make efficient 167 | for t, m, s in zip(tensor, mean, std): 168 | t.sub_(m).div_(s) 169 | return tensor 170 | 171 | 172 | def resize(img, size, interpolation=Image.BILINEAR): 173 | """Resize the input PIL Image to the given size. 174 | 175 | Args: 176 | img (PIL Image): Image to be resized. 177 | size (sequence or int): Desired output size. If size is a sequence like 178 | (h, w), the output size will be matched to this. If size is an int, 179 | the smaller edge of the image will be matched to this number maintaing 180 | the aspect ratio. i.e, if height > width, then image will be rescaled to 181 | (size * height / width, size) 182 | interpolation (int, optional): Desired interpolation. Default is 183 | ``PIL.Image.BILINEAR`` 184 | 185 | Returns: 186 | PIL Image: Resized image. 187 | """ 188 | if not _is_pil_image(img): 189 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 190 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 191 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 192 | 193 | if isinstance(size, int): 194 | w, h = img.size 195 | if (w <= h and w == size) or (h <= w and h == size): 196 | return img 197 | if w < h: 198 | ow = size 199 | oh = int(size * h / w) 200 | return img.resize((ow, oh), interpolation) 201 | else: 202 | oh = size 203 | ow = int(size * w / h) 204 | return img.resize((ow, oh), interpolation) 205 | else: 206 | return img.resize(size[::-1], interpolation) 207 | 208 | 209 | def scale(*args, **kwargs): 210 | warnings.warn("The use of the transforms.Scale transform is deprecated, " + 211 | "please use transforms.Resize instead.") 212 | return resize(*args, **kwargs) 213 | 214 | 215 | def pad(img, padding, fill=0, padding_mode='constant'): 216 | """Pad the given PIL Image on all sides with speficified padding mode and fill value. 217 | 218 | Args: 219 | img (PIL Image): Image to be padded. 220 | padding (int or tuple): Padding on each border. If a single int is provided this 221 | is used to pad all borders. If tuple of length 2 is provided this is the padding 222 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 223 | this is the padding for the left, top, right and bottom borders 224 | respectively. 225 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 226 | length 3, it is used to fill R, G, B channels respectively. 227 | This value is only used when the padding_mode is constant 228 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 229 | constant: pads with a constant value, this value is specified with fill 230 | edge: pads with the last value on the edge of the image 231 | reflect: pads with reflection of image (without repeating the last value on the edge) 232 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 233 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 234 | symmetric: pads with reflection of image (repeating the last value on the edge) 235 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 236 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 237 | 238 | Returns: 239 | PIL Image: Padded image. 240 | """ 241 | if not _is_pil_image(img): 242 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 243 | 244 | if not isinstance(padding, (numbers.Number, tuple)): 245 | raise TypeError('Got inappropriate padding arg') 246 | if not isinstance(fill, (numbers.Number, str, tuple)): 247 | raise TypeError('Got inappropriate fill arg') 248 | if not isinstance(padding_mode, str): 249 | raise TypeError('Got inappropriate padding_mode arg') 250 | 251 | if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: 252 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 253 | "{} element tuple".format(len(padding))) 254 | 255 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ 256 | 'Padding mode should be either constant, edge, reflect or symmetric' 257 | 258 | if padding_mode == 'constant': 259 | return ImageOps.expand(img, border=padding, fill=fill) 260 | else: 261 | if isinstance(padding, int): 262 | pad_left = pad_right = pad_top = pad_bottom = padding 263 | if isinstance(padding, collections.Sequence) and len(padding) == 2: 264 | pad_left = pad_right = padding[0] 265 | pad_top = pad_bottom = padding[1] 266 | if isinstance(padding, collections.Sequence) and len(padding) == 4: 267 | pad_left = padding[0] 268 | pad_top = padding[1] 269 | pad_right = padding[2] 270 | pad_bottom = padding[3] 271 | 272 | img = np.asarray(img) 273 | # RGB image 274 | if len(img.shape) == 3: 275 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) 276 | # Grayscale image 277 | if len(img.shape) == 2: 278 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) 279 | 280 | return Image.fromarray(img) 281 | 282 | 283 | def crop(img, i, j, h, w): 284 | """Crop the given PIL Image. 285 | 286 | Args: 287 | img (PIL Image): Image to be cropped. 288 | i: Upper pixel coordinate. 289 | j: Left pixel coordinate. 290 | h: Height of the cropped image. 291 | w: Width of the cropped image. 292 | 293 | Returns: 294 | PIL Image: Cropped image. 295 | """ 296 | if not _is_pil_image(img): 297 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 298 | 299 | return img.crop((j, i, j + w, i + h)) 300 | 301 | 302 | def center_crop(img, output_size): 303 | if isinstance(output_size, numbers.Number): 304 | output_size = (int(output_size), int(output_size)) 305 | w, h = img.size 306 | th, tw = output_size 307 | i = int(round((h - th) / 2.)) 308 | j = int(round((w - tw) / 2.)) 309 | return crop(img, i, j, th, tw) 310 | 311 | 312 | def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): 313 | """Crop the given PIL Image and resize it to desired size. 314 | 315 | Notably used in RandomResizedCrop. 316 | 317 | Args: 318 | img (PIL Image): Image to be cropped. 319 | i: Upper pixel coordinate. 320 | j: Left pixel coordinate. 321 | h: Height of the cropped image. 322 | w: Width of the cropped image. 323 | size (sequence or int): Desired output size. Same semantics as ``scale``. 324 | interpolation (int, optional): Desired interpolation. Default is 325 | ``PIL.Image.BILINEAR``. 326 | Returns: 327 | PIL Image: Cropped image. 328 | """ 329 | assert _is_pil_image(img), 'img should be PIL Image' 330 | img = crop(img, i, j, h, w) 331 | img = resize(img, size, interpolation) 332 | return img 333 | 334 | 335 | def hflip(img): 336 | """Horizontally flip the given PIL Image. 337 | 338 | Args: 339 | img (PIL Image): Image to be flipped. 340 | 341 | Returns: 342 | PIL Image: Horizontall flipped image. 
343 | """ 344 | if not _is_pil_image(img): 345 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 346 | 347 | return img.transpose(Image.FLIP_LEFT_RIGHT) 348 | 349 | 350 | def vflip(img): 351 | """Vertically flip the given PIL Image. 352 | 353 | Args: 354 | img (PIL Image): Image to be flipped. 355 | 356 | Returns: 357 | PIL Image: Vertically flipped image. 358 | """ 359 | if not _is_pil_image(img): 360 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 361 | 362 | return img.transpose(Image.FLIP_TOP_BOTTOM) 363 | 364 | 365 | def swap(img, crop): 366 | def crop_image(image, cropnum): 367 | width, high = image.size 368 | crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)] 369 | crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)] 370 | im_list = [] 371 | for j in range(len(crop_y) - 1): 372 | for i in range(len(crop_x) - 1): 373 | im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high)))) 374 | return im_list 375 | 376 | widthcut, highcut = img.size 377 | img = img.crop((10, 10, widthcut-10, highcut-10)) 378 | images = crop_image(img, crop) 379 | pro = 5 380 | if pro >= 5: 381 | tmpx = [] 382 | tmpy = [] 383 | count_x = 0 384 | count_y = 0 385 | k = 1 386 | RAN = 2 387 | for i in range(crop[1] * crop[0]): 388 | tmpx.append(images[i]) 389 | count_x += 1 390 | if len(tmpx) >= k: 391 | tmp = tmpx[count_x - RAN:count_x] 392 | random.shuffle(tmp) 393 | tmpx[count_x - RAN:count_x] = tmp 394 | if count_x == crop[0]: 395 | tmpy.append(tmpx) 396 | count_x = 0 397 | count_y += 1 398 | tmpx = [] 399 | if len(tmpy) >= k: 400 | tmp2 = tmpy[count_y - RAN:count_y] 401 | random.shuffle(tmp2) 402 | tmpy[count_y - RAN:count_y] = tmp2 403 | random_im = [] 404 | for line in tmpy: 405 | random_im.extend(line) 406 | 407 | # random.shuffle(images) 408 | width, high = img.size 409 | iw = int(width / crop[0]) 410 | ih = int(high / crop[1]) 411 | toImage = Image.new('RGB', (iw * crop[0], ih * crop[1])) 412 | x = 0 413 | y = 0 414 | for i in random_im: 415 | i = i.resize((iw, ih), Image.ANTIALIAS) 416 | toImage.paste(i, (x * iw, y * ih)) 417 | x += 1 418 | if x == crop[0]: 419 | x = 0 420 | y += 1 421 | else: 422 | toImage = img 423 | toImage = toImage.resize((widthcut, highcut)) 424 | return toImage 425 | 426 | 427 | 428 | def five_crop(img, size): 429 | """Crop the given PIL Image into four corners and the central crop. 430 | 431 | .. Note:: 432 | This transform returns a tuple of images and there may be a 433 | mismatch in the number of inputs and targets your ``Dataset`` returns. 434 | 435 | Args: 436 | size (sequence or int): Desired output size of the crop. If size is an 437 | int instead of sequence like (h, w), a square crop (size, size) is 438 | made. 439 | Returns: 440 | tuple: tuple (tl, tr, bl, br, center) corresponding top left, 441 | top right, bottom left, bottom right and center crop. 442 | """ 443 | if isinstance(size, numbers.Number): 444 | size = (int(size), int(size)) 445 | else: 446 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 
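The `swap` routine above is the patch-shuffling ("destruction") step used by DCL: it trims a 10-pixel border, cuts the image into a `crop[0] x crop[1]` grid, locally shuffles neighbouring patches, and pastes them back at the original resolution. A minimal usage sketch, assuming the repository root is on `PYTHONPATH` so this module imports as `transforms.functional`, and using a hypothetical `example.jpg` and a 7x7 grid purely for illustration (note this copy of the code predates newer Pillow releases, e.g. it still uses `Image.ANTIALIAS`):

```python
from PIL import Image

from transforms import functional as F  # assumed import path for the module above

img = Image.open('example.jpg').convert('RGB')  # hypothetical input image
destructed = F.swap(img, (7, 7))                # locally shuffle a 7x7 grid of patches
destructed.save('example_destructed.jpg')
```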
447 | 448 | w, h = img.size 449 | crop_h, crop_w = size 450 | if crop_w > w or crop_h > h: 451 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size, 452 | (h, w))) 453 | tl = img.crop((0, 0, crop_w, crop_h)) 454 | tr = img.crop((w - crop_w, 0, w, crop_h)) 455 | bl = img.crop((0, h - crop_h, crop_w, h)) 456 | br = img.crop((w - crop_w, h - crop_h, w, h)) 457 | center = center_crop(img, (crop_h, crop_w)) 458 | return (tl, tr, bl, br, center) 459 | 460 | 461 | def ten_crop(img, size, vertical_flip=False): 462 | """Crop the given PIL Image into four corners and the central crop plus the 463 | flipped version of these (horizontal flipping is used by default). 464 | 465 | .. Note:: 466 | This transform returns a tuple of images and there may be a 467 | mismatch in the number of inputs and targets your ``Dataset`` returns. 468 | 469 | Args: 470 | size (sequence or int): Desired output size of the crop. If size is an 471 | int instead of sequence like (h, w), a square crop (size, size) is 472 | made. 473 | vertical_flip (bool): Use vertical flipping instead of horizontal 474 | 475 | Returns: 476 | tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, 477 | br_flip, center_flip) corresponding top left, top right, 478 | bottom left, bottom right and center crop and same for the 479 | flipped image. 480 | """ 481 | if isinstance(size, numbers.Number): 482 | size = (int(size), int(size)) 483 | else: 484 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 485 | 486 | first_five = five_crop(img, size) 487 | 488 | if vertical_flip: 489 | img = vflip(img) 490 | else: 491 | img = hflip(img) 492 | 493 | second_five = five_crop(img, size) 494 | return first_five + second_five 495 | 496 | 497 | def adjust_brightness(img, brightness_factor): 498 | """Adjust brightness of an Image. 499 | 500 | Args: 501 | img (PIL Image): PIL Image to be adjusted. 502 | brightness_factor (float): How much to adjust the brightness. Can be 503 | any non negative number. 0 gives a black image, 1 gives the 504 | original image while 2 increases the brightness by a factor of 2. 505 | 506 | Returns: 507 | PIL Image: Brightness adjusted image. 508 | """ 509 | if not _is_pil_image(img): 510 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 511 | 512 | enhancer = ImageEnhance.Brightness(img) 513 | img = enhancer.enhance(brightness_factor) 514 | return img 515 | 516 | 517 | def adjust_contrast(img, contrast_factor): 518 | """Adjust contrast of an Image. 519 | 520 | Args: 521 | img (PIL Image): PIL Image to be adjusted. 522 | contrast_factor (float): How much to adjust the contrast. Can be any 523 | non negative number. 0 gives a solid gray image, 1 gives the 524 | original image while 2 increases the contrast by a factor of 2. 525 | 526 | Returns: 527 | PIL Image: Contrast adjusted image. 528 | """ 529 | if not _is_pil_image(img): 530 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 531 | 532 | enhancer = ImageEnhance.Contrast(img) 533 | img = enhancer.enhance(contrast_factor) 534 | return img 535 | 536 | 537 | def adjust_saturation(img, saturation_factor): 538 | """Adjust color saturation of an image. 539 | 540 | Args: 541 | img (PIL Image): PIL Image to be adjusted. 542 | saturation_factor (float): How much to adjust the saturation. 0 will 543 | give a black and white image, 1 will give the original image while 544 | 2 will enhance the saturation by a factor of 2. 
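`five_crop` and `ten_crop` above return tuples of PIL Images rather than a single image, which is why the corresponding transform classes later warn about a possible mismatch with your `Dataset` targets. A small sketch under the same assumed import path, with a hypothetical input file:

```python
from PIL import Image

from transforms import functional as F  # assumed import path

img = Image.open('example.jpg').convert('RGB')   # hypothetical input
tl, tr, bl, br, center = F.five_crop(img, 224)   # four 224x224 corner crops plus the centre crop
crops = F.ten_crop(img, 224)                     # the same five crops plus their horizontal flips
print(len(crops))                                # -> 10
```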
545 | 546 | Returns: 547 | PIL Image: Saturation adjusted image. 548 | """ 549 | if not _is_pil_image(img): 550 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 551 | 552 | enhancer = ImageEnhance.Color(img) 553 | img = enhancer.enhance(saturation_factor) 554 | return img 555 | 556 | 557 | def adjust_hue(img, hue_factor): 558 | """Adjust hue of an image. 559 | 560 | The image hue is adjusted by converting the image to HSV and 561 | cyclically shifting the intensities in the hue channel (H). 562 | The image is then converted back to original image mode. 563 | 564 | `hue_factor` is the amount of shift in H channel and must be in the 565 | interval `[-0.5, 0.5]`. 566 | 567 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 568 | 569 | Args: 570 | img (PIL Image): PIL Image to be adjusted. 571 | hue_factor (float): How much to shift the hue channel. Should be in 572 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 573 | HSV space in positive and negative direction respectively. 574 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 575 | with complementary colors while 0 gives the original image. 576 | 577 | Returns: 578 | PIL Image: Hue adjusted image. 579 | """ 580 | if not(-0.5 <= hue_factor <= 0.5): 581 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 582 | 583 | if not _is_pil_image(img): 584 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 585 | 586 | input_mode = img.mode 587 | if input_mode in {'L', '1', 'I', 'F'}: 588 | return img 589 | 590 | h, s, v = img.convert('HSV').split() 591 | 592 | np_h = np.array(h, dtype=np.uint8) 593 | # uint8 addition take cares of rotation across boundaries 594 | with np.errstate(over='ignore'): 595 | np_h += np.uint8(hue_factor * 255) 596 | h = Image.fromarray(np_h, 'L') 597 | 598 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 599 | return img 600 | 601 | 602 | def adjust_gamma(img, gamma, gain=1): 603 | """Perform gamma correction on an image. 604 | 605 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 606 | based on the following equation: 607 | 608 | I_out = 255 * gain * ((I_in / 255) ** gamma) 609 | 610 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 611 | 612 | Args: 613 | img (PIL Image): PIL Image to be adjusted. 614 | gamma (float): Non negative real number. gamma larger than 1 make the 615 | shadows darker, while gamma smaller than 1 make dark regions 616 | lighter. 617 | gain (float): The constant multiplier. 618 | """ 619 | if not _is_pil_image(img): 620 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 621 | 622 | if gamma < 0: 623 | raise ValueError('Gamma should be a non-negative real number') 624 | 625 | input_mode = img.mode 626 | img = img.convert('RGB') 627 | 628 | gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 629 | img = img.point(gamma_map) # use PIL's point-function to accelerate this part 630 | 631 | img = img.convert(input_mode) 632 | return img 633 | 634 | 635 | def rotate(img, angle, resample=False, expand=False, center=None): 636 | """Rotate the image by angle. 637 | 638 | 639 | Args: 640 | img (PIL Image): PIL Image to be rotated. 641 | angle ({float, int}): In degrees degrees counter clockwise order. 642 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 643 | An optional resampling filter. 
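The photometric helpers above (`adjust_brightness`, `adjust_contrast`, `adjust_saturation`, `adjust_hue`, `adjust_gamma`) each take a PIL Image and return a new PIL Image, so they chain naturally. A sketch under the same assumed import path:

```python
from PIL import Image

from transforms import functional as F  # assumed import path

img = Image.open('example.jpg').convert('RGB')  # hypothetical input
img = F.adjust_brightness(img, 1.2)   # 20% brighter
img = F.adjust_contrast(img, 0.8)     # slightly lower contrast
img = F.adjust_saturation(img, 1.5)   # more saturated
img = F.adjust_hue(img, 0.05)         # small cyclic hue shift; must lie in [-0.5, 0.5]
img = F.adjust_gamma(img, 0.9)        # gamma < 1 lightens dark regions
```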
644 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 645 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 646 | expand (bool, optional): Optional expansion flag. 647 | If true, expands the output image to make it large enough to hold the entire rotated image. 648 | If false or omitted, make the output image the same size as the input image. 649 | Note that the expand flag assumes rotation around the center and no translation. 650 | center (2-tuple, optional): Optional center of rotation. 651 | Origin is the upper left corner. 652 | Default is the center of the image. 653 | """ 654 | 655 | if not _is_pil_image(img): 656 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 657 | 658 | return img.rotate(angle, resample, expand, center) 659 | 660 | 661 | def _get_inverse_affine_matrix(center, angle, translate, scale, shear): 662 | # Helper method to compute inverse matrix for affine transformation 663 | 664 | # As it is explained in PIL.Image.rotate 665 | # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1 666 | # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] 667 | # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] 668 | # RSS is rotation with scale and shear matrix 669 | # RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0] 670 | # [ sin(a)*scale cos(a + shear)*scale 0] 671 | # [ 0 0 1] 672 | # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 673 | 674 | angle = math.radians(angle) 675 | shear = math.radians(shear) 676 | scale = 1.0 / scale 677 | 678 | # Inverted rotation matrix with scale and shear 679 | d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle) 680 | matrix = [ 681 | math.cos(angle + shear), math.sin(angle + shear), 0, 682 | -math.sin(angle), math.cos(angle), 0 683 | ] 684 | matrix = [scale / d * m for m in matrix] 685 | 686 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 687 | matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1]) 688 | matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1]) 689 | 690 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1 691 | matrix[2] += center[0] 692 | matrix[5] += center[1] 693 | return matrix 694 | 695 | 696 | def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None): 697 | """Apply affine transformation on the image keeping image center invariant 698 | 699 | Args: 700 | img (PIL Image): PIL Image to be rotated. 701 | angle ({float, int}): rotation angle in degrees between -180 and 180, clockwise direction. 702 | translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) 703 | scale (float): overall scale 704 | shear (float): shear angle value in degrees between -180 to 180, clockwise direction. 705 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 706 | An optional resampling filter. 707 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 708 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 709 | fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) 710 | """ 711 | if not _is_pil_image(img): 712 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 713 | 714 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 715 | "Argument translate should be a list or tuple of length 2" 716 | 717 | assert scale > 0.0, "Argument scale should be positive" 718 | 719 | output_size = img.size 720 | center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5) 721 | matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) 722 | kwargs = {"fillcolor": fillcolor} if PILLOW_VERSION[0] == '5' else {} 723 | return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs) 724 | 725 | 726 | def to_grayscale(img, num_output_channels=1): 727 | """Convert image to grayscale version of image. 728 | 729 | Args: 730 | img (PIL Image): Image to be converted to grayscale. 731 | 732 | Returns: 733 | PIL Image: Grayscale version of the image. 734 | if num_output_channels == 1 : returned image is single channel 735 | if num_output_channels == 3 : returned image is 3 channel with r == g == b 736 | """ 737 | if not _is_pil_image(img): 738 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 739 | 740 | if num_output_channels == 1: 741 | img = img.convert('L') 742 | elif num_output_channels == 3: 743 | img = img.convert('L') 744 | np_img = np.array(img, dtype=np.uint8) 745 | np_img = np.dstack([np_img, np_img, np_img]) 746 | img = Image.fromarray(np_img, 'RGB') 747 | else: 748 | raise ValueError('num_output_channels should be either 1 or 3') 749 | 750 | return img 751 | -------------------------------------------------------------------------------- /transforms/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps, ImageEnhance 6 | try: 7 | import accimage 8 | except ImportError: 9 | accimage = None 10 | import numpy as np 11 | import numbers 12 | import types 13 | import collections 14 | import warnings 15 | 16 | from . import functional as F 17 | 18 | __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", 19 | "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", 20 | "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", 21 | "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "Randomswap"] 22 | 23 | _pil_interpolation_to_str = { 24 | Image.NEAREST: 'PIL.Image.NEAREST', 25 | Image.BILINEAR: 'PIL.Image.BILINEAR', 26 | Image.BICUBIC: 'PIL.Image.BICUBIC', 27 | Image.LANCZOS: 'PIL.Image.LANCZOS', 28 | } 29 | 30 | 31 | class Compose(object): 32 | """Composes several transforms together. 33 | 34 | Args: 35 | transforms (list of ``Transform`` objects): list of transforms to compose. 36 | 37 | Example: 38 | >>> transforms.Compose([ 39 | >>> transforms.CenterCrop(10), 40 | >>> transforms.ToTensor(), 41 | >>> ]) 42 | """ 43 | 44 | def __init__(self, transforms): 45 | self.transforms = transforms 46 | 47 | def __call__(self, img): 48 | for t in self.transforms: 49 | img = t(img) 50 | return img 51 | 52 | def __repr__(self): 53 | format_string = self.__class__.__name__ + '(' 54 | for t in self.transforms: 55 | format_string += '\n' 56 | format_string += ' {0}'.format(t) 57 | format_string += '\n)' 58 | return format_string 59 | 60 | 61 | class ToTensor(object): 62 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 
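A quick sketch of the last two functional helpers shown above, `affine` and `to_grayscale`, before the class wrappers in `transforms/transforms.py` (same import-path and environment assumptions as before; the file is hypothetical):

```python
from PIL import Image

from transforms import functional as F  # assumed import path

img = Image.open('example.jpg').convert('RGB')      # hypothetical input
# 15 degree rotation, (10, -5) px shift, 1.1x scale and 5 degree shear,
# all about the image centre (which is what affine() keeps invariant).
warped = F.affine(img, angle=15, translate=(10, -5), scale=1.1, shear=5)
gray3 = F.to_grayscale(img, num_output_channels=3)  # r == g == b, still 3 channels
```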
63 | 64 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 65 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 66 | """ 67 | 68 | def __call__(self, pic): 69 | """ 70 | Args: 71 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 72 | 73 | Returns: 74 | Tensor: Converted image. 75 | """ 76 | return F.to_tensor(pic) 77 | 78 | def __repr__(self): 79 | return self.__class__.__name__ + '()' 80 | 81 | 82 | class ToPILImage(object): 83 | """Convert a tensor or an ndarray to PIL Image. 84 | 85 | Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape 86 | H x W x C to a PIL Image while preserving the value range. 87 | 88 | Args: 89 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 90 | If ``mode`` is ``None`` (default) there are some assumptions made about the input data: 91 | 1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. 92 | 2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. 93 | 3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e, 94 | ``int``, ``float``, ``short``). 95 | 96 | .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 97 | """ 98 | def __init__(self, mode=None): 99 | self.mode = mode 100 | 101 | def __call__(self, pic): 102 | """ 103 | Args: 104 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 105 | 106 | Returns: 107 | PIL Image: Image converted to PIL Image. 108 | 109 | """ 110 | return F.to_pil_image(pic, self.mode) 111 | 112 | def __repr__(self): 113 | format_string = self.__class__.__name__ + '(' 114 | if self.mode is not None: 115 | format_string += 'mode={0}'.format(self.mode) 116 | format_string += ')' 117 | return format_string 118 | 119 | 120 | class Normalize(object): 121 | """Normalize a tensor image with mean and standard deviation. 122 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 123 | will normalize each channel of the input ``torch.*Tensor`` i.e. 124 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 125 | 126 | Args: 127 | mean (sequence): Sequence of means for each channel. 128 | std (sequence): Sequence of standard deviations for each channel. 129 | """ 130 | 131 | def __init__(self, mean, std): 132 | self.mean = mean 133 | self.std = std 134 | 135 | def __call__(self, tensor): 136 | """ 137 | Args: 138 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 139 | 140 | Returns: 141 | Tensor: Normalized Tensor image. 142 | """ 143 | return F.normalize(tensor, self.mean, self.std) 144 | 145 | def __repr__(self): 146 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 147 | 148 | 149 | class Randomswap(object): 150 | def __init__(self, size): 151 | self.size = size 152 | if isinstance(size, numbers.Number): 153 | self.size = (int(size), int(size)) 154 | else: 155 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 156 | self.size = size 157 | 158 | def __call__(self, img): 159 | return F.swap(img, self.size) 160 | 161 | def __repr__(self): 162 | return self.__class__.__name__ + '(size={0})'.format(self.size) 163 | 164 | 165 | class Resize(object): 166 | """Resize the input PIL Image to the given size. 167 | 168 | Args: 169 | size (sequence or int): Desired output size. If size is a sequence like 170 | (h, w), output size will be matched to this. 
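`Randomswap` above is simply the transform-class wrapper around `F.swap`, so it slots into a `Compose` pipeline like any other transform. A sketch with the same import-path and environment assumptions as before; the resize size, the 7x7 grid and the ImageNet normalization statistics are illustrative values, not necessarily the ones used in this repo's config:

```python
from PIL import Image

from transforms import transforms as T  # assumed import path for the module above

train_tf = T.Compose([
    T.Resize((512, 512)),
    T.RandomHorizontalFlip(),
    T.Randomswap((7, 7)),      # DCL destruction: locally shuffle a 7x7 patch grid
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet statistics
])

img = Image.open('example.jpg').convert('RGB')  # hypothetical input
x = train_tf(img)                               # FloatTensor of shape (3, 512, 512)
```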
If size is an int, 171 | smaller edge of the image will be matched to this number. 172 | i.e, if height > width, then image will be rescaled to 173 | (size * height / width, size) 174 | interpolation (int, optional): Desired interpolation. Default is 175 | ``PIL.Image.BILINEAR`` 176 | """ 177 | 178 | def __init__(self, size, interpolation=Image.BILINEAR): 179 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 180 | self.size = size 181 | self.interpolation = interpolation 182 | 183 | def __call__(self, img): 184 | """ 185 | Args: 186 | img (PIL Image): Image to be scaled. 187 | 188 | Returns: 189 | PIL Image: Rescaled image. 190 | """ 191 | return F.resize(img, self.size, self.interpolation) 192 | 193 | def __repr__(self): 194 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 195 | return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) 196 | 197 | 198 | class Scale(Resize): 199 | """ 200 | Note: This transform is deprecated in favor of Resize. 201 | """ 202 | def __init__(self, *args, **kwargs): 203 | warnings.warn("The use of the transforms.Scale transform is deprecated, " + 204 | "please use transforms.Resize instead.") 205 | super(Scale, self).__init__(*args, **kwargs) 206 | 207 | 208 | class CenterCrop(object): 209 | """Crops the given PIL Image at the center. 210 | 211 | Args: 212 | size (sequence or int): Desired output size of the crop. If size is an 213 | int instead of sequence like (h, w), a square crop (size, size) is 214 | made. 215 | """ 216 | 217 | def __init__(self, size): 218 | if isinstance(size, numbers.Number): 219 | self.size = (int(size), int(size)) 220 | else: 221 | self.size = size 222 | 223 | def __call__(self, img): 224 | """ 225 | Args: 226 | img (PIL Image): Image to be cropped. 227 | 228 | Returns: 229 | PIL Image: Cropped image. 230 | """ 231 | return F.center_crop(img, self.size) 232 | 233 | def __repr__(self): 234 | return self.__class__.__name__ + '(size={0})'.format(self.size) 235 | 236 | 237 | class Pad(object): 238 | """Pad the given PIL Image on all sides with the given "pad" value. 239 | 240 | Args: 241 | padding (int or tuple): Padding on each border. If a single int is provided this 242 | is used to pad all borders. If tuple of length 2 is provided this is the padding 243 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 244 | this is the padding for the left, top, right and bottom borders 245 | respectively. 246 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 247 | length 3, it is used to fill R, G, B channels respectively. 248 | This value is only used when the padding_mode is constant 249 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 
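One detail worth noting about `Resize`: an `int` keeps the aspect ratio by matching the shorter edge, while an `(h, w)` tuple forces an exact output size, and that tuple is in (height, width) order even though PIL reports `size` as (width, height). A quick check:

```python
from PIL import Image

from transforms import transforms as T  # assumed import path

img = Image.new('RGB', (640, 480))      # PIL size is (width, height)
print(T.Resize(256)(img).size)          # (341, 256): the shorter edge becomes 256
print(T.Resize((256, 320))(img).size)   # (320, 256): exact (h, w) = (256, 320) output
```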
250 | constant: pads with a constant value, this value is specified with fill 251 | edge: pads with the last value at the edge of the image 252 | reflect: pads with reflection of image (without repeating the last value on the edge) 253 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 254 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 255 | symmetric: pads with reflection of image (repeating the last value on the edge) 256 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 257 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 258 | """ 259 | 260 | def __init__(self, padding, fill=0, padding_mode='constant'): 261 | assert isinstance(padding, (numbers.Number, tuple)) 262 | assert isinstance(fill, (numbers.Number, str, tuple)) 263 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] 264 | if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: 265 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 266 | "{} element tuple".format(len(padding))) 267 | 268 | self.padding = padding 269 | self.fill = fill 270 | self.padding_mode = padding_mode 271 | 272 | def __call__(self, img): 273 | """ 274 | Args: 275 | img (PIL Image): Image to be padded. 276 | 277 | Returns: 278 | PIL Image: Padded image. 279 | """ 280 | return F.pad(img, self.padding, self.fill, self.padding_mode) 281 | 282 | def __repr__(self): 283 | return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ 284 | format(self.padding, self.fill, self.padding_mode) 285 | 286 | 287 | class Lambda(object): 288 | """Apply a user-defined lambda as a transform. 289 | 290 | Args: 291 | lambd (function): Lambda/function to be used for transform. 292 | """ 293 | 294 | def __init__(self, lambd): 295 | assert isinstance(lambd, types.LambdaType) 296 | self.lambd = lambd 297 | 298 | def __call__(self, img): 299 | return self.lambd(img) 300 | 301 | def __repr__(self): 302 | return self.__class__.__name__ + '()' 303 | 304 | 305 | class RandomTransforms(object): 306 | """Base class for a list of transformations with randomness 307 | 308 | Args: 309 | transforms (list or tuple): list of transformations 310 | """ 311 | 312 | def __init__(self, transforms): 313 | assert isinstance(transforms, (list, tuple)) 314 | self.transforms = transforms 315 | 316 | def __call__(self, *args, **kwargs): 317 | raise NotImplementedError() 318 | 319 | def __repr__(self): 320 | format_string = self.__class__.__name__ + '(' 321 | for t in self.transforms: 322 | format_string += '\n' 323 | format_string += ' {0}'.format(t) 324 | format_string += '\n)' 325 | return format_string 326 | 327 | 328 | class RandomApply(RandomTransforms): 329 | """Apply randomly a list of transformations with a given probability 330 | 331 | Args: 332 | transforms (list or tuple): list of transformations 333 | p (float): probability 334 | """ 335 | 336 | def __init__(self, transforms, p=0.5): 337 | super(RandomApply, self).__init__(transforms) 338 | self.p = p 339 | 340 | def __call__(self, img): 341 | if self.p < random.random(): 342 | return img 343 | for t in self.transforms: 344 | img = t(img) 345 | return img 346 | 347 | def __repr__(self): 348 | format_string = self.__class__.__name__ + '(' 349 | format_string += '\n p={}'.format(self.p) 350 | for t in self.transforms: 351 | format_string += '\n' 352 | format_string += ' {0}'.format(t) 353 | format_string += '\n)' 354 | return format_string 355 | 356 | 357 | class RandomOrder(RandomTransforms): 358 | """Apply a list of transformations 
in a random order 359 | """ 360 | def __call__(self, img): 361 | order = list(range(len(self.transforms))) 362 | random.shuffle(order) 363 | for i in order: 364 | img = self.transforms[i](img) 365 | return img 366 | 367 | 368 | class RandomChoice(RandomTransforms): 369 | """Apply single transformation randomly picked from a list 370 | """ 371 | def __call__(self, img): 372 | t = random.choice(self.transforms) 373 | return t(img) 374 | 375 | 376 | class RandomCrop(object): 377 | """Crop the given PIL Image at a random location. 378 | 379 | Args: 380 | size (sequence or int): Desired output size of the crop. If size is an 381 | int instead of sequence like (h, w), a square crop (size, size) is 382 | made. 383 | padding (int or sequence, optional): Optional padding on each border 384 | of the image. Default is 0, i.e no padding. If a sequence of length 385 | 4 is provided, it is used to pad left, top, right, bottom borders 386 | respectively. 387 | pad_if_needed (boolean): It will pad the image if smaller than the 388 | desired size to avoid raising an exception. 389 | """ 390 | 391 | def __init__(self, size, padding=0, pad_if_needed=False): 392 | if isinstance(size, numbers.Number): 393 | self.size = (int(size), int(size)) 394 | else: 395 | self.size = size 396 | self.padding = padding 397 | self.pad_if_needed = pad_if_needed 398 | 399 | @staticmethod 400 | def get_params(img, output_size): 401 | """Get parameters for ``crop`` for a random crop. 402 | 403 | Args: 404 | img (PIL Image): Image to be cropped. 405 | output_size (tuple): Expected output size of the crop. 406 | 407 | Returns: 408 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 409 | """ 410 | w, h = img.size 411 | th, tw = output_size 412 | if w == tw and h == th: 413 | return 0, 0, h, w 414 | 415 | i = random.randint(0, h - th) 416 | j = random.randint(0, w - tw) 417 | return i, j, th, tw 418 | 419 | def __call__(self, img): 420 | """ 421 | Args: 422 | img (PIL Image): Image to be cropped. 423 | 424 | Returns: 425 | PIL Image: Cropped image. 426 | """ 427 | if self.padding > 0: 428 | img = F.pad(img, self.padding) 429 | 430 | # pad the width if needed 431 | if self.pad_if_needed and img.size[0] < self.size[1]: 432 | img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0)) 433 | # pad the height if needed 434 | if self.pad_if_needed and img.size[1] < self.size[0]: 435 | img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2))) 436 | 437 | i, j, h, w = self.get_params(img, self.size) 438 | 439 | return F.crop(img, i, j, h, w) 440 | 441 | def __repr__(self): 442 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 443 | 444 | 445 | class RandomHorizontalFlip(object): 446 | """Horizontally flip the given PIL Image randomly with a given probability. 447 | 448 | Args: 449 | p (float): probability of the image being flipped. Default value is 0.5 450 | """ 451 | 452 | def __init__(self, p=0.5): 453 | self.p = p 454 | 455 | def __call__(self, img): 456 | """ 457 | Args: 458 | img (PIL Image): Image to be flipped. 459 | 460 | Returns: 461 | PIL Image: Randomly flipped image. 462 | """ 463 | if random.random() < self.p: 464 | return F.hflip(img) 465 | return img 466 | 467 | def __repr__(self): 468 | return self.__class__.__name__ + '(p={})'.format(self.p) 469 | 470 | 471 | class RandomVerticalFlip(object): 472 | """Vertically flip the given PIL Image randomly with a given probability. 473 | 474 | Args: 475 | p (float): probability of the image being flipped. 
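`RandomApply`, `RandomOrder` and `RandomChoice` compose other transforms stochastically, while `RandomCrop` can pad first (`padding`) or pad just enough for the crop to fit (`pad_if_needed`). A sketch combining them, with illustrative sizes and a hypothetical input:

```python
from PIL import Image

from transforms import transforms as T  # assumed import path

augment = T.Compose([
    T.RandomApply([T.ColorJitter(brightness=0.2, contrast=0.2)], p=0.5),
    T.RandomHorizontalFlip(),
    T.RandomCrop(448, padding=8, pad_if_needed=True),
])
img = Image.open('example.jpg').convert('RGB')  # hypothetical input
out = augment(img)                              # a 448x448 crop of the augmented image
```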
Default value is 0.5 476 | """ 477 | 478 | def __init__(self, p=0.5): 479 | self.p = p 480 | 481 | def __call__(self, img): 482 | """ 483 | Args: 484 | img (PIL Image): Image to be flipped. 485 | 486 | Returns: 487 | PIL Image: Randomly flipped image. 488 | """ 489 | if random.random() < self.p: 490 | return F.vflip(img) 491 | return img 492 | 493 | def __repr__(self): 494 | return self.__class__.__name__ + '(p={})'.format(self.p) 495 | 496 | 497 | class RandomResizedCrop(object): 498 | """Crop the given PIL Image to random size and aspect ratio. 499 | 500 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 501 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 502 | is finally resized to given size. 503 | This is popularly used to train the Inception networks. 504 | 505 | Args: 506 | size: expected output size of each edge 507 | scale: range of size of the origin size cropped 508 | ratio: range of aspect ratio of the origin aspect ratio cropped 509 | interpolation: Default: PIL.Image.BILINEAR 510 | """ 511 | 512 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 513 | self.size = (size, size) 514 | self.interpolation = interpolation 515 | self.scale = scale 516 | self.ratio = ratio 517 | 518 | @staticmethod 519 | def get_params(img, scale, ratio): 520 | """Get parameters for ``crop`` for a random sized crop. 521 | 522 | Args: 523 | img (PIL Image): Image to be cropped. 524 | scale (tuple): range of size of the origin size cropped 525 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 526 | 527 | Returns: 528 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 529 | sized crop. 530 | """ 531 | for attempt in range(10): 532 | area = img.size[0] * img.size[1] 533 | target_area = random.uniform(*scale) * area 534 | aspect_ratio = random.uniform(*ratio) 535 | 536 | w = int(round(math.sqrt(target_area * aspect_ratio))) 537 | h = int(round(math.sqrt(target_area / aspect_ratio))) 538 | 539 | if random.random() < 0.5: 540 | w, h = h, w 541 | 542 | if w <= img.size[0] and h <= img.size[1]: 543 | i = random.randint(0, img.size[1] - h) 544 | j = random.randint(0, img.size[0] - w) 545 | return i, j, h, w 546 | 547 | # Fallback 548 | w = min(img.size[0], img.size[1]) 549 | i = (img.size[1] - w) // 2 550 | j = (img.size[0] - w) // 2 551 | return i, j, w, w 552 | 553 | def __call__(self, img): 554 | """ 555 | Args: 556 | img (PIL Image): Image to be cropped and resized. 557 | 558 | Returns: 559 | PIL Image: Randomly cropped and resized image. 560 | """ 561 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 562 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 563 | 564 | def __repr__(self): 565 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 566 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 567 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 568 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 569 | format_string += ', interpolation={0})'.format(interpolate_str) 570 | return format_string 571 | 572 | 573 | class RandomSizedCrop(RandomResizedCrop): 574 | """ 575 | Note: This transform is deprecated in favor of RandomResizedCrop. 
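`RandomResizedCrop` first samples a crop box whose area and aspect ratio fall inside `scale` and `ratio` (falling back to a centred square after 10 failed attempts) and only then resizes it, so the output is always exactly `size x size`. For example:

```python
from PIL import Image

from transforms import transforms as T  # assumed import path

rrc = T.RandomResizedCrop(224, scale=(0.5, 1.0), ratio=(3. / 4., 4. / 3.))
img = Image.open('example.jpg').convert('RGB')          # hypothetical input
print(rrc(img).size)                                    # always (224, 224)
i, j, h, w = rrc.get_params(img, rrc.scale, rrc.ratio)  # the sampled crop box itself
```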
576 | """ 577 | def __init__(self, *args, **kwargs): 578 | warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + 579 | "please use transforms.RandomResizedCrop instead.") 580 | super(RandomSizedCrop, self).__init__(*args, **kwargs) 581 | 582 | 583 | class FiveCrop(object): 584 | """Crop the given PIL Image into four corners and the central crop 585 | 586 | .. Note:: 587 | This transform returns a tuple of images and there may be a mismatch in the number of 588 | inputs and targets your Dataset returns. See below for an example of how to deal with 589 | this. 590 | 591 | Args: 592 | size (sequence or int): Desired output size of the crop. If size is an ``int`` 593 | instead of sequence like (h, w), a square crop of size (size, size) is made. 594 | 595 | Example: 596 | >>> transform = Compose([ 597 | >>> FiveCrop(size), # this is a list of PIL Images 598 | >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor 599 | >>> ]) 600 | >>> #In your test loop you can do the following: 601 | >>> input, target = batch # input is a 5d tensor, target is 2d 602 | >>> bs, ncrops, c, h, w = input.size() 603 | >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops 604 | >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops 605 | """ 606 | 607 | def __init__(self, size): 608 | self.size = size 609 | if isinstance(size, numbers.Number): 610 | self.size = (int(size), int(size)) 611 | else: 612 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 613 | self.size = size 614 | 615 | def __call__(self, img): 616 | return F.five_crop(img, self.size) 617 | 618 | def __repr__(self): 619 | return self.__class__.__name__ + '(size={0})'.format(self.size) 620 | 621 | 622 | class TenCrop(object): 623 | """Crop the given PIL Image into four corners and the central crop plus the flipped version of 624 | these (horizontal flipping is used by default) 625 | 626 | .. Note:: 627 | This transform returns a tuple of images and there may be a mismatch in the number of 628 | inputs and targets your Dataset returns. See below for an example of how to deal with 629 | this. 630 | 631 | Args: 632 | size (sequence or int): Desired output size of the crop. If size is an 633 | int instead of sequence like (h, w), a square crop (size, size) is 634 | made. 635 | vertical_flip(bool): Use vertical flipping instead of horizontal 636 | 637 | Example: 638 | >>> transform = Compose([ 639 | >>> TenCrop(size), # this is a list of PIL Images 640 | >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor 641 | >>> ]) 642 | >>> #In your test loop you can do the following: 643 | >>> input, target = batch # input is a 5d tensor, target is 2d 644 | >>> bs, ncrops, c, h, w = input.size() 645 | >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops 646 | >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops 647 | """ 648 | 649 | def __init__(self, size, vertical_flip=False): 650 | self.size = size 651 | if isinstance(size, numbers.Number): 652 | self.size = (int(size), int(size)) 653 | else: 654 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 
655 | self.size = size 656 | self.vertical_flip = vertical_flip 657 | 658 | def __call__(self, img): 659 | return F.ten_crop(img, self.size, self.vertical_flip) 660 | 661 | def __repr__(self): 662 | return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) 663 | 664 | 665 | class LinearTransformation(object): 666 | """Transform a tensor image with a square transformation matrix computed 667 | offline. 668 | 669 | Given transformation_matrix, will flatten the torch.*Tensor, compute the dot 670 | product with the transformation matrix and reshape the tensor to its 671 | original shape. 672 | 673 | Applications: 674 | - whitening: zero-center the data, compute the data covariance matrix 675 | [D x D] with np.dot(X.T, X), perform SVD on this matrix and 676 | pass it as transformation_matrix. 677 | 678 | Args: 679 | transformation_matrix (Tensor): tensor [D x D], D = C x H x W 680 | """ 681 | 682 | def __init__(self, transformation_matrix): 683 | if transformation_matrix.size(0) != transformation_matrix.size(1): 684 | raise ValueError("transformation_matrix should be square. Got " + 685 | "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) 686 | self.transformation_matrix = transformation_matrix 687 | 688 | def __call__(self, tensor): 689 | """ 690 | Args: 691 | tensor (Tensor): Tensor image of size (C, H, W) to be whitened. 692 | 693 | Returns: 694 | Tensor: Transformed image. 695 | """ 696 | if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0): 697 | raise ValueError("tensor and transformation matrix have incompatible shape." + 698 | "[{} x {} x {}] != ".format(*tensor.size()) + 699 | "{}".format(self.transformation_matrix.size(0))) 700 | flat_tensor = tensor.view(1, -1) 701 | transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) 702 | tensor = transformed_tensor.view(tensor.size()) 703 | return tensor 704 | 705 | def __repr__(self): 706 | format_string = self.__class__.__name__ + '(' 707 | format_string += (str(self.transformation_matrix.numpy().tolist()) + ')') 708 | return format_string 709 | 710 | 711 | class ColorJitter(object): 712 | """Randomly change the brightness, contrast and saturation of an image. 713 | 714 | Args: 715 | brightness (float): How much to jitter brightness. brightness_factor 716 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 717 | contrast (float): How much to jitter contrast. contrast_factor 718 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 719 | saturation (float): How much to jitter saturation. saturation_factor 720 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 721 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 722 | [-hue, hue]. Should be >=0 and <= 0.5. 723 | """ 724 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 725 | self.brightness = brightness 726 | self.contrast = contrast 727 | self.saturation = saturation 728 | self.hue = hue 729 | 730 | @staticmethod 731 | def get_params(brightness, contrast, saturation, hue): 732 | """Get a randomized transform to be applied on image. 733 | 734 | Arguments are same as that of __init__. 735 | 736 | Returns: 737 | Transform which randomly adjusts brightness, contrast and 738 | saturation in a random order. 
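`LinearTransformation` expects a square `D x D` matrix with `D = C * H * W`; it flattens the tensor image, multiplies, and reshapes back, which is how an offline-computed whitening matrix would be applied. A shape-only sanity check with an identity matrix:

```python
import torch

from transforms import transforms as T  # assumed import path

d = 3 * 8 * 8                              # D = C * H * W for a toy 3x8x8 tensor image
whiten = T.LinearTransformation(torch.eye(d))
x = torch.rand(3, 8, 8)
y = whiten(x)                              # flatten -> matmul -> reshape back to (3, 8, 8)
assert y.shape == x.shape and torch.allclose(x, y)
```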
739 | """ 740 | transforms = [] 741 | if brightness > 0: 742 | brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness) 743 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 744 | 745 | if contrast > 0: 746 | contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast) 747 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 748 | 749 | if saturation > 0: 750 | saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation) 751 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 752 | 753 | if hue > 0: 754 | hue_factor = random.uniform(-hue, hue) 755 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) 756 | 757 | random.shuffle(transforms) 758 | transform = Compose(transforms) 759 | 760 | return transform 761 | 762 | def __call__(self, img): 763 | """ 764 | Args: 765 | img (PIL Image): Input image. 766 | 767 | Returns: 768 | PIL Image: Color jittered image. 769 | """ 770 | transform = self.get_params(self.brightness, self.contrast, 771 | self.saturation, self.hue) 772 | return transform(img) 773 | 774 | def __repr__(self): 775 | format_string = self.__class__.__name__ + '(' 776 | format_string += 'brightness={0}'.format(self.brightness) 777 | format_string += ', contrast={0}'.format(self.contrast) 778 | format_string += ', saturation={0}'.format(self.saturation) 779 | format_string += ', hue={0})'.format(self.hue) 780 | return format_string 781 | 782 | 783 | class RandomRotation(object): 784 | """Rotate the image by angle. 785 | 786 | Args: 787 | degrees (sequence or float or int): Range of degrees to select from. 788 | If degrees is a number instead of sequence like (min, max), the range of degrees 789 | will be (-degrees, +degrees). 790 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 791 | An optional resampling filter. 792 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 793 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 794 | expand (bool, optional): Optional expansion flag. 795 | If true, expands the output to make it large enough to hold the entire rotated image. 796 | If false or omitted, make the output image the same size as the input image. 797 | Note that the expand flag assumes rotation around the center and no translation. 798 | center (2-tuple, optional): Optional center of rotation. 799 | Origin is the upper left corner. 800 | Default is the center of the image. 801 | """ 802 | 803 | def __init__(self, degrees, resample=False, expand=False, center=None): 804 | if isinstance(degrees, numbers.Number): 805 | if degrees < 0: 806 | raise ValueError("If degrees is a single number, it must be positive.") 807 | self.degrees = (-degrees, degrees) 808 | else: 809 | if len(degrees) != 2: 810 | raise ValueError("If degrees is a sequence, it must be of len 2.") 811 | self.degrees = degrees 812 | 813 | self.resample = resample 814 | self.expand = expand 815 | self.center = center 816 | 817 | @staticmethod 818 | def get_params(degrees): 819 | """Get parameters for ``rotate`` for a random rotation. 820 | 821 | Returns: 822 | sequence: params to be passed to ``rotate`` for random rotation. 823 | """ 824 | angle = random.uniform(degrees[0], degrees[1]) 825 | 826 | return angle 827 | 828 | def __call__(self, img): 829 | """ 830 | img (PIL Image): Image to be rotated. 831 | 832 | Returns: 833 | PIL Image: Rotated image. 
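`ColorJitter.get_params` builds a freshly shuffled `Compose` of `Lambda` adjustments on every call, so each image sees the four adjustments in a different random order; `RandomRotation` (whose `__call__` is completed just below) simply draws an angle from `degrees` and delegates to `F.rotate`. For instance:

```python
from PIL import Image

from transforms import transforms as T  # assumed import path

jitter = T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
rotate = T.RandomRotation(15, resample=Image.BILINEAR)
img = Image.open('example.jpg').convert('RGB')  # hypothetical input
out = rotate(jitter(img))                       # jitter, then rotate by an angle in [-15, 15]
```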
834 | """ 835 | 836 | angle = self.get_params(self.degrees) 837 | 838 | return F.rotate(img, angle, self.resample, self.expand, self.center) 839 | 840 | def __repr__(self): 841 | format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) 842 | format_string += ', resample={0}'.format(self.resample) 843 | format_string += ', expand={0}'.format(self.expand) 844 | if self.center is not None: 845 | format_string += ', center={0}'.format(self.center) 846 | format_string += ')' 847 | return format_string 848 | 849 | 850 | class RandomAffine(object): 851 | """Random affine transformation of the image keeping center invariant 852 | 853 | Args: 854 | degrees (sequence or float or int): Range of degrees to select from. 855 | If degrees is a number instead of sequence like (min, max), the range of degrees 856 | will be (-degrees, +degrees). Set to 0 to desactivate rotations. 857 | translate (tuple, optional): tuple of maximum absolute fraction for horizontal 858 | and vertical translations. For example translate=(a, b), then horizontal shift 859 | is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is 860 | randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. 861 | scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is 862 | randomly sampled from the range a <= scale <= b. Will keep original scale by default. 863 | shear (sequence or float or int, optional): Range of degrees to select from. 864 | If degrees is a number instead of sequence like (min, max), the range of degrees 865 | will be (-degrees, +degrees). Will not apply shear by default 866 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 867 | An optional resampling filter. 868 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 869 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 870 | fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) 871 | """ 872 | 873 | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): 874 | if isinstance(degrees, numbers.Number): 875 | if degrees < 0: 876 | raise ValueError("If degrees is a single number, it must be positive.") 877 | self.degrees = (-degrees, degrees) 878 | else: 879 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ 880 | "degrees should be a list or tuple and it must be of length 2." 881 | self.degrees = degrees 882 | 883 | if translate is not None: 884 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 885 | "translate should be a list or tuple and it must be of length 2." 886 | for t in translate: 887 | if not (0.0 <= t <= 1.0): 888 | raise ValueError("translation values should be between 0 and 1") 889 | self.translate = translate 890 | 891 | if scale is not None: 892 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 893 | "scale should be a list or tuple and it must be of length 2." 
894 | for s in scale: 895 | if s <= 0: 896 | raise ValueError("scale values should be positive") 897 | self.scale = scale 898 | 899 | if shear is not None: 900 | if isinstance(shear, numbers.Number): 901 | if shear < 0: 902 | raise ValueError("If shear is a single number, it must be positive.") 903 | self.shear = (-shear, shear) 904 | else: 905 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ 906 | "shear should be a list or tuple and it must be of length 2." 907 | self.shear = shear 908 | else: 909 | self.shear = shear 910 | 911 | self.resample = resample 912 | self.fillcolor = fillcolor 913 | 914 | @staticmethod 915 | def get_params(degrees, translate, scale_ranges, shears, img_size): 916 | """Get parameters for affine transformation 917 | 918 | Returns: 919 | sequence: params to be passed to the affine transformation 920 | """ 921 | angle = random.uniform(degrees[0], degrees[1]) 922 | if translate is not None: 923 | max_dx = translate[0] * img_size[0] 924 | max_dy = translate[1] * img_size[1] 925 | translations = (np.round(random.uniform(-max_dx, max_dx)), 926 | np.round(random.uniform(-max_dy, max_dy))) 927 | else: 928 | translations = (0, 0) 929 | 930 | if scale_ranges is not None: 931 | scale = random.uniform(scale_ranges[0], scale_ranges[1]) 932 | else: 933 | scale = 1.0 934 | 935 | if shears is not None: 936 | shear = random.uniform(shears[0], shears[1]) 937 | else: 938 | shear = 0.0 939 | 940 | return angle, translations, scale, shear 941 | 942 | def __call__(self, img): 943 | """ 944 | img (PIL Image): Image to be transformed. 945 | 946 | Returns: 947 | PIL Image: Affine transformed image. 948 | """ 949 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) 950 | return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) 951 | 952 | def __repr__(self): 953 | s = '{name}(degrees={degrees}' 954 | if self.translate is not None: 955 | s += ', translate={translate}' 956 | if self.scale is not None: 957 | s += ', scale={scale}' 958 | if self.shear is not None: 959 | s += ', shear={shear}' 960 | if self.resample > 0: 961 | s += ', resample={resample}' 962 | if self.fillcolor != 0: 963 | s += ', fillcolor={fillcolor}' 964 | s += ')' 965 | d = dict(self.__dict__) 966 | d['resample'] = _pil_interpolation_to_str[d['resample']] 967 | return s.format(name=self.__class__.__name__, **d) 968 | 969 | 970 | class Grayscale(object): 971 | """Convert image to grayscale. 972 | 973 | Args: 974 | num_output_channels (int): (1 or 3) number of channels desired for output image 975 | 976 | Returns: 977 | PIL Image: Grayscale version of the input. 978 | - If num_output_channels == 1 : returned image is single channel 979 | - If num_output_channels == 3 : returned image is 3 channel with r == g == b 980 | 981 | """ 982 | 983 | def __init__(self, num_output_channels=1): 984 | self.num_output_channels = num_output_channels 985 | 986 | def __call__(self, img): 987 | """ 988 | Args: 989 | img (PIL Image): Image to be converted to grayscale. 990 | 991 | Returns: 992 | PIL Image: Randomly grayscaled image. 993 | """ 994 | return F.to_grayscale(img, num_output_channels=self.num_output_channels) 995 | 996 | def __repr__(self): 997 | return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) 998 | 999 | 1000 | class RandomGrayscale(object): 1001 | """Randomly convert image to grayscale with a probability of p (default 0.1). 
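`RandomAffine` samples an angle, a pixel translation, a scale and a shear per image and feeds them to `F.affine`, keeping the image centre fixed. A sketch with moderate, purely illustrative ranges:

```python
from PIL import Image

from transforms import transforms as T  # assumed import path

affine = T.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=5)
img = Image.open('example.jpg').convert('RGB')  # hypothetical input
out = affine(img)                               # random rotation/shift/scale/shear each call
```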
1002 | 1003 | Args: 1004 | p (float): probability that image should be converted to grayscale. 1005 | 1006 | Returns: 1007 | PIL Image: Grayscale version of the input image with probability p and unchanged 1008 | with probability (1-p). 1009 | - If input image is 1 channel: grayscale version is 1 channel 1010 | - If input image is 3 channel: grayscale version is 3 channel with r == g == b 1011 | 1012 | """ 1013 | 1014 | def __init__(self, p=0.1): 1015 | self.p = p 1016 | 1017 | def __call__(self, img): 1018 | """ 1019 | Args: 1020 | img (PIL Image): Image to be converted to grayscale. 1021 | 1022 | Returns: 1023 | PIL Image: Randomly grayscaled image. 1024 | """ 1025 | num_output_channels = 1 if img.mode == 'L' else 3 1026 | if random.random() < self.p: 1027 | return F.to_grayscale(img, num_output_channels=num_output_channels) 1028 | return img 1029 | 1030 | def __repr__(self): 1031 | return self.__class__.__name__ + '(p={0})'.format(self.p) 1032 | -------------------------------------------------------------------------------- /utils/Asoftmax_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | from torch.nn import Parameter 6 | import math 7 | 8 | import pdb 9 | 10 | class AngleLoss(nn.Module): 11 | def __init__(self, gamma=0): 12 | super(AngleLoss, self).__init__() 13 | self.gamma = gamma 14 | self.it = 0 15 | self.LambdaMin = 50.0 16 | self.LambdaMax = 1500.0 17 | self.lamb = 1500.0 18 | 19 | def forward(self, input, target, decay=None): 20 | self.it += 1 21 | cos_theta,phi_theta = input 22 | target = target.view(-1,1) #size=(B,1) 23 | 24 | index = cos_theta.data * 0.0 #size=(B,Classnum) 25 | index.scatter_(1,target.data.view(-1,1),1) 26 | index = index.byte() 27 | index = Variable(index) 28 | 29 | if decay is None: 30 | self.lamb = max(self.LambdaMin,self.LambdaMax/(1+0.1*self.it )) 31 | else: 32 | self.LambdaMax *= decay 33 | self.lamb = max(self.LambdaMin, self.LambdaMax) 34 | output = cos_theta * 1.0 #size=(B,Classnum) 35 | output[index] -= cos_theta[index]*(1.0+0)/(1+self.lamb) 36 | output[index] += phi_theta[index]*(1.0+0)/(1+self.lamb) 37 | 38 | logpt = F.log_softmax(output, 1) 39 | logpt = logpt.gather(1,target) 40 | logpt = logpt.view(-1) 41 | pt = Variable(logpt.data.exp()) 42 | 43 | loss = -1 * (1-pt)**self.gamma * logpt 44 | loss = loss.mean() 45 | 46 | return loss 47 | 48 | -------------------------------------------------------------------------------- /utils/__pycache__/Asoftmax_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/utils/__pycache__/Asoftmax_loss.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/autoaugment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/utils/__pycache__/autoaugment.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataset_DCL.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/utils/__pycache__/dataset_DCL.cpython-36.pyc 
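`AngleLoss` in `utils/Asoftmax_loss.py` above is a focal-style A-Softmax (SphereFace) loss: it expects the `(cos_theta, phi_theta)` pair produced by `models/Asoftmax_linear.py` (not shown here) and anneals the margin through `lamb`. A toy sketch with random stand-in tensors, assuming a PyTorch version contemporary with this repo (the code still relies on `Variable` and uint8 mask indexing):

```python
import torch

from utils.Asoftmax_loss import AngleLoss  # path as laid out in this repository

criterion = AngleLoss(gamma=0)
B, C = 4, 200                         # batch size and class count (e.g. CUB has 200 classes)
cos_theta = torch.rand(B, C) * 2 - 1  # stand-ins for the (cos_theta, phi_theta) pair that
phi_theta = cos_theta - 0.3           # the A-softmax linear layer would normally produce
target = torch.randint(0, C, (B,))    # int64 class labels
loss = criterion((cos_theta, phi_theta), target)
print(loss.item())
```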
-------------------------------------------------------------------------------- /utils/__pycache__/eval_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/utils/__pycache__/eval_model.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/train_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/utils/__pycache__/train_model.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JDAI-CV/DCL/895081603dc68aeeda07301dbddf32b364ecacf7/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/autoaugment.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageEnhance, ImageOps 2 | import numpy as np 3 | import random 4 | 5 | 6 | class ImageNetPolicy(object): 7 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 8 | Example: 9 | >>> policy = ImageNetPolicy() 10 | >>> transformed = policy(image) 11 | Example as a PyTorch Transform: 12 | >>> transform=transforms.Compose([ 13 | >>> transforms.Resize(256), 14 | >>> ImageNetPolicy(), 15 | >>> transforms.ToTensor()]) 16 | """ 17 | def __init__(self, fillcolor=(128, 128, 128)): 18 | self.policies = [ 19 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 20 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 21 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 22 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 23 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 24 | 25 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 26 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 27 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 28 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 29 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 30 | 31 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 32 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 33 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 34 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 35 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 36 | 37 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 38 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 39 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 40 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 41 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 42 | 43 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 44 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 45 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 46 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 47 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 48 | ] 49 | 50 | 51 | def __call__(self, img): 52 | policy_idx = random.randint(0, len(self.policies) - 1) 53 | return self.policies[policy_idx](img) 
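`ImageNetPolicy` above picks one of its two-operation sub-policies at random on every call, so it can be dropped into a transform pipeline like any other callable. A sketch, with the caveat that `SubPolicy` (further down) still uses `np.int`, so it needs the older NumPy pinned in `conda_list.txt`:

```python
from PIL import Image

from utils.autoaugment import ImageNetPolicy  # path as laid out in this repository

policy = ImageNetPolicy()
img = Image.open('example.jpg').convert('RGB')  # hypothetical input
augmented = policy(img)                         # one randomly chosen two-operation sub-policy
```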
54 | 55 | def __repr__(self): 56 | return "AutoAugment ImageNet Policy" 57 | 58 | 59 | class CIFAR10Policy(object): 60 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 61 | Example: 62 | >>> policy = CIFAR10Policy() 63 | >>> transformed = policy(image) 64 | Example as a PyTorch Transform: 65 | >>> transform=transforms.Compose([ 66 | >>> transforms.Resize(256), 67 | >>> CIFAR10Policy(), 68 | >>> transforms.ToTensor()]) 69 | """ 70 | def __init__(self, fillcolor=(128, 128, 128)): 71 | self.policies = [ 72 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 73 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 74 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 75 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 76 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 77 | 78 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 79 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 80 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 81 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 82 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 83 | 84 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 85 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 86 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 87 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 88 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 89 | 90 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 91 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 92 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 93 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 94 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 95 | 96 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 97 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 98 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 99 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 100 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 101 | ] 102 | 103 | 104 | def __call__(self, img): 105 | policy_idx = random.randint(0, len(self.policies) - 1) 106 | return self.policies[policy_idx](img) 107 | 108 | def __repr__(self): 109 | return "AutoAugment CIFAR10 Policy" 110 | 111 | 112 | class SVHNPolicy(object): 113 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
114 | Example: 115 | >>> policy = SVHNPolicy() 116 | >>> transformed = policy(image) 117 | Example as a PyTorch Transform: 118 | >>> transform=transforms.Compose([ 119 | >>> transforms.Resize(256), 120 | >>> SVHNPolicy(), 121 | >>> transforms.ToTensor()]) 122 | """ 123 | def __init__(self, fillcolor=(128, 128, 128)): 124 | self.policies = [ 125 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 126 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 127 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 128 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 129 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 130 | 131 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 132 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 133 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 134 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 135 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 136 | 137 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 138 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 139 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 140 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 141 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 142 | 143 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 144 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 145 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 146 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 147 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 148 | 149 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 150 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 151 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 152 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 153 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 154 | ] 155 | 156 | 157 | def __call__(self, img): 158 | policy_idx = random.randint(0, len(self.policies) - 1) 159 | return self.policies[policy_idx](img) 160 | 161 | def __repr__(self): 162 | return "AutoAugment SVHN Policy" 163 | 164 | 165 | class SubPolicy(object): 166 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 167 | ranges = { 168 | "shearX": np.linspace(0, 0.3, 10), 169 | "shearY": np.linspace(0, 0.3, 10), 170 | "translateX": np.linspace(0, 150 / 331, 10), 171 | "translateY": np.linspace(0, 150 / 331, 10), 172 | "rotate": np.linspace(0, 30, 10), 173 | "color": np.linspace(0.0, 0.9, 10), 174 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 175 | "solarize": np.linspace(256, 0, 10), 176 | "contrast": np.linspace(0.0, 0.9, 10), 177 | "sharpness": np.linspace(0.0, 0.9, 10), 178 | "brightness": np.linspace(0.0, 0.9, 10), 179 | "autocontrast": [0] * 10, 180 | "equalize": [0] * 10, 181 | "invert": [0] * 10 182 | } 183 | 184 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 185 | def rotate_with_fill(img, magnitude): 186 | rot = img.convert("RGBA").rotate(magnitude) 187 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 188 | 189 | func = { 190 | "shearX": lambda img, magnitude: img.transform( 191 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 192 | 
Image.BICUBIC, fillcolor=fillcolor), 193 | "shearY": lambda img, magnitude: img.transform( 194 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 195 | Image.BICUBIC, fillcolor=fillcolor), 196 | "translateX": lambda img, magnitude: img.transform( 197 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 198 | fillcolor=fillcolor), 199 | "translateY": lambda img, magnitude: img.transform( 200 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 201 | fillcolor=fillcolor), 202 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 203 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 204 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 205 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 206 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 207 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 208 | 1 + magnitude * random.choice([-1, 1])), 209 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 210 | 1 + magnitude * random.choice([-1, 1])), 211 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 212 | 1 + magnitude * random.choice([-1, 1])), 213 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 214 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 215 | "invert": lambda img, magnitude: ImageOps.invert(img) 216 | } 217 | 218 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( 219 | # operation1, ranges[operation1][magnitude_idx1], 220 | # operation2, ranges[operation2][magnitude_idx2]) 221 | self.p1 = p1 222 | self.operation1 = func[operation1] 223 | self.magnitude1 = ranges[operation1][magnitude_idx1] 224 | self.p2 = p2 225 | self.operation2 = func[operation2] 226 | self.magnitude2 = ranges[operation2][magnitude_idx2] 227 | 228 | 229 | def __call__(self, img): 230 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 231 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 232 | return img 233 | -------------------------------------------------------------------------------- /utils/dataset_DCL.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | from __future__ import division 3 | import os 4 | import torch 5 | import torch.utils.data as data 6 | import pandas 7 | import random 8 | import PIL.Image as Image 9 | from PIL import ImageStat 10 | 11 | import pdb 12 | 13 | def random_sample(img_names, labels): 14 | anno_dict = {} 15 | img_list = [] 16 | anno_list = [] 17 | for img, anno in zip(img_names, labels): 18 | if not anno in anno_dict: 19 | anno_dict[anno] = [img] 20 | else: 21 | anno_dict[anno].append(img) 22 | 23 | for anno in anno_dict.keys(): 24 | anno_len = len(anno_dict[anno]) 25 | fetch_keys = random.sample(list(range(anno_len)), anno_len//10) 26 | img_list.extend([anno_dict[anno][x] for x in fetch_keys]) 27 | anno_list.extend([anno for x in fetch_keys]) 28 | return img_list, anno_list 29 | 30 | 31 | 32 | class dataset(data.Dataset): 33 | def __init__(self, Config, anno, swap_size=[7,7], common_aug=None, swap=None, totensor=None, train=False, train_val=False, test=False): 34 | self.root_path = Config.rawdata_root 35 | self.numcls = Config.numcls 36 | self.dataset = Config.dataset 37 | self.use_cls_2 = Config.cls_2 38 | 
self.use_cls_mul = Config.cls_2xmul 39 | if isinstance(anno, pandas.core.frame.DataFrame): 40 | self.paths = anno['ImageName'].tolist() 41 | self.labels = anno['label'].tolist() 42 | elif isinstance(anno, dict): 43 | self.paths = anno['img_name'] 44 | self.labels = anno['label'] 45 | 46 | if train_val: 47 | self.paths, self.labels = random_sample(self.paths, self.labels) 48 | self.common_aug = common_aug 49 | self.swap = swap 50 | self.totensor = totensor 51 | self.cfg = Config 52 | self.train = train 53 | self.swap_size = swap_size 54 | self.test = test 55 | 56 | def __len__(self): 57 | return len(self.paths) 58 | 59 | def __getitem__(self, item): 60 | img_path = os.path.join(self.root_path, self.paths[item]) 61 | img = self.pil_loader(img_path) 62 | if self.test: 63 | img = self.totensor(img) 64 | label = self.labels[item] 65 | return img, label, self.paths[item] 66 | img_unswap = self.common_aug(img) if not self.common_aug is None else img 67 | 68 | image_unswap_list = self.crop_image(img_unswap, self.swap_size) 69 | 70 | swap_range = self.swap_size[0] * self.swap_size[1] 71 | swap_law1 = [(i-(swap_range//2))/swap_range for i in range(swap_range)] 72 | 73 | if self.train: 74 | img_swap = self.swap(img_unswap) 75 | image_swap_list = self.crop_image(img_swap, self.swap_size) 76 | unswap_stats = [sum(ImageStat.Stat(im).mean) for im in image_unswap_list] 77 | swap_stats = [sum(ImageStat.Stat(im).mean) for im in image_swap_list] 78 | swap_law2 = [] 79 | for swap_im in swap_stats: 80 | distance = [abs(swap_im - unswap_im) for unswap_im in unswap_stats] 81 | index = distance.index(min(distance)) 82 | swap_law2.append((index-(swap_range//2))/swap_range) 83 | img_swap = self.totensor(img_swap) 84 | label = self.labels[item] 85 | if self.use_cls_mul: 86 | label_swap = label + self.numcls 87 | if self.use_cls_2: 88 | label_swap = -1 89 | img_unswap = self.totensor(img_unswap) 90 | return img_unswap, img_swap, label, label_swap, swap_law1, swap_law2, self.paths[item] 91 | else: 92 | label = self.labels[item] 93 | swap_law2 = [(i-(swap_range//2))/swap_range for i in range(swap_range)] 94 | label_swap = label 95 | img_unswap = self.totensor(img_unswap) 96 | return img_unswap, label, label_swap, swap_law1, swap_law2, self.paths[item] 97 | 98 | def pil_loader(self,imgpath): 99 | with open(imgpath, 'rb') as f: 100 | with Image.open(f) as img: 101 | return img.convert('RGB') 102 | 103 | def crop_image(self, image, cropnum): 104 | width, high = image.size 105 | crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)] 106 | crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)] 107 | im_list = [] 108 | for j in range(len(crop_y) - 1): 109 | for i in range(len(crop_x) - 1): 110 | im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high)))) 111 | return im_list 112 | 113 | 114 | def get_weighted_sampler(self): 115 | img_nums = len(self.labels) 116 | weights = [self.labels.count(x) for x in range(self.numcls)] 117 | return torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=img_nums) 118 | 119 | 120 | def collate_fn4train(batch): 121 | imgs = [] 122 | label = [] 123 | label_swap = [] 124 | law_swap = [] 125 | img_name = [] 126 | for sample in batch: 127 | imgs.append(sample[0]) 128 | imgs.append(sample[1]) 129 | label.append(sample[2]) 130 | label.append(sample[2]) 131 | if sample[3] == -1: 132 | label_swap.append(1) 133 | label_swap.append(0) 134 | else: 135 | label_swap.append(sample[2]) 136 | 
label_swap.append(sample[3]) 137 | law_swap.append(sample[4]) 138 | law_swap.append(sample[5]) 139 | img_name.append(sample[-1]) 140 | return torch.stack(imgs, 0), label, label_swap, law_swap, img_name 141 | 142 | def collate_fn4val(batch): 143 | imgs = [] 144 | label = [] 145 | label_swap = [] 146 | law_swap = [] 147 | img_name = [] 148 | for sample in batch: 149 | imgs.append(sample[0]) 150 | label.append(sample[1]) 151 | if sample[3] == -1: 152 | label_swap.append(1) 153 | else: 154 | label_swap.append(sample[2]) 155 | law_swap.append(sample[3]) 156 | img_name.append(sample[-1]) 157 | return torch.stack(imgs, 0), label, label_swap, law_swap, img_name 158 | 159 | def collate_fn4backbone(batch): 160 | imgs = [] 161 | label = [] 162 | img_name = [] 163 | for sample in batch: 164 | imgs.append(sample[0]) 165 | if len(sample) == 7: 166 | label.append(sample[2]) 167 | else: 168 | label.append(sample[1]) 169 | img_name.append(sample[-1]) 170 | return torch.stack(imgs, 0), label, img_name 171 | 172 | 173 | def collate_fn4test(batch): 174 | imgs = [] 175 | label = [] 176 | img_name = [] 177 | for sample in batch: 178 | imgs.append(sample[0]) 179 | label.append(sample[1]) 180 | img_name.append(sample[-1]) 181 | return torch.stack(imgs, 0), label, img_name 182 | -------------------------------------------------------------------------------- /utils/eval_model.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | from __future__ import print_function, division 3 | import os,time,datetime 4 | import numpy as np 5 | import datetime 6 | from math import ceil 7 | 8 | import torch 9 | from torch import nn 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | 13 | from utils.utils import LossRecord 14 | 15 | import pdb 16 | 17 | def dt(): 18 | return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S") 19 | 20 | def eval_turn(Config, model, data_loader, val_version, epoch_num, log_file): 21 | 22 | model.train(False) 23 | 24 | val_corrects1 = 0 25 | val_corrects2 = 0 26 | val_corrects3 = 0 27 | val_size = data_loader.__len__() 28 | item_count = data_loader.total_item_len 29 | t0 = time.time() 30 | get_l1_loss = nn.L1Loss() 31 | get_ce_loss = nn.CrossEntropyLoss() 32 | 33 | val_batch_size = data_loader.batch_size 34 | val_epoch_step = data_loader.__len__() 35 | num_cls = data_loader.num_cls 36 | 37 | val_loss_recorder = LossRecord(val_batch_size) 38 | val_celoss_recorder = LossRecord(val_batch_size) 39 | print('evaluating %s ...'%val_version, flush=True) 40 | with torch.no_grad(): 41 | for batch_cnt_val, data_val in enumerate(data_loader): 42 | inputs = Variable(data_val[0].cuda()) 43 | labels = Variable(torch.from_numpy(np.array(data_val[1])).long().cuda()) 44 | outputs = model(inputs) 45 | loss = 0 46 | 47 | ce_loss = get_ce_loss(outputs[0], labels).item() 48 | loss += ce_loss 49 | 50 | val_loss_recorder.update(loss) 51 | val_celoss_recorder.update(ce_loss) 52 | 53 | if Config.use_dcl and Config.cls_2xmul: 54 | outputs_pred = outputs[0] + outputs[1][:,0:num_cls] + outputs[1][:,num_cls:2*num_cls] 55 | else: 56 | outputs_pred = outputs[0] 57 | top3_val, top3_pos = torch.topk(outputs_pred, 3) 58 | 59 | print('{:s} eval_batch: {:-6d} / {:d} loss: {:8.4f}'.format(val_version, batch_cnt_val, val_epoch_step, loss), flush=True) 60 | 61 | batch_corrects1 = torch.sum((top3_pos[:, 0] == labels)).data.item() 62 | val_corrects1 += batch_corrects1 63 | batch_corrects2 = torch.sum((top3_pos[:, 1] == labels)).data.item() 64 | val_corrects2 
+= (batch_corrects2 + batch_corrects1) 65 | batch_corrects3 = torch.sum((top3_pos[:, 2] == labels)).data.item() 66 | val_corrects3 += (batch_corrects3 + batch_corrects2 + batch_corrects1) 67 | 68 | val_acc1 = val_corrects1 / item_count 69 | val_acc2 = val_corrects2 / item_count 70 | val_acc3 = val_corrects3 / item_count 71 | 72 | log_file.write(val_version + '\t' +str(val_loss_recorder.get_val())+'\t' + str(val_celoss_recorder.get_val()) + '\t' + str(val_acc1) + '\t' + str(val_acc3) + '\n') 73 | 74 | t1 = time.time() 75 | since = t1-t0 76 | print('--'*30, flush=True) 77 | print('% 3d %s %s %s-loss: %.4f ||%s-acc@1: %.4f %s-acc@2: %.4f %s-acc@3: %.4f ||time: %d' % (epoch_num, val_version, dt(), val_version, val_loss_recorder.get_val(init=True), val_version, val_acc1,val_version, val_acc2, val_version, val_acc3, since), flush=True) 78 | print('--' * 30, flush=True) 79 | 80 | return val_acc1, val_acc2, val_acc3 81 | 82 | -------------------------------------------------------------------------------- /utils/test_tool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import cv2 5 | import datetime 6 | 7 | import torch 8 | from torchvision.utils import save_image, make_grid 9 | 10 | import pdb 11 | 12 | def dt(): 13 | return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S") 14 | 15 | def set_text(text, img): 16 | font = cv2.FONT_HERSHEY_SIMPLEX 17 | if isinstance(text, str): 18 | cont = text 19 | cv2.putText(img, cont, (20, 50), font, 0.5, (0, 0, 255), 1, cv2.LINE_AA) 20 | if isinstance(text, float): 21 | cont = '%.4f'%text 22 | cv2.putText(img, cont, (20, 50), font, 0.5, (0, 0, 255), 1, cv2.LINE_AA) 23 | if isinstance(text, list): 24 | for count in range(len(img)): 25 | cv2.putText(img[count], text[count], (20, 50), font, 0.5, (0, 0, 255), 1, cv2.LINE_AA) 26 | return img 27 | 28 | def save_multi_img(img_list, text_list, grid_size=[5,5], sub_size=200, save_dir='./', save_name=None): 29 | if len(img_list) > grid_size[0]*grid_size[1]: 30 | merge_height = math.ceil(len(img_list) / grid_size[0]) * sub_size 31 | else: 32 | merge_height = grid_size[1]*sub_size 33 | merged_img = np.zeros((merge_height, grid_size[0]*sub_size, 3)) 34 | 35 | if isinstance(img_list[0], str): 36 | img_name_list = img_list 37 | img_list = [] 38 | for img_name in img_name_list: 39 | img_list.append(cv2.imread(img_name)) 40 | 41 | img_counter = 0 42 | for img, txt in zip(img_list, text_list): 43 | img = cv2.resize(img, (sub_size, sub_size)) 44 | img = set_text(txt, img) 45 | pos = [img_counter // grid_size[1], img_counter % grid_size[1]] 46 | sub_pos = [pos[0]*sub_size, (pos[0]+1)*sub_size, 47 | pos[1]*sub_size, (pos[1]+1)*sub_size] 48 | merged_img[sub_pos[0]:sub_pos[1], sub_pos[2]:sub_pos[3], :] = img 49 | img_counter += 1 50 | 51 | if save_name is None: 52 | img_save_path = os.path.join(save_dir, dt()+'.png') 53 | else: 54 | img_save_path = os.path.join(save_dir, save_name+'.png') 55 | cv2.imwrite(img_save_path, merged_img) 56 | print('saved img in %s ...'%img_save_path) 57 | 58 | 59 | def cls_base_acc(result_gather): 60 | top1_acc = {} 61 | top3_acc = {} 62 | cls_count = {} 63 | for img_item in result_gather.keys(): 64 | acc_case = result_gather[img_item] 65 | 66 | if acc_case['label'] in cls_count: 67 | cls_count[acc_case['label']] += 1 68 | if acc_case['top1_cat'] == acc_case['label']: 69 | top1_acc[acc_case['label']] += 1 70 | if acc_case['label'] in [acc_case['top1_cat'], acc_case['top2_cat'], acc_case['top3_cat']]: 71 | 
top3_acc[acc_case['label']] += 1 72 | else: 73 | cls_count[acc_case['label']] = 1 74 | if acc_case['top1_cat'] == acc_case['label']: 75 | top1_acc[acc_case['label']] = 1 76 | else: 77 | top1_acc[acc_case['label']] = 0 78 | 79 | if acc_case['label'] in [acc_case['top1_cat'], acc_case['top2_cat'], acc_case['top3_cat']]: 80 | top3_acc[acc_case['label']] = 1 81 | else: 82 | top3_acc[acc_case['label']] = 0 83 | 84 | for label_item in cls_count: 85 | top1_acc[label_item] /= max(1.0*cls_count[label_item], 0.001) 86 | top3_acc[label_item] /= max(1.0*cls_count[label_item], 0.001) 87 | 88 | print('top1_acc:', top1_acc) 89 | print('top3_acc:', top3_acc) 90 | print('cls_count', cls_count) 91 | 92 | return top1_acc, top3_acc, cls_count 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /utils/train_model.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | from __future__ import print_function, division 3 | 4 | import os,time,datetime 5 | import numpy as np 6 | from math import ceil 7 | import datetime 8 | 9 | import torch 10 | from torch import nn 11 | from torch.autograd import Variable 12 | #from torchvision.utils import make_grid, save_image 13 | 14 | from utils.utils import LossRecord, clip_gradient 15 | from models.focal_loss import FocalLoss 16 | from utils.eval_model import eval_turn 17 | from utils.Asoftmax_loss import AngleLoss 18 | 19 | import pdb 20 | 21 | def dt(): 22 | return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S") 23 | 24 | 25 | def train(Config, 26 | model, 27 | epoch_num, 28 | start_epoch, 29 | optimizer, 30 | exp_lr_scheduler, 31 | data_loader, 32 | save_dir, 33 | data_size=448, 34 | savepoint=500, 35 | checkpoint=1000 36 | ): 37 | # savepoint: save without evalution 38 | # checkpoint: save with evaluation 39 | 40 | step = 0 41 | eval_train_flag = False 42 | rec_loss = [] 43 | checkpoint_list = [] 44 | 45 | train_batch_size = data_loader['train'].batch_size 46 | train_epoch_step = data_loader['train'].__len__() 47 | train_loss_recorder = LossRecord(train_batch_size) 48 | 49 | if savepoint > train_epoch_step: 50 | savepoint = 1*train_epoch_step 51 | checkpoint = savepoint 52 | 53 | date_suffix = dt() 54 | log_file = open(os.path.join(Config.log_folder, 'formal_log_r50_dcl_%s_%s.log'%(str(data_size), date_suffix)), 'a') 55 | 56 | add_loss = nn.L1Loss() 57 | get_ce_loss = nn.CrossEntropyLoss() 58 | get_focal_loss = FocalLoss() 59 | get_angle_loss = AngleLoss() 60 | 61 | for epoch in range(start_epoch,epoch_num-1): 62 | exp_lr_scheduler.step(epoch) 63 | model.train(True) 64 | 65 | save_grad = [] 66 | for batch_cnt, data in enumerate(data_loader['train']): 67 | step += 1 68 | loss = 0 69 | model.train(True) 70 | if Config.use_backbone: 71 | inputs, labels, img_names = data 72 | inputs = Variable(inputs.cuda()) 73 | labels = Variable(torch.from_numpy(np.array(labels)).cuda()) 74 | 75 | if Config.use_dcl: 76 | inputs, labels, labels_swap, swap_law, img_names = data 77 | 78 | inputs = Variable(inputs.cuda()) 79 | labels = Variable(torch.from_numpy(np.array(labels)).cuda()) 80 | labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).cuda()) 81 | swap_law = Variable(torch.from_numpy(np.array(swap_law)).float().cuda()) 82 | 83 | optimizer.zero_grad() 84 | 85 | if inputs.size(0) < 2*train_batch_size: 86 | outputs = model(inputs, inputs[0:-1:2]) 87 | else: 88 | outputs = model(inputs, None) 89 | 90 | if Config.use_focal_loss: 91 | ce_loss = 
get_focal_loss(outputs[0], labels) 92 | else: 93 | ce_loss = get_ce_loss(outputs[0], labels) 94 | 95 | if Config.use_Asoftmax: 96 | fetch_batch = labels.size(0) 97 | if batch_cnt % (train_epoch_step // 5) == 0: 98 | angle_loss = get_angle_loss(outputs[3], labels[0:fetch_batch:2], decay=0.9) 99 | else: 100 | angle_loss = get_angle_loss(outputs[3], labels[0:fetch_batch:2]) 101 | loss += angle_loss 102 | 103 | loss += ce_loss 104 | 105 | alpha_ = 1 106 | beta_ = 1 107 | gamma_ = 0.01 if Config.dataset == 'STCAR' or Config.dataset == 'AIR' else 1 108 | if Config.use_dcl: 109 | swap_loss = get_ce_loss(outputs[1], labels_swap) * beta_ 110 | loss += swap_loss 111 | law_loss = add_loss(outputs[2], swap_law) * gamma_ 112 | loss += law_loss 113 | 114 | loss.backward() 115 | torch.cuda.synchronize() 116 | 117 | optimizer.step() 118 | torch.cuda.synchronize() 119 | 120 | if Config.use_dcl: 121 | print('step: {:-8d} / {:d} loss=ce_loss+swap_loss+law_loss: {:6.4f} = {:6.4f} + {:6.4f} + {:6.4f} '.format(step, train_epoch_step, loss.detach().item(), ce_loss.detach().item(), swap_loss.detach().item(), law_loss.detach().item()), flush=True) 122 | if Config.use_backbone: 123 | print('step: {:-8d} / {:d} loss=ce_loss+swap_loss+law_loss: {:6.4f} = {:6.4f} '.format(step, train_epoch_step, loss.detach().item(), ce_loss.detach().item()), flush=True) 124 | rec_loss.append(loss.detach().item()) 125 | 126 | train_loss_recorder.update(loss.detach().item()) 127 | 128 | # evaluation & save 129 | if step % checkpoint == 0: 130 | rec_loss = [] 131 | print(32*'-', flush=True) 132 | print('step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'.format(step, train_epoch_step, 1.0*step/train_epoch_step, epoch, train_loss_recorder.get_val()), flush=True) 133 | print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True) 134 | if eval_train_flag: 135 | trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(Config, model, data_loader['trainval'], 'trainval', epoch, log_file) 136 | if abs(trainval_acc1 - trainval_acc3) < 0.01: 137 | eval_train_flag = False 138 | 139 | val_acc1, val_acc2, val_acc3 = eval_turn(Config, model, data_loader['val'], 'val', epoch, log_file) 140 | 141 | save_path = os.path.join(save_dir, 'weights_%d_%d_%.4f_%.4f.pth'%(epoch, batch_cnt, val_acc1, val_acc3)) 142 | torch.cuda.synchronize() 143 | torch.save(model.state_dict(), save_path) 144 | print('saved model to %s' % (save_path), flush=True) 145 | torch.cuda.empty_cache() 146 | 147 | # save only 148 | elif step % savepoint == 0: 149 | train_loss_recorder.update(rec_loss) 150 | rec_loss = [] 151 | save_path = os.path.join(save_dir, 'savepoint_weights-%d-%s.pth'%(step, dt())) 152 | 153 | checkpoint_list.append(save_path) 154 | if len(checkpoint_list) == 6: 155 | os.remove(checkpoint_list[0]) 156 | del checkpoint_list[0] 157 | torch.save(model.state_dict(), save_path) 158 | torch.cuda.empty_cache() 159 | 160 | 161 | log_file.close() 162 | 163 | 164 | 165 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import pdb 6 | 7 | 8 | class LossRecord(object): 9 | def __init__(self, batch_size): 10 | self.rec_loss = 0 11 | self.count = 0 12 | self.batch_size = batch_size 13 | 14 | def update(self, loss): 15 | if isinstance(loss, list): 16 | avg_loss = sum(loss) 17 | avg_loss /= (len(loss)*self.batch_size) 18 | 
self.rec_loss += avg_loss 19 | self.count += 1 20 | if isinstance(loss, float): 21 | self.rec_loss += loss/self.batch_size 22 | self.count += 1 23 | 24 | def get_val(self, init=False): 25 | pop_loss = self.rec_loss / self.count 26 | if init: 27 | self.rec_loss = 0 28 | self.count = 0 29 | return pop_loss 30 | 31 | 32 | def weights_normal_init(model, dev=0.01): 33 | if isinstance(model, list): 34 | for m in model: 35 | weights_normal_init(m, dev) 36 | else: 37 | for m in model.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | m.weight.data.normal_(0.0, dev) 40 | elif isinstance(m, nn.Linear): 41 | m.weight.data.normal_(0.0, dev) 42 | 43 | 44 | def clip_gradient(model, clip_norm): 45 | """Computes a gradient clipping coefficient based on gradient norm.""" 46 | totalnorm = 0 47 | for p in model.parameters(): 48 | if p.requires_grad: 49 | modulenorm = p.grad.data.norm() 50 | totalnorm += modulenorm ** 2 51 | totalnorm = torch.sqrt(totalnorm).item() 52 | norm = (clip_norm / max(totalnorm, clip_norm)) 53 | for p in model.parameters(): 54 | if p.requires_grad: 55 | p.grad.mul_(norm) 56 | 57 | 58 | def Linear(in_features, out_features, bias=True): 59 | """Weight-normalized Linear layer (input: N x T x C)""" 60 | m = nn.Linear(in_features, out_features, bias=bias) 61 | m.weight.data.uniform_(-0.1, 0.1) 62 | if bias: 63 | m.bias.data.uniform_(-0.1, 0.1) 64 | return m 65 | 66 | 67 | class convolution(nn.Module): 68 | def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True): 69 | super(convolution, self).__init__() 70 | 71 | pad = (k - 1) // 2 72 | self.conv = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(pad, pad), stride=(stride, stride), bias=not with_bn) 73 | self.bn = nn.BatchNorm2d(out_dim) if with_bn else nn.Sequential() 74 | self.relu = nn.ReLU(inplace=True) 75 | 76 | def forward(self, x): 77 | conv = self.conv(x) 78 | bn = self.bn(conv) 79 | relu = self.relu(bn) 80 | return relu 81 | 82 | class fully_connected(nn.Module): 83 | def __init__(self, inp_dim, out_dim, with_bn=True): 84 | super(fully_connected, self).__init__() 85 | self.with_bn = with_bn 86 | 87 | self.linear = nn.Linear(inp_dim, out_dim) 88 | if self.with_bn: 89 | self.bn = nn.BatchNorm1d(out_dim) 90 | self.relu = nn.ReLU(inplace=True) 91 | 92 | def forward(self, x): 93 | linear = self.linear(x) 94 | bn = self.bn(linear) if self.with_bn else linear 95 | relu = self.relu(bn) 96 | return relu 97 | 98 | class residual(nn.Module): 99 | def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True): 100 | super(residual, self).__init__() 101 | 102 | self.conv1 = nn.Conv2d(inp_dim, out_dim, (3, 3), padding=(1, 1), stride=(stride, stride), bias=False) 103 | self.bn1 = nn.BatchNorm2d(out_dim) 104 | self.relu1 = nn.ReLU(inplace=True) 105 | 106 | self.conv2 = nn.Conv2d(out_dim, out_dim, (3, 3), padding=(1, 1), bias=False) 107 | self.bn2 = nn.BatchNorm2d(out_dim) 108 | 109 | self.skip = nn.Sequential( 110 | nn.Conv2d(inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False), 111 | nn.BatchNorm2d(out_dim) 112 | ) if stride != 1 or inp_dim != out_dim else nn.Sequential() 113 | self.relu = nn.ReLU(inplace=True) 114 | 115 | def forward(self, x): 116 | conv1 = self.conv1(x) 117 | bn1 = self.bn1(conv1) 118 | relu1 = self.relu1(bn1) 119 | 120 | conv2 = self.conv2(relu1) 121 | bn2 = self.bn2(conv2) 122 | 123 | skip = self.skip(x) 124 | return self.relu(bn2 + skip) 125 | --------------------------------------------------------------------------------
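
The AutoAugment policies in `utils/autoaugment.py` are plain callables on PIL images, so they slot directly into a `torchvision` pipeline. A minimal sketch follows; the resize/crop sizes and ImageNet normalization statistics are illustrative assumptions, not values taken from this repository's configuration:

```python
# Illustrative only: sizes and normalization statistics are assumptions,
# not values read from config.py.
from PIL import Image
from torchvision import transforms

from utils.autoaugment import ImageNetPolicy

train_transform = transforms.Compose([
    transforms.Resize((512, 512)),          # assumed resize resolution
    ImageNetPolicy(),                       # applies one randomly chosen sub-policy
    transforms.RandomCrop((448, 448)),      # assumed crop size
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open('some_image.jpg').convert('RGB')    # hypothetical image path
tensor = train_transform(img)                        # 3 x 448 x 448 float tensor
```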
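
A hedged sketch of wiring `utils/dataset_DCL.py` into a training `DataLoader` with `collate_fn4train`. The `Config` fields, annotation dict, and transforms below are stand-ins for what `config.py` and `transforms/transforms.py` normally supply; in particular, the identity `swap` lambda only marks where the real patch-shuffling transform would go:

```python
# Illustrative only: the Config fields, annotation entries and transforms are
# stand-ins for what config.py and transforms/transforms.py normally provide.
from types import SimpleNamespace

from torch.utils.data import DataLoader
from torchvision import transforms

from utils.dataset_DCL import dataset, collate_fn4train

cfg = SimpleNamespace(rawdata_root='datasets/CUB/data',   # assumed image root
                      numcls=200, dataset='CUB',
                      cls_2=True, cls_2xmul=False)

anno = {'img_name': ['001.jpg', '002.jpg'],               # hypothetical file names
        'label': [0, 1]}

to_tensor = transforms.Compose([transforms.Resize((448, 448)),
                                transforms.ToTensor()])

train_set = dataset(cfg, anno,
                    swap_size=[7, 7],
                    common_aug=None,        # e.g. resize / random-flip augmentation
                    swap=lambda im: im,     # placeholder for the real patch-swap transform
                    totensor=to_tensor,
                    train=True)

train_loader = DataLoader(train_set, batch_size=2, shuffle=True,
                          collate_fn=collate_fn4train)
# Each batch is (images, label, label_swap, law_swap, img_name); the original and
# the swapped view of every sample are stacked along the batch dimension, and with
# cls_2=True the swapped view gets a binary swapped/unswapped target.
```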
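
`cls_base_acc` in `utils/test_tool.py` summarizes per-class accuracy from a dictionary of per-image predictions. A small self-contained example with hypothetical entries:

```python
# Hypothetical per-image records in the format cls_base_acc expects:
# the ground-truth label plus the top-3 predicted categories.
from utils.test_tool import cls_base_acc

result_gather = {
    'img_0001.jpg': {'label': 3, 'top1_cat': 3, 'top2_cat': 7, 'top3_cat': 1},
    'img_0002.jpg': {'label': 3, 'top1_cat': 5, 'top2_cat': 3, 'top3_cat': 9},
    'img_0003.jpg': {'label': 8, 'top1_cat': 2, 'top2_cat': 4, 'top3_cat': 6},
}

top1_acc, top3_acc, cls_count = cls_base_acc(result_gather)
# class 3: top-1 0.5, top-3 1.0 (2 samples); class 8: top-1 0.0, top-3 0.0 (1 sample)
```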
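
`LossRecord` in `utils/utils.py` tracks a running per-sample loss average (as in `train_model.py`), and `clip_gradient` rescales gradients whose total norm exceeds a threshold. A minimal sketch with a toy model, assuming nothing beyond the code above:

```python
# Illustrative only: a toy model and random data stand in for the real
# DCL training loop in utils/train_model.py.
import torch
from torch import nn

from utils.utils import LossRecord, clip_gradient

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

batch_size = 8
recorder = LossRecord(batch_size)

for _ in range(10):
    x = torch.randn(batch_size, 16)
    y = torch.randint(0, 4, (batch_size,))
    loss = criterion(model(x), y)

    optimizer.zero_grad()
    loss.backward()
    clip_gradient(model, clip_norm=10.0)   # rescale gradients if their total norm exceeds 10
    optimizer.step()

    recorder.update(loss.item())           # accumulates per-sample loss

print('mean per-sample loss:', recorder.get_val(init=True))
```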