├── .gitignore ├── README.md ├── __init__.py ├── configs └── config.json ├── data_example ├── jpg │ ├── image_00001.jpg │ ├── image_00002.jpg │ ├── image_00003.jpg │ ├── image_00004.jpg │ ├── image_00005.jpg │ ├── image_00006.jpg │ ├── image_00007.jpg │ ├── image_00008.jpg │ ├── image_00009.jpg │ └── image_00010.jpg ├── testset.txt └── trainset.txt ├── data_loader ├── data_augmentation.py ├── data_processor.py └── dataset.py ├── inference.py ├── logs └── .gitignore ├── nets ├── alexnet_module.py ├── densenet_module.py ├── dpn_module.py ├── inception_resnet_v2_module.py ├── inception_v3_module.py ├── inception_v4_module.py ├── mnasnet_module.py ├── mobilenet_v2_module.py ├── mobilenet_v3_module.py ├── nasnet_a_large_module.py ├── nasnet_a_mobile_module.py ├── net_interface.py ├── oct_resnet_module.py ├── pnasnet_5_large_module.py ├── polynet_module.py ├── resnet_module.py ├── resnext_module.py ├── senet_module.py ├── shufflenet_v2_module.py ├── squeezenet_module.py ├── vgg_module.py └── xception_module.py ├── requirements.txt ├── train.py ├── trainers ├── base_model.py ├── base_trainer.py ├── example_model.py └── example_trainer.py └── utils ├── config.py ├── logger.py ├── logger_summarizer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Added by ~ cadene ~ 2 | .DS_Store 3 | ._.DS_Store 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # image-classification-pytorch 2 | This repo is designed for those who want to start an image classification project quickly. 3 | It provides fast experiment setup and aims to maximize the number of experiments you can complete in a given amount of time. 4 | It includes a number of convolutional neural network modules. You can build your own DNN easily (see the sketch below).
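Each file under `nets/` follows the same pattern: an `nn.Module` subclass plus a small factory function that reads `num_classes` from `**kwargs` (see `nets/alexnet_module.py`). Below is a minimal sketch of a custom module; `my_net_module.py`, `MyNet` and `my_net` are hypothetical names, not part of this repo:

    # nets/my_net_module.py -- hypothetical example of a custom net module
    import torch.nn as nn

    class MyNet(nn.Module):
        def __init__(self, num_classes=1000):
            super(MyNet, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(1),  # global average pooling to 1x1
            )
            self.classifier = nn.Linear(32, num_classes)

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)  # flatten to (batch, 32)
            return self.classifier(x)

    def my_net(**kwargs):
        # factory function; train.py/inference.py resolve it by name via import_module + getattr
        num_classes = kwargs.get('num_classes', 1000)
        return MyNet(num_classes=num_classes)

To train it, set `"model_module_name": "my_net_module"` and `"model_net_name": "my_net"` in `configs/config.json`.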
5 | 6 | ## Requirements 7 | Python 3 only. Tested with CUDA 9.0 and cuDNN 7. 8 | 9 | * albumentations==0.1.1 10 | * easydict==1.8 11 | * imgaug==0.2.6 12 | * opencv-python==3.4.3.18 13 | * protobuf==3.6.1 14 | * scikit-image==0.14.0 15 | * tensorboardX==1.4 16 | * torch==0.4.1 17 | * torchvision==0.2.1 18 | 19 | ## model 20 | | net | inputsize | 21 | |-------------------------|-----------| 22 | | vggnet | 224 | 23 | | alexnet | 224 | 24 | | resnet | 224 | 25 | | inceptionV3 | 299 | 26 | | inceptionV4 | 299 | 27 | | squeezenet | 224 | 28 | | densenet | 224 | 29 | | dpnnet | 224 | 30 | | inception-resnet-v2 | 299 | 31 | | mobilenetV2 | 224 | 32 | | nasnet-a-large | 331 | 33 | | nasnet-mobile | 224 | 34 | | polynet | 331 | 35 | | resnext | 224 | 36 | | senet | 224 | 37 | | xception | 299 | 38 | | pnasnet | 331 | 39 | | shufflenetV2 | 224 | 40 | | mnasnet | 224 | 41 | | mobilenetV3 | 224 | 42 | | oct-resnet | 224/256 | 43 | | ... | ... | 44 | 45 | ### pre-trained model 46 | You can download pretrained models via the URLs listed in each `$net`_module.py under `nets/`. 47 | 48 | #### From [torchvision](https://github.com/pytorch/vision/) package: 49 | 50 | - ResNet ([resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth), [resnet34](https://download.pytorch.org/models/resnet34-333f7ec4.pth), [resnet50](https://download.pytorch.org/models/resnet50-19c8e357.pth), [resnet101](https://download.pytorch.org/models/resnet101-5d3b4d8f.pth), [resnet152](https://download.pytorch.org/models/resnet152-b121ed2d.pth)) 51 | - DenseNet ([densenet121](https://download.pytorch.org/models/densenet121-a639ec97.pth), [densenet169](https://download.pytorch.org/models/densenet169-b2777c0a.pth), [densenet201](https://download.pytorch.org/models/densenet201-c1103571.pth), [densenet161](https://download.pytorch.org/models/densenet161-8d451a50.pth)) 52 | - Inception v3 ([inception_v3](https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth)) 53 | - VGG ([vgg11](https://download.pytorch.org/models/vgg11-bbd30ac9.pth), [vgg11_bn](https://download.pytorch.org/models/vgg11_bn-6002323d.pth), [vgg13](https://download.pytorch.org/models/vgg13-c768596a.pth), [vgg13_bn](https://download.pytorch.org/models/vgg13_bn-abd245e5.pth), [vgg16](https://download.pytorch.org/models/vgg16-397923af.pth), [vgg16_bn](https://download.pytorch.org/models/vgg16_bn-6c64b313.pth), [vgg19](https://download.pytorch.org/models/vgg19-dcbb9e9d.pth), [vgg19_bn](https://download.pytorch.org/models/vgg19_bn-c79401a0.pth)) 54 | - SqueezeNet ([squeezenet1_0](https://download.pytorch.org/models/squeezenet1_0-a815701f.pth), [squeezenet1_1](https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth)) 55 | - AlexNet ([alexnet](https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth)) 56 | 57 | #### From [Pretrained models for PyTorch](https://github.com/Cadene/pretrained-models.pytorch) package: 58 | - ResNeXt ([resnext101_32x4d](http://data.lip6.fr/cadene/pretrainedmodels/resnext101_32x4d-29e315fa.pth), [resnext101_64x4d](http://data.lip6.fr/cadene/pretrainedmodels/resnext101_64x4d-e77a0586.pth)) 59 | - NASNet-A Large (`nasnet_a_large`: [imagenet](http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth), [imagenet+background](http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth)) 60 | - NASNet-A Mobile (`nasnet_a_mobile`: [imagenet](http://data.lip6.fr/cadene/pretrainedmodels/nasnetamobile-7e03cead.pth)) 61 | - Inception-ResNet v2 (`inception_resnet_v2`: 
[imagenet](http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth), [imagenet+background](http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth)) 62 | - Dual Path Networks ([dpn68](http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth), [dpn68b](http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth), `dpn92`: [imagenet](http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth), [imagenet+5k](http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth), [dpn98](http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth), [dpn131](http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth), [dpn107](http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth)) 63 | - Inception v4 (`inception_v4`: [imagenet](http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth), [imagenet+background](http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth)) 64 | - Xception ([xception](http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth)) 65 | - Squeeze-and-Excitation Networks ([senet154](http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth), [se_resnet50](http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth), [se_resnet101](http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth), [se_resnet152](http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth), [se_resnext50_32x4d](http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth), [se_resnext101_32x4d](http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth)) 66 | - PNASNet-5-Large (`pnasnet_5_large`: [imagenet](http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth), [imagenet+background](http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth)) 67 | - PolyNet ([polynet](http://data.lip6.fr/cadene/pretrainedmodels/polynet-f71d82a5.pth)) 68 | 69 | #### From [mobilenetV2](https://github.com/ericsun99/MobileNet-V2-Pytorch) package: 70 | - Mobilenet V2 ([mobilenet_v2](https://github.com/ericsun99/MobileNet-V2-Pytorch)) 71 | 72 | #### From [shufflenetV2](https://github.com/ericsun99/Shufflenet-v2-Pytorch) package: 73 | - Shufflenet V2 ([shufflenet_v2](https://github.com/ericsun99/Shufflenet-v2-Pytorch)) 74 | 75 | #### From [MnasNet](https://github.com/billhhh/MnasNet-pytorch-pretrained) package: 76 | - Mnasnet ([MnasNet](https://github.com/billhhh/MnasNet-pytorch-pretrained)) 77 | 78 | #### From [mobilenetV3](https://github.com/kuan-wang/pytorch-mobilenet-v3) package: 79 | - Mobilenet V3 ([mobilenet_v3_large](https://github.com/kuan-wang/pytorch-mobilenet-v3), [mobilenet_v3_small](https://github.com/kuan-wang/pytorch-mobilenet-v3)) 80 | 81 | #### From [OctaveResnet](https://github.com/d-li14/octconv.pytorch) package: 82 | - Octave Resnet ([oct_resnet26](https://github.com/d-li14/octconv.pytorch), [oct_resnet50](https://github.com/d-li14/octconv.pytorch), [oct_resnet101](https://github.com/d-li14/octconv.pytorch), [oct_resnet152](https://github.com/d-li14/octconv.pytorch), [oct_resnet200](https://github.com/d-li14/octconv.pytorch)) 83 | 84 | ## usage 85 | 86 | ### configuration 87 | | configure | description | 88 | |---------------------------------|---------------------------------------------------------------------------| 89 | | model_module_name | eg: vgg_module | 90 | | model_net_name | net function name in module, eg:vgg16 | 91 | | gpu_id | eg: single GPU: "0", multi-GPUs:"0,1,3,4,7" | 
92 | | async_loading | make an asynchronous copy to the GPU | 93 | | is_tensorboard | whether to use tensorboard for visualization | 94 | | evaluate_before_train | evaluate accuracy before training | 95 | | shuffle | shuffle your training data | 96 | | data_aug | augment your training data | 97 | | img_height | input height | 98 | | img_width | input width | 99 | | num_channels | input channels | 100 | | num_classes | output number of classes | 101 | | batch_size | train batch size | 102 | | dataloader_workers | number of workers when loading data | 103 | | learning_rate | learning rate | 104 | | learning_rate_decay | learning rate decay rate | 105 | | learning_rate_decay_epoch | decay the learning rate every n epochs | 106 | | train_mode | eg: "fromscratch","finetune","update" | 107 | | file_label_separator | separator between data-name and label. eg:"----" | 108 | | pretrained_path | pretrained model directory | 109 | | pretrained_file | pretrained model filename. eg:"alexnet-owt-4df8aa71.pth" | 110 | | pretrained_model_num_classes | number of output classes the pretrained model was trained with. eg:1000 for imagenet | 111 | | save_path | directory where models are saved | 112 | | save_name | filename for saved models | 113 | | train_data_root_dir | training data root dir | 114 | | val_data_root_dir | testing data root dir | 115 | | train_data_file | a txt file listing training images and their labels | 116 | | val_data_file | a txt file listing testing images and their labels | 117 | 118 | ### Training 119 | 1. Make your training & testing data and label lists as txt files (a helper sketch is given at the end of this README): 120 | 121 | Each line holds an image path and a single label index, eg: 122 | 123 | apple.jpg----0 124 | k.jpg----3 125 | 30.jpg----0 126 | data/2.jpg----1 127 | abc.jpg----1 128 | 2. Edit the configuration in configs/config.json. 129 | 130 | 3. Train: 131 | 132 | python3 train.py 133 | 134 | ### Inference 135 | eg: a model trained with inception_resnet_v2 on Oxford 102 Category Flowers (vgg/data/flowers/102): 136 | 137 | python3 inference.py --image test.jpg --module inception_resnet_v2_module --net inception_resnet_v2 --model model.pth --size 299 --cls 102 138 | 139 | ### tensorboardX 140 | 141 | tensorboard --logdir='./logs/' runs 142 | 143 | logdir is the log directory inside your project directory 144 | 145 | ## References 146 | 1.[https://github.com/pytorch](https://github.com/pytorch) 147 | 2.[https://github.com/victoresque/pytorch-template](https://github.com/victoresque/pytorch-template) 148 | 3.[https://pytorch.org](https://pytorch.org) 149 | 4.[https://github.com/yunjey/pytorch-tutorial](https://github.com/yunjey/pytorch-tutorial) 150 | 5.[https://www.tensorflow.org](https://www.tensorflow.org) 151 | 6.[https://github.com/Cadene/pretrained-models.pytorch/tree/master/pretrainedmodels/models](https://github.com/Cadene/pretrained-models.pytorch/tree/master/pretrainedmodels/models) 152 | 7.[https://github.com/ericsun99/MobileNet-V2-Pytorch](https://github.com/ericsun99/MobileNet-V2-Pytorch) 153 | 8.[http://www.robots.ox.ac.uk/~vgg/data/flowers/102](http://www.robots.ox.ac.uk/~vgg/data/flowers/102) 154 | 9.[https://github.com/ericsun99/Shufflenet-v2-Pytorch](https://github.com/ericsun99/Shufflenet-v2-Pytorch) 155 | 10.[https://github.com/billhhh/MnasNet-pytorch-pretrained](https://github.com/billhhh/MnasNet-pytorch-pretrained) 156 | 11.[https://github.com/d-li14/octconv.pytorch](https://github.com/d-li14/octconv.pytorch) 157 | 12.[https://github.com/kuan-wang/pytorch-mobilenet-v3](https://github.com/kuan-wang/pytorch-mobilenet-v3) 158 |
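## appendix: label-list helper (example)

The following is a minimal sketch for generating the label-list txt files, assuming your images are organized one class per sub-directory; the script name `make_label_list.py`, the directory layout and the paths are assumptions, not part of this repo. It writes `name----label` lines matching `file_label_separator`:

    # make_label_list.py -- hypothetical helper; adapt the paths to your dataset
    import os

    def make_label_list(root_dir, out_file, separator='----'):
        # one sub-directory per class; class index = alphabetical order
        class_names = sorted(d for d in os.listdir(root_dir)
                             if os.path.isdir(os.path.join(root_dir, d)))
        with open(out_file, 'w') as f:
            for label, name in enumerate(class_names):
                for fn in sorted(os.listdir(os.path.join(root_dir, name))):
                    # store paths relative to train_data_root_dir / val_data_root_dir
                    f.write(os.path.join(name, fn) + separator + str(label) + '\n')

    if __name__ == '__main__':
        make_label_list('/path/to/train_images', 'trainset.txt')
        make_label_list('/path/to/test_images', 'testset.txt')

-------------------------------------------------------------------------------- /__init__.py: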
-------------------------------------------------------------------------------- 1 | # coding=utf-8 -------------------------------------------------------------------------------- /configs/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_module_name": "inception_resnet_v2_module", 3 | "model_net_name": "inception_resnet_v2", 4 | "gpu_id": "0,1", 5 | "async_loading": true, 6 | "is_tensorboard": true, 7 | "evaluate_before_train": true, 8 | "shuffle": true, 9 | "data_aug": true, 10 | 11 | "num_epochs": 30, 12 | "img_height": 299, 13 | "img_width": 299, 14 | "num_channels": 3, 15 | "num_classes": 102, 16 | "batch_size": 64, 17 | "dataloader_workers": 1, 18 | "learning_rate": 1e-4, 19 | "learning_rate_decay": 0.9, 20 | "learning_rate_decay_epoch": 100, 21 | 22 | "train_mode": "finetune", 23 | "file_label_separator": "----", 24 | "pretrained_path": "/hdd/pretrain_model/pytorch", 25 | "pretrained_file": "inceptionresnetv2-520b38e4.pth", 26 | "save_path": "/hdd/datasets/flower_102/save_model", 27 | "save_name": "model.pth", 28 | 29 | "train_data_root_dir": "/hdd/datasets/flower_102", 30 | "val_data_root_dir": "/hdd/datasets/flower_102", 31 | "train_data_file": "/hdd/datasets/flower_102/trainset.txt", 32 | "val_data_file": "/hdd/datasets/flower_102/testset.txt" 33 | 34 | 35 | } -------------------------------------------------------------------------------- /data_example/jpg/image_00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/data_example/jpg/image_00001.jpg -------------------------------------------------------------------------------- /data_example/jpg/image_00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/data_example/jpg/image_00002.jpg -------------------------------------------------------------------------------- /data_example/jpg/image_00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/data_example/jpg/image_00003.jpg -------------------------------------------------------------------------------- /data_example/jpg/image_00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/data_example/jpg/image_00004.jpg -------------------------------------------------------------------------------- /data_example/jpg/image_00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/data_example/jpg/image_00005.jpg -------------------------------------------------------------------------------- /data_example/jpg/image_00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/data_example/jpg/image_00006.jpg -------------------------------------------------------------------------------- /data_example/jpg/image_00007.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/data_example/jpg/image_00007.jpg -------------------------------------------------------------------------------- /data_example/jpg/image_00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/data_example/jpg/image_00008.jpg -------------------------------------------------------------------------------- /data_example/jpg/image_00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/data_example/jpg/image_00009.jpg -------------------------------------------------------------------------------- /data_example/jpg/image_00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/data_example/jpg/image_00010.jpg -------------------------------------------------------------------------------- /data_loader/data_augmentation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import numpy as np 4 | 5 | ########################################################## 6 | # name: DataAugmenters 7 | # breif: 8 | # 9 | # usage: 10 | ########################################################## 11 | class DataAugmenters: 12 | def __init__(self, config): 13 | self.config = config 14 | self.augmentation = self._aug_albumentations() 15 | 16 | def _example(self,image,**kwargs): 17 | """ 18 | example 19 | :param image: 20 | :param kwargs: 21 | :return: 22 | """ 23 | return image 24 | 25 | 26 | def run(self,image,**kwargs): 27 | """ 28 | augment your image data 29 | :param image: 30 | :param kwargs: 31 | :return: 32 | """ 33 | data = {'image': image} 34 | augmented = self.augmentation(**data) 35 | return augmented['image'] 36 | 37 | 38 | def _aug_albumentations(self): 39 | p = 0.9 40 | from albumentations import ( 41 | HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90, 42 | Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue, 43 | IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, 44 | IAASharpen, IAAEmboss, RandomContrast, RandomBrightness, Flip, OneOf, Compose, JpegCompression, 45 | DualTransform) 46 | return Compose([ 47 | RandomRotate90(), 48 | Flip(), 49 | Transpose(), 50 | OneOf([ 51 | IAAAdditiveGaussianNoise(), 52 | GaussNoise(), 53 | ], p=0.2), 54 | OneOf([ 55 | MotionBlur(p=0.2), 56 | MedianBlur(blur_limit=3, p=0.1), 57 | Blur(blur_limit=3, p=0.1), 58 | ], p=0.2), 59 | ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2), 60 | OneOf([ 61 | OpticalDistortion(p=0.3), 62 | GridDistortion(p=0.1), 63 | IAAPiecewiseAffine(p=0.3), 64 | ], p=0.2), 65 | OneOf([ 66 | CLAHE(clip_limit=2), 67 | IAASharpen(), 68 | IAAEmboss(), 69 | RandomContrast(), 70 | RandomBrightness(), 71 | ], p=0.3), 72 | HueSaturationValue(p=0.3), 73 | ], p=p) 74 | 75 | 76 | if __name__ == '__main__': 77 | print('done!') 78 | -------------------------------------------------------------------------------- /data_loader/data_processor.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import random 4 | import numpy as np 5 | import cv2 6 | import scipy.misc as misc 7 | from data_loader.data_augmentation import DataAugmenters 8 | 9 | 10 | ########################################################## 11 | # name: DataProcessor 12 | # brief: image loading, resizing, normalization and 13 | # augmentation helpers 14 | # usage: DataProcessor(config) 15 | ########################################################## 16 | class DataProcessor: 17 | def __init__(self, config): 18 | self.config = config 19 | self.DataAugmenters = DataAugmenters(self.config) 20 | 21 | def image_loader(self, filename, **kwargs): 22 | """ 23 | load your image data 24 | :param filename: 25 | :return: 26 | """ 27 | image = cv2.imread(filename) 28 | if image is None: 29 | raise ValueError('cv2.imread failed to load: {}'.format(filename)) 30 | return image 31 | 32 | 33 | def image_resize(self, image, **kwargs): 34 | """ 35 | resize your image data 36 | :param image: 37 | :param kwargs: 38 | :return: 39 | """ 40 | _size = (self.config['img_width'], self.config['img_height']) 41 | _resize_image = cv2.resize(image, _size) 42 | return _resize_image[:,:,::-1] # bgr2rgb 43 | 44 | def input_norm(self, image, **kwargs): 45 | """ 46 | normalize your image data 47 | :param image: 48 | :return: 49 | """ 50 | return ((image - 127) * 0.0078125).astype(np.float32) # 0.0078125 == 1/128; maps pixels to roughly [-1, 1] 51 | 52 | 53 | def data_aug(self, image, **kwargs): 54 | """ 55 | augment your image data with DataAugmenters 56 | :param image: 57 | :return: 58 | """ 59 | return self.DataAugmenters.run(image, **kwargs) 60 | 61 | 62 | if __name__ == '__main__': 63 | 64 | print('done!') 65 | -------------------------------------------------------------------------------- /data_loader/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import cv2 5 | import copy 6 | import numpy as np 7 | import torch 8 | from torch.autograd import Variable 9 | from torchvision import transforms 10 | from torch.utils.data import Dataset, DataLoader 11 | from data_loader.data_processor import DataProcessor 12 | 13 | 14 | class PyTorchDataset(Dataset): 15 | def __init__(self, txt, config, transform=None, loader = None, 16 | target_transform=None, is_train_set=True): 17 | self.config = config 18 | imgs = [] 19 | with open(txt,'r') as f: 20 | for line in f: 21 | line = line.strip('\r\n') # strip trailing newline characters 22 | words = line.split(self.config['file_label_separator']) 23 | # single label here so we use int(words[1]) 24 | imgs.append((words[0], int(words[1]))) 25 | 26 | self.DataProcessor = DataProcessor(self.config) 27 | self.imgs = imgs 28 | self.transform = transform 29 | self.target_transform = target_transform 30 | self.is_train_set = is_train_set 31 | 32 | 33 | def __getitem__(self, index): 34 | fn, label = self.imgs[index] 35 | _root_dir = self.config['train_data_root_dir'] if self.is_train_set else self.config['val_data_root_dir'] 36 | image = self.self_defined_loader(os.path.join(_root_dir, fn)) 37 | if self.transform is not None: 38 | image = self.transform(image) 39 | 40 | return image, label 41 | 42 | 43 | def __len__(self): 44 | return len(self.imgs) 45 | 46 | 47 | def self_defined_loader(self, filename): 48 | image = self.DataProcessor.image_loader(filename) 49 | image = self.DataProcessor.image_resize(image) 50 | if self.is_train_set and self.config['data_aug']: 51 | image = self.DataProcessor.data_aug(image) 52 | image = self.DataProcessor.input_norm(image) 53 | return image
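# note: self_defined_loader returns a float32 HWC numpy array scaled to roughly [-1, 1];
# the transforms.ToTensor() used in get_data_loader below only transposes float arrays
# to CHW -- it rescales by 1/255 only for uint8 input, so these values are left unchanged.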
54 | 55 | 56 | def get_data_loader(config): 57 | """ 58 | build the training and validation DataLoaders from the txt label lists 59 | :param config: 60 | :return: 61 | """ 62 | train_data_file = config['train_data_file'] 63 | test_data_file = config['val_data_file'] 64 | batch_size = config['batch_size'] 65 | num_workers = config['dataloader_workers'] 66 | shuffle = config['shuffle'] 67 | 68 | if not os.path.isfile(train_data_file): 69 | raise ValueError('train_data_file does not exist') 70 | if not os.path.isfile(test_data_file): 71 | raise ValueError('val_data_file does not exist') 72 | 73 | train_data = PyTorchDataset(txt=train_data_file,config=config, 74 | transform=transforms.ToTensor(), is_train_set=True) 75 | test_data = PyTorchDataset(txt=test_data_file,config=config, 76 | transform=transforms.ToTensor(), is_train_set=False) 77 | train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=shuffle, 78 | num_workers=num_workers) 79 | test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False, 80 | num_workers=num_workers) 81 | 82 | return train_loader, test_loader 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import numpy as np 6 | import cv2 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | import torchvision.transforms as transforms 11 | from importlib import import_module 12 | 13 | class TagPytorchInference(object): 14 | 15 | def __init__(self, **kwargs): 16 | _input_size = kwargs.get('input_size',299) 17 | self.input_size = (_input_size, _input_size) 18 | self.gpu_index = kwargs.get('gpu_index', '0') 19 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 20 | os.environ["CUDA_VISIBLE_DEVICES"] = self.gpu_index 21 | self.net = self._create_model(**kwargs) 22 | self._load(**kwargs) 23 | self.net.eval() 24 | self.transforms = transforms.ToTensor() 25 | if torch.cuda.is_available(): 26 | self.net.cuda() 27 | 28 | def close(self): 29 | torch.cuda.empty_cache() 30 | 31 | 32 | def _create_model(self, **kwargs): 33 | module_name = kwargs.get('module_name','vgg_module') 34 | net_name = kwargs.get('net_name', 'vgg16') 35 | m = import_module('nets.' 
+ module_name) 36 | model = getattr(m, net_name) 37 | net = model(**kwargs) 38 | return net 39 | 40 | 41 | def _load(self, **kwargs): 42 | model_name = kwargs.get('model_name', 'model.pth') 43 | model_filename = model_name 44 | state_dict = torch.load(model_filename, map_location=None if torch.cuda.is_available() else 'cpu') # fall back to CPU when no GPU is present 45 | self.net.load_state_dict(state_dict) 46 | 47 | 48 | def run(self, image_data, **kwargs): 49 | _image_data = self.image_preprocess(image_data) 50 | input = self.transforms(_image_data) 51 | # add the batch dimension 52 | input = input.unsqueeze(0) 53 | if torch.cuda.is_available(): 54 | input = input.cuda() 55 | logit = self.net(Variable(input)) 56 | # softmax 57 | infer = F.softmax(logit, 1) 58 | return infer.data.cpu().numpy().tolist() 59 | 60 | 61 | def image_preprocess(self, image_data): 62 | _image = cv2.resize(image_data, self.input_size) 63 | _image = _image[:,:,::-1] # bgr2rgb 64 | return _image.copy() 65 | 66 | if __name__ == "__main__": 67 | # # python3 inference.py --image test.jpg --module inception_resnet_v2_module --net inception_resnet_v2 --model model.pth 68 | import argparse 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('-image', "--image", type=str, help='Assign the image path.', default=None) 71 | parser.add_argument('-module', "--module", type=str, help='Assign the module name.', default=None) 72 | parser.add_argument('-net', "--net", type=str, help='Assign the net name.', default=None) 73 | parser.add_argument('-model', "--model", type=str, help='Assign the model path.', default=None) 74 | parser.add_argument('-cls', "--cls", type=int, help='Assign the classes number.', default=None) 75 | parser.add_argument('-size', "--size", type=int, help='Assign the input size.', default=None) 76 | args = parser.parse_args() 77 | if args.image is None or args.module is None or args.net is None or args.model is None\ 78 | or args.size is None or args.cls is None: 79 | raise TypeError('input error') 80 | if not os.path.exists(args.model): 81 | raise TypeError('cannot find model file') 82 | if not os.path.exists(args.image): 83 | raise TypeError('cannot find image file') 84 | print('test:') 85 | filename = args.image 86 | module_name = args.module 87 | net_name = args.net 88 | model_name = args.model 89 | input_size = args.size 90 | num_classes = args.cls 91 | image = cv2.imread(filename) 92 | if image is None: 93 | raise TypeError('image data is none') 94 | tagInfer = TagPytorchInference(module_name=module_name,net_name=net_name, 95 | num_classes=num_classes, model_name=model_name, 96 | input_size=input_size) 97 | result = tagInfer.run(image) 98 | print(result) 99 | print('done!') -------------------------------------------------------------------------------- /logs/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frotms/image_classification_pytorch/379b0596604efed8de2425584e7acc65996da726/logs/.gitignore -------------------------------------------------------------------------------- /nets/alexnet_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ''' 3 | https://github.com/pytorch/vision.git 4 | ''' 5 | 6 | import torch.nn as nn 7 | import torch.utils.model_zoo as model_zoo 8 | 9 | 10 | __all__ = ['AlexNet', 'alexnet'] 11 | 12 | 13 | model_urls = { 14 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 15 | } 16 | 17 | 18 | class AlexNet(nn.Module): 19 | 20 | def __init__(self, 
num_classes=1000): 21 | super(AlexNet, self).__init__() 22 | self.features = nn.Sequential( 23 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 24 | nn.ReLU(inplace=True), 25 | nn.MaxPool2d(kernel_size=3, stride=2), 26 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(kernel_size=3, stride=2), 29 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 34 | nn.ReLU(inplace=True), 35 | nn.MaxPool2d(kernel_size=3, stride=2), 36 | ) 37 | self.classifier = nn.Sequential( 38 | nn.Dropout(), 39 | nn.Linear(256 * 6 * 6, 4096), 40 | nn.ReLU(inplace=True), 41 | nn.Dropout(), 42 | nn.Linear(4096, 4096), 43 | nn.ReLU(inplace=True), 44 | nn.Linear(4096, num_classes), 45 | ) 46 | 47 | def forward(self, x): 48 | x = self.features(x) 49 | x = x.view(x.size(0), 256 * 6 * 6) 50 | x = self.classifier(x) 51 | return x 52 | 53 | 54 | def _alexnet(pretrained=False, **kwargs): 55 | r"""AlexNet model architecture from the 56 | `"One weird trick..." `_ paper. 57 | Args: 58 | pretrained (bool): If True, returns a model pre-trained on ImageNet 59 | """ 60 | model = AlexNet(**kwargs) 61 | if pretrained: 62 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) 63 | return model 64 | 65 | 66 | def alexnet(**kwargs): 67 | r"""AlexNet model architecture from the 68 | `"One weird trick..." `_ paper. 69 | Args: 70 | pretrained (bool): If True, returns a model pre-trained on ImageNet 71 | """ 72 | num_classes = kwargs.get('num_classes', 1000) 73 | model = AlexNet(num_classes=num_classes) 74 | return model 75 | 76 | -------------------------------------------------------------------------------- /nets/densenet_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ''' 3 | https://github.com/pytorch/vision.git 4 | https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py 5 | ''' 6 | 7 | import re 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.model_zoo as model_zoo 12 | from collections import OrderedDict 13 | 14 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 15 | 16 | 17 | model_urls = { 18 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 19 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 20 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 21 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 22 | } 23 | 24 | 25 | def densenet121(pretrained=False, **kwargs): 26 | r"""Densenet-121 model from 27 | `"Densely Connected Convolutional Networks" `_ 28 | Args: 29 | pretrained (bool): If True, returns a model pre-trained on ImageNet 30 | """ 31 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 32 | **kwargs) 33 | if pretrained: 34 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 35 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 36 | # They are also in the checkpoints in model_urls. This pattern is used 37 | # to find such keys. 
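    # e.g. 'features.denseblock1.denselayer1.norm.1.weight' is renamed to
    # 'features.denseblock1.denselayer1.norm1.weight' before loading.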
38 | pattern = re.compile( 39 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 40 | state_dict = model_zoo.load_url(model_urls['densenet121']) 41 | for key in list(state_dict.keys()): 42 | res = pattern.match(key) 43 | if res: 44 | new_key = res.group(1) + res.group(2) 45 | state_dict[new_key] = state_dict[key] 46 | del state_dict[key] 47 | model.load_state_dict(state_dict) 48 | return model 49 | 50 | 51 | def densenet169(pretrained=False, **kwargs): 52 | r"""Densenet-169 model from 53 | `"Densely Connected Convolutional Networks" `_ 54 | Args: 55 | pretrained (bool): If True, returns a model pre-trained on ImageNet 56 | """ 57 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 58 | **kwargs) 59 | if pretrained: 60 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 61 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 62 | # They are also in the checkpoints in model_urls. This pattern is used 63 | # to find such keys. 64 | pattern = re.compile( 65 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 66 | state_dict = model_zoo.load_url(model_urls['densenet169']) 67 | for key in list(state_dict.keys()): 68 | res = pattern.match(key) 69 | if res: 70 | new_key = res.group(1) + res.group(2) 71 | state_dict[new_key] = state_dict[key] 72 | del state_dict[key] 73 | model.load_state_dict(state_dict) 74 | return model 75 | 76 | 77 | def densenet201(pretrained=False, **kwargs): 78 | r"""Densenet-201 model from 79 | `"Densely Connected Convolutional Networks" `_ 80 | Args: 81 | pretrained (bool): If True, returns a model pre-trained on ImageNet 82 | """ 83 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 84 | **kwargs) 85 | if pretrained: 86 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 87 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 88 | # They are also in the checkpoints in model_urls. This pattern is used 89 | # to find such keys. 90 | pattern = re.compile( 91 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 92 | state_dict = model_zoo.load_url(model_urls['densenet201']) 93 | for key in list(state_dict.keys()): 94 | res = pattern.match(key) 95 | if res: 96 | new_key = res.group(1) + res.group(2) 97 | state_dict[new_key] = state_dict[key] 98 | del state_dict[key] 99 | model.load_state_dict(state_dict) 100 | return model 101 | 102 | 103 | def densenet161(pretrained=False, **kwargs): 104 | r"""Densenet-161 model from 105 | `"Densely Connected Convolutional Networks" `_ 106 | Args: 107 | pretrained (bool): If True, returns a model pre-trained on ImageNet 108 | """ 109 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 110 | **kwargs) 111 | if pretrained: 112 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 113 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 114 | # They are also in the checkpoints in model_urls. This pattern is used 115 | # to find such keys. 
116 | pattern = re.compile( 117 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 118 | state_dict = model_zoo.load_url(model_urls['densenet161']) 119 | for key in list(state_dict.keys()): 120 | res = pattern.match(key) 121 | if res: 122 | new_key = res.group(1) + res.group(2) 123 | state_dict[new_key] = state_dict[key] 124 | del state_dict[key] 125 | model.load_state_dict(state_dict) 126 | return model 127 | 128 | 129 | class _DenseLayer(nn.Sequential): 130 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 131 | super(_DenseLayer, self).__init__() 132 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 133 | self.add_module('relu1', nn.ReLU(inplace=True)), 134 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 135 | growth_rate, kernel_size=1, stride=1, bias=False)), 136 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 137 | self.add_module('relu2', nn.ReLU(inplace=True)), 138 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 139 | kernel_size=3, stride=1, padding=1, bias=False)), 140 | self.drop_rate = drop_rate 141 | 142 | def forward(self, x): 143 | new_features = super(_DenseLayer, self).forward(x) 144 | if self.drop_rate > 0: 145 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 146 | return torch.cat([x, new_features], 1) 147 | 148 | 149 | class _DenseBlock(nn.Sequential): 150 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 151 | super(_DenseBlock, self).__init__() 152 | for i in range(num_layers): 153 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 154 | self.add_module('denselayer%d' % (i + 1), layer) 155 | 156 | 157 | class _Transition(nn.Sequential): 158 | def __init__(self, num_input_features, num_output_features): 159 | super(_Transition, self).__init__() 160 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 161 | self.add_module('relu', nn.ReLU(inplace=True)) 162 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 163 | kernel_size=1, stride=1, bias=False)) 164 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 165 | 166 | 167 | class DenseNet(nn.Module): 168 | r"""Densenet-BC model class, based on 169 | `"Densely Connected Convolutional Networks" `_ 170 | Args: 171 | growth_rate (int) - how many filters to add each layer (`k` in paper) 172 | block_config (list of 4 ints) - how many layers in each pooling block 173 | num_init_features (int) - the number of filters to learn in the first convolution layer 174 | bn_size (int) - multiplicative factor for number of bottle neck layers 175 | (i.e. 
bn_size * k features in the bottleneck layer) 176 | drop_rate (float) - dropout rate after each dense layer 177 | num_classes (int) - number of classification classes 178 | """ 179 | 180 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 181 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 182 | 183 | super(DenseNet, self).__init__() 184 | 185 | # First convolution 186 | self.features = nn.Sequential(OrderedDict([ 187 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 188 | ('norm0', nn.BatchNorm2d(num_init_features)), 189 | ('relu0', nn.ReLU(inplace=True)), 190 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 191 | ])) 192 | 193 | # Each denseblock 194 | num_features = num_init_features 195 | for i, num_layers in enumerate(block_config): 196 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 197 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 198 | self.features.add_module('denseblock%d' % (i + 1), block) 199 | num_features = num_features + num_layers * growth_rate 200 | if i != len(block_config) - 1: 201 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 202 | self.features.add_module('transition%d' % (i + 1), trans) 203 | num_features = num_features // 2 204 | 205 | # Final batch norm 206 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 207 | 208 | # Linear layer 209 | self.classifier = nn.Linear(num_features, num_classes) 210 | 211 | # Official init from torch repo. 212 | for m in self.modules(): 213 | if isinstance(m, nn.Conv2d): 214 | nn.init.kaiming_normal_(m.weight) 215 | elif isinstance(m, nn.BatchNorm2d): 216 | nn.init.constant_(m.weight, 1) 217 | nn.init.constant_(m.bias, 0) 218 | elif isinstance(m, nn.Linear): 219 | nn.init.constant_(m.bias, 0) 220 | 221 | def forward(self, x): 222 | features = self.features(x) 223 | out = F.relu(features, inplace=True) 224 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 225 | out = self.classifier(out) 226 | return out 227 | 228 | 229 | def densenet121(**kwargs): 230 | r"""Densenet-121 model from 231 | `"Densely Connected Convolutional Networks" `_ 232 | Args: 233 | pretrained (bool): If True, returns a model pre-trained on ImageNet 234 | """ 235 | num_classes = kwargs.get('num_classes', 1000) 236 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),num_classes=num_classes) 237 | return model 238 | 239 | 240 | def densenet169(**kwargs): 241 | r"""Densenet-169 model from 242 | `"Densely Connected Convolutional Networks" `_ 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | """ 246 | num_classes = kwargs.get('num_classes', 1000) 247 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),num_classes=num_classes) 248 | return model 249 | 250 | 251 | def densenet201(**kwargs): 252 | r"""Densenet-201 model from 253 | `"Densely Connected Convolutional Networks" `_ 254 | Args: 255 | pretrained (bool): If True, returns a model pre-trained on ImageNet 256 | """ 257 | num_classes = kwargs.get('num_classes', 1000) 258 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),num_classes=num_classes) 259 | return model 260 | 261 | 262 | def densenet161(**kwargs): 263 | r"""Densenet-161 model from 264 | `"Densely Connected Convolutional Networks" `_ 265 | Args: 266 | pretrained (bool): If True, returns a 
model pre-trained on ImageNet 267 | """ 268 | num_classes = kwargs.get('num_classes', 1000) 269 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),num_classes=num_classes) 270 | return model 271 | -------------------------------------------------------------------------------- /nets/inception_resnet_v2_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | import os 6 | import sys 7 | 8 | __all__ = ['InceptionResNetV2', 'inceptionresnetv2'] 9 | 10 | pretrained_settings = { 11 | 'inceptionresnetv2': { 12 | 'imagenet': { 13 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth', 14 | 'input_space': 'RGB', 15 | 'input_size': [3, 299, 299], 16 | 'input_range': [0, 1], 17 | 'mean': [0.5, 0.5, 0.5], 18 | 'std': [0.5, 0.5, 0.5], 19 | 'num_classes': 1000 20 | }, 21 | 'imagenet+background': { 22 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth', 23 | 'input_space': 'RGB', 24 | 'input_size': [3, 299, 299], 25 | 'input_range': [0, 1], 26 | 'mean': [0.5, 0.5, 0.5], 27 | 'std': [0.5, 0.5, 0.5], 28 | 'num_classes': 1001 29 | } 30 | } 31 | } 32 | 33 | 34 | class BasicConv2d(nn.Module): 35 | 36 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 37 | super(BasicConv2d, self).__init__() 38 | self.conv = nn.Conv2d(in_planes, out_planes, 39 | kernel_size=kernel_size, stride=stride, 40 | padding=padding, bias=False) # verify bias false 41 | self.bn = nn.BatchNorm2d(out_planes, 42 | eps=0.001, # value found in tensorflow 43 | momentum=0.1, # default pytorch value 44 | affine=True) 45 | self.relu = nn.ReLU(inplace=False) 46 | 47 | def forward(self, x): 48 | x = self.conv(x) 49 | x = self.bn(x) 50 | x = self.relu(x) 51 | return x 52 | 53 | 54 | class Mixed_5b(nn.Module): 55 | 56 | def __init__(self): 57 | super(Mixed_5b, self).__init__() 58 | 59 | self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) 60 | 61 | self.branch1 = nn.Sequential( 62 | BasicConv2d(192, 48, kernel_size=1, stride=1), 63 | BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) 64 | ) 65 | 66 | self.branch2 = nn.Sequential( 67 | BasicConv2d(192, 64, kernel_size=1, stride=1), 68 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), 69 | BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) 70 | ) 71 | 72 | self.branch3 = nn.Sequential( 73 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 74 | BasicConv2d(192, 64, kernel_size=1, stride=1) 75 | ) 76 | 77 | def forward(self, x): 78 | x0 = self.branch0(x) 79 | x1 = self.branch1(x) 80 | x2 = self.branch2(x) 81 | x3 = self.branch3(x) 82 | out = torch.cat((x0, x1, x2, x3), 1) 83 | return out 84 | 85 | 86 | class Block35(nn.Module): 87 | 88 | def __init__(self, scale=1.0): 89 | super(Block35, self).__init__() 90 | 91 | self.scale = scale 92 | 93 | self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) 94 | 95 | self.branch1 = nn.Sequential( 96 | BasicConv2d(320, 32, kernel_size=1, stride=1), 97 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 98 | ) 99 | 100 | self.branch2 = nn.Sequential( 101 | BasicConv2d(320, 32, kernel_size=1, stride=1), 102 | BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), 103 | BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) 104 | ) 105 | 106 | self.conv2d = nn.Conv2d(128, 320, kernel_size=1, 
stride=1) 107 | self.relu = nn.ReLU(inplace=False) 108 | 109 | def forward(self, x): 110 | x0 = self.branch0(x) 111 | x1 = self.branch1(x) 112 | x2 = self.branch2(x) 113 | out = torch.cat((x0, x1, x2), 1) 114 | out = self.conv2d(out) 115 | out = out * self.scale + x 116 | out = self.relu(out) 117 | return out 118 | 119 | 120 | class Mixed_6a(nn.Module): 121 | 122 | def __init__(self): 123 | super(Mixed_6a, self).__init__() 124 | 125 | self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) 126 | 127 | self.branch1 = nn.Sequential( 128 | BasicConv2d(320, 256, kernel_size=1, stride=1), 129 | BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), 130 | BasicConv2d(256, 384, kernel_size=3, stride=2) 131 | ) 132 | 133 | self.branch2 = nn.MaxPool2d(3, stride=2) 134 | 135 | def forward(self, x): 136 | x0 = self.branch0(x) 137 | x1 = self.branch1(x) 138 | x2 = self.branch2(x) 139 | out = torch.cat((x0, x1, x2), 1) 140 | return out 141 | 142 | 143 | class Block17(nn.Module): 144 | 145 | def __init__(self, scale=1.0): 146 | super(Block17, self).__init__() 147 | 148 | self.scale = scale 149 | 150 | self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) 151 | 152 | self.branch1 = nn.Sequential( 153 | BasicConv2d(1088, 128, kernel_size=1, stride=1), 154 | BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)), 155 | BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0)) 156 | ) 157 | 158 | self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) 159 | self.relu = nn.ReLU(inplace=False) 160 | 161 | def forward(self, x): 162 | x0 = self.branch0(x) 163 | x1 = self.branch1(x) 164 | out = torch.cat((x0, x1), 1) 165 | out = self.conv2d(out) 166 | out = out * self.scale + x 167 | out = self.relu(out) 168 | return out 169 | 170 | 171 | class Mixed_7a(nn.Module): 172 | 173 | def __init__(self): 174 | super(Mixed_7a, self).__init__() 175 | 176 | self.branch0 = nn.Sequential( 177 | BasicConv2d(1088, 256, kernel_size=1, stride=1), 178 | BasicConv2d(256, 384, kernel_size=3, stride=2) 179 | ) 180 | 181 | self.branch1 = nn.Sequential( 182 | BasicConv2d(1088, 256, kernel_size=1, stride=1), 183 | BasicConv2d(256, 288, kernel_size=3, stride=2) 184 | ) 185 | 186 | self.branch2 = nn.Sequential( 187 | BasicConv2d(1088, 256, kernel_size=1, stride=1), 188 | BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), 189 | BasicConv2d(288, 320, kernel_size=3, stride=2) 190 | ) 191 | 192 | self.branch3 = nn.MaxPool2d(3, stride=2) 193 | 194 | def forward(self, x): 195 | x0 = self.branch0(x) 196 | x1 = self.branch1(x) 197 | x2 = self.branch2(x) 198 | x3 = self.branch3(x) 199 | out = torch.cat((x0, x1, x2, x3), 1) 200 | return out 201 | 202 | 203 | class Block8(nn.Module): 204 | 205 | def __init__(self, scale=1.0, noReLU=False): 206 | super(Block8, self).__init__() 207 | 208 | self.scale = scale 209 | self.noReLU = noReLU 210 | 211 | self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) 212 | 213 | self.branch1 = nn.Sequential( 214 | BasicConv2d(2080, 192, kernel_size=1, stride=1), 215 | BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)), 216 | BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0)) 217 | ) 218 | 219 | self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) 220 | if not self.noReLU: 221 | self.relu = nn.ReLU(inplace=False) 222 | 223 | def forward(self, x): 224 | x0 = self.branch0(x) 225 | x1 = self.branch1(x) 226 | out = torch.cat((x0, x1), 1) 227 | out = self.conv2d(out) 228 | out = out * self.scale + x 229 | if not self.noReLU: 230 
| out = self.relu(out) 231 | return out 232 | 233 | 234 | class InceptionResNetV2(nn.Module): 235 | 236 | def __init__(self, num_classes=1001): 237 | super(InceptionResNetV2, self).__init__() 238 | # Special attributs 239 | self.input_space = None 240 | self.input_size = (299, 299, 3) 241 | self.mean = None 242 | self.std = None 243 | # Modules 244 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) 245 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) 246 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) 247 | self.maxpool_3a = nn.MaxPool2d(3, stride=2) 248 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) 249 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) 250 | self.maxpool_5a = nn.MaxPool2d(3, stride=2) 251 | self.mixed_5b = Mixed_5b() 252 | self.repeat = nn.Sequential( 253 | Block35(scale=0.17), 254 | Block35(scale=0.17), 255 | Block35(scale=0.17), 256 | Block35(scale=0.17), 257 | Block35(scale=0.17), 258 | Block35(scale=0.17), 259 | Block35(scale=0.17), 260 | Block35(scale=0.17), 261 | Block35(scale=0.17), 262 | Block35(scale=0.17) 263 | ) 264 | self.mixed_6a = Mixed_6a() 265 | self.repeat_1 = nn.Sequential( 266 | Block17(scale=0.10), 267 | Block17(scale=0.10), 268 | Block17(scale=0.10), 269 | Block17(scale=0.10), 270 | Block17(scale=0.10), 271 | Block17(scale=0.10), 272 | Block17(scale=0.10), 273 | Block17(scale=0.10), 274 | Block17(scale=0.10), 275 | Block17(scale=0.10), 276 | Block17(scale=0.10), 277 | Block17(scale=0.10), 278 | Block17(scale=0.10), 279 | Block17(scale=0.10), 280 | Block17(scale=0.10), 281 | Block17(scale=0.10), 282 | Block17(scale=0.10), 283 | Block17(scale=0.10), 284 | Block17(scale=0.10), 285 | Block17(scale=0.10) 286 | ) 287 | self.mixed_7a = Mixed_7a() 288 | self.repeat_2 = nn.Sequential( 289 | Block8(scale=0.20), 290 | Block8(scale=0.20), 291 | Block8(scale=0.20), 292 | Block8(scale=0.20), 293 | Block8(scale=0.20), 294 | Block8(scale=0.20), 295 | Block8(scale=0.20), 296 | Block8(scale=0.20), 297 | Block8(scale=0.20) 298 | ) 299 | self.block8 = Block8(noReLU=True) 300 | self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1) 301 | self.avgpool_1a = nn.AvgPool2d(8, count_include_pad=False) 302 | self.last_linear = nn.Linear(1536, num_classes) 303 | 304 | def features(self, input): 305 | x = self.conv2d_1a(input) 306 | x = self.conv2d_2a(x) 307 | x = self.conv2d_2b(x) 308 | x = self.maxpool_3a(x) 309 | x = self.conv2d_3b(x) 310 | x = self.conv2d_4a(x) 311 | x = self.maxpool_5a(x) 312 | x = self.mixed_5b(x) 313 | x = self.repeat(x) 314 | x = self.mixed_6a(x) 315 | x = self.repeat_1(x) 316 | x = self.mixed_7a(x) 317 | x = self.repeat_2(x) 318 | x = self.block8(x) 319 | x = self.conv2d_7b(x) 320 | return x 321 | 322 | def logits(self, features): 323 | x = self.avgpool_1a(features) 324 | x = x.view(x.size(0), -1) 325 | x = self.last_linear(x) 326 | return x 327 | 328 | def forward(self, input): 329 | x = self.features(input) 330 | x = self.logits(x) 331 | return x 332 | 333 | def inceptionresnetv2(num_classes=1000, pretrained='imagenet'): 334 | r"""InceptionResNetV2 model architecture from the 335 | `"InceptionV4, Inception-ResNet..." `_ paper. 
336 | """ 337 | if pretrained: 338 | settings = pretrained_settings['inceptionresnetv2'][pretrained] 339 | assert num_classes == settings['num_classes'], \ 340 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 341 | 342 | # both 'imagenet'&'imagenet+background' are loaded from same parameters 343 | model = InceptionResNetV2(num_classes=1001) 344 | model.load_state_dict(model_zoo.load_url(settings['url'])) 345 | 346 | if pretrained == 'imagenet': 347 | new_last_linear = nn.Linear(1536, 1000) 348 | new_last_linear.weight.data = model.last_linear.weight.data[1:] 349 | new_last_linear.bias.data = model.last_linear.bias.data[1:] 350 | model.last_linear = new_last_linear 351 | 352 | model.input_space = settings['input_space'] 353 | model.input_size = settings['input_size'] 354 | model.input_range = settings['input_range'] 355 | 356 | model.mean = settings['mean'] 357 | model.std = settings['std'] 358 | else: 359 | model = InceptionResNetV2(num_classes=num_classes) 360 | return model 361 | 362 | 363 | def inception_resnet_v2(**kwargs): 364 | num_classes = kwargs.get('num_classes', 1001) 365 | model = InceptionResNetV2(num_classes=num_classes) 366 | return model 367 | 368 | 369 | ''' 370 | TEST 371 | Run this code with: 372 | ``` 373 | cd $HOME/pretrained-models.pytorch 374 | python -m pretrainedmodels.inceptionresnetv2 375 | ``` 376 | ''' 377 | if __name__ == '__main__': 378 | 379 | assert inceptionresnetv2(num_classes=10, pretrained=None) 380 | print('success') 381 | assert inceptionresnetv2(num_classes=1000, pretrained='imagenet') 382 | print('success') 383 | assert inceptionresnetv2(num_classes=1001, pretrained='imagenet+background') 384 | print('success') 385 | 386 | # fail 387 | assert inceptionresnetv2(num_classes=1001, pretrained='imagenet') -------------------------------------------------------------------------------- /nets/inception_v3_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ''' 3 | https://github.com/pytorch/vision.git 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.model_zoo as model_zoo 10 | 11 | 12 | __all__ = ['Inception3', 'inception_v3'] 13 | 14 | 15 | model_urls = { 16 | # Inception v3 ported from TensorFlow 17 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 18 | } 19 | 20 | 21 | def _inception_v3(pretrained=False, **kwargs): 22 | r"""Inception v3 model architecture from 23 | `"Rethinking the Inception Architecture for Computer Vision" `_. 
24 | Args: 25 | pretrained (bool): If True, returns a model pre-trained on ImageNet 26 | """ 27 | if pretrained: 28 | if 'transform_input' not in kwargs: 29 | kwargs['transform_input'] = True 30 | model = Inception3(**kwargs) 31 | model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google'])) 32 | return model 33 | 34 | return Inception3(**kwargs) 35 | 36 | 37 | def inception_v3(**kwargs): 38 | num_classes = kwargs.get('num_classes', 1000) 39 | return Inception3(num_classes=num_classes) 40 | 41 | 42 | class Inception3(nn.Module): 43 | 44 | def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): 45 | super(Inception3, self).__init__() 46 | self.aux_logits = aux_logits 47 | self.transform_input = transform_input 48 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) 49 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 50 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 51 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 52 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 53 | self.Mixed_5b = InceptionA(192, pool_features=32) 54 | self.Mixed_5c = InceptionA(256, pool_features=64) 55 | self.Mixed_5d = InceptionA(288, pool_features=64) 56 | self.Mixed_6a = InceptionB(288) 57 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 58 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 59 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 60 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 61 | if aux_logits: 62 | self.AuxLogits = InceptionAux(768, num_classes) 63 | self.Mixed_7a = InceptionD(768) 64 | self.Mixed_7b = InceptionE(1280) 65 | self.Mixed_7c = InceptionE(2048) 66 | self.fc = nn.Linear(2048, num_classes) 67 | 68 | for m in self.modules(): 69 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 70 | import scipy.stats as stats 71 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1 72 | X = stats.truncnorm(-2, 2, scale=stddev) 73 | values = torch.Tensor(X.rvs(m.weight.data.numel())) 74 | values = values.view(m.weight.data.size()) 75 | m.weight.data.copy_(values) 76 | elif isinstance(m, nn.BatchNorm2d): 77 | m.weight.data.fill_(1) 78 | m.bias.data.zero_() 79 | 80 | def forward(self, x): 81 | if self.transform_input: 82 | x = x.clone() 83 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 84 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 85 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 86 | # 299 x 299 x 3 87 | x = self.Conv2d_1a_3x3(x) 88 | # 149 x 149 x 32 89 | x = self.Conv2d_2a_3x3(x) 90 | # 147 x 147 x 32 91 | x = self.Conv2d_2b_3x3(x) 92 | # 147 x 147 x 64 93 | x = F.max_pool2d(x, kernel_size=3, stride=2) 94 | # 73 x 73 x 64 95 | x = self.Conv2d_3b_1x1(x) 96 | # 73 x 73 x 80 97 | x = self.Conv2d_4a_3x3(x) 98 | # 71 x 71 x 192 99 | x = F.max_pool2d(x, kernel_size=3, stride=2) 100 | # 35 x 35 x 192 101 | x = self.Mixed_5b(x) 102 | # 35 x 35 x 256 103 | x = self.Mixed_5c(x) 104 | # 35 x 35 x 288 105 | x = self.Mixed_5d(x) 106 | # 35 x 35 x 288 107 | x = self.Mixed_6a(x) 108 | # 17 x 17 x 768 109 | x = self.Mixed_6b(x) 110 | # 17 x 17 x 768 111 | x = self.Mixed_6c(x) 112 | # 17 x 17 x 768 113 | x = self.Mixed_6d(x) 114 | # 17 x 17 x 768 115 | x = self.Mixed_6e(x) 116 | # 17 x 17 x 768 117 | if self.training and self.aux_logits: 118 | aux = self.AuxLogits(x) 119 | # 17 x 17 x 768 120 | x = self.Mixed_7a(x) 121 | # 8 x 8 x 1280 122 | x = self.Mixed_7b(x) 123 | # 8 x 8 x 2048 124 | x = self.Mixed_7c(x) 125 | # 8 x 8 x 2048 126 | x = F.avg_pool2d(x, 
kernel_size=8) 127 | # 1 x 1 x 2048 128 | x = F.dropout(x, training=self.training) 129 | # 1 x 1 x 2048 130 | x = x.view(x.size(0), -1) 131 | # 2048 132 | x = self.fc(x) 133 | # 1000 (num_classes) 134 | if self.training and self.aux_logits: 135 | return x#, aux 136 | return x 137 | 138 | 139 | class InceptionA(nn.Module): 140 | 141 | def __init__(self, in_channels, pool_features): 142 | super(InceptionA, self).__init__() 143 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) 144 | 145 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 146 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 147 | 148 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 149 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 150 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 151 | 152 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 153 | 154 | def forward(self, x): 155 | branch1x1 = self.branch1x1(x) 156 | 157 | branch5x5 = self.branch5x5_1(x) 158 | branch5x5 = self.branch5x5_2(branch5x5) 159 | 160 | branch3x3dbl = self.branch3x3dbl_1(x) 161 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 162 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 163 | 164 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 165 | branch_pool = self.branch_pool(branch_pool) 166 | 167 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 168 | return torch.cat(outputs, 1) 169 | 170 | 171 | class InceptionB(nn.Module): 172 | 173 | def __init__(self, in_channels): 174 | super(InceptionB, self).__init__() 175 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) 176 | 177 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 178 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 179 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) 180 | 181 | def forward(self, x): 182 | branch3x3 = self.branch3x3(x) 183 | 184 | branch3x3dbl = self.branch3x3dbl_1(x) 185 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 186 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 187 | 188 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 189 | 190 | outputs = [branch3x3, branch3x3dbl, branch_pool] 191 | return torch.cat(outputs, 1) 192 | 193 | 194 | class InceptionC(nn.Module): 195 | 196 | def __init__(self, in_channels, channels_7x7): 197 | super(InceptionC, self).__init__() 198 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 199 | 200 | c7 = channels_7x7 201 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 202 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 203 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 204 | 205 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) 206 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 207 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 208 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 209 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 210 | 211 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 212 | 213 | def forward(self, x): 214 | branch1x1 = self.branch1x1(x) 215 | 216 | branch7x7 = self.branch7x7_1(x) 217 | branch7x7 = self.branch7x7_2(branch7x7) 218 | branch7x7 = self.branch7x7_3(branch7x7) 219 | 220 | branch7x7dbl = 
self.branch7x7dbl_1(x) 221 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 222 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 223 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 224 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 225 | 226 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 227 | branch_pool = self.branch_pool(branch_pool) 228 | 229 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 230 | return torch.cat(outputs, 1) 231 | 232 | 233 | class InceptionD(nn.Module): 234 | 235 | def __init__(self, in_channels): 236 | super(InceptionD, self).__init__() 237 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 238 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) 239 | 240 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 241 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 242 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 243 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 244 | 245 | def forward(self, x): 246 | branch3x3 = self.branch3x3_1(x) 247 | branch3x3 = self.branch3x3_2(branch3x3) 248 | 249 | branch7x7x3 = self.branch7x7x3_1(x) 250 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 251 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 252 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 253 | 254 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 255 | outputs = [branch3x3, branch7x7x3, branch_pool] 256 | return torch.cat(outputs, 1) 257 | 258 | 259 | class InceptionE(nn.Module): 260 | 261 | def __init__(self, in_channels): 262 | super(InceptionE, self).__init__() 263 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 264 | 265 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 266 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 267 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 268 | 269 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 270 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 271 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 272 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 273 | 274 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 275 | 276 | def forward(self, x): 277 | branch1x1 = self.branch1x1(x) 278 | 279 | branch3x3 = self.branch3x3_1(x) 280 | branch3x3 = [ 281 | self.branch3x3_2a(branch3x3), 282 | self.branch3x3_2b(branch3x3), 283 | ] 284 | branch3x3 = torch.cat(branch3x3, 1) 285 | 286 | branch3x3dbl = self.branch3x3dbl_1(x) 287 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 288 | branch3x3dbl = [ 289 | self.branch3x3dbl_3a(branch3x3dbl), 290 | self.branch3x3dbl_3b(branch3x3dbl), 291 | ] 292 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 293 | 294 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 295 | branch_pool = self.branch_pool(branch_pool) 296 | 297 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 298 | return torch.cat(outputs, 1) 299 | 300 | 301 | class InceptionAux(nn.Module): 302 | 303 | def __init__(self, in_channels, num_classes): 304 | super(InceptionAux, self).__init__() 305 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 306 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 307 | self.conv1.stddev = 0.01 308 | self.fc = nn.Linear(768, num_classes) 309 | self.fc.stddev = 0.001 310 | 311 | def 
forward(self, x): 312 | # 17 x 17 x 768 313 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 314 | # 5 x 5 x 768 315 | x = self.conv0(x) 316 | # 5 x 5 x 128 317 | x = self.conv1(x) 318 | # 1 x 1 x 768 319 | x = x.view(x.size(0), -1) 320 | # 768 321 | x = self.fc(x) 322 | # 1000 323 | return x 324 | 325 | 326 | class BasicConv2d(nn.Module): 327 | 328 | def __init__(self, in_channels, out_channels, **kwargs): 329 | super(BasicConv2d, self).__init__() 330 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 331 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 332 | 333 | def forward(self, x): 334 | x = self.conv(x) 335 | x = self.bn(x) 336 | return F.relu(x, inplace=True) 337 | 338 | 339 | -------------------------------------------------------------------------------- /nets/inception_v4_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | import os 6 | import sys 7 | 8 | __all__ = ['InceptionV4', 'inceptionv4'] 9 | 10 | pretrained_settings = { 11 | 'inceptionv4': { 12 | 'imagenet': { 13 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth', 14 | 'input_space': 'RGB', 15 | 'input_size': [3, 299, 299], 16 | 'input_range': [0, 1], 17 | 'mean': [0.5, 0.5, 0.5], 18 | 'std': [0.5, 0.5, 0.5], 19 | 'num_classes': 1000 20 | }, 21 | 'imagenet+background': { 22 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth', 23 | 'input_space': 'RGB', 24 | 'input_size': [3, 299, 299], 25 | 'input_range': [0, 1], 26 | 'mean': [0.5, 0.5, 0.5], 27 | 'std': [0.5, 0.5, 0.5], 28 | 'num_classes': 1001 29 | } 30 | } 31 | } 32 | 33 | 34 | class BasicConv2d(nn.Module): 35 | 36 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 37 | super(BasicConv2d, self).__init__() 38 | self.conv = nn.Conv2d(in_planes, out_planes, 39 | kernel_size=kernel_size, stride=stride, 40 | padding=padding, bias=False) # verify bias false 41 | self.bn = nn.BatchNorm2d(out_planes, 42 | eps=0.001, # value found in tensorflow 43 | momentum=0.1, # default pytorch value 44 | affine=True) 45 | self.relu = nn.ReLU(inplace=True) 46 | 47 | def forward(self, x): 48 | x = self.conv(x) 49 | x = self.bn(x) 50 | x = self.relu(x) 51 | return x 52 | 53 | 54 | class Mixed_3a(nn.Module): 55 | 56 | def __init__(self): 57 | super(Mixed_3a, self).__init__() 58 | self.maxpool = nn.MaxPool2d(3, stride=2) 59 | self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) 60 | 61 | def forward(self, x): 62 | x0 = self.maxpool(x) 63 | x1 = self.conv(x) 64 | out = torch.cat((x0, x1), 1) 65 | return out 66 | 67 | 68 | class Mixed_4a(nn.Module): 69 | 70 | def __init__(self): 71 | super(Mixed_4a, self).__init__() 72 | 73 | self.branch0 = nn.Sequential( 74 | BasicConv2d(160, 64, kernel_size=1, stride=1), 75 | BasicConv2d(64, 96, kernel_size=3, stride=1) 76 | ) 77 | 78 | self.branch1 = nn.Sequential( 79 | BasicConv2d(160, 64, kernel_size=1, stride=1), 80 | BasicConv2d(64, 64, kernel_size=(1,7), stride=1, padding=(0,3)), 81 | BasicConv2d(64, 64, kernel_size=(7,1), stride=1, padding=(3,0)), 82 | BasicConv2d(64, 96, kernel_size=(3,3), stride=1) 83 | ) 84 | 85 | def forward(self, x): 86 | x0 = self.branch0(x) 87 | x1 = self.branch1(x) 88 | out = torch.cat((x0, x1), 1) 89 | return out 90 | 91 | 92 | class Mixed_5a(nn.Module): 93 | 94 | def __init__(self): 95 | super(Mixed_5a, self).__init__() 
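        # Note: like Mixed_3a above, this is a reduction block: a stride-2 3x3
        # conv and a stride-2 max-pool run in parallel over the same 192-channel
        # input and are concatenated along channels, so for a 299x299 image the
        # 71x71x192 map becomes the 35x35x384 tensor that the Inception_A
        # blocks below expect. A minimal shape check (illustrative sketch):
        #   >>> m = Mixed_5a()
        #   >>> m(torch.randn(1, 192, 71, 71)).shape
        #   torch.Size([1, 384, 35, 35])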
96 | self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) 97 | self.maxpool = nn.MaxPool2d(3, stride=2) 98 | 99 | def forward(self, x): 100 | x0 = self.conv(x) 101 | x1 = self.maxpool(x) 102 | out = torch.cat((x0, x1), 1) 103 | return out 104 | 105 | 106 | class Inception_A(nn.Module): 107 | 108 | def __init__(self): 109 | super(Inception_A, self).__init__() 110 | self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) 111 | 112 | self.branch1 = nn.Sequential( 113 | BasicConv2d(384, 64, kernel_size=1, stride=1), 114 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) 115 | ) 116 | 117 | self.branch2 = nn.Sequential( 118 | BasicConv2d(384, 64, kernel_size=1, stride=1), 119 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), 120 | BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) 121 | ) 122 | 123 | self.branch3 = nn.Sequential( 124 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 125 | BasicConv2d(384, 96, kernel_size=1, stride=1) 126 | ) 127 | 128 | def forward(self, x): 129 | x0 = self.branch0(x) 130 | x1 = self.branch1(x) 131 | x2 = self.branch2(x) 132 | x3 = self.branch3(x) 133 | out = torch.cat((x0, x1, x2, x3), 1) 134 | return out 135 | 136 | 137 | class Reduction_A(nn.Module): 138 | 139 | def __init__(self): 140 | super(Reduction_A, self).__init__() 141 | self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) 142 | 143 | self.branch1 = nn.Sequential( 144 | BasicConv2d(384, 192, kernel_size=1, stride=1), 145 | BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), 146 | BasicConv2d(224, 256, kernel_size=3, stride=2) 147 | ) 148 | 149 | self.branch2 = nn.MaxPool2d(3, stride=2) 150 | 151 | def forward(self, x): 152 | x0 = self.branch0(x) 153 | x1 = self.branch1(x) 154 | x2 = self.branch2(x) 155 | out = torch.cat((x0, x1, x2), 1) 156 | return out 157 | 158 | 159 | class Inception_B(nn.Module): 160 | 161 | def __init__(self): 162 | super(Inception_B, self).__init__() 163 | self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) 164 | 165 | self.branch1 = nn.Sequential( 166 | BasicConv2d(1024, 192, kernel_size=1, stride=1), 167 | BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), 168 | BasicConv2d(224, 256, kernel_size=(7,1), stride=1, padding=(3,0)) 169 | ) 170 | 171 | self.branch2 = nn.Sequential( 172 | BasicConv2d(1024, 192, kernel_size=1, stride=1), 173 | BasicConv2d(192, 192, kernel_size=(7,1), stride=1, padding=(3,0)), 174 | BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)), 175 | BasicConv2d(224, 224, kernel_size=(7,1), stride=1, padding=(3,0)), 176 | BasicConv2d(224, 256, kernel_size=(1,7), stride=1, padding=(0,3)) 177 | ) 178 | 179 | self.branch3 = nn.Sequential( 180 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 181 | BasicConv2d(1024, 128, kernel_size=1, stride=1) 182 | ) 183 | 184 | def forward(self, x): 185 | x0 = self.branch0(x) 186 | x1 = self.branch1(x) 187 | x2 = self.branch2(x) 188 | x3 = self.branch3(x) 189 | out = torch.cat((x0, x1, x2, x3), 1) 190 | return out 191 | 192 | 193 | class Reduction_B(nn.Module): 194 | 195 | def __init__(self): 196 | super(Reduction_B, self).__init__() 197 | 198 | self.branch0 = nn.Sequential( 199 | BasicConv2d(1024, 192, kernel_size=1, stride=1), 200 | BasicConv2d(192, 192, kernel_size=3, stride=2) 201 | ) 202 | 203 | self.branch1 = nn.Sequential( 204 | BasicConv2d(1024, 256, kernel_size=1, stride=1), 205 | BasicConv2d(256, 256, kernel_size=(1,7), stride=1, padding=(0,3)), 206 | BasicConv2d(256, 320, kernel_size=(7,1), stride=1, 
padding=(3,0)), 207 | BasicConv2d(320, 320, kernel_size=3, stride=2) 208 | ) 209 | 210 | self.branch2 = nn.MaxPool2d(3, stride=2) 211 | 212 | def forward(self, x): 213 | x0 = self.branch0(x) 214 | x1 = self.branch1(x) 215 | x2 = self.branch2(x) 216 | out = torch.cat((x0, x1, x2), 1) 217 | return out 218 | 219 | 220 | class Inception_C(nn.Module): 221 | 222 | def __init__(self): 223 | super(Inception_C, self).__init__() 224 | 225 | self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) 226 | 227 | self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) 228 | self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1,3), stride=1, padding=(0,1)) 229 | self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3,1), stride=1, padding=(1,0)) 230 | 231 | self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) 232 | self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3,1), stride=1, padding=(1,0)) 233 | self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1,3), stride=1, padding=(0,1)) 234 | self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1,3), stride=1, padding=(0,1)) 235 | self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3,1), stride=1, padding=(1,0)) 236 | 237 | self.branch3 = nn.Sequential( 238 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 239 | BasicConv2d(1536, 256, kernel_size=1, stride=1) 240 | ) 241 | 242 | def forward(self, x): 243 | x0 = self.branch0(x) 244 | 245 | x1_0 = self.branch1_0(x) 246 | x1_1a = self.branch1_1a(x1_0) 247 | x1_1b = self.branch1_1b(x1_0) 248 | x1 = torch.cat((x1_1a, x1_1b), 1) 249 | 250 | x2_0 = self.branch2_0(x) 251 | x2_1 = self.branch2_1(x2_0) 252 | x2_2 = self.branch2_2(x2_1) 253 | x2_3a = self.branch2_3a(x2_2) 254 | x2_3b = self.branch2_3b(x2_2) 255 | x2 = torch.cat((x2_3a, x2_3b), 1) 256 | 257 | x3 = self.branch3(x) 258 | 259 | out = torch.cat((x0, x1, x2, x3), 1) 260 | return out 261 | 262 | 263 | class InceptionV4(nn.Module): 264 | 265 | def __init__(self, num_classes=1001): 266 | super(InceptionV4, self).__init__() 267 | # Special attributs 268 | self.input_space = None 269 | self.input_size = (299, 299, 3) 270 | self.mean = None 271 | self.std = None 272 | # Modules 273 | self.features = nn.Sequential( 274 | BasicConv2d(3, 32, kernel_size=3, stride=2), 275 | BasicConv2d(32, 32, kernel_size=3, stride=1), 276 | BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), 277 | Mixed_3a(), 278 | Mixed_4a(), 279 | Mixed_5a(), 280 | Inception_A(), 281 | Inception_A(), 282 | Inception_A(), 283 | Inception_A(), 284 | Reduction_A(), # Mixed_6a 285 | Inception_B(), 286 | Inception_B(), 287 | Inception_B(), 288 | Inception_B(), 289 | Inception_B(), 290 | Inception_B(), 291 | Inception_B(), 292 | Reduction_B(), # Mixed_7a 293 | Inception_C(), 294 | Inception_C(), 295 | Inception_C() 296 | ) 297 | self.avg_pool = nn.AvgPool2d(8, count_include_pad=False) 298 | self.last_linear = nn.Linear(1536, num_classes) 299 | 300 | def logits(self, features): 301 | x = self.avg_pool(features) 302 | x = x.view(x.size(0), -1) 303 | x = self.last_linear(x) 304 | return x 305 | 306 | def forward(self, input): 307 | x = self.features(input) 308 | x = self.logits(x) 309 | return x 310 | 311 | 312 | def _inceptionv4(num_classes=1000, pretrained='imagenet'): 313 | if pretrained: 314 | settings = pretrained_settings['inceptionv4'][pretrained] 315 | assert num_classes == settings['num_classes'], \ 316 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 317 | 318 | # both 'imagenet'&'imagenet+background' are loaded 
from same parameters 319 | model = InceptionV4(num_classes=1001) 320 | model.load_state_dict(model_zoo.load_url(settings['url'])) 321 | 322 | if pretrained == 'imagenet': 323 | new_last_linear = nn.Linear(1536, 1000) 324 | new_last_linear.weight.data = model.last_linear.weight.data[1:] 325 | new_last_linear.bias.data = model.last_linear.bias.data[1:] 326 | model.last_linear = new_last_linear 327 | 328 | model.input_space = settings['input_space'] 329 | model.input_size = settings['input_size'] 330 | model.input_range = settings['input_range'] 331 | model.mean = settings['mean'] 332 | model.std = settings['std'] 333 | else: 334 | model = InceptionV4(num_classes=num_classes) 335 | return model 336 | 337 | 338 | def inception_v4(**kwargs): 339 | num_classes = kwargs.get('num_classes', 1000) 340 | model = InceptionV4(num_classes=num_classes) 341 | return model 342 | 343 | 344 | ''' 345 | TEST 346 | Run this code with: 347 | ``` 348 | cd $HOME/pretrained-models.pytorch 349 | python -m pretrainedmodels.inceptionv4 350 | ``` 351 | ''' 352 | if __name__ == '__main__': 353 | 354 | assert inceptionv4(num_classes=10, pretrained=None) 355 | print('success') 356 | assert inceptionv4(num_classes=1000, pretrained='imagenet') 357 | print('success') 358 | assert inceptionv4(num_classes=1001, pretrained='imagenet+background') 359 | print('success') 360 | 361 | # fail 362 | assert inceptionv4(num_classes=1001, pretrained='imagenet') -------------------------------------------------------------------------------- /nets/mnasnet_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ''' 3 | https://github.com/billhhh/MnasNet-pytorch-pretrained 4 | 5 | ''' 6 | from torch.autograd import Variable 7 | import torch.nn as nn 8 | import torch 9 | import math 10 | 11 | 12 | def Conv_3x3(inp, oup, stride): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | nn.BatchNorm2d(oup), 16 | nn.ReLU6(inplace=True) 17 | ) 18 | 19 | 20 | def Conv_1x1(inp, oup): 21 | return nn.Sequential( 22 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 23 | nn.BatchNorm2d(oup), 24 | nn.ReLU6(inplace=True) 25 | ) 26 | 27 | def SepConv_3x3(inp, oup): #input=32, output=16 28 | return nn.Sequential( 29 | # dw 30 | nn.Conv2d(inp, inp , 3, 1, 1, groups=inp, bias=False), 31 | nn.BatchNorm2d(inp), 32 | nn.ReLU6(inplace=True), 33 | # pw-linear 34 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 35 | nn.BatchNorm2d(oup), 36 | ) 37 | 38 | 39 | class InvertedResidual(nn.Module): 40 | def __init__(self, inp, oup, stride, expand_ratio, kernel): 41 | super(InvertedResidual, self).__init__() 42 | self.stride = stride 43 | assert stride in [1, 2] 44 | 45 | self.use_res_connect = self.stride == 1 and inp == oup 46 | 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(inp * expand_ratio), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, kernel, stride, kernel // 2, groups=inp * expand_ratio, bias=False), 54 | nn.BatchNorm2d(inp * expand_ratio), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 58 | nn.BatchNorm2d(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | if self.use_res_connect: 63 | return x + self.conv(x) 64 | else: 65 | return self.conv(x) 66 | 67 | 68 | class MnasNet(nn.Module): 69 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 70 | super(MnasNet, self).__init__() 71 | 72 | # 
setting of inverted residual blocks 73 | self.interverted_residual_setting = [ 74 | # t, c, n, s, k 75 | [3, 24, 3, 2, 3], # -> 56x56 76 | [3, 40, 3, 2, 5], # -> 28x28 77 | [6, 80, 3, 2, 5], # -> 14x14 78 | [6, 96, 2, 1, 3], # -> 14x14 79 | [6, 192, 4, 2, 5], # -> 7x7 80 | [6, 320, 1, 1, 3], # -> 7x7 81 | ] 82 | 83 | assert input_size % 32 == 0 84 | input_channel = int(32 * width_mult) 85 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 86 | 87 | # building first two layer 88 | self.features = [Conv_3x3(3, input_channel, 2), SepConv_3x3(input_channel, 16)] 89 | input_channel = 16 90 | 91 | # building inverted residual blocks (MBConv) 92 | for t, c, n, s, k in self.interverted_residual_setting: 93 | output_channel = int(c * width_mult) 94 | for i in range(n): 95 | if i == 0: 96 | self.features.append(InvertedResidual(input_channel, output_channel, s, t, k)) 97 | else: 98 | self.features.append(InvertedResidual(input_channel, output_channel, 1, t, k)) 99 | input_channel = output_channel 100 | 101 | # building last several layers 102 | self.features.append(Conv_1x1(input_channel, self.last_channel)) 103 | self.features.append(nn.AdaptiveAvgPool2d(1)) 104 | 105 | # make it nn.Sequential 106 | self.features = nn.Sequential(*self.features) 107 | 108 | # building classifier 109 | self.classifier = nn.Sequential( 110 | nn.Dropout(0.0), 111 | nn.Linear(self.last_channel, n_class), 112 | ) 113 | 114 | self._initialize_weights() 115 | 116 | def forward(self, x): 117 | x = self.features(x) 118 | x = x.view(-1, self.last_channel) 119 | x = self.classifier(x) 120 | return x 121 | 122 | def _initialize_weights(self): 123 | for m in self.modules(): 124 | if isinstance(m, nn.Conv2d): 125 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 126 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 127 | if m.bias is not None: 128 | m.bias.data.zero_() 129 | elif isinstance(m, nn.BatchNorm2d): 130 | m.weight.data.fill_(1) 131 | m.bias.data.zero_() 132 | elif isinstance(m, nn.Linear): 133 | n = m.weight.size(1) 134 | m.weight.data.normal_(0, 0.01) 135 | m.bias.data.zero_() 136 | 137 | 138 | if __name__ == '__main__': 139 | """Testing 140 | """ 141 | model = MnasNet() 142 | print(model) 143 | -------------------------------------------------------------------------------- /nets/mobilenet_v2_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ''' 3 | https://github.com/pytorch/vision.git 4 | https://github.com/ericsun99/MobileNet-V2-Pytorch 5 | pretrain model: mobilenetv2_Top1_71.806_Top2_90.410.pth.tar 6 | ''' 7 | 8 | import torch.nn as nn 9 | import math 10 | 11 | 12 | def conv_bn(inp, oup, stride): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | nn.BatchNorm2d(oup), 16 | nn.ReLU6(inplace=True) 17 | ) 18 | 19 | 20 | def conv_1x1_bn(inp, oup): 21 | return nn.Sequential( 22 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 23 | nn.BatchNorm2d(oup), 24 | nn.ReLU6(inplace=True) 25 | ) 26 | 27 | 28 | class InvertedResidual(nn.Module): 29 | def __init__(self, inp, oup, stride, expand_ratio): 30 | super(InvertedResidual, self).__init__() 31 | self.stride = stride 32 | assert stride in [1, 2] 33 | 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | self.conv = nn.Sequential( 37 | # pw 38 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 39 | nn.BatchNorm2d(inp * expand_ratio), 40 | nn.ReLU6(inplace=True), 41 | # dw 42 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 43 | nn.BatchNorm2d(inp * expand_ratio), 44 | nn.ReLU6(inplace=True), 45 | # pw-linear 46 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 47 | nn.BatchNorm2d(oup), 48 | ) 49 | 50 | def forward(self, x): 51 | if self.use_res_connect: 52 | return x + self.conv(x) 53 | else: 54 | return self.conv(x) 55 | 56 | 57 | class MobileNetV2(nn.Module): 58 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 59 | super(MobileNetV2, self).__init__() 60 | # setting of inverted residual blocks 61 | self.interverted_residual_setting = [ 62 | # t, c, n, s 63 | [1, 16, 1, 1], 64 | [6, 24, 2, 2], 65 | [6, 32, 3, 2], 66 | [6, 64, 4, 2], 67 | [6, 96, 3, 1], 68 | [6, 160, 3, 2], 69 | [6, 320, 1, 1], 70 | ] 71 | 72 | # building first layer 73 | assert input_size % 32 == 0 74 | input_channel = int(32 * width_mult) 75 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 76 | self.features = [conv_bn(3, input_channel, 2)] 77 | # building inverted residual blocks 78 | for t, c, n, s in self.interverted_residual_setting: 79 | output_channel = int(c * width_mult) 80 | for i in range(n): 81 | if i == 0: 82 | self.features.append(InvertedResidual(input_channel, output_channel, s, t)) 83 | else: 84 | self.features.append(InvertedResidual(input_channel, output_channel, 1, t)) 85 | input_channel = output_channel 86 | # building last several layers 87 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 88 | self.features.append(nn.AvgPool2d(input_size//32)) 89 | # make it nn.Sequential 90 | self.features = nn.Sequential(*self.features) 91 | 92 | # building classifier 93 | self.classifier = nn.Sequential( 94 | nn.Dropout(), 95 | nn.Linear(self.last_channel, n_class), 96 | ) 97 | 98 | self._initialize_weights() 99 | 100 | 
def forward(self, x): 101 | x = self.features(x) 102 | x = x.view(-1, self.last_channel) 103 | x = self.classifier(x) 104 | return x 105 | 106 | def _initialize_weights(self): 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. / n)) 111 | if m.bias is not None: 112 | m.bias.data.zero_() 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | elif isinstance(m, nn.Linear): 117 | n = m.weight.size(1) 118 | m.weight.data.normal_(0, 0.01) 119 | m.bias.data.zero_() 120 | 121 | 122 | def mobilenet_v2(**kwargs): 123 | width_mult = kwargs.get('width_mult', 1.0) 124 | n_class = kwargs.get('num_classes',1000) 125 | return MobileNetV2(n_class=n_class, width_mult=width_mult) 126 | 127 | 128 | -------------------------------------------------------------------------------- /nets/mobilenet_v3_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ''' 3 | Creates a MobileNetV3 Model as defined in: 4 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019). 5 | Searching for MobileNetV3 6 | arXiv preprint arXiv:1905.02244. 7 | https://arxiv.org/abs/1905.02244 8 | https://github.com/kuan-wang/pytorch-mobilenet-v3/blob/master/mobilenetv3.py 9 | ''' 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | __all__ = ['MobileNetV3', 'mobilenetv3', 'mobilenet_v3_large', 'mobilenet_v3_small'] 16 | 17 | 18 | def conv_bn(inp, oup, stride, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU): 19 | return nn.Sequential( 20 | conv_layer(inp, oup, 3, stride, 1, bias=False), 21 | norm_layer(oup), 22 | nlin_layer(inplace=True) 23 | ) 24 | 25 | 26 | def conv_1x1_bn(inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU): 27 | return nn.Sequential( 28 | conv_layer(inp, oup, 1, 1, 0, bias=False), 29 | norm_layer(oup), 30 | nlin_layer(inplace=True) 31 | ) 32 | 33 | 34 | class Hswish(nn.Module): 35 | def __init__(self, inplace=True): 36 | super(Hswish, self).__init__() 37 | self.inplace = inplace 38 | 39 | def forward(self, x): 40 | return x * F.relu6(x + 3., inplace=self.inplace) / 6. 41 | 42 | 43 | class Hsigmoid(nn.Module): 44 | def __init__(self, inplace=True): 45 | super(Hsigmoid, self).__init__() 46 | self.inplace = inplace 47 | 48 | def forward(self, x): 49 | return F.relu6(x + 3., inplace=self.inplace) / 6. 50 | 51 | 52 | class SEModule(nn.Module): 53 | def __init__(self, channel, reduction=4): 54 | super(SEModule, self).__init__() 55 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 56 | self.fc = nn.Sequential( 57 | nn.Linear(channel, channel // reduction, bias=False), 58 | nn.ReLU(inplace=True), 59 | nn.Linear(channel // reduction, channel, bias=False), 60 | Hsigmoid() 61 | # nn.Sigmoid() 62 | ) 63 | 64 | def forward(self, x): 65 | b, c, _, _ = x.size() 66 | y = self.avg_pool(x).view(b, c) 67 | y = self.fc(y).view(b, c, 1, 1) 68 | return x * y.expand_as(x) 69 | 70 | 71 | class Identity(nn.Module): 72 | def __init__(self, channel): 73 | super(Identity, self).__init__() 74 | 75 | def forward(self, x): 76 | return x 77 | 78 | 79 | def make_divisible(x, divisible_by=8): 80 | import numpy as np 81 | return int(np.ceil(x * 1. 
/ divisible_by) * divisible_by) 82 | 83 | 84 | class MobileBottleneck(nn.Module): 85 | def __init__(self, inp, oup, kernel, stride, exp, se=False, nl='RE'): 86 | super(MobileBottleneck, self).__init__() 87 | assert stride in [1, 2] 88 | assert kernel in [3, 5] 89 | padding = (kernel - 1) // 2 90 | self.use_res_connect = stride == 1 and inp == oup 91 | 92 | conv_layer = nn.Conv2d 93 | norm_layer = nn.BatchNorm2d 94 | if nl == 'RE': 95 | nlin_layer = nn.ReLU # or ReLU6 96 | elif nl == 'HS': 97 | nlin_layer = Hswish 98 | else: 99 | raise NotImplementedError 100 | if se: 101 | SELayer = SEModule 102 | else: 103 | SELayer = Identity 104 | 105 | self.conv = nn.Sequential( 106 | # pw 107 | conv_layer(inp, exp, 1, 1, 0, bias=False), 108 | norm_layer(exp), 109 | nlin_layer(inplace=True), 110 | # dw 111 | conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False), 112 | norm_layer(exp), 113 | SELayer(exp), 114 | nlin_layer(inplace=True), 115 | # pw-linear 116 | conv_layer(exp, oup, 1, 1, 0, bias=False), 117 | norm_layer(oup), 118 | ) 119 | 120 | def forward(self, x): 121 | if self.use_res_connect: 122 | return x + self.conv(x) 123 | else: 124 | return self.conv(x) 125 | 126 | 127 | class MobileNetV3(nn.Module): 128 | def __init__(self, num_classes=1000, input_size=224, mode='large', width_mult=1.0): 129 | super(MobileNetV3, self).__init__() 130 | input_channel = 16 131 | last_channel = 1280 132 | if mode == 'large': 133 | # refer to Table 1 in paper 134 | mobile_setting = [ 135 | # k, exp, c, se, nl, s, 136 | [3, 16, 16, False, 'RE', 1], 137 | [3, 64, 24, False, 'RE', 2], 138 | [3, 72, 24, False, 'RE', 1], 139 | [5, 72, 40, True, 'RE', 2], 140 | [5, 120, 40, True, 'RE', 1], 141 | [5, 120, 40, True, 'RE', 1], 142 | [3, 240, 80, False, 'HS', 2], 143 | [3, 200, 80, False, 'HS', 1], 144 | [3, 184, 80, False, 'HS', 1], 145 | [3, 184, 80, False, 'HS', 1], 146 | [3, 480, 112, True, 'HS', 1], 147 | [3, 672, 112, True, 'HS', 1], 148 | [5, 672, 112, True, 'HS', 1], # c = 112, paper set it to 160 by error 149 | [5, 672, 160, True, 'HS', 2], 150 | [5, 960, 160, True, 'HS', 1], 151 | ] 152 | elif mode == 'small': 153 | # refer to Table 2 in paper 154 | mobile_setting = [ 155 | # k, exp, c, se, nl, s, 156 | [3, 16, 16, True, 'RE', 2], 157 | [3, 72, 24, False, 'RE', 2], 158 | [3, 88, 24, False, 'RE', 1], 159 | [5, 96, 40, True, 'HS', 2], # stride = 2, paper set it to 1 by error 160 | [5, 240, 40, True, 'HS', 1], 161 | [5, 240, 40, True, 'HS', 1], 162 | [5, 120, 48, True, 'HS', 1], 163 | [5, 144, 48, True, 'HS', 1], 164 | [5, 288, 96, True, 'HS', 2], 165 | [5, 576, 96, True, 'HS', 1], 166 | [5, 576, 96, True, 'HS', 1], 167 | ] 168 | else: 169 | raise NotImplementedError 170 | 171 | # building first layer 172 | assert input_size % 32 == 0 173 | # input_channel = make_divisible(input_channel * width_mult) # first channel is always 16! 
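        # make_divisible() (defined above) rounds scaled channel counts up to a
        # multiple of 8; for example, c=40 with width_mult=0.75 gives
        # ceil(30 / 8) * 8 = 32 channels. As in mobilenet_v2_module.py,
        # last_channel is widened for width_mult > 1.0 but never shrunk below
        # its default of 1280.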
174 | self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel 175 | self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)] 176 | 177 | # building mobile blocks 178 | for k, exp, c, se, nl, s in mobile_setting: 179 | output_channel = make_divisible(c * width_mult) 180 | exp_channel = make_divisible(exp * width_mult) 181 | self.features.append(MobileBottleneck(input_channel, output_channel, k, s, exp_channel, se, nl)) 182 | input_channel = output_channel 183 | 184 | # building last several layers 185 | if mode == 'large': 186 | last_conv = make_divisible(960 * width_mult) 187 | self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)) 188 | self.features.append(nn.AdaptiveAvgPool2d(1)) 189 | self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) 190 | self.features.append(Hswish(inplace=True)) 191 | self.features.append(nn.Conv2d(last_channel, num_classes, 1, 1, 0)) 192 | elif mode == 'small': 193 | last_conv = make_divisible(576 * width_mult) 194 | self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)) 195 | self.features.append(SEModule(last_conv)) # refer to paper Table2 196 | self.features.append(nn.AdaptiveAvgPool2d(1)) 197 | self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) 198 | self.features.append(Hswish(inplace=True)) 199 | self.features.append(nn.Conv2d(last_channel, num_classes, 1, 1, 0)) 200 | else: 201 | raise NotImplementedError 202 | 203 | # make it nn.Sequential 204 | self.features = nn.Sequential(*self.features) 205 | 206 | self._initialize_weights() 207 | 208 | def forward(self, x): 209 | x = self.features(x) 210 | x = x.mean(3).mean(2) 211 | return x 212 | 213 | def _initialize_weights(self): 214 | # weight initialization 215 | for m in self.modules(): 216 | if isinstance(m, nn.Conv2d): 217 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 218 | if m.bias is not None: 219 | nn.init.zeros_(m.bias) 220 | elif isinstance(m, nn.BatchNorm2d): 221 | nn.init.ones_(m.weight) 222 | nn.init.zeros_(m.bias) 223 | elif isinstance(m, nn.Linear): 224 | nn.init.normal_(m.weight, 0, 0.01) 225 | if m.bias is not None: 226 | nn.init.zeros_(m.bias) 227 | 228 | 229 | def mobilenet_v3_large(pretrained=False, **kwargs): 230 | model = MobileNetV3(mode='large', **kwargs) 231 | if pretrained: 232 | raise NotImplementedError 233 | return model 234 | 235 | 236 | def mobilenet_v3_small(pretrained=False, **kwargs): 237 | model = MobileNetV3(mode='small', **kwargs) 238 | if pretrained: 239 | raise NotImplementedError 240 | return model 241 | 242 | 243 | if __name__ == '__main__': 244 | net = mobilenet_v3_large() 245 | print('mobilenetv3:\n', net) 246 | print('Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0)) 247 | input_size=(16, 3, 224, 224) 248 | x = torch.randn(input_size) 249 | out = net(x) 250 | 251 | -------------------------------------------------------------------------------- /nets/net_interface.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | from importlib import import_module 4 | 5 | class NetModule(object): 6 | def __init__(self, module_name, net_name, **kwargs): 7 | self.module_name = module_name 8 | self.net_name = net_name 9 | self.m = import_module('nets.' 
+ self.module_name) 10 | 11 | def create_model(self, **kwargs): 12 | """ 13 | when using a model pretrained on ImageNet, the pretrained model's num_classes is 1000 (or 1001 for checkpoints that keep the background class) 14 | :param kwargs: 15 | :return: 16 | """ 17 | _model = getattr(self.m, self.net_name) 18 | model = _model(**kwargs) 19 | return model -------------------------------------------------------------------------------- /nets/oct_resnet_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ 3 | https://arxiv.org/abs/1904.05049 4 | https://github.com/d-li14/octconv.pytorch 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | 11 | __all__ = ['OctResNet', 'oct_resnet26', 'oct_resnet50', 'oct_resnet101', 'oct_resnet152', 'oct_resnet200'] 12 | 13 | 14 | 15 | class OctaveConv(nn.Module): 16 | def __init__(self, in_channels, out_channels, kernel_size, alpha_in=0.5, alpha_out=0.5, stride=1, padding=0, dilation=1, 17 | groups=1, bias=False): 18 | super(OctaveConv, self).__init__() 19 | self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) 20 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 21 | assert stride == 1 or stride == 2, "Stride should be 1 or 2." 22 | self.stride = stride 23 | assert 0 <= alpha_in <= 1 and 0 <= alpha_out <= 1, "Alphas should be in the interval from 0 to 1." 24 | self.alpha_in, self.alpha_out = alpha_in, alpha_out 25 | self.conv_l2l = None if alpha_in == 0 or alpha_out == 0 else \ 26 | nn.Conv2d(int(alpha_in * in_channels), int(alpha_out * out_channels), 27 | kernel_size, 1, padding, dilation, groups, bias) 28 | self.conv_l2h = None if alpha_in == 0 or alpha_out == 1 else \ 29 | nn.Conv2d(int(alpha_in * in_channels), out_channels - int(alpha_out * out_channels), 30 | kernel_size, 1, padding, dilation, groups, bias) 31 | self.conv_h2l = None if alpha_in == 1 or alpha_out == 0 else \ 32 | nn.Conv2d(in_channels - int(alpha_in * in_channels), int(alpha_out * out_channels), 33 | kernel_size, 1, padding, dilation, groups, bias) 34 | self.conv_h2h = None if alpha_in == 1 or alpha_out == 1 else \ 35 | nn.Conv2d(in_channels - int(alpha_in * in_channels), out_channels - int(alpha_out * out_channels), 36 | kernel_size, 1, padding, dilation, groups, bias) 37 | 38 | def forward(self, x): 39 | x_h, x_l = x if type(x) is tuple else (x, None) 40 | 41 | if x_h is not None: 42 | x_h = self.downsample(x_h) if self.stride == 2 else x_h 43 | x_h2h = self.conv_h2h(x_h) 44 | x_h2l = self.conv_h2l(self.downsample(x_h)) if self.alpha_out > 0 else None 45 | if x_l is not None: 46 | x_l2h = self.conv_l2h(x_l) 47 | x_l2h = self.upsample(x_l2h) if self.stride == 1 else x_l2h 48 | x_l2l = self.downsample(x_l) if self.stride == 2 else x_l 49 | x_l2l = self.conv_l2l(x_l2l) if self.alpha_out > 0 else None 50 | x_h = x_l2h + x_h2h 51 | x_l = x_h2l + x_l2l if x_h2l is not None and x_l2l is not None else None 52 | return x_h, x_l 53 | else: 54 | return x_h2h, x_h2l 55 | 56 | 57 | class Conv_BN(nn.Module): 58 | def __init__(self, in_channels, out_channels, kernel_size, alpha_in=0.5, alpha_out=0.5, stride=1, padding=0, dilation=1, 59 | groups=1, bias=False, norm_layer=nn.BatchNorm2d): 60 | super(Conv_BN, self).__init__() 61 | self.conv = OctaveConv(in_channels, out_channels, kernel_size, alpha_in, alpha_out, stride, padding, dilation, 62 | groups, bias) 63 | self.bn_h = None if alpha_out == 1 else norm_layer(int(out_channels * (1 - alpha_out))) 64 | self.bn_l = None if alpha_out == 0 else norm_layer(int(out_channels * alpha_out)) 65 | 66 | def forward(self, x): 67 | 
x_h, x_l = self.conv(x) 68 | x_h = self.bn_h(x_h) 69 | x_l = self.bn_l(x_l) if x_l is not None else None 70 | return x_h, x_l 71 | 72 | 73 | class Conv_BN_ACT(nn.Module): 74 | def __init__(self, in_channels, out_channels, kernel_size, alpha_in=0.5, alpha_out=0.5, stride=1, padding=0, dilation=1, 75 | groups=1, bias=False, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU): 76 | super(Conv_BN_ACT, self).__init__() 77 | self.conv = OctaveConv(in_channels, out_channels, kernel_size, alpha_in, alpha_out, stride, padding, dilation, 78 | groups, bias) 79 | self.bn_h = None if alpha_out == 1 else norm_layer(int(out_channels * (1 - alpha_out))) 80 | self.bn_l = None if alpha_out == 0 else norm_layer(int(out_channels * alpha_out)) 81 | self.act = activation_layer(inplace=True) 82 | 83 | def forward(self, x): 84 | x_h, x_l = self.conv(x) 85 | x_h = self.act(self.bn_h(x_h)) 86 | x_l = self.act(self.bn_l(x_l)) if x_l is not None else None 87 | return x_h, x_l 88 | 89 | 90 | class Bottleneck(nn.Module): 91 | expansion = 4 92 | 93 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 94 | base_width=64, alpha_in=0.5, alpha_out=0.5, norm_layer=None, output=False): 95 | super(Bottleneck, self).__init__() 96 | if norm_layer is None: 97 | norm_layer = nn.BatchNorm2d 98 | width = int(planes * (base_width / 64.)) * groups 99 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 100 | self.conv1 = Conv_BN_ACT(inplanes, width, kernel_size=1, alpha_in=alpha_in, alpha_out=alpha_out, norm_layer=norm_layer) 101 | self.conv2 = Conv_BN_ACT(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, norm_layer=norm_layer, 102 | alpha_in=0 if output else 0.5, alpha_out=0 if output else 0.5) 103 | self.conv3 = Conv_BN(width, planes * self.expansion, kernel_size=1, norm_layer=norm_layer, 104 | alpha_in=0 if output else 0.5, alpha_out=0 if output else 0.5) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.downsample = downsample 107 | self.stride = stride 108 | 109 | def forward(self, x): 110 | identity_h = x[0] if type(x) is tuple else x 111 | identity_l = x[1] if type(x) is tuple else None 112 | 113 | x_h, x_l = self.conv1(x) 114 | x_h, x_l = self.conv2((x_h, x_l)) 115 | x_h, x_l = self.conv3((x_h, x_l)) 116 | 117 | if self.downsample is not None: 118 | identity_h, identity_l = self.downsample(x) 119 | 120 | x_h += identity_h 121 | x_l = x_l + identity_l if identity_l is not None else None 122 | 123 | x_h = self.relu(x_h) 124 | x_l = self.relu(x_l) if x_l is not None else None 125 | 126 | return x_h, x_l 127 | 128 | 129 | 130 | 131 | 132 | class OctResNet(nn.Module): 133 | 134 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 135 | groups=1, width_per_group=64, norm_layer=None): 136 | super(OctResNet, self).__init__() 137 | if norm_layer is None: 138 | norm_layer = nn.BatchNorm2d 139 | 140 | self.inplanes = 64 141 | self.groups = groups 142 | self.base_width = width_per_group 143 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 144 | bias=False) 145 | self.bn1 = norm_layer(self.inplanes) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, alpha_in=0) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 151 | 
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer, alpha_out=0, output=True) 152 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 153 | self.fc = nn.Linear(512 * block.expansion, num_classes) 154 | 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | 162 | # Zero-initialize the last BN in each residual branch, 163 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 164 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 165 | if zero_init_residual: 166 | for m in self.modules(): 167 | if isinstance(m, Bottleneck): 168 | nn.init.constant_(m.bn3.weight, 0) 169 | elif isinstance(m, BasicBlock): 170 | nn.init.constant_(m.bn2.weight, 0) 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, alpha_in=0.5, alpha_out=0.5, norm_layer=None, output=False): 173 | if norm_layer is None: 174 | norm_layer = nn.BatchNorm2d 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | Conv_BN(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, alpha_in=alpha_in, alpha_out=alpha_out) 179 | ) 180 | 181 | layers = [] 182 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 183 | self.base_width, alpha_in, alpha_out, norm_layer, output)) 184 | self.inplanes = planes * block.expansion 185 | for _ in range(1, blocks): 186 | layers.append(block(self.inplanes, planes, groups=self.groups, 187 | base_width=self.base_width, norm_layer=norm_layer, 188 | alpha_in=0 if output else 0.5, alpha_out=0 if output else 0.5, output=output)) 189 | 190 | return nn.Sequential(*layers) 191 | 192 | def forward(self, x): 193 | x = self.conv1(x) 194 | x = self.bn1(x) 195 | x = self.relu(x) 196 | x = self.maxpool(x) 197 | 198 | x_h, x_l = self.layer1(x) 199 | x_h, x_l = self.layer2((x_h,x_l)) 200 | x_h, x_l = self.layer3((x_h,x_l)) 201 | x_h, x_l = self.layer4((x_h,x_l)) 202 | x = self.avgpool(x_h) 203 | x = x.view(x.size(0), -1) 204 | x = self.fc(x) 205 | 206 | return x 207 | 208 | 209 | def oct_resnet26(pretrained=False, **kwargs): 210 | """Constructs a Octave ResNet-26 model. 211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = OctResNet(Bottleneck, [2, 2, 2, 2], **kwargs) 216 | return model 217 | 218 | 219 | def oct_resnet50(pretrained=False, **kwargs): 220 | """Constructs a Octave ResNet-50 model. 221 | 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | """ 225 | model = OctResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 226 | return model 227 | 228 | 229 | def oct_resnet101(pretrained=False, **kwargs): 230 | """Constructs a Octave ResNet-101 model. 231 | 232 | Args: 233 | pretrained (bool): If True, returns a model pre-trained on ImageNet 234 | """ 235 | model = OctResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 236 | return model 237 | 238 | 239 | def oct_resnet152(pretrained=False, **kwargs): 240 | """Constructs a Octave ResNet-152 model. 
241 | 242 | Args: 243 | pretrained (bool): If True, returns a model pre-trained on ImageNet 244 | """ 245 | model = OctResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 246 | return model 247 | 248 | 249 | def oct_resnet200(pretrained=False, **kwargs): 250 | """Constructs a Octave ResNet-200 model. 251 | 252 | Args: 253 | pretrained (bool): If True, returns a model pre-trained on ImageNet 254 | """ 255 | model = OctResNet(Bottleneck, [3, 24, 36, 3], **kwargs) 256 | return model 257 | 258 | if __name__ == '__main__': 259 | net = oct_resnet50() 260 | print('octave convolution:\n', net) 261 | print('Total params: %.2fM' % (sum(p.numel() for p in net.parameters()) / 1000000.0)) 262 | input_size = (16, 3, 256, 256) 263 | x = torch.randn(input_size) 264 | out = net(x) -------------------------------------------------------------------------------- /nets/pnasnet_5_large_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | 9 | pretrained_settings = { 10 | 'pnasnet5large': { 11 | 'imagenet': { 12 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', 13 | 'input_space': 'RGB', 14 | 'input_size': [3, 331, 331], 15 | 'input_range': [0, 1], 16 | 'mean': [0.5, 0.5, 0.5], 17 | 'std': [0.5, 0.5, 0.5], 18 | 'num_classes': 1000 19 | }, 20 | 'imagenet+background': { 21 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', 22 | 'input_space': 'RGB', 23 | 'input_size': [3, 331, 331], 24 | 'input_range': [0, 1], 25 | 'mean': [0.5, 0.5, 0.5], 26 | 'std': [0.5, 0.5, 0.5], 27 | 'num_classes': 1001 28 | } 29 | } 30 | } 31 | 32 | 33 | class MaxPool(nn.Module): 34 | 35 | def __init__(self, kernel_size, stride=1, padding=1, zero_pad=False): 36 | super(MaxPool, self).__init__() 37 | self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None 38 | self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding) 39 | 40 | def forward(self, x): 41 | if self.zero_pad: 42 | x = self.zero_pad(x) 43 | x = self.pool(x) 44 | if self.zero_pad: 45 | x = x[:, :, 1:, 1:] 46 | return x 47 | 48 | 49 | class SeparableConv2d(nn.Module): 50 | 51 | def __init__(self, in_channels, out_channels, dw_kernel_size, dw_stride, 52 | dw_padding): 53 | super(SeparableConv2d, self).__init__() 54 | self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, 55 | kernel_size=dw_kernel_size, 56 | stride=dw_stride, padding=dw_padding, 57 | groups=in_channels, bias=False) 58 | self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 59 | kernel_size=1, bias=False) 60 | 61 | def forward(self, x): 62 | x = self.depthwise_conv2d(x) 63 | x = self.pointwise_conv2d(x) 64 | return x 65 | 66 | 67 | class BranchSeparables(nn.Module): 68 | 69 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 70 | stem_cell=False, zero_pad=False): 71 | super(BranchSeparables, self).__init__() 72 | padding = kernel_size // 2 73 | middle_channels = out_channels if stem_cell else in_channels 74 | self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None 75 | self.relu_1 = nn.ReLU() 76 | self.separable_1 = SeparableConv2d(in_channels, middle_channels, 77 | kernel_size, dw_stride=stride, 78 | dw_padding=padding) 79 | self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001) 80 | self.relu_2 = nn.ReLU() 81 | self.separable_2 = 
SeparableConv2d(middle_channels, out_channels, 82 | kernel_size, dw_stride=1, 83 | dw_padding=padding) 84 | self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001) 85 | 86 | def forward(self, x): 87 | x = self.relu_1(x) 88 | if self.zero_pad: 89 | x = self.zero_pad(x) 90 | x = self.separable_1(x) 91 | if self.zero_pad: 92 | x = x[:, :, 1:, 1:].contiguous() 93 | x = self.bn_sep_1(x) 94 | x = self.relu_2(x) 95 | x = self.separable_2(x) 96 | x = self.bn_sep_2(x) 97 | return x 98 | 99 | 100 | class ReluConvBn(nn.Module): 101 | 102 | def __init__(self, in_channels, out_channels, kernel_size, stride=1): 103 | super(ReluConvBn, self).__init__() 104 | self.relu = nn.ReLU() 105 | self.conv = nn.Conv2d(in_channels, out_channels, 106 | kernel_size=kernel_size, stride=stride, 107 | bias=False) 108 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 109 | 110 | def forward(self, x): 111 | x = self.relu(x) 112 | x = self.conv(x) 113 | x = self.bn(x) 114 | return x 115 | 116 | 117 | class FactorizedReduction(nn.Module): 118 | 119 | def __init__(self, in_channels, out_channels): 120 | super(FactorizedReduction, self).__init__() 121 | self.relu = nn.ReLU() 122 | self.path_1 = nn.Sequential(OrderedDict([ 123 | ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), 124 | ('conv', nn.Conv2d(in_channels, out_channels // 2, 125 | kernel_size=1, bias=False)), 126 | ])) 127 | self.path_2 = nn.Sequential(OrderedDict([ 128 | ('pad', nn.ZeroPad2d((0, 1, 0, 1))), 129 | ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), 130 | ('conv', nn.Conv2d(in_channels, out_channels // 2, 131 | kernel_size=1, bias=False)), 132 | ])) 133 | self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001) 134 | 135 | def forward(self, x): 136 | x = self.relu(x) 137 | 138 | x_path1 = self.path_1(x) 139 | 140 | x_path2 = self.path_2.pad(x) 141 | x_path2 = x_path2[:, :, 1:, 1:] 142 | x_path2 = self.path_2.avgpool(x_path2) 143 | x_path2 = self.path_2.conv(x_path2) 144 | 145 | out = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) 146 | return out 147 | 148 | 149 | class CellBase(nn.Module): 150 | 151 | def cell_forward(self, x_left, x_right): 152 | x_comb_iter_0_left = self.comb_iter_0_left(x_left) 153 | x_comb_iter_0_right = self.comb_iter_0_right(x_left) 154 | x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right 155 | 156 | x_comb_iter_1_left = self.comb_iter_1_left(x_right) 157 | x_comb_iter_1_right = self.comb_iter_1_right(x_right) 158 | x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right 159 | 160 | x_comb_iter_2_left = self.comb_iter_2_left(x_right) 161 | x_comb_iter_2_right = self.comb_iter_2_right(x_right) 162 | x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right 163 | 164 | x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2) 165 | x_comb_iter_3_right = self.comb_iter_3_right(x_right) 166 | x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right 167 | 168 | x_comb_iter_4_left = self.comb_iter_4_left(x_left) 169 | if self.comb_iter_4_right: 170 | x_comb_iter_4_right = self.comb_iter_4_right(x_right) 171 | else: 172 | x_comb_iter_4_right = x_right 173 | x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right 174 | 175 | x_out = torch.cat( 176 | [x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, 177 | x_comb_iter_4], 1) 178 | return x_out 179 | 180 | 181 | class CellStem0(CellBase): 182 | 183 | def __init__(self, in_channels_left, out_channels_left, in_channels_right, 184 | out_channels_right): 185 | super(CellStem0, self).__init__() 186 | self.conv_1x1 = 
ReluConvBn(in_channels_right, out_channels_right, 187 | kernel_size=1) 188 | self.comb_iter_0_left = BranchSeparables(in_channels_left, 189 | out_channels_left, 190 | kernel_size=5, stride=2, 191 | stem_cell=True) 192 | self.comb_iter_0_right = nn.Sequential(OrderedDict([ 193 | ('max_pool', MaxPool(3, stride=2)), 194 | ('conv', nn.Conv2d(in_channels_left, out_channels_left, 195 | kernel_size=1, bias=False)), 196 | ('bn', nn.BatchNorm2d(out_channels_left, eps=0.001)), 197 | ])) 198 | self.comb_iter_1_left = BranchSeparables(out_channels_right, 199 | out_channels_right, 200 | kernel_size=7, stride=2) 201 | self.comb_iter_1_right = MaxPool(3, stride=2) 202 | self.comb_iter_2_left = BranchSeparables(out_channels_right, 203 | out_channels_right, 204 | kernel_size=5, stride=2) 205 | self.comb_iter_2_right = BranchSeparables(out_channels_right, 206 | out_channels_right, 207 | kernel_size=3, stride=2) 208 | self.comb_iter_3_left = BranchSeparables(out_channels_right, 209 | out_channels_right, 210 | kernel_size=3) 211 | self.comb_iter_3_right = MaxPool(3, stride=2) 212 | self.comb_iter_4_left = BranchSeparables(in_channels_right, 213 | out_channels_right, 214 | kernel_size=3, stride=2, 215 | stem_cell=True) 216 | self.comb_iter_4_right = ReluConvBn(out_channels_right, 217 | out_channels_right, 218 | kernel_size=1, stride=2) 219 | 220 | def forward(self, x_left): 221 | x_right = self.conv_1x1(x_left) 222 | x_out = self.cell_forward(x_left, x_right) 223 | return x_out 224 | 225 | 226 | class Cell(CellBase): 227 | 228 | def __init__(self, in_channels_left, out_channels_left, in_channels_right, 229 | out_channels_right, is_reduction=False, zero_pad=False, 230 | match_prev_layer_dimensions=False): 231 | super(Cell, self).__init__() 232 | 233 | # If `is_reduction` is set to `True` stride 2 is used for 234 | # convolutional and pooling layers to reduce the spatial size of 235 | # the output of a cell approximately by a factor of 2. 236 | stride = 2 if is_reduction else 1 237 | 238 | # If `match_prev_layer_dimensions` is set to `True` 239 | # `FactorizedReduction` is used to reduce the spatial size 240 | # of the left input of a cell approximately by a factor of 2. 
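        # For example, cell_0, cell_5 and cell_9 in PNASNet5Large below are
        # built with match_prev_layer_dimensions=True because each directly
        # follows a reduction cell, so their left input still has the larger
        # pre-reduction spatial size.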
241 | self.match_prev_layer_dimensions = match_prev_layer_dimensions 242 | if match_prev_layer_dimensions: 243 | self.conv_prev_1x1 = FactorizedReduction(in_channels_left, 244 | out_channels_left) 245 | else: 246 | self.conv_prev_1x1 = ReluConvBn(in_channels_left, 247 | out_channels_left, kernel_size=1) 248 | 249 | self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right, 250 | kernel_size=1) 251 | self.comb_iter_0_left = BranchSeparables(out_channels_left, 252 | out_channels_left, 253 | kernel_size=5, stride=stride, 254 | zero_pad=zero_pad) 255 | self.comb_iter_0_right = MaxPool(3, stride=stride, zero_pad=zero_pad) 256 | self.comb_iter_1_left = BranchSeparables(out_channels_right, 257 | out_channels_right, 258 | kernel_size=7, stride=stride, 259 | zero_pad=zero_pad) 260 | self.comb_iter_1_right = MaxPool(3, stride=stride, zero_pad=zero_pad) 261 | self.comb_iter_2_left = BranchSeparables(out_channels_right, 262 | out_channels_right, 263 | kernel_size=5, stride=stride, 264 | zero_pad=zero_pad) 265 | self.comb_iter_2_right = BranchSeparables(out_channels_right, 266 | out_channels_right, 267 | kernel_size=3, stride=stride, 268 | zero_pad=zero_pad) 269 | self.comb_iter_3_left = BranchSeparables(out_channels_right, 270 | out_channels_right, 271 | kernel_size=3) 272 | self.comb_iter_3_right = MaxPool(3, stride=stride, zero_pad=zero_pad) 273 | self.comb_iter_4_left = BranchSeparables(out_channels_left, 274 | out_channels_left, 275 | kernel_size=3, stride=stride, 276 | zero_pad=zero_pad) 277 | if is_reduction: 278 | self.comb_iter_4_right = ReluConvBn(out_channels_right, 279 | out_channels_right, 280 | kernel_size=1, stride=stride) 281 | else: 282 | self.comb_iter_4_right = None 283 | 284 | def forward(self, x_left, x_right): 285 | x_left = self.conv_prev_1x1(x_left) 286 | x_right = self.conv_1x1(x_right) 287 | x_out = self.cell_forward(x_left, x_right) 288 | return x_out 289 | 290 | 291 | class PNASNet5Large(nn.Module): 292 | def __init__(self, num_classes=1001): 293 | super().__init__() 294 | self.num_classes = num_classes 295 | self.conv_0 = nn.Sequential(OrderedDict([ 296 | ('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)), 297 | ('bn', nn.BatchNorm2d(96, eps=0.001)) 298 | ])) 299 | self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54, 300 | in_channels_right=96, 301 | out_channels_right=54) 302 | self.cell_stem_1 = Cell(in_channels_left=96, out_channels_left=108, 303 | in_channels_right=270, out_channels_right=108, 304 | match_prev_layer_dimensions=True, 305 | is_reduction=True) 306 | self.cell_0 = Cell(in_channels_left=270, out_channels_left=216, 307 | in_channels_right=540, out_channels_right=216, 308 | match_prev_layer_dimensions=True) 309 | self.cell_1 = Cell(in_channels_left=540, out_channels_left=216, 310 | in_channels_right=1080, out_channels_right=216) 311 | self.cell_2 = Cell(in_channels_left=1080, out_channels_left=216, 312 | in_channels_right=1080, out_channels_right=216) 313 | self.cell_3 = Cell(in_channels_left=1080, out_channels_left=216, 314 | in_channels_right=1080, out_channels_right=216) 315 | self.cell_4 = Cell(in_channels_left=1080, out_channels_left=432, 316 | in_channels_right=1080, out_channels_right=432, 317 | is_reduction=True, zero_pad=True) 318 | self.cell_5 = Cell(in_channels_left=1080, out_channels_left=432, 319 | in_channels_right=2160, out_channels_right=432, 320 | match_prev_layer_dimensions=True) 321 | self.cell_6 = Cell(in_channels_left=2160, out_channels_left=432, 322 | in_channels_right=2160, out_channels_right=432) 
323 | self.cell_7 = Cell(in_channels_left=2160, out_channels_left=432, 324 | in_channels_right=2160, out_channels_right=432) 325 | self.cell_8 = Cell(in_channels_left=2160, out_channels_left=864, 326 | in_channels_right=2160, out_channels_right=864, 327 | is_reduction=True) 328 | self.cell_9 = Cell(in_channels_left=2160, out_channels_left=864, 329 | in_channels_right=4320, out_channels_right=864, 330 | match_prev_layer_dimensions=True) 331 | self.cell_10 = Cell(in_channels_left=4320, out_channels_left=864, 332 | in_channels_right=4320, out_channels_right=864) 333 | self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864, 334 | in_channels_right=4320, out_channels_right=864) 335 | self.relu = nn.ReLU() 336 | self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0) 337 | self.dropout = nn.Dropout(0.5) 338 | self.last_linear = nn.Linear(4320, num_classes) 339 | 340 | def features(self, x): 341 | x_conv_0 = self.conv_0(x) 342 | x_stem_0 = self.cell_stem_0(x_conv_0) 343 | x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0) 344 | x_cell_0 = self.cell_0(x_stem_0, x_stem_1) 345 | x_cell_1 = self.cell_1(x_stem_1, x_cell_0) 346 | x_cell_2 = self.cell_2(x_cell_0, x_cell_1) 347 | x_cell_3 = self.cell_3(x_cell_1, x_cell_2) 348 | x_cell_4 = self.cell_4(x_cell_2, x_cell_3) 349 | x_cell_5 = self.cell_5(x_cell_3, x_cell_4) 350 | x_cell_6 = self.cell_6(x_cell_4, x_cell_5) 351 | x_cell_7 = self.cell_7(x_cell_5, x_cell_6) 352 | x_cell_8 = self.cell_8(x_cell_6, x_cell_7) 353 | x_cell_9 = self.cell_9(x_cell_7, x_cell_8) 354 | x_cell_10 = self.cell_10(x_cell_8, x_cell_9) 355 | x_cell_11 = self.cell_11(x_cell_9, x_cell_10) 356 | return x_cell_11 357 | 358 | def logits(self, features): 359 | x = self.relu(features) 360 | x = self.avg_pool(x) 361 | x = x.view(x.size(0), -1) 362 | x = self.dropout(x) 363 | x = self.last_linear(x) 364 | return x 365 | 366 | def forward(self, input): 367 | x = self.features(input) 368 | x = self.logits(x) 369 | return x 370 | 371 | 372 | def pnasnet5large(num_classes=1001, pretrained='imagenet'): 373 | r"""PNASNet-5 model architecture from the 374 | `"Progressive Neural Architecture Search" 375 | `_ paper. 
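    Illustrative usage (an added sketch, not part of the original
    docstring; the 331x331 input size comes from the README model table,
    and downloading the weights needs network access):
        >>> model = pnasnet5large(num_classes=1000, pretrained='imagenet')
        >>> model.last_linear   # Linear(4320, 1000) after the background class is dropped
        >>> model.input_size    # populated from pretrained_settings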
376 | """ 377 | if pretrained: 378 | settings = pretrained_settings['pnasnet5large'][pretrained] 379 | assert num_classes == settings[ 380 | 'num_classes'], 'num_classes should be {}, but is {}'.format( 381 | settings['num_classes'], num_classes) 382 | 383 | # both 'imagenet'&'imagenet+background' are loaded from same parameters 384 | model = PNASNet5Large(num_classes=1001) 385 | model.load_state_dict(model_zoo.load_url(settings['url'])) 386 | 387 | if pretrained == 'imagenet': 388 | new_last_linear = nn.Linear(model.last_linear.in_features, 1000) 389 | new_last_linear.weight.data = model.last_linear.weight.data[1:] 390 | new_last_linear.bias.data = model.last_linear.bias.data[1:] 391 | model.last_linear = new_last_linear 392 | 393 | model.input_space = settings['input_space'] 394 | model.input_size = settings['input_size'] 395 | model.input_range = settings['input_range'] 396 | 397 | model.mean = settings['mean'] 398 | model.std = settings['std'] 399 | else: 400 | model = PNASNet5Large(num_classes=num_classes) 401 | return model 402 | 403 | 404 | def pnasnet_5_large(**kwargs): 405 | num_classes = kwargs.get('num_classes', 1001) 406 | model = PNASNet5Large(num_classes=num_classes) 407 | return model 408 | 409 | -------------------------------------------------------------------------------- /nets/polynet_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils import model_zoo 5 | 6 | __all__ = ['PolyNet', 'polynet'] 7 | 8 | pretrained_settings = { 9 | 'polynet': { 10 | 'imagenet': { 11 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/polynet-f71d82a5.pth', 12 | 'input_space': 'RGB', 13 | 'input_size': [3, 331, 331], 14 | 'input_range': [0, 1], 15 | 'mean': [0.485, 0.456, 0.406], 16 | 'std': [0.229, 0.224, 0.225], 17 | 'num_classes': 1000 18 | }, 19 | } 20 | } 21 | 22 | 23 | class BasicConv2d(nn.Module): 24 | 25 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, 26 | output_relu=True): 27 | super(BasicConv2d, self).__init__() 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 29 | stride=stride, padding=padding, bias=False) 30 | self.bn = nn.BatchNorm2d(out_planes) 31 | self.relu = nn.ReLU() if output_relu else None 32 | 33 | def forward(self, x): 34 | x = self.conv(x) 35 | x = self.bn(x) 36 | if self.relu: 37 | x = self.relu(x) 38 | return x 39 | 40 | 41 | class PolyConv2d(nn.Module): 42 | """A block that is used inside poly-N (poly-2, poly-3, and so on) modules. 43 | The Convolution layer is shared between all Inception blocks inside 44 | a poly-N module. BatchNorm layers are not shared between Inception blocks 45 | and therefore the number of BatchNorm layers is equal to the number of 46 | Inception blocks inside a poly-N module. 
47 | """ 48 | 49 | def __init__(self, in_planes, out_planes, kernel_size, num_blocks, 50 | stride=1, padding=0): 51 | super(PolyConv2d, self).__init__() 52 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 53 | stride=stride, padding=padding, bias=False) 54 | self.bn_blocks = nn.ModuleList([ 55 | nn.BatchNorm2d(out_planes) for _ in range(num_blocks) 56 | ]) 57 | self.relu = nn.ReLU() 58 | 59 | def forward(self, x, block_index): 60 | x = self.conv(x) 61 | bn = self.bn_blocks[block_index] 62 | x = bn(x) 63 | x = self.relu(x) 64 | return x 65 | 66 | 67 | class Stem(nn.Module): 68 | 69 | def __init__(self): 70 | super(Stem, self).__init__() 71 | self.conv1 = nn.Sequential( 72 | BasicConv2d(3, 32, kernel_size=3, stride=2), 73 | BasicConv2d(32, 32, kernel_size=3), 74 | BasicConv2d(32, 64, kernel_size=3, padding=1), 75 | ) 76 | self.conv1_pool_branch = nn.MaxPool2d(3, stride=2) 77 | self.conv1_branch = BasicConv2d(64, 96, kernel_size=3, stride=2) 78 | self.conv2_short = nn.Sequential( 79 | BasicConv2d(160, 64, kernel_size=1), 80 | BasicConv2d(64, 96, kernel_size=3), 81 | ) 82 | self.conv2_long = nn.Sequential( 83 | BasicConv2d(160, 64, kernel_size=1), 84 | BasicConv2d(64, 64, kernel_size=(7, 1), padding=(3, 0)), 85 | BasicConv2d(64, 64, kernel_size=(1, 7), padding=(0, 3)), 86 | BasicConv2d(64, 96, kernel_size=3), 87 | ) 88 | self.conv2_pool_branch = nn.MaxPool2d(3, stride=2) 89 | self.conv2_branch = BasicConv2d(192, 192, kernel_size=3, stride=2) 90 | 91 | def forward(self, x): 92 | x = self.conv1(x) 93 | 94 | x0 = self.conv1_pool_branch(x) 95 | x1 = self.conv1_branch(x) 96 | x = torch.cat((x0, x1), 1) 97 | 98 | x0 = self.conv2_short(x) 99 | x1 = self.conv2_long(x) 100 | x = torch.cat((x0, x1), 1) 101 | 102 | x0 = self.conv2_pool_branch(x) 103 | x1 = self.conv2_branch(x) 104 | out = torch.cat((x0, x1), 1) 105 | return out 106 | 107 | 108 | class BlockA(nn.Module): 109 | """Inception-ResNet-A block.""" 110 | 111 | def __init__(self): 112 | super(BlockA, self).__init__() 113 | self.path0 = nn.Sequential( 114 | BasicConv2d(384, 32, kernel_size=1), 115 | BasicConv2d(32, 48, kernel_size=3, padding=1), 116 | BasicConv2d(48, 64, kernel_size=3, padding=1), 117 | ) 118 | self.path1 = nn.Sequential( 119 | BasicConv2d(384, 32, kernel_size=1), 120 | BasicConv2d(32, 32, kernel_size=3, padding=1), 121 | ) 122 | self.path2 = BasicConv2d(384, 32, kernel_size=1) 123 | self.conv2d = BasicConv2d(128, 384, kernel_size=1, output_relu=False) 124 | 125 | def forward(self, x): 126 | x0 = self.path0(x) 127 | x1 = self.path1(x) 128 | x2 = self.path2(x) 129 | out = torch.cat((x0, x1, x2), 1) 130 | out = self.conv2d(out) 131 | return out 132 | 133 | 134 | class BlockB(nn.Module): 135 | """Inception-ResNet-B block.""" 136 | 137 | def __init__(self): 138 | super(BlockB, self).__init__() 139 | self.path0 = nn.Sequential( 140 | BasicConv2d(1152, 128, kernel_size=1), 141 | BasicConv2d(128, 160, kernel_size=(1, 7), padding=(0, 3)), 142 | BasicConv2d(160, 192, kernel_size=(7, 1), padding=(3, 0)), 143 | ) 144 | self.path1 = BasicConv2d(1152, 192, kernel_size=1) 145 | self.conv2d = BasicConv2d(384, 1152, kernel_size=1, output_relu=False) 146 | 147 | def forward(self, x): 148 | x0 = self.path0(x) 149 | x1 = self.path1(x) 150 | out = torch.cat((x0, x1), 1) 151 | out = self.conv2d(out) 152 | return out 153 | 154 | 155 | class BlockC(nn.Module): 156 | """Inception-ResNet-C block.""" 157 | 158 | def __init__(self): 159 | super(BlockC, self).__init__() 160 | self.path0 = nn.Sequential( 161 | BasicConv2d(2048, 192, 
kernel_size=1), 162 | BasicConv2d(192, 224, kernel_size=(1, 3), padding=(0, 1)), 163 | BasicConv2d(224, 256, kernel_size=(3, 1), padding=(1, 0)), 164 | ) 165 | self.path1 = BasicConv2d(2048, 192, kernel_size=1) 166 | self.conv2d = BasicConv2d(448, 2048, kernel_size=1, output_relu=False) 167 | 168 | def forward(self, x): 169 | x0 = self.path0(x) 170 | x1 = self.path1(x) 171 | out = torch.cat((x0, x1), 1) 172 | out = self.conv2d(out) 173 | return out 174 | 175 | 176 | class ReductionA(nn.Module): 177 | """A dimensionality reduction block that is placed after stage-a 178 | Inception-ResNet blocks. 179 | """ 180 | 181 | def __init__(self): 182 | super(ReductionA, self).__init__() 183 | self.path0 = nn.Sequential( 184 | BasicConv2d(384, 256, kernel_size=1), 185 | BasicConv2d(256, 256, kernel_size=3, padding=1), 186 | BasicConv2d(256, 384, kernel_size=3, stride=2), 187 | ) 188 | self.path1 = BasicConv2d(384, 384, kernel_size=3, stride=2) 189 | self.path2 = nn.MaxPool2d(3, stride=2) 190 | 191 | def forward(self, x): 192 | x0 = self.path0(x) 193 | x1 = self.path1(x) 194 | x2 = self.path2(x) 195 | out = torch.cat((x0, x1, x2), 1) 196 | return out 197 | 198 | 199 | class ReductionB(nn.Module): 200 | """A dimensionality reduction block that is placed after stage-b 201 | Inception-ResNet blocks. 202 | """ 203 | def __init__(self): 204 | super(ReductionB, self).__init__() 205 | self.path0 = nn.Sequential( 206 | BasicConv2d(1152, 256, kernel_size=1), 207 | BasicConv2d(256, 256, kernel_size=3, padding=1), 208 | BasicConv2d(256, 256, kernel_size=3, stride=2), 209 | ) 210 | self.path1 = nn.Sequential( 211 | BasicConv2d(1152, 256, kernel_size=1), 212 | BasicConv2d(256, 256, kernel_size=3, stride=2), 213 | ) 214 | self.path2 = nn.Sequential( 215 | BasicConv2d(1152, 256, kernel_size=1), 216 | BasicConv2d(256, 384, kernel_size=3, stride=2), 217 | ) 218 | self.path3 = nn.MaxPool2d(3, stride=2) 219 | 220 | def forward(self, x): 221 | x0 = self.path0(x) 222 | x1 = self.path1(x) 223 | x2 = self.path2(x) 224 | x3 = self.path3(x) 225 | out = torch.cat((x0, x1, x2, x3), 1) 226 | return out 227 | 228 | 229 | class InceptionResNetBPoly(nn.Module): 230 | """Base class for constructing poly-N Inception-ResNet-B modules. 231 | When `num_blocks` is equal to 1, a module will have only a first-order path 232 | and will be equal to a standard Inception-ResNet-B block. 233 | When `num_blocks` is equal to 2, a module will have first-order and 234 | second-order paths and will be called Inception-ResNet-B poly-2 module. 235 | Increasing value of the `num_blocks` parameter will produce a higher order 236 | Inception-ResNet-B poly-N modules. 
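    Schematically, for poly-3 the forward pass below computes
        out = relu(x + s*F1(x) + s*F2(relu(F1(x))) + s*F3(relu(F2(relu(F1(x))))))
    where s is `scale`; the Fi share their 1x1/1x7/7x1 convolution
    weights through PolyConv2d but keep per-block BatchNorm layers and
    per-block output 1x1 projections.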
237 | """ 238 | 239 | def __init__(self, scale, num_blocks): 240 | super(InceptionResNetBPoly, self).__init__() 241 | assert num_blocks >= 1, 'num_blocks should be greater or equal to 1' 242 | self.scale = scale 243 | self.num_blocks = num_blocks 244 | self.path0_1x1 = PolyConv2d(1152, 128, kernel_size=1, 245 | num_blocks=self.num_blocks) 246 | self.path0_1x7 = PolyConv2d(128, 160, kernel_size=(1, 7), 247 | num_blocks=self.num_blocks, padding=(0, 3)) 248 | self.path0_7x1 = PolyConv2d(160, 192, kernel_size=(7, 1), 249 | num_blocks=self.num_blocks, padding=(3, 0)) 250 | self.path1 = PolyConv2d(1152, 192, kernel_size=1, 251 | num_blocks=self.num_blocks) 252 | # conv2d blocks are not shared between Inception-ResNet-B blocks 253 | self.conv2d_blocks = nn.ModuleList([ 254 | BasicConv2d(384, 1152, kernel_size=1, output_relu=False) 255 | for _ in range(self.num_blocks) 256 | ]) 257 | self.relu = nn.ReLU() 258 | 259 | def forward_block(self, x, block_index): 260 | x0 = self.path0_1x1(x, block_index) 261 | x0 = self.path0_1x7(x0, block_index) 262 | x0 = self.path0_7x1(x0, block_index) 263 | x1 = self.path1(x, block_index) 264 | out = torch.cat((x0, x1), 1) 265 | conv2d_block = self.conv2d_blocks[block_index] 266 | out = conv2d_block(out) 267 | return out 268 | 269 | def forward(self, x): 270 | out = x 271 | for block_index in range(self.num_blocks): 272 | x = self.forward_block(x, block_index) 273 | out = out + x * self.scale 274 | x = self.relu(x) 275 | out = self.relu(out) 276 | return out 277 | 278 | 279 | class InceptionResNetCPoly(nn.Module): 280 | """Base class for constructing poly-N Inception-ResNet-C modules. 281 | When `num_blocks` is equal to 1, a module will have only a first-order path 282 | and will be equal to a standard Inception-ResNet-C block. 283 | When `num_blocks` is equal to 2, a module will have first-order and 284 | second-order paths and will be called Inception-ResNet-C poly-2 module. 285 | Increasing value of the `num_blocks` parameter will produce a higher order 286 | Inception-ResNet-C poly-N modules. 
287 | """ 288 | 289 | def __init__(self, scale, num_blocks): 290 | super(InceptionResNetCPoly, self).__init__() 291 | assert num_blocks >= 1, 'num_blocks should be greater or equal to 1' 292 | self.scale = scale 293 | self.num_blocks = num_blocks 294 | self.path0_1x1 = PolyConv2d(2048, 192, kernel_size=1, 295 | num_blocks=self.num_blocks) 296 | self.path0_1x3 = PolyConv2d(192, 224, kernel_size=(1, 3), 297 | num_blocks=self.num_blocks, padding=(0, 1)) 298 | self.path0_3x1 = PolyConv2d(224, 256, kernel_size=(3, 1), 299 | num_blocks=self.num_blocks, padding=(1, 0)) 300 | self.path1 = PolyConv2d(2048, 192, kernel_size=1, 301 | num_blocks=self.num_blocks) 302 | # conv2d blocks are not shared between Inception-ResNet-C blocks 303 | self.conv2d_blocks = nn.ModuleList([ 304 | BasicConv2d(448, 2048, kernel_size=1, output_relu=False) 305 | for _ in range(self.num_blocks) 306 | ]) 307 | self.relu = nn.ReLU() 308 | 309 | def forward_block(self, x, block_index): 310 | x0 = self.path0_1x1(x, block_index) 311 | x0 = self.path0_1x3(x0, block_index) 312 | x0 = self.path0_3x1(x0, block_index) 313 | x1 = self.path1(x, block_index) 314 | out = torch.cat((x0, x1), 1) 315 | conv2d_block = self.conv2d_blocks[block_index] 316 | out = conv2d_block(out) 317 | return out 318 | 319 | def forward(self, x): 320 | out = x 321 | for block_index in range(self.num_blocks): 322 | x = self.forward_block(x, block_index) 323 | out = out + x * self.scale 324 | x = self.relu(x) 325 | out = self.relu(out) 326 | return out 327 | 328 | 329 | class MultiWay(nn.Module): 330 | """Base class for constructing N-way modules (2-way, 3-way, and so on).""" 331 | 332 | def __init__(self, scale, block_cls, num_blocks): 333 | super(MultiWay, self).__init__() 334 | assert num_blocks >= 1, 'num_blocks should be greater or equal to 1' 335 | self.scale = scale 336 | self.blocks = nn.ModuleList([block_cls() for _ in range(num_blocks)]) 337 | self.relu = nn.ReLU() 338 | 339 | def forward(self, x): 340 | out = x 341 | for block in self.blocks: 342 | out = out + block(x) * self.scale 343 | out = self.relu(out) 344 | return out 345 | 346 | 347 | # Some helper classes to simplify the construction of PolyNet 348 | 349 | class InceptionResNetA2Way(MultiWay): 350 | 351 | def __init__(self, scale): 352 | super(InceptionResNetA2Way, self).__init__(scale, block_cls=BlockA, 353 | num_blocks=2) 354 | 355 | 356 | class InceptionResNetB2Way(MultiWay): 357 | 358 | def __init__(self, scale): 359 | super(InceptionResNetB2Way, self).__init__(scale, block_cls=BlockB, 360 | num_blocks=2) 361 | 362 | 363 | class InceptionResNetC2Way(MultiWay): 364 | 365 | def __init__(self, scale): 366 | super(InceptionResNetC2Way, self).__init__(scale, block_cls=BlockC, 367 | num_blocks=2) 368 | 369 | 370 | class InceptionResNetBPoly3(InceptionResNetBPoly): 371 | 372 | def __init__(self, scale): 373 | super(InceptionResNetBPoly3, self).__init__(scale, num_blocks=3) 374 | 375 | 376 | class InceptionResNetCPoly3(InceptionResNetCPoly): 377 | 378 | def __init__(self, scale): 379 | super(InceptionResNetCPoly3, self).__init__(scale, num_blocks=3) 380 | 381 | 382 | class PolyNet(nn.Module): 383 | 384 | def __init__(self, num_classes=1000): 385 | super(PolyNet, self).__init__() 386 | self.stem = Stem() 387 | self.stage_a = nn.Sequential( 388 | InceptionResNetA2Way(scale=1), 389 | InceptionResNetA2Way(scale=0.992308), 390 | InceptionResNetA2Way(scale=0.984615), 391 | InceptionResNetA2Way(scale=0.976923), 392 | InceptionResNetA2Way(scale=0.969231), 393 | InceptionResNetA2Way(scale=0.961538), 394 
| InceptionResNetA2Way(scale=0.953846), 395 | InceptionResNetA2Way(scale=0.946154), 396 | InceptionResNetA2Way(scale=0.938462), 397 | InceptionResNetA2Way(scale=0.930769), 398 | ) 399 | self.reduction_a = ReductionA() 400 | self.stage_b = nn.Sequential( 401 | InceptionResNetBPoly3(scale=0.923077), 402 | InceptionResNetB2Way(scale=0.915385), 403 | InceptionResNetBPoly3(scale=0.907692), 404 | InceptionResNetB2Way(scale=0.9), 405 | InceptionResNetBPoly3(scale=0.892308), 406 | InceptionResNetB2Way(scale=0.884615), 407 | InceptionResNetBPoly3(scale=0.876923), 408 | InceptionResNetB2Way(scale=0.869231), 409 | InceptionResNetBPoly3(scale=0.861538), 410 | InceptionResNetB2Way(scale=0.853846), 411 | InceptionResNetBPoly3(scale=0.846154), 412 | InceptionResNetB2Way(scale=0.838462), 413 | InceptionResNetBPoly3(scale=0.830769), 414 | InceptionResNetB2Way(scale=0.823077), 415 | InceptionResNetBPoly3(scale=0.815385), 416 | InceptionResNetB2Way(scale=0.807692), 417 | InceptionResNetBPoly3(scale=0.8), 418 | InceptionResNetB2Way(scale=0.792308), 419 | InceptionResNetBPoly3(scale=0.784615), 420 | InceptionResNetB2Way(scale=0.776923), 421 | ) 422 | self.reduction_b = ReductionB() 423 | self.stage_c = nn.Sequential( 424 | InceptionResNetCPoly3(scale=0.769231), 425 | InceptionResNetC2Way(scale=0.761538), 426 | InceptionResNetCPoly3(scale=0.753846), 427 | InceptionResNetC2Way(scale=0.746154), 428 | InceptionResNetCPoly3(scale=0.738462), 429 | InceptionResNetC2Way(scale=0.730769), 430 | InceptionResNetCPoly3(scale=0.723077), 431 | InceptionResNetC2Way(scale=0.715385), 432 | InceptionResNetCPoly3(scale=0.707692), 433 | InceptionResNetC2Way(scale=0.7), 434 | ) 435 | self.avg_pool = nn.AvgPool2d(9, stride=1) 436 | self.dropout = nn.Dropout(0.2) 437 | self.last_linear = nn.Linear(2048, num_classes) 438 | 439 | def features(self, x): 440 | x = self.stem(x) 441 | x = self.stage_a(x) 442 | x = self.reduction_a(x) 443 | x = self.stage_b(x) 444 | x = self.reduction_b(x) 445 | x = self.stage_c(x) 446 | return x 447 | 448 | def logits(self, x): 449 | x = self.avg_pool(x) 450 | x = self.dropout(x) 451 | x = x.view(x.size(0), -1) 452 | x = self.last_linear(x) 453 | return x 454 | 455 | def forward(self, x): 456 | x = self.features(x) 457 | x = self.logits(x) 458 | return x 459 | 460 | 461 | def _polynet(num_classes=1000, pretrained='imagenet'): 462 | """PolyNet architecture from the paper 463 | 'PolyNet: A Pursuit of Structural Diversity in Very Deep Networks' 464 | https://arxiv.org/abs/1611.05725 465 | """ 466 | if pretrained: 467 | settings = pretrained_settings['polynet'][pretrained] 468 | assert num_classes == settings['num_classes'], \ 469 | 'num_classes should be {}, but is {}'.format( 470 | settings['num_classes'], num_classes) 471 | model = PolyNet(num_classes=num_classes) 472 | model.load_state_dict(model_zoo.load_url(settings['url'])) 473 | model.input_space = settings['input_space'] 474 | model.input_size = settings['input_size'] 475 | model.input_range = settings['input_range'] 476 | model.mean = settings['mean'] 477 | model.std = settings['std'] 478 | else: 479 | model = PolyNet(num_classes=num_classes) 480 | return model 481 | 482 | 483 | def polynet(**kwargs): 484 | num_classes = kwargs.get('num_classes', 1000) 485 | model = PolyNet(num_classes=num_classes) 486 | return model 487 | -------------------------------------------------------------------------------- /nets/resnet_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ''' 3 | 
https://github.com/pytorch/vision.git 4 | ''' 5 | 6 | import torch.nn as nn 7 | import math 8 | import torch.utils.model_zoo as model_zoo 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = nn.BatchNorm2d(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = nn.BatchNorm2d(planes * 4) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, num_classes=1000): 104 | self.inplanes = 64 105 | super(ResNet, self).__init__() 106 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 107 | bias=False) 108 | self.bn1 = nn.BatchNorm2d(64) 109 | self.relu = nn.ReLU(inplace=True) 110 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 111 | self.layer1 = self._make_layer(block, 64, layers[0]) 112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 115 | self.avgpool = nn.AvgPool2d(7, stride=1) 116 | self.fc = nn.Linear(512 * block.expansion, num_classes) 117 | 118 | for m in self.modules(): 119 
| if isinstance(m, nn.Conv2d): 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(2. / n)) 122 | elif isinstance(m, nn.BatchNorm2d): 123 | m.weight.data.fill_(1) 124 | m.bias.data.zero_() 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | nn.Conv2d(self.inplanes, planes * block.expansion, 131 | kernel_size=1, stride=stride, bias=False), 132 | nn.BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for i in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.relu(x) 147 | x = self.maxpool(x) 148 | 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | 154 | x = self.avgpool(x) 155 | x = x.view(x.size(0), -1) 156 | x = self.fc(x) 157 | 158 | return x 159 | 160 | 161 | def resnet18(**kwargs): 162 | """Constructs a ResNet-18 model. 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | """ 166 | num_classes = kwargs.get('num_classes', 1000) 167 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes) 168 | return model 169 | 170 | 171 | def resnet34(**kwargs): 172 | """Constructs a ResNet-34 model. 173 | Args: 174 | pretrained (bool): If True, returns a model pre-trained on ImageNet 175 | """ 176 | num_classes = kwargs.get('num_classes', 1000) 177 | model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes) 178 | return model 179 | 180 | 181 | def resnet50(**kwargs): 182 | """Constructs a ResNet-50 model. 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | num_classes = kwargs.get('num_classes', 1000) 187 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes) 188 | return model 189 | 190 | 191 | def resnet101(**kwargs): 192 | """Constructs a ResNet-101 model. 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | num_classes = kwargs.get('num_classes', 1000) 197 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes) 198 | return model 199 | 200 | 201 | def resnet152(**kwargs): 202 | """Constructs a ResNet-152 model. 
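    Example (an added sketch, not part of the original docstring):
        >>> model = resnet152(num_classes=100)  # Bottleneck layout [3, 8, 36, 3]
        >>> model.fc                            # Linear(512 * 4, 100), since Bottleneck.expansion == 4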
203 | Args: 204 | pretrained (bool): If True, returns a model pre-trained on ImageNet 205 | """ 206 | num_classes = kwargs.get('num_classes', 1000) 207 | model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes) 208 | return model 209 | 210 | 211 | def re_define_layers(model, **kwargs): 212 | """ 213 | for fin 214 | :param kwargs: 215 | :return: 216 | """ 217 | model_net_name = kwargs.get('model_net_name','resnet101') 218 | num_classes = kwargs.get('num_classes', 1000) 219 | if model_net_name == 'resnet18': 220 | block = BasicBlock 221 | elif model_net_name == 'resnet34': 222 | block = BasicBlock 223 | elif model_net_name == 'resnet50': 224 | block = Bottleneck 225 | elif model_net_name == 'resnet101': 226 | block = Bottleneck 227 | elif model_net_name == 'resnet152': 228 | block = Bottleneck 229 | else: 230 | raise ValueError('model_net_name error !') 231 | model.fc = nn.Linear(512 * block.expansion, num_classes) 232 | 233 | 234 | if __name__ == "__main__": 235 | print('begin...') 236 | print('done') 237 | -------------------------------------------------------------------------------- /nets/shufflenet_v2_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ''' 3 | https://github.com/ericsun99/Shufflenet-v2-Pytorch 4 | 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | from collections import OrderedDict 11 | from torch.nn import init 12 | import math 13 | 14 | def conv_bn(inp, oup, stride): 15 | return nn.Sequential( 16 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 17 | nn.BatchNorm2d(oup), 18 | nn.ReLU(inplace=True) 19 | ) 20 | 21 | 22 | def conv_1x1_bn(inp, oup): 23 | return nn.Sequential( 24 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 25 | nn.BatchNorm2d(oup), 26 | nn.ReLU(inplace=True) 27 | ) 28 | 29 | def channel_shuffle(x, groups): 30 | batchsize, num_channels, height, width = x.data.size() 31 | 32 | channels_per_group = num_channels // groups 33 | 34 | # reshape 35 | x = x.view(batchsize, groups, 36 | channels_per_group, height, width) 37 | 38 | x = torch.transpose(x, 1, 2).contiguous() 39 | 40 | # flatten 41 | x = x.view(batchsize, -1, height, width) 42 | 43 | return x 44 | 45 | class InvertedResidual(nn.Module): 46 | def __init__(self, inp, oup, stride, benchmodel): 47 | super(InvertedResidual, self).__init__() 48 | self.benchmodel = benchmodel 49 | self.stride = stride 50 | assert stride in [1, 2] 51 | 52 | oup_inc = oup//2 53 | 54 | if self.benchmodel == 1: 55 | #assert inp == oup_inc 56 | self.banch2 = nn.Sequential( 57 | # pw 58 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 59 | nn.BatchNorm2d(oup_inc), 60 | nn.ReLU(inplace=True), 61 | # dw 62 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 63 | nn.BatchNorm2d(oup_inc), 64 | # pw-linear 65 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 66 | nn.BatchNorm2d(oup_inc), 67 | nn.ReLU(inplace=True), 68 | ) 69 | else: 70 | self.banch1 = nn.Sequential( 71 | # dw 72 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 73 | nn.BatchNorm2d(inp), 74 | # pw-linear 75 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 76 | nn.BatchNorm2d(oup_inc), 77 | nn.ReLU(inplace=True), 78 | ) 79 | 80 | self.banch2 = nn.Sequential( 81 | # pw 82 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 83 | nn.BatchNorm2d(oup_inc), 84 | nn.ReLU(inplace=True), 85 | # dw 86 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 87 | 
nn.BatchNorm2d(oup_inc),
88 |                 # pw-linear
89 |                 nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False),
90 |                 nn.BatchNorm2d(oup_inc),
91 |                 nn.ReLU(inplace=True),
92 |             )
93 | 
94 |     @staticmethod
95 |     def _concat(x, out):
96 |         # concatenate along channel axis
97 |         return torch.cat((x, out), 1)
98 | 
99 |     def forward(self, x):
100 |         if 1==self.benchmodel:
101 |             x1 = x[:, :(x.shape[1]//2), :, :]
102 |             x2 = x[:, (x.shape[1]//2):, :, :]
103 |             out = self._concat(x1, self.banch2(x2))
104 |         elif 2==self.benchmodel:
105 |             out = self._concat(self.banch1(x), self.banch2(x))
106 | 
107 |         return channel_shuffle(out, 2)
108 | 
109 | 
110 | class ShuffleNetV2(nn.Module):
111 |     def __init__(self, n_class=1000, input_size=224, width_mult=1.):
112 |         super(ShuffleNetV2, self).__init__()
113 | 
114 |         assert input_size % 32 == 0
115 | 
116 |         self.stage_repeats = [4, 8, 4]
117 |         # index 0 is invalid and should never be called.
118 |         # only used for indexing convenience.
119 |         if width_mult == 0.5:
120 |             self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
121 |         elif width_mult == 1.0:
122 |             self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
123 |         elif width_mult == 1.5:
124 |             self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
125 |         elif width_mult == 2.0:
126 |             self.stage_out_channels = [-1, 24, 224, 488, 976, 2048]
127 |         else:
128 |             raise ValueError(
129 |                 # fixed: the original formatted an undefined `num_groups` here
130 |                 'width_mult {} is not supported'.format(width_mult))
131 | 
132 |         # building first layer
133 |         input_channel = self.stage_out_channels[1]
134 |         self.conv1 = conv_bn(3, input_channel, 2)
135 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
136 | 
137 |         self.features = []
138 |         # building inverted residual blocks
139 |         for idxstage in range(len(self.stage_repeats)):
140 |             numrepeat = self.stage_repeats[idxstage]
141 |             output_channel = self.stage_out_channels[idxstage+2]
142 |             for i in range(numrepeat):
143 |                 if i == 0:
144 |                     #inp, oup, stride, benchmodel):
145 |                     self.features.append(InvertedResidual(input_channel, output_channel, 2, 2))
146 |                 else:
147 |                     self.features.append(InvertedResidual(input_channel, output_channel, 1, 1))
148 |                 input_channel = output_channel
149 | 
150 | 
151 |         # make it nn.Sequential
152 |         self.features = nn.Sequential(*self.features)
153 | 
154 |         # building last several layers
155 |         self.conv_last = conv_1x1_bn(input_channel, self.stage_out_channels[-1])
156 |         self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size/32)))
157 |         # building classifier
158 |         self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class))
159 | 
160 |     def forward(self, x):
161 |         x = self.conv1(x)
162 |         x = self.maxpool(x)
163 |         x = self.features(x)
164 |         x = self.conv_last(x)
165 |         x = self.globalpool(x)
166 |         x = x.view(-1, self.stage_out_channels[-1])
167 |         x = self.classifier(x)
168 |         return x
169 | 
170 | def _shufflenetv2(width_mult=1.):
171 |     model = ShuffleNetV2(width_mult=width_mult)
172 |     return model
173 | 
174 | def shufflenet_v2(**kwargs):
175 |     num_classes = kwargs.get('num_classes', 1000)
176 |     width_mult = kwargs.get('width_mult', 1.)
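    # Added note: the ShuffleNetV2 constructor above only defines channel
    # tables for width_mult in {0.5, 1.0, 1.5, 2.0}; other values raise a
    # ValueError, and input_size must stay a multiple of 32.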
177 | model = ShuffleNetV2(n_class=num_classes, width_mult=width_mult) 178 | return model 179 | 180 | if __name__ == "__main__": 181 | """Testing 182 | """ 183 | model = ShuffleNetV2() 184 | print(model) 185 | -------------------------------------------------------------------------------- /nets/squeezenet_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ''' 3 | https://github.com/pytorch/vision.git 4 | ''' 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as init 10 | import torch.utils.model_zoo as model_zoo 11 | 12 | 13 | __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] 14 | 15 | 16 | model_urls = { 17 | 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', 18 | 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', 19 | } 20 | 21 | 22 | class Fire(nn.Module): 23 | 24 | def __init__(self, inplanes, squeeze_planes, 25 | expand1x1_planes, expand3x3_planes): 26 | super(Fire, self).__init__() 27 | self.inplanes = inplanes 28 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 29 | self.squeeze_activation = nn.ReLU(inplace=True) 30 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 31 | kernel_size=1) 32 | self.expand1x1_activation = nn.ReLU(inplace=True) 33 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 34 | kernel_size=3, padding=1) 35 | self.expand3x3_activation = nn.ReLU(inplace=True) 36 | 37 | def forward(self, x): 38 | x = self.squeeze_activation(self.squeeze(x)) 39 | return torch.cat([ 40 | self.expand1x1_activation(self.expand1x1(x)), 41 | self.expand3x3_activation(self.expand3x3(x)) 42 | ], 1) 43 | 44 | 45 | class SqueezeNet(nn.Module): 46 | 47 | def __init__(self, version=1.0, num_classes=1000): 48 | super(SqueezeNet, self).__init__() 49 | if version not in [1.0, 1.1]: 50 | raise ValueError("Unsupported SqueezeNet version {version}:" 51 | "1.0 or 1.1 expected".format(version=version)) 52 | self.num_classes = num_classes 53 | if version == 1.0: 54 | self.features = nn.Sequential( 55 | nn.Conv2d(3, 96, kernel_size=7, stride=2), 56 | nn.ReLU(inplace=True), 57 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 58 | Fire(96, 16, 64, 64), 59 | Fire(128, 16, 64, 64), 60 | Fire(128, 32, 128, 128), 61 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 62 | Fire(256, 32, 128, 128), 63 | Fire(256, 48, 192, 192), 64 | Fire(384, 48, 192, 192), 65 | Fire(384, 64, 256, 256), 66 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 67 | Fire(512, 64, 256, 256), 68 | ) 69 | else: 70 | self.features = nn.Sequential( 71 | nn.Conv2d(3, 64, kernel_size=3, stride=2), 72 | nn.ReLU(inplace=True), 73 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 74 | Fire(64, 16, 64, 64), 75 | Fire(128, 16, 64, 64), 76 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 77 | Fire(128, 32, 128, 128), 78 | Fire(256, 32, 128, 128), 79 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 80 | Fire(256, 48, 192, 192), 81 | Fire(384, 48, 192, 192), 82 | Fire(384, 64, 256, 256), 83 | Fire(512, 64, 256, 256), 84 | ) 85 | # Final convolution is initialized differently form the rest 86 | final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) 87 | self.classifier = nn.Sequential( 88 | nn.Dropout(p=0.5), 89 | final_conv, 90 | nn.ReLU(inplace=True), 91 | nn.AvgPool2d(13, stride=1) 92 | ) 93 | for m in self.modules(): 94 | if isinstance(m, nn.Conv2d): 95 | if m is final_conv: 96 | 
init.normal(m.weight.data, mean=0.0, std=0.01) 97 | else: 98 | init.kaiming_uniform(m.weight.data) 99 | if m.bias is not None: 100 | m.bias.data.zero_() 101 | 102 | def forward(self, x): 103 | x = self.features(x) 104 | x = self.classifier(x) 105 | return x.view(x.size(0), self.num_classes) 106 | 107 | 108 | def _squeezenet1_0(pretrained=False, **kwargs): 109 | r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level 110 | accuracy with 50x fewer parameters and <0.5MB model size" 111 | `_ paper. 112 | Args: 113 | pretrained (bool): If True, returns a model pre-trained on ImageNet 114 | """ 115 | model = SqueezeNet(version=1.0, **kwargs) 116 | if pretrained: 117 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0'])) 118 | return model 119 | 120 | 121 | def squeezenet1_0(**kwargs): 122 | r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level 123 | accuracy with 50x fewer parameters and <0.5MB model size" 124 | `_ paper. 125 | Args: 126 | pretrained (bool): If True, returns a model pre-trained on ImageNet 127 | """ 128 | num_classes = kwargs.get('num_classes', 1000) 129 | model = SqueezeNet(version=1.0, num_classes=num_classes) 130 | return model 131 | 132 | 133 | def _squeezenet1_1(pretrained=False, **kwargs): 134 | r"""SqueezeNet 1.1 model from the `official SqueezeNet repo 135 | `_. 136 | SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters 137 | than SqueezeNet 1.0, without sacrificing accuracy. 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on ImageNet 140 | """ 141 | model = SqueezeNet(version=1.1, **kwargs) 142 | if pretrained: 143 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1'])) 144 | return model 145 | 146 | 147 | def squeezenet1_1(**kwargs): 148 | r"""SqueezeNet 1.1 model from the `official SqueezeNet repo 149 | `_. 150 | SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters 151 | than SqueezeNet 1.0, without sacrificing accuracy. 
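    Example (an added sketch, not part of the original docstring):
        >>> model = squeezenet1_1(num_classes=10)
        >>> # the classifier's final 1x1 conv maps 512 channels to num_classes;
        >>> # forward() reshapes the pooled map to (batch, num_classes)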
152 | Args: 153 | pretrained (bool): If True, returns a model pre-trained on ImageNet 154 | """ 155 | num_classes = kwargs.get('num_classes', 1000) 156 | model = SqueezeNet(version=1.1, num_classes=num_classes) 157 | return model 158 | -------------------------------------------------------------------------------- /nets/vgg_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ''' 3 | https://github.com/pytorch/vision.git 4 | ''' 5 | 6 | import torch.nn as nn 7 | import torch.utils.model_zoo as model_zoo 8 | import math 9 | 10 | 11 | __all__ = [ 12 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 13 | 'vgg19_bn', 'vgg19', 14 | ] 15 | 16 | 17 | model_urls = { 18 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 19 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 20 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 21 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 22 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 23 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 24 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 25 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 26 | } 27 | 28 | 29 | class VGG(nn.Module): 30 | 31 | def __init__(self, features, num_classes=1000, init_weights=True): 32 | super(VGG, self).__init__() 33 | self.features = features 34 | self.classifier = nn.Sequential( 35 | nn.Linear(512 * 7 * 7, 4096), 36 | nn.ReLU(True), 37 | nn.Dropout(), 38 | nn.Linear(4096, 4096), 39 | nn.ReLU(True), 40 | nn.Dropout(), 41 | nn.Linear(4096, num_classes), 42 | ) 43 | if init_weights: 44 | self._initialize_weights() 45 | 46 | def forward(self, x): 47 | x = self.features(x) 48 | x = x.view(x.size(0), -1) 49 | x = self.classifier(x) 50 | return x 51 | 52 | def _initialize_weights(self): 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 56 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 57 | if m.bias is not None: 58 | m.bias.data.zero_() 59 | elif isinstance(m, nn.BatchNorm2d): 60 | m.weight.data.fill_(1) 61 | m.bias.data.zero_() 62 | elif isinstance(m, nn.Linear): 63 | m.weight.data.normal_(0, 0.01) 64 | m.bias.data.zero_() 65 | 66 | 67 | def make_layers(cfg, batch_norm=False): 68 | layers = [] 69 | in_channels = 3 70 | for v in cfg: 71 | if v == 'M': 72 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 73 | else: 74 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 75 | if batch_norm: 76 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 77 | else: 78 | layers += [conv2d, nn.ReLU(inplace=True)] 79 | in_channels = v 80 | return nn.Sequential(*layers) 81 | 82 | 83 | cfg = { 84 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 85 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 86 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 87 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 88 | } 89 | 90 | 91 | def vgg11(**kwargs): 92 | """VGG 11-layer model (configuration "A") 93 | Args: 94 | pretrained (bool): If True, returns a model pre-trained on ImageNet 95 | """ 96 | pretrained = kwargs.get('pretrained', False) 97 | num_classes = kwargs.get('num_classes', 1000) 98 | if pretrained: 99 | kwargs['init_weights'] = False 100 | model = VGG(make_layers(cfg['A']), num_classes=num_classes)#, **kwargs) 101 | #if pretrained: 102 | # model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 103 | return model 104 | 105 | 106 | def vgg11_bn(**kwargs): 107 | """VGG 11-layer model (configuration "A") with batch normalization 108 | Args: 109 | pretrained (bool): If True, returns a model pre-trained on ImageNet 110 | """ 111 | pretrained = kwargs.get('pretrained', False) 112 | num_classes = kwargs.get('num_classes', 1000) 113 | if pretrained: 114 | kwargs['init_weights'] = False 115 | model = VGG(make_layers(cfg['A'], batch_norm=True), num_classes=num_classes)#, **kwargs) 116 | #if pretrained: 117 | # model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 118 | return model 119 | 120 | 121 | def vgg13(**kwargs): 122 | """VGG 13-layer model (configuration "B") 123 | Args: 124 | pretrained (bool): If True, returns a model pre-trained on ImageNet 125 | """ 126 | pretrained = kwargs.get('pretrained', False) 127 | num_classes = kwargs.get('num_classes', 1000) 128 | if pretrained: 129 | kwargs['init_weights'] = False 130 | model = VGG(make_layers(cfg['B']), num_classes=num_classes)#, **kwargs) 131 | #if pretrained: 132 | # model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 133 | return model 134 | 135 | 136 | def vgg13_bn(**kwargs): 137 | """VGG 13-layer model (configuration "B") with batch normalization 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on ImageNet 140 | """ 141 | pretrained = kwargs.get('pretrained', False) 142 | num_classes = kwargs.get('num_classes', 1000) 143 | if pretrained: 144 | kwargs['init_weights'] = False 145 | model = VGG(make_layers(cfg['B'], batch_norm=True), num_classes=num_classes)#, **kwargs) 146 | #if pretrained: 147 | # model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 148 | return model 149 | 150 | 151 | def vgg16(**kwargs): 152 | """VGG 16-layer model (configuration "D") 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | """ 156 | pretrained = 
kwargs.get('pretrained', False) 157 | num_classes = kwargs.get('num_classes', 1000) 158 | if pretrained: 159 | kwargs['init_weights'] = False 160 | model = VGG(make_layers(cfg['D']), num_classes=num_classes)#, **kwargs) 161 | #if pretrained: 162 | # model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 163 | return model 164 | 165 | 166 | def vgg16_bn(**kwargs): 167 | """VGG 16-layer model (configuration "D") with batch normalization 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | """ 171 | pretrained = kwargs.get('pretrained', False) 172 | num_classes = kwargs.get('num_classes', 1000) 173 | if pretrained: 174 | kwargs['init_weights'] = False 175 | model = VGG(make_layers(cfg['D'], batch_norm=True), num_classes=num_classes)#, **kwargs) 176 | #if pretrained: 177 | # model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 178 | return model 179 | 180 | 181 | def vgg19(**kwargs): 182 | """VGG 19-layer model (configuration "E") 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | pretrained = kwargs.get('pretrained', False) 187 | num_classes = kwargs.get('num_classes', 1000) 188 | if pretrained: 189 | kwargs['init_weights'] = False 190 | model = VGG(make_layers(cfg['E']), num_classes=num_classes)#, **kwargs) 191 | #if pretrained: 192 | # model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 193 | return model 194 | 195 | 196 | def vgg19_bn(**kwargs): 197 | """VGG 19-layer model (configuration 'E') with batch normalization 198 | Args: 199 | pretrained (bool): If True, returns a model pre-trained on ImageNet 200 | """ 201 | pretrained = kwargs.get('pretrained', False) 202 | num_classes = kwargs.get('num_classes', 1000) 203 | if pretrained: 204 | kwargs['init_weights'] = False 205 | model = VGG(make_layers(cfg['E'], batch_norm=True), num_classes=num_classes)#, **kwargs) 206 | #if pretrained: 207 | # model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 208 | return model 209 | -------------------------------------------------------------------------------- /nets/xception_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) 3 | 4 | @author: tstandley 5 | Adapted by cadene 6 | 7 | Creates an Xception Model as defined in: 8 | 9 | Francois Chollet 10 | Xception: Deep Learning with Depthwise Separable Convolutions 11 | https://arxiv.org/pdf/1610.02357.pdf 12 | 13 | This weights ported from the Keras implementation. 
Achieves the following performance on the validation set: 14 | 15 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292 16 | 17 | REMEMBER to set your image size to 3x299x299 for both test and validation 18 | 19 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 20 | std=[0.5, 0.5, 0.5]) 21 | 22 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 23 | """ 24 | from __future__ import print_function, division, absolute_import 25 | import math 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | import torch.utils.model_zoo as model_zoo 30 | from torch.nn import init 31 | 32 | __all__ = ['xception'] 33 | 34 | pretrained_settings = { 35 | 'xception': { 36 | 'imagenet': { 37 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth', 38 | 'input_space': 'RGB', 39 | 'input_size': [3, 299, 299], 40 | 'input_range': [0, 1], 41 | 'mean': [0.5, 0.5, 0.5], 42 | 'std': [0.5, 0.5, 0.5], 43 | 'num_classes': 1000, 44 | 'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 45 | } 46 | } 47 | } 48 | 49 | 50 | class SeparableConv2d(nn.Module): 51 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False): 52 | super(SeparableConv2d,self).__init__() 53 | 54 | self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias) 55 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias) 56 | 57 | def forward(self,x): 58 | x = self.conv1(x) 59 | x = self.pointwise(x) 60 | return x 61 | 62 | 63 | class Block(nn.Module): 64 | def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True): 65 | super(Block, self).__init__() 66 | 67 | if out_filters != in_filters or strides!=1: 68 | self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False) 69 | self.skipbn = nn.BatchNorm2d(out_filters) 70 | else: 71 | self.skip=None 72 | 73 | self.relu = nn.ReLU(inplace=True) 74 | rep=[] 75 | 76 | filters=in_filters 77 | if grow_first: 78 | rep.append(self.relu) 79 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) 80 | rep.append(nn.BatchNorm2d(out_filters)) 81 | filters = out_filters 82 | 83 | for i in range(reps-1): 84 | rep.append(self.relu) 85 | rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False)) 86 | rep.append(nn.BatchNorm2d(filters)) 87 | 88 | if not grow_first: 89 | rep.append(self.relu) 90 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) 91 | rep.append(nn.BatchNorm2d(out_filters)) 92 | 93 | if not start_with_relu: 94 | rep = rep[1:] 95 | else: 96 | rep[0] = nn.ReLU(inplace=False) 97 | 98 | if strides != 1: 99 | rep.append(nn.MaxPool2d(3,strides,1)) 100 | self.rep = nn.Sequential(*rep) 101 | 102 | def forward(self,inp): 103 | x = self.rep(inp) 104 | 105 | if self.skip is not None: 106 | skip = self.skip(inp) 107 | skip = self.skipbn(skip) 108 | else: 109 | skip = inp 110 | 111 | x+=skip 112 | return x 113 | 114 | 115 | class Xception(nn.Module): 116 | """ 117 | Xception optimized for the ImageNet dataset, as specified in 118 | https://arxiv.org/pdf/1610.02357.pdf 119 | """ 120 | def __init__(self, num_classes=1000): 121 | """ Constructor 122 | Args: 123 | num_classes: number of classes 124 | """ 125 | super(Xception, self).__init__() 126 | self.num_classes = num_classes 127 | 128 | self.conv1 = nn.Conv2d(3, 32, 
3,2, 0, bias=False) 129 | self.bn1 = nn.BatchNorm2d(32) 130 | self.relu = nn.ReLU(inplace=True) 131 | 132 | self.conv2 = nn.Conv2d(32,64,3,bias=False) 133 | self.bn2 = nn.BatchNorm2d(64) 134 | #do relu here 135 | 136 | self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True) 137 | self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True) 138 | self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True) 139 | 140 | self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True) 141 | self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True) 142 | self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True) 143 | self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True) 144 | 145 | self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True) 146 | self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True) 147 | self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True) 148 | self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True) 149 | 150 | self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False) 151 | 152 | self.conv3 = SeparableConv2d(1024,1536,3,1,1) 153 | self.bn3 = nn.BatchNorm2d(1536) 154 | 155 | #do relu here 156 | self.conv4 = SeparableConv2d(1536,2048,3,1,1) 157 | self.bn4 = nn.BatchNorm2d(2048) 158 | 159 | self.last_linear = nn.Linear(2048, num_classes) 160 | # self.last_linear = self.fc 161 | # #------- init weights -------- 162 | # for m in self.modules(): 163 | # if isinstance(m, nn.Conv2d): 164 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 165 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 166 | # elif isinstance(m, nn.BatchNorm2d): 167 | # m.weight.data.fill_(1) 168 | # m.bias.data.zero_() 169 | # #----------------------------- 170 | 171 | def features(self, input): 172 | x = self.conv1(input) 173 | x = self.bn1(x) 174 | x = self.relu(x) 175 | 176 | x = self.conv2(x) 177 | x = self.bn2(x) 178 | x = self.relu(x) 179 | 180 | x = self.block1(x) 181 | x = self.block2(x) 182 | x = self.block3(x) 183 | x = self.block4(x) 184 | x = self.block5(x) 185 | x = self.block6(x) 186 | x = self.block7(x) 187 | x = self.block8(x) 188 | x = self.block9(x) 189 | x = self.block10(x) 190 | x = self.block11(x) 191 | x = self.block12(x) 192 | 193 | x = self.conv3(x) 194 | x = self.bn3(x) 195 | x = self.relu(x) 196 | 197 | x = self.conv4(x) 198 | x = self.bn4(x) 199 | return x 200 | 201 | def logits(self, features): 202 | x = self.relu(features) 203 | 204 | x = F.adaptive_avg_pool2d(x, (1, 1)) 205 | x = x.view(x.size(0), -1) 206 | x = self.last_linear(x) 207 | return x 208 | 209 | def forward(self, input): 210 | x = self.features(input) 211 | x = self.logits(x) 212 | return x 213 | 214 | 215 | def _xception(num_classes=1000, pretrained='imagenet'): 216 | model = Xception(num_classes=num_classes) 217 | if pretrained: 218 | settings = pretrained_settings['xception'][pretrained] 219 | assert num_classes == settings['num_classes'], \ 220 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 221 | 222 | model = Xception(num_classes=num_classes) 223 | model.load_state_dict(model_zoo.load_url(settings['url'])) 224 | 225 | model.input_space = settings['input_space'] 226 | model.input_size = settings['input_size'] 227 | model.input_range = settings['input_range'] 228 | model.mean = settings['mean'] 229 | model.std = settings['std'] 230 | 231 | # TODO: ugly 232 | model.last_linear = model.fc 233 | del model.fc 234 | return model 235 | 236 | 237 | def 
xception(**kwargs): 238 | num_classes = kwargs.get('num_classes', 1000) 239 | model = Xception(num_classes=num_classes) 240 | return model 241 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.1.1 2 | cloudpickle==0.5.6 3 | cycler==0.10.0 4 | dask==0.19.2 5 | decorator==4.3.0 6 | easydict==1.8 7 | imgaug==0.2.6 8 | kiwisolver==1.0.1 9 | matplotlib==3.0.0 10 | networkx==2.2 11 | numpy==1.15.2 12 | opencv-python==3.4.3.18 13 | Pillow==5.2.0 14 | pkg-resources==0.0.0 15 | protobuf==3.6.1 16 | pyparsing==2.2.1 17 | python-dateutil==2.7.3 18 | PyWavelets==1.0.1 19 | scikit-image==0.14.0 20 | scipy==1.1.0 21 | six==1.11.0 22 | tensorboardX==1.4 23 | toolz==0.9.0 24 | torch==0.4.1 25 | torchvision==0.2.1 26 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import argparse 3 | import textwrap 4 | import time 5 | import os, sys 6 | sys.path.append(os.path.dirname(__file__)) 7 | from utils.config import process_config, check_config_dict 8 | from utils.logger import ExampleLogger 9 | from trainers.example_model import ExampleModel 10 | from trainers.example_trainer import ExampleTrainer 11 | from data_loader.dataset import get_data_loader 12 | 13 | config = process_config(os.path.join(os.path.dirname(__file__), 'configs', 'config.json')) 14 | 15 | class ImageClassificationPytorch: 16 | def __init__(self, config): 17 | gpu_id = config['gpu_id'] 18 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 19 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id 20 | check_config_dict(config) 21 | self.config = config 22 | self.init() 23 | 24 | 25 | def init(self): 26 | # create net 27 | self.model = ExampleModel(self.config) 28 | # load 29 | self.model.load() 30 | # create your data generator 31 | self.train_loader, self.test_loader = get_data_loader(self.config) 32 | # create logger 33 | self.logger = ExampleLogger(self.config) 34 | # create trainer and path all previous components to it 35 | self.trainer = ExampleTrainer(self.model, self.train_loader, self.test_loader, self.config, self.logger) 36 | 37 | 38 | def run(self): 39 | # here you train your model 40 | self.trainer.train() 41 | 42 | 43 | def close(self): 44 | # close 45 | self.logger.close() 46 | 47 | 48 | def main(): 49 | imageClassificationPytorch = ImageClassificationPytorch(config) 50 | imageClassificationPytorch.run() 51 | imageClassificationPytorch.close() 52 | 53 | 54 | if __name__ == '__main__': 55 | 56 | now = time.strftime('%Y-%m-%d | %H:%M:%S', time.localtime(time.time())) 57 | 58 | print('----------------------------------------------------------------------') 59 | print('Time: ' + now) 60 | print('----------------------------------------------------------------------') 61 | print(' Now start ...') 62 | print('----------------------------------------------------------------------') 63 | 64 | main() 65 | 66 | print('----------------------------------------------------------------------') 67 | print(' All Done!') 68 | print('----------------------------------------------------------------------') 69 | print('Start time: ' + now) 70 | print('Now time: ' + time.strftime('%Y-%m-%d | %H:%M:%S', time.localtime(time.time()))) 71 | print('----------------------------------------------------------------------') 
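# ----------------------------------------------------------------------
# Added illustrative sketch (not part of the original train.py): a quick
# smoke test of one of the network factories in nets/. The num_classes
# value and the 224x224 input size are assumptions taken from the README
# model table, not from configs/config.json.
# ----------------------------------------------------------------------
def _smoke_test_resnet50():
    import torch
    from nets.resnet_module import resnet50
    net = resnet50(num_classes=1000)   # the factory reads the num_classes kwarg
    net.eval()                         # inference mode for dropout/batchnorm
    x = torch.randn(1, 3, 224, 224)    # dummy N x C x H x W batch
    with torch.no_grad():
        logits = net(x)                # (1, 1000) class scores
    return logits.argmax(dim=1)        # predicted class index per image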
--------------------------------------------------------------------------------
/trainers/base_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | 
3 | import math
4 | import os
5 | from collections import OrderedDict
6 | import numpy as np
7 | import torch
8 | 
9 | class BaseModel:
10 |     def __init__(self, config):
11 |         self.config = config
12 | 
13 |     # save function that saves the checkpoint in the path defined in the config file
14 |     def save(self):
15 |         """
16 |         implement the logic of saving model
17 |         """
18 |         print("Saving model...")
19 |         save_path = self.config['save_path']
20 |         if not os.path.exists(save_path):
21 |             os.makedirs(save_path)
22 |         save_name = os.path.join(save_path, self.config['save_name'])
23 |         state_dict = OrderedDict()
24 |         for item, value in self.net.state_dict().items():
25 |             if 'module' in item.split('.')[0]:
26 |                 name = '.'.join(item.split('.')[1:])
27 |             else:
28 |                 name = item
29 |             state_dict[name] = value
30 |         torch.save(state_dict, save_name)
31 |         print("Model saved: ", save_name)
32 | 
33 |     # load the latest checkpoint from the experiment path defined in config_file
34 |     def load(self):
35 |         """
36 |         implement the logic of loading model
37 |         """
38 |         raise NotImplementedError
39 | 
40 | 
41 |     def build_model(self):
42 |         """
43 |         implement the logic of model
44 |         """
45 |         raise NotImplementedError
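The key-renaming loop in save() exists because nn.DataParallel registers the wrapped network under the attribute `module`, which prefixes every parameter name with 'module.'. A self-contained illustration with a toy layer (CPU-only sketch, not repo code):

```python
# why BaseModel.save() strips the 'module.' prefix before writing checkpoints
import torch.nn as nn

net = nn.Linear(4, 2)
print(list(net.state_dict().keys()))      # ['weight', 'bias']

wrapped = nn.DataParallel(net)            # wraps net as self.module
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']
```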
--------------------------------------------------------------------------------
/trainers/base_trainer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import print_function
3 | import os
4 | import time
5 | import torch
6 | from torch.autograd import Variable
7 | 
8 | 
9 | class BaseTrainer:
10 |     def __init__(self, model, train_loader, val_loader, config, logger):
11 |         self.model = model
12 |         self.logger = logger
13 |         self.config = config
14 |         self.train_loader = train_loader
15 |         self.val_loader = val_loader
16 |         self.eval_train = 0.
17 |         self.eval_validate = 0.
18 |         self.optimizer = None
19 |         self.loss = None
20 | 
21 | 
22 |     def train(self):
23 |         total_epoch_num = self.config['num_epochs']
24 |         if self.config['evaluate_before_train']:
25 |             self.evaluate_epoch()
26 |             self.eval_validate = self.eval_top1.avg
27 |             print('\n')
28 |         for cur_epoch in range(1, total_epoch_num+1):
29 |             epoch_start_time = time.time()
30 |             self.cur_epoch = cur_epoch
31 |             self.train_epoch()
32 |             self.evaluate_epoch()
33 |             self.eval_train, self.eval_validate = self.train_top1.avg, self.eval_top1.avg
34 |             # printer
35 |             self.logger.log_printer.epoch_case_print(self.cur_epoch,
36 |                                                      self.train_top1.avg, self.eval_top1.avg,
37 |                                                      self.train_losses.avg, self.eval_losses.avg,
38 |                                                      time.time()-epoch_start_time)
39 |             # save model
40 |             self.model.save()
41 |             # logger
42 |             self.logger.write_info_to_logger(variable_dict={'epoch': self.cur_epoch, 'lr': self.learning_rate,
43 |                                                             'train_acc': self.eval_train, 'validate_acc': self.eval_validate,
44 |                                                             'train_avg_loss': self.train_losses.avg, 'validate_avg_loss': self.eval_losses.avg,
45 |                                                             'gpus_index': self.config['gpu_id'],
46 |                                                             'save_name': os.path.join(self.config['save_path'],
47 |                                                                                       self.config['save_name']),
48 |                                                             'net_name': self.config['model_net_name']})
49 |             self.logger.write()
50 |             # tensorboard summary
51 |             if self.config['is_tensorboard']:
52 |                 self.logger.summarizer.data_summarize(self.cur_epoch, summarizer='train',
53 |                                                       summaries_dict={'train_acc': self.eval_train, 'train_avg_loss': self.train_losses.avg})
54 |                 self.logger.summarizer.data_summarize(self.cur_epoch, summarizer='validate',
55 |                                                       summaries_dict={'validate_acc': self.eval_validate, 'validate_avg_loss': self.eval_losses.avg})
56 |             # if self.cur_epoch == total_epoch_num:
57 |             #     self.logger.summarizer.graph_summary(self.model.net)
58 | 
59 | 
60 |     def train_epoch(self):
61 |         """
62 |         implement the logic of an epoch:
63 |         - loop over the number of iterations in the config and call the train step
64 |         """
65 |         raise NotImplementedError
66 | 
67 | 
68 |     def train_step(self):
69 |         """
70 |         implement the logic of the train step
71 |         """
72 |         raise NotImplementedError
73 | 
74 | 
75 |     def evaluate_epoch(self):
76 |         """
77 |         implement the logic of an epoch:
78 |         - loop over the number of iterations in the config and call the evaluate step
79 |         """
80 |         raise NotImplementedError
81 | 
82 | 
83 |     def evaluate_step(self):
84 |         """
85 |         implement the logic of the evaluate step
86 |         """
87 |         raise NotImplementedError
88 | 
89 | 
90 |     def get_loss(self):
91 |         """
92 |         implement the logic of model loss
93 |         """
94 |         raise NotImplementedError
95 | 
96 | 
97 |     def create_optimization(self):
98 |         """
99 |         implement the logic of the optimization
100 |         """
101 |         raise NotImplementedError
--------------------------------------------------------------------------------
/trainers/example_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | import math
4 | from collections import OrderedDict
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | 
9 | from trainers.base_model import BaseModel
10 | from nets.net_interface import NetModule
11 | 
12 | class ExampleModel(BaseModel):
13 |     def __init__(self, config):
14 |         super(ExampleModel, self).__init__(config)
15 |         self.config = config
16 |         self.interface = NetModule(self.config['model_module_name'], self.config['model_net_name'])
17 |         self.create_model()
18 | 
19 | 
20 |     def create_model(self):
21 |         self.net = self.interface.create_model(num_classes=self.config['num_classes'])
22 |         if torch.cuda.is_available():
23 |             self.net.cuda()
24 | 
25 | 
26 |     def load(self):
27 |         # train_mode: 'fromscratch', 'finetune' or 'update'
28 |         # if not updating all parameters:
29 |         # for param in list(self.net.parameters())[:-1]:  # only update parameters of the last layer
30 |         #     param.requires_grad = False
31 |         train_mode = self.config['train_mode']
32 |         if train_mode == 'fromscratch':
33 |             if torch.cuda.device_count() > 1:
34 |                 self.net = nn.DataParallel(self.net)
35 |             if torch.cuda.is_available():
36 |                 self.net.cuda()
37 |             print('from scratch...')
38 | 
39 |         elif train_mode == 'finetune':
40 |             self._load()
41 |             if torch.cuda.device_count() > 1:
42 |                 self.net = nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
43 |             if torch.cuda.is_available():
44 |                 self.net.cuda()
45 |             print('finetuning...')
46 | 
47 |         elif train_mode == 'update':
48 |             self._load()
49 |             print('updating...')
50 | 
51 |         else:
52 |             raise ValueError('unsupported train_mode: {}'.format(train_mode))
53 | 
54 | 
55 |     def _load(self):
56 |         _state_dict = torch.load(os.path.join(self.config['pretrained_path'], self.config['pretrained_file']),
57 |                                  map_location=None)
58 |         # for multi-gpus
59 |         state_dict = OrderedDict()
60 |         for item, value in _state_dict.items():
61 |             if 'module' in item.split('.')[0]:
62 |                 name = '.'.join(item.split('.')[1:])
63 |             else:
64 |                 name = item
65 |             state_dict[name] = value
66 |         # for handling in case of different models compared to the saved pretrain-weight
67 |         model_dict = self.net.state_dict()
68 |         same = {k: v for k, v in state_dict.items() if \
69 |                 (k in model_dict and model_dict[k].size() == v.size())}
70 |         diff = {k: v for k, v in state_dict.items() if \
71 |                 (k in model_dict and model_dict[k].size() != v.size()) or (k not in model_dict)}
72 |         print('diff: ', [i for i, v in diff.items()])
73 |         model_dict.update(same)
74 |         self.net.load_state_dict(model_dict)
75 | 
76 | 
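The same/diff split in _load() is the whole partial-loading policy: a pretrained tensor is kept only when the current net has the same key with the same shape. A toy illustration with invented keys (not repo code):

```python
# toy demo of ExampleModel._load()'s filtering rule -- all keys are made up
import torch

model_dict = {'conv1.weight': torch.zeros(64, 3, 7, 7),   # matches the checkpoint
              'fc.weight': torch.zeros(10, 512)}          # re-headed: 10 classes
pretrained = {'conv1.weight': torch.zeros(64, 3, 7, 7),
              'fc.weight': torch.zeros(1000, 512),        # ImageNet head: dropped
              'aux.weight': torch.zeros(8)}               # layer the net lacks: dropped
same = {k: v for k, v in pretrained.items()
        if k in model_dict and model_dict[k].size() == v.size()}
print(sorted(same))  # ['conv1.weight'] -- only shape-compatible weights transfer
```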
--------------------------------------------------------------------------------
/trainers/example_trainer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | from trainers.base_trainer import BaseTrainer
7 | from utils import utils
8 | 
9 | class ExampleTrainer(BaseTrainer):
10 |     def __init__(self, model, train_loader, val_loader, config, logger):
11 |         super(ExampleTrainer, self).__init__(model, train_loader, val_loader, config, logger)
12 |         self.create_optimization()
13 | 
14 | 
15 |     def train_epoch(self):
16 |         """
17 |         training in an epoch
18 |         :return:
19 |         """
20 |         # Learning rate adjustment
21 |         self.learning_rate = self.adjust_learning_rate(self.optimizer, self.cur_epoch)
22 |         self.train_losses = utils.AverageMeter()
23 |         self.train_top1 = utils.AverageMeter()
24 |         self.train_top5 = utils.AverageMeter()
25 |         # Set the model to be in training mode (for dropout and batchnorm)
26 |         self.model.net.train()
27 |         for batch_idx, (batch_x, batch_y) in enumerate(self.train_loader):
28 |             if torch.cuda.is_available():
29 |                 batch_x, batch_y = batch_x.cuda(non_blocking=self.config['async_loading']), batch_y.cuda(non_blocking=self.config['async_loading'])
30 |             batch_x_var, batch_y_var = Variable(batch_x), Variable(batch_y)
31 |             self.train_step(batch_x_var, batch_y_var)
32 | 
33 |             # printer
34 |             self.logger.log_printer.iter_case_print(self.cur_epoch, self.eval_train, self.eval_validate,
35 |                                                     len(self.train_loader), batch_idx+1, self.train_losses.avg, self.learning_rate)
36 | 
37 |             # tensorboard summary
38 |             if self.config['is_tensorboard']:
39 |                 self.logger.summarizer.data_summarize(batch_idx, summarizer="train", summaries_dict={"lr": self.learning_rate, 'train_loss': self.train_losses.avg})
40 | 
41 |         time.sleep(1)
42 | 
43 | 
44 |     def train_step(self, images, labels):
45 |         """
46 |         training in a step
47 |         :param images:
48 |         :param labels:
49 |         :return:
50 |         """
51 |         # Forward pass
52 |         infer = self.model.net(images)
53 | 
54 |         # label to one_hot
55 |         # ids = labels.long().view(-1,1)
56 |         # print(ids)
57 |         # # one_hot_labels = torch.zeros(32, 2).scatter_(dim=1, index=ids, value=1.)
58 | 
59 |         # Loss function
60 |         losses = self.get_loss(infer, labels)
61 | 
62 |         loss = losses.item()
63 |         # measure accuracy and record loss
64 |         prec1, prec5 = self.compute_accuracy(infer.data, labels.data, topk=(1, 5))
65 |         self.train_losses.update(loss, images.size(0))
66 |         self.train_top1.update(prec1[0], images.size(0))
67 |         self.train_top5.update(prec5[0], images.size(0))
68 |         # Optimization step
69 |         if torch.cuda.device_count() > 1 and torch.cuda.is_available():
70 |             self.optimizer.module.zero_grad()
71 |         else:
72 |             self.optimizer.zero_grad()
73 |         losses.backward()
74 |         if torch.cuda.device_count() > 1 and torch.cuda.is_available():
75 |             self.optimizer.module.step()
76 |         else:
77 |             self.optimizer.step()
78 | 
79 | 
80 |     def get_loss(self, pred, label):
81 |         """
82 |         compute loss
83 |         :param pred:
84 |         :param label:
85 |         :return:
86 |         """
87 |         criterion = nn.CrossEntropyLoss()  # nn.MSELoss()
88 |         if torch.cuda.is_available():
89 |             criterion.cuda()
90 |         return criterion(pred, label)
91 | 
92 | 
93 |     def create_optimization(self):
94 |         """
95 |         optimizer
96 |         :return:
97 |         """
98 |         self.optimizer = torch.optim.Adam(self.model.net.parameters(),
99 |                                           lr=self.config['learning_rate'], weight_decay=0)  # lr: 1e-4
100 |         if torch.cuda.device_count() > 1:
101 |             print('optimizer device_count: ', torch.cuda.device_count())
102 |             self.optimizer = nn.DataParallel(self.optimizer, device_ids=range(torch.cuda.device_count()))
103 |         """
104 |         # optimizing parameters separately
105 |         ignored_params = list(map(id, self.model.net.fc.parameters()))
106 |         base_params = filter(lambda p: id(p) not in ignored_params,
107 |                              self.model.net.parameters())
108 |         self.optimizer = torch.optim.Adam([
109 |             {'params': base_params},
110 |             {'params': self.model.net.fc.parameters(), 'lr': 1e-3}
111 |         ], lr=1e-2, betas=(0.9, 0.99), eps=1e-08, weight_decay=0, amsgrad=False)"""
112 | 
113 | 
114 |     def adjust_learning_rate(self, optimizer, epoch):
115 |         """
116 |         decay learning rate
117 |         :param optimizer:
118 |         :param epoch: the first epoch is 1
119 |         :return:
120 |         """
121 |         # """Decay Learning rate at 1/2 and 3/4 of the num_epochs"""
122 |         # lr = lr_init
123 |         # if epoch >= num_epochs * 0.75:
124 |         #     lr *= decay_rate ** 2
125 |         # elif epoch >= num_epochs * 0.5:
126 |         #     lr *= decay_rate
127 |         learning_rate = self.config['learning_rate'] * (self.config['learning_rate_decay'] ** ((epoch - 1) // self.config['learning_rate_decay_epoch']))
128 |         if torch.cuda.device_count() > 1 and torch.cuda.is_available():
129 |             for param_group in optimizer.module.param_groups:
130 |                 param_group['lr'] = learning_rate
131 |         else:
132 |             for param_group in optimizer.param_groups:
133 |                 param_group['lr'] = learning_rate
134 | 
135 |         return learning_rate
136 | 
137 | 
138 |     def compute_accuracy(self, output, target, topk=(1,)):
139 |         """
140 |         compute top-n accuracy
141 |         :param output:
142 |         :param target:
143 |         :param topk:
144 |         :return:
145 |         """
146 |         maxk = max(topk)
147 |         batch_size = target.size(0)
148 |         _, idx = output.topk(maxk, 1, True, True)
149 |         idx = idx.t()
150 |         correct = idx.eq(target.view(1, -1).expand_as(idx))
151 |         acc_arr = []
152 |         for k in topk:
153 |             correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
154 |             acc_arr.append(correct_k.mul_(1.0 / batch_size))
155 |         return acc_arr
156 | 
157 | 
158 |     def evaluate_epoch(self):
159 |         """
160 |         evaluating in an epoch
161 |         :return:
162 |         """
163 |         self.eval_losses = utils.AverageMeter()
164 |         self.eval_top1 = utils.AverageMeter()
165 |         self.eval_top5 = utils.AverageMeter()
166 |         # Set the model to be in testing mode (for dropout and batchnorm)
167 |         self.model.net.eval()
168 |         for batch_idx, (batch_x, batch_y) in enumerate(self.val_loader):
169 |             if torch.cuda.is_available():
170 |                 batch_x, batch_y = batch_x.cuda(non_blocking=self.config['async_loading']), batch_y.cuda(non_blocking=self.config['async_loading'])
171 |             batch_x_var, batch_y_var = Variable(batch_x), Variable(batch_y)
172 |             self.evaluate_step(batch_x_var, batch_y_var)
173 |             utils.view_bar(batch_idx+1, len(self.val_loader))
174 | 
175 | 
176 |     def evaluate_step(self, images, labels):
177 |         """
178 |         evaluating in a step
179 |         :param images:
180 |         :param labels:
181 |         :return:
182 |         """
183 |         with torch.no_grad():
184 |             infer = self.model.net(images)
185 |             # label to one_hot
186 |             # ids = labels.long().view(-1, 1)
187 |             # one_hot_labels = torch.zeros(32, 2).scatter_(dim=1, index=ids, value=1.)
188 | 
189 |             # Loss function
190 |             losses = self.get_loss(infer, labels)
191 |             loss = losses.item()
192 | 
193 |             # measure accuracy and record loss
194 |             prec1, prec5 = self.compute_accuracy(infer.data, labels.data, topk=(1, 5))
195 | 
196 |             self.eval_losses.update(loss, images.size(0))
197 |             self.eval_top1.update(prec1[0], images.size(0))
198 |             self.eval_top5.update(prec5[0], images.size(0))
199 | 
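adjust_learning_rate implements a plain step decay: the rate is multiplied by learning_rate_decay once every learning_rate_decay_epoch epochs. A worked example under assumed config values (not the repo's defaults):

```python
# step-decay schedule from ExampleTrainer.adjust_learning_rate, in isolation;
# the three values below are illustrative assumptions
learning_rate, decay, decay_epoch = 1e-4, 0.1, 20

for epoch in (1, 20, 21, 40, 41):
    lr = learning_rate * decay ** ((epoch - 1) // decay_epoch)
    print(epoch, lr)  # 1e-04 for epochs 1-20, 1e-05 for 21-40, 1e-06 from 41
```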
TypeError("is_tensorboard param input err...") 47 | if isinstance(config_dict["evaluate_before_train"],bool) is False: 48 | raise TypeError("evaluate_before_train param input err...") 49 | if isinstance(config_dict["shuffle"],bool) is False: 50 | raise TypeError("shuffle param input err...") 51 | if isinstance(config_dict["data_aug"],bool) is False: 52 | raise TypeError("data_aug param input err...") 53 | 54 | if isinstance(config_dict["num_epochs"],int) is False: 55 | raise TypeError("num_epochs param input err...") 56 | if isinstance(config_dict["img_height"],int) is False: 57 | raise TypeError("img_height param input err...") 58 | if isinstance(config_dict["img_width"],int) is False: 59 | raise TypeError("img_width param input err...") 60 | if isinstance(config_dict["num_channels"],int) is False: 61 | raise TypeError("num_channels param input err...") 62 | if isinstance(config_dict["num_classes"],int) is False: 63 | raise TypeError("num_classes param input err...") 64 | if isinstance(config_dict["batch_size"],int) is False: 65 | raise TypeError("batch_size param input err...") 66 | if isinstance(config_dict["dataloader_workers"],int) is False: 67 | raise TypeError("dataloader_workers param input err...") 68 | if isinstance(config_dict["learning_rate"],(int,float)) is False: 69 | raise TypeError("learning_rate param input err...") 70 | if isinstance(config_dict["learning_rate_decay"],(int,float)) is False: 71 | raise TypeError("learning_rate_decay param input err...") 72 | if isinstance(config_dict["learning_rate_decay_epoch"],int) is False: 73 | raise TypeError("learning_rate_decay_epoch param input err...") 74 | 75 | if isinstance(config_dict["train_mode"],str) is False: 76 | raise TypeError("train_mode param input err...") 77 | if isinstance(config_dict["file_label_separator"],str) is False: 78 | raise TypeError("file_label_separator param input err...") 79 | if isinstance(config_dict["pretrained_path"],str) is False: 80 | raise TypeError("pretrained_path param input err...") 81 | if isinstance(config_dict["pretrained_file"],str) is False: 82 | raise TypeError("pretrained_file param input err...") 83 | if isinstance(config_dict["save_path"],str) is False: 84 | raise TypeError("save_path param input err...") 85 | if isinstance(config_dict["save_name"],str) is False: 86 | raise TypeError("save_name param input err...") 87 | 88 | if not os.path.exists(os.path.join(config_dict["pretrained_path"], config_dict["pretrained_file"])): 89 | raise ValueError("cannot find pretrained_path or pretrained_file...") 90 | if not os.path.exists(config_dict["save_path"]): 91 | raise ValueError("cannot find save_path...") 92 | 93 | if isinstance(config_dict["train_data_root_dir"],str) is False: 94 | raise TypeError("train_data_root_dir param input err...") 95 | if isinstance(config_dict["val_data_root_dir"],str) is False: 96 | raise TypeError("val_data_root_dir param input err...") 97 | if isinstance(config_dict["train_data_file"],str) is False: 98 | raise TypeError("train_data_file param input err...") 99 | if isinstance(config_dict["val_data_file"],str) is False: 100 | raise TypeError("val_data_file param input err...") 101 | 102 | if not os.path.exists(config_dict["train_data_root_dir"]): 103 | raise ValueError("cannot find train_data_root_dir...") 104 | if not os.path.exists(config_dict["val_data_root_dir"]): 105 | raise ValueError("cannot find val_data_root_dir...") 106 | if not os.path.exists(config_dict["train_data_file"]): 107 | raise ValueError("cannot find train_data_file...") 108 | if not 
os.path.exists(config_dict["val_data_file"]): 109 | raise ValueError("cannot find val_data_file...") 110 | 111 | 112 | 113 | #global_config = process_config('configs/config.json') 114 | 115 | if __name__ == '__main__': 116 | config = global_config 117 | print(config['experiment_dir']) 118 | print('done') -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import print_function 3 | import os, sys 4 | import numpy as np 5 | import logging 6 | from utils import utils 7 | from utils import logger_summarizer 8 | 9 | class ExampleLogger: 10 | """ 11 | self.log_writer.info("log") 12 | self.log_writer.warning("warning log) 13 | self.log_writer.error("error log ") 14 | 15 | try: 16 | main() 17 | except KeyboardInterrupt: 18 | sys.stdout.flush() 19 | """ 20 | def __init__(self, config): 21 | self.config = config 22 | self.log_writer = self.init() 23 | self.log_printer = DefinedPrinter() 24 | if self.config['is_tensorboard']: 25 | self.summarizer = logger_summarizer.Logger(self.config) 26 | self.log_info = {} 27 | 28 | 29 | def init(self): 30 | """ 31 | initial 32 | :return: 33 | """ 34 | log_writer = logging.getLogger(__name__) 35 | log_writer.setLevel(logging.INFO) 36 | self.log_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs') 37 | if not os.path.exists(self.log_dir): 38 | os.makedirs(self.log_dir) 39 | handler = logging.FileHandler(os.path.join(self.log_dir, 'alg_training.log'),encoding='utf-8') 40 | handler.setLevel(logging.DEBUG) 41 | logging_format = logging.Formatter("@%(asctime)s [%(filename)s -- %(funcName)s]%(lineno)s : %(levelname)s - %(message)s", datefmt='%Y-%m-%d %A %H:%M:%S') 42 | handler.setFormatter(logging_format) 43 | log_writer.addHandler(handler) 44 | return log_writer 45 | 46 | 47 | def write_info_to_logger(self, variable_dict): 48 | """ 49 | print 50 | :param variable_dict: 51 | :return: 52 | """ 53 | if variable_dict is not None: 54 | for tag, value in variable_dict.items(): 55 | self.log_info[tag] = value 56 | 57 | 58 | def write(self): 59 | """ 60 | log writing 61 | :return: 62 | """ 63 | _info = 'epoch: %d, lr: %f, eval_train: %f, eval_validate: %f, train_avg_loss: %f, validate_avg_loss: %f, gpu_index: %s, net: %s, save: %s' % ( 64 | self.log_info['epoch'],self.log_info['lr'], self.log_info['train_acc'], self.log_info['validate_acc'], 65 | self.log_info['train_avg_loss'], self.log_info['validate_avg_loss'], 66 | self.log_info['gpus_index'], self.log_info['net_name'], self.log_info['save_name']) 67 | 68 | self.log_writer.info(_info) 69 | sys.stdout.flush() 70 | 71 | 72 | def write_warning(self, warning_dict): 73 | """ 74 | warninginfo writing 75 | :return: 76 | """ 77 | _info = 'epoch: %d, lr: %f, loss: %f'%(warning_dict['epoch'],warning_dict['lr'], warning_dict['loss']) 78 | self.log_writer.warning(_info) 79 | sys.stdout.flush() 80 | 81 | def clear(self): 82 | """ 83 | clear log_info 84 | :return: 85 | """ 86 | self.log_info = {} 87 | 88 | 89 | def close(self): 90 | if self.config['is_tensorboard']: 91 | self.summarizer.train_summary_writer.close() 92 | self.summarizer.validate_summary_writer.close() 93 | 94 | 95 | 96 | class DefinedPrinter: 97 | """ 98 | Printer 99 | """ 100 | 101 | def init_case_print(self, loss_start, eval_start_train, eval_start_val): 102 | """ 103 | print when init 104 | :param loss_start: 105 | :param eval_start_train: 106 | :param eval_start_val: 107 | :return: 108 
| """ 109 | log = "\nInitial Situation:\n" + \ 110 | "Loss= \033[1;32m" + "{:.6f}".format(loss_start) + "\033[0m, " + \ 111 | "Training EVAL= \033[1;36m" + "{:.5f}".format(eval_start_train * 100) + "%\033[0m , " + \ 112 | "Validating EVAL= \033[0;31m" + "{:.5f}".format(eval_start_val * 100) + '%\033[0m' 113 | print('\n\r', log) 114 | print('---------------------------------------------------------------------------') 115 | 116 | 117 | def iter_case_print(self, epoch, eval_train, eval_validate, limit, iteration, loss, lr): 118 | """ 119 | print per batch 120 | :param epoch: 121 | :param eval_train: 122 | :param eval_validate: 123 | :param limit: 124 | :param iteration: 125 | :param loss: 126 | :param lr: 127 | :param global_step: 128 | :return: 129 | """ 130 | 131 | log = "Epoch \033[1;33m" + str(epoch) + "\033[0m, " + \ 132 | "Iter \033[1;33m" + str(iteration) + '/' + str(limit) + "\033[0m, " + \ 133 | "Loss \033[1;32m" + "{:.6f}".format(loss) + "\033[0m, " + \ 134 | "lr \033[1;37;45m" + "{:.6f}".format(lr) + "\033[0m, " + \ 135 | "Training EVAL \033[1;36m" + "{:.5f}".format(eval_train * 100) + "%\033[0m, " + \ 136 | "Validating EVAL \033[1;31m" + "{:.5f}".format(eval_validate * 100) + "%\033[0m, " 137 | 138 | print(log) 139 | 140 | def epoch_case_print(self, epoch, eval_train, eval_validate, loss_train, loss_validate, fitTime): 141 | """ 142 | print per epoch 143 | :param epoch: 144 | :param eval_train: 145 | :param eval_validate: 146 | :param fitTime: 147 | :return: 148 | """ 149 | log = "\nEpoch \033[1;36m" + str(epoch) + "\033[0m, " + \ 150 | "Training EVAL \033[1;36m" + "{:.5f}".format(eval_train * 100) + "%\033[0m , " + \ 151 | "Validating EVAL= \033[0;31m" + "{:.5f}".format(eval_validate * 100) + '%\033[0m , ' + \ 152 | "\n\r" + \ 153 | "Training avg_loss \033[1;32m" + "{:.5f}".format(loss_train) + "\033[0m, " + \ 154 | "Validating avg_loss \033[1;32m" + "{:.5f}".format(loss_validate) + "\033[0m, " + \ 155 | "epoch time " + str(fitTime) + ' ms' + '\n' 156 | print('\n\r', log, '\n\r') 157 | 158 | -------------------------------------------------------------------------------- /utils/logger_summarizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import os 4 | # import tensorflow as tf 5 | import torch 6 | from tensorboardX import SummaryWriter 7 | 8 | 9 | class Logger: 10 | """ 11 | tensorboardX summary for pytorch 12 | """ 13 | def __init__(self, config): 14 | self.config = config 15 | self.train_summary_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs", "train") 16 | self.validate_summary_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs", "val") 17 | if not os.path.exists(self.train_summary_dir): 18 | os.makedirs(self.train_summary_dir) 19 | if not os.path.exists(self.validate_summary_dir): 20 | os.makedirs(self.validate_summary_dir) 21 | self.train_summary_writer = SummaryWriter(self.train_summary_dir) 22 | self.validate_summary_writer = SummaryWriter(self.validate_summary_dir) 23 | 24 | 25 | # it can summarize scalars and images. 
26 |     def data_summarize(self, step, summarizer="train", summaries_dict=None):
27 |         """
28 |         :param step: the step of the summary
29 |         :param summarizer: use the train summary writer or the validate one
30 |         :param summaries_dict: the dict of the summaries values (tag, value)
31 |         :return:
32 |         """
33 |         summary_writer = self.train_summary_writer if summarizer == "train" else self.validate_summary_writer
34 |         if summaries_dict is not None:
35 |             summary_writer.add_scalars('./', summaries_dict, step)
36 |             # summary = tf.Summary()
37 |             # for tag, value in summaries_dict.items():
38 |             #     summary.value.add(tag=tag, simple_value=value)
39 |             # summary_writer.add_summary(summary, step)
40 |             # summary_writer.flush()
41 | 
42 | 
43 |     def graph_summary(self, net, summarizer="train"):
44 |         summary_writer = self.train_summary_writer if summarizer == "train" else self.validate_summary_writer
45 |         input_to_model = torch.rand(1, self.config['num_channels'], self.config['img_height'], self.config['img_width'])  # NCHW
46 |         summary_writer.add_graph(net, (input_to_model,))
47 | 
48 | 
49 |     def close(self):
50 |         self.train_summary_writer.close()
51 |         self.validate_summary_writer.close()
52 | 
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import print_function
3 | import numpy as np
4 | import time
5 | import sys
6 | import os
7 | 
8 | 
9 | class AverageMeter(object):
10 |     """Computes and stores the average and current value
11 |     Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
12 |     """
13 |     def __init__(self):
14 |         self.reset()
15 | 
16 |     def reset(self):
17 |         self.val = 0
18 |         self.avg = 0
19 |         self.sum = 0
20 |         self.count = 0
21 | 
22 |     def update(self, val, n=1):
23 |         self.val = val
24 |         self.sum += val * n
25 |         self.count += n
26 |         self.avg = self.sum / self.count
27 | 
28 | 
29 | def accuracy(output, target, topk=(1,)):
30 |     """Computes the precision@k for the specified values of k"""
31 |     maxk = max(topk)
32 |     batch_size = target.size(0)
33 | 
34 |     _, pred = output.topk(maxk, 1, True, True)
35 |     pred = pred.t()
36 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
37 | 
38 |     res = []
39 |     for k in topk:
40 |         correct_k = correct[:k].view(-1).float().sum(0)
41 |         res.append(correct_k.mul_(100.0 / batch_size))
42 |     return res
43 | 
44 | 
45 | def view_bar(num, total):
46 |     """
47 |     draw a console progress bar; callers pass a 1-based step index
48 |     :param num:
49 |     :param total:
50 |     :return:
51 |     """
52 |     rate = float(num) / total
53 |     rate_num = int(rate * 100)
54 |     if num != total:
55 |         r = '\r[%s%s]%d%%' % ("=" * rate_num, " " * (100 - rate_num), rate_num,)
56 |     else:
57 |         r = '\r[%s%s]%d%%' % ("=" * 100, " " * 0, 100,)
58 |     sys.stdout.write(r)
59 |     sys.stdout.flush()
--------------------------------------------------------------------------------
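To close, a small self-contained sketch of how AverageMeter and accuracy from utils/utils.py combine during evaluation. The logits and labels are synthetic, and the import assumes the repo root is on PYTHONPATH.

```python
# standalone sanity check for utils/utils.py -- synthetic data, CPU only
import torch
from utils.utils import AverageMeter, accuracy

top1 = AverageMeter()
logits = torch.tensor([[2.0, 0.5, 0.1],   # predicted class 0 (correct)
                       [0.2, 1.5, 0.3],   # predicted class 1 (correct)
                       [0.1, 0.4, 3.0]])  # predicted class 2 (wrong)
labels = torch.tensor([0, 1, 0])

prec1, = accuracy(logits, labels, topk=(1,))  # 2 of 3 correct -> ~66.67 percent
top1.update(prec1.item(), logits.size(0))     # weight the running mean by batch size
print('top-1 so far: %.2f%%' % top1.avg)
```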