├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── extract_features.py
├── imgs
│   ├── pytorch.png
│   └── transfer-learning.jpeg
├── inference.py
├── inference
│   ├── alexnet.sh
│   ├── recursive_resnet.sh
│   ├── resnet.sh
│   └── vggnet.sh
├── main.py
├── networks
│   ├── __init__.py
│   └── resnet.py
├── test
│   ├── alexnet.sh
│   ├── resnet.sh
│   └── vggnet.sh
└── train
    ├── alexnet.sh
    ├── inception.sh
    ├── resnet.sh
    ├── squeeze.sh
    ├── vggnet.sh
    └── xception.sh
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# torch
*.t7
vectors/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Bumsoo Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------


# fine-tuning.pytorch
PyTorch implementation of fine-tuning (transfer learning) CNN networks.
This project is made by Bumsoo Kim.

Korea University, Master's-Ph.D. integrated course.



## Fine-Tuning
In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.

Further explanations can be found [here](http://cs231n.github.io/transfer-learning/).
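
For intuition, the two modes map to PyTorch roughly as below (a minimal sketch, not code from this repository; `num_classes` is a placeholder for your own label count):

```python
import torch.nn as nn
from torchvision import models

num_classes = 2  # placeholder: the number of classes in your task

# (a) ConvNet as a fixed feature extractor: freeze the pretrained backbone
#     and train only the newly attached classifier head.
extractor = models.resnet50(pretrained=True)
for param in extractor.parameters():
    param.requires_grad = False  # backbone weights stay frozen
extractor.fc = nn.Linear(extractor.fc.in_features, num_classes)

# (b) ConvNet as an initialization (fine-tuning): keep every layer trainable,
#     usually with a smaller learning rate than training from scratch.
finetuned = models.resnet50(pretrained=True)
finetuned.fc = nn.Linear(finetuned.fc.in_features, num_classes)
```

In this repository, the `--finetune` option loads the pretrained weights and `--resetClassifier` swaps in a classifier head sized to your dataset.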

## Requirements
See the [installation instruction](INSTALL.md) for a step-by-step installation guide.
See the [server instruction](SERVER.md) for server setup.
- Install [cuda-8.0](https://developer.nvidia.com/cuda-downloads)
- Install [cudnn v5.1](https://developer.nvidia.com/cudnn)
- Download [PyTorch for python-2.7](https://pytorch.org) and clone the repository.
- Download [PyTorch-3.5](https://pytorch.org) to use the additional pretrained model libraries with anaconda3.
```bash
pip install http://download.pytorch.org/whl/cu80/torch-0.1.12.post2-cp27-none-linux_x86_64.whl
pip install torchvision
git clone https://github.com/meliketoy/resnet-fine-tuning.pytorch
```

- Download pretrained models for PyTorch (only for python-3.5)
```bash
$ git clone https://github.com/Cadene/pretrained-models.pytorch.git
$ cd pretrained-models.pytorch
$ python setup.py install
```

## Basic Setups
After you have cloned this repository into your file system, open [config.py](./config.py)
and edit the lines below to point to your data directories.
```bash
data_base = [:dir to your original dataset]
aug_base = [:dir to your actually trained dataset]
```

For training, your data should be organized in the following hierarchy.
Code for organizing your data into this structure is provided [here](https://github.com/meliketoy/image-preprocessing).

```bash
[:data file name]

|-train
  |-[:class 0]
  |-[:class 1]
  |-[:class 2]
  ...
  |-[:class n]
|-val
  |-[:class 0]
  |-[:class 1]
  |-[:class 2]
  ...
  |-[:class n]
```

## How to run
After you have cloned the repository, you can train on your dataset by running the commands below.

You can set the dimension of the additional layer in [config.py](./config.py).

The resetClassifier option will automatically detect the number of classes in your data folder and resize the last classifier layer accordingly.

```bash
# zero-base training
python main.py --lr [:lr] --depth [:depth] --resetClassifier

# fine-tuning
python main.py --finetune --lr [:lr] --depth [:depth]

# fine-tuning with additional linear layers
python main.py --finetune --addlayer --lr [:lr] --depth [:depth]
```

## Train various networks

I have added fine-tuning & transfer learning scripts for alexnet, VGG(11, 13, 16, 19),
ResNet(18, 34, 50, 101, 152), SqueezeNet, Inception-v3 and Xception.

Please modify the [scripts](./train) and run the line below.

```bash

$ ./train/[:network].sh

# For example, if you want to fine-tune alexnet, just run
$ ./train/alexnet.sh

```

## Test (Inference) various networks

To test your fine-tuned model on alexnet, VGG(11, 13, 16, 19) or ResNet(18, 34, 50, 101, 152),

first set your data directory as test_dir in [config.py](./config.py).

Please modify the [scripts](./test) and run the line below.

```bash

$ ./test/[:network].sh

```
For example, if you have trained ResNet with 50 layers, first modify the [resnet test script](./test/resnet.sh).

```bash
$ vi ./test/resnet.sh

python main.py \
--net_type resnet \
--depth 50 \
--testOnly

$ ./test/resnet.sh

```

The script above will automatically load the checkpoint saved for the given network and depth, and evaluate it on your test dataset.

## Feature extraction
For various training mechanisms, extracted feature vectors are needed.

This repository provides not only feature extraction from pre-trained networks,

but also extraction from a model that you trained yourself.

Just set the test directory test_dir in [config.py](config.py) and run the code below.

```bash
python extract_features.py
```

This will automatically create pickles in a newly created 'vectors' directory,

each of which contains a dictionary with the fields below.

Currently, 'score' only covers 0~1 scores for binary classes.

Confidence scores for multi-class problems will be added later.

```bash
pickle_file [name : image base name]
  |- 'file_path' : path of the test image
  |- 'feature' : extracted feature vector
  |- 'score' : score for binary classification
```
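
To read a saved vector back later, a minimal sketch looks like this (assuming an extracted file named `some_image.pickle` exists):

```python
import pickle

# Load one extracted vector back from the 'vectors' directory.
with open('vectors/some_image.pickle', 'rb') as pkl:
    vec = pickle.load(pkl)

print(vec['file_path'])  # path of the source image
print(vec['feature'])    # extracted feature vector
print(vec['score'])      # softmax score of the positive class
```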

Enjoy :-)

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# Configuration File

# Base directory for data formats
#name = 'INBREAST_5'
#test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/HWEJIN_INBREAST_SPLIT'

#name = 'GURO_CELL'
#test_dir = '/home/bumsoo/Data/split/GURO_CELL/val/'

# INBREAST
#name = 'INBREAST_TRAIN'
#test_dir = '/home/bumsoo/Data/test/HWEJIN_INBREAST_SPLIT'

# GURO_SPLIT
name = 'GURO_TRAIN'
#test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/HWEJIN_GURO_SPLIT'

# GURO_ALL -> INBREAST_ALL
#name = 'GURO_ALL'
#test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/INBREAST_ALL'

# MIX_TRAIN -> MIX_TEST
#name = 'MIX_TRAIN'
#test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/HWEJIN_MIX_TEST'
test_dir = '/home/bumsoo/inference_patches/guro_patches_test_8'

# GURO80+INBREAST_ALL
#name = 'GURO80+INBREAST'
#test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/GURO80+INBREAST'

# INBREAST80+GURO_ALL
#name = 'INBREAST80+GURO'
#test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/INBREAST80+GURO'

# Inference (INBREAST_test)
#name = 'INBREAST_TRAIN'
#test_dir = '/home/bumsoo/Data/test/inbreast_patches_test_9'
#test_dir = '/mnt/datasets/inbreast_patches_test_9'

# Inference (GURO_test)
#name = 'GURO_TRAIN'
#test_dir = '/home/bumsoo/Data/test/guro_patches_test_0'
#test_dir = '/home/bumsoo/guro_patches_test_9'

# Inference (GURO_ALL -> INBREAST_ALL)
#name = 'GURO_ALL'
#test_dir = '/home/bumsoo/Data/test/PATCH_INBREAST_TEST/inbreast_patches_4'

data_base = '/home/mnt/datasets/'+name
aug_base = '/home/bumsoo/Data/split/'+name

# model option
batch_size = 16
num_epochs = 50
lr_decay_epoch = 20
feature_size = 500

# meanstd options
# INBREAST_SPLIT
#mean = [0.601176900699946, 0.601176900699946, 0.601176900699946]
#std = [0.083943294373731825, 0.083943294373731825, 0.083943294373731825]

# GURO_SPLIT
#mean = [0.49113493759286625, 0.49113493759286625, 0.49113493759286625]
#std = [0.14704804249157166, 0.14704804249157166, 0.14704804249157166]

# GURO_ALL
#mean = [0.42641446119819587, 0.42641446119819587, 0.42641446119819587]
#std = [0.19647293715592193, 0.19647293715592193, 0.19647293715592193]

# GURO+INBREAST
#mean = [0.53753781240686382, 0.53753781240686382, 0.53753781240686382]
#std = [0.12187187243213095, 0.12187187243213095, 0.12187187243213095]

# MIX_TRAIN
#mean = [0.50528327792298555, 0.50528327792298555, 0.50528327792298555]
#std = [0.13993786443871117, 0.13993786443871117, 0.13993786443871117]

# GURO80+INBREAST_ALL
#mean = [0.49977846189176656, 0.49977846189176656, 0.49977846189176656]
#std = [0.14111615457915755, 0.14111615457915755, 0.14111615457915755]

# INBREAST80+GURO_ALL
mean = [0.4856586910840433, 0.4856586910840433, 0.4856586910840433]
std = [0.14210993338737993, 0.14210993338737993, 0.14210993338737993]

# GURO_CELL
#mean = [0.78076776409256798, 0.61738499185119988, 0.62287074541563914]
#std = [0.18391759503019442, 0.26082926658759176, 0.23288027411260487]

# INBREAST_1
#mean = [0.60284723168105081, 0.60284723168105081, 0.60284723168105081]
#std = [0.081163047606150382, 0.081163047606150382, 0.081163047606150382]

# INBREAST_2
#mean = [0.61158796966756579, 0.61158796966756579, 0.61158796966756579]
#std = [0.08487070239187032, 0.08487070239187032, 0.08487070239187032]

# INBREAST_3
#mean = [0.60108720150874573, 0.60108720150874573, 0.60108720150874573]
#std = [0.081551750213639501, 0.081551750213639501, 0.081551750213639501]

# INBREAST_4
#mean = [0.60402172760178874, 0.60402172760178874, 0.60402172760178874]
#std = [0.078366899563820674, 0.078366899563820674, 0.078366899563820674]

# INBREAST_5
#mean = [0.59631095620282071, 0.59631095620282071, 0.59631095620282071]
#std = [0.080351500548752522, 0.080351500548752522, 0.080351500548752522]
--------------------------------------------------------------------------------
/extract_features.py:
--------------------------------------------------------------------------------
# ************************************************************
# Author : Bumsoo Kim, 2017
# Github : https://github.com/meliketoy/fine-tuning.pytorch
#
# Korea University, Data-Mining Lab
# Deep Convolutional Network Fine tuning Implementation
#
# Description : extract_features.py
# The main code for extracting features of a trained model.
# ***********************************************************

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import config as cf
import torchvision
import time
import copy
import os
import sys
import argparse

from torchvision import datasets, models, transforms
from networks import *
from torch.autograd import Variable
from PIL import Image
import pickle

parser = argparse.ArgumentParser(description='PyTorch Digital Mammography Training')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--net_type', default='resnet', type=str, help='model')
parser.add_argument('--depth', default=50, type=int, help='depth of model')
parser.add_argument('--finetune', '-f', action='store_true', help='Fine tune pretrained model')
parser.add_argument('--addlayer', '-a', action='store_true', help='Add additional layer in fine-tuning')
parser.add_argument('--testOnly', '-t', action='store_true', help='Test mode with the saved model')
args = parser.parse_args()

# Phase 1 : Data Upload
print('\n[Phase 1] : Data Preparation')

data_dir = cf.test_dir
trainset_dir = cf.data_base.split("/")[-1] + os.sep
print("| Preparing %s dataset..." %(cf.test_dir.split("/")[-1]))

use_gpu = torch.cuda.is_available()

# Phase 2 : Model setup
print('\n[Phase 2] : Model setup')

def getNetwork(args):
    if (args.net_type == 'vggnet'):
        # networks/ only defines resnet, so fall back to torchvision's VGG here
        net = models.vgg16(pretrained=args.finetune)
        file_name = 'vgg-%s' %(args.depth)
    elif (args.net_type == 'resnet'):
        net = resnet(args.finetune, args.depth)
        file_name = 'resnet-%s' %(args.depth)
    else:
        print('Error : Network should be either [VGGNet / ResNet]')
        sys.exit(1)

    return net, file_name

def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=0)

print("| Loading checkpoint model for feature extraction...")
assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
assert os.path.isdir('checkpoint/'+trainset_dir), 'Error: No model has been trained on the dataset!'
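# The .t7 checkpoint is the dict saved by main.py ({'model': ..., 'acc': ..., 'epoch': ...});
# 'model' holds the full trained network object, so no architecture rebuild is needed here.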
_, file_name = getNetwork(args)
checkpoint = torch.load('./checkpoint/'+trainset_dir+file_name+'.t7')
model = checkpoint['model']

print("| Constructing a feature extractor from the model...")
if(args.net_type == 'alexnet' or args.net_type == 'vggnet'):
    feature_map = list(checkpoint['model'].module.classifier.children())
    feature_map.pop()
    new_classifier = nn.Sequential(*feature_map)
    extractor = copy.deepcopy(checkpoint['model'])
    extractor.module.classifier = new_classifier
elif (args.net_type == 'resnet'):
    feature_map = list(model.module.children())
    feature_map.pop()
    extractor = nn.Sequential(*feature_map)
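# In both branches the final classification layer is popped off, so the
# extractor outputs the penultimate-layer activations of the trained model.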

if use_gpu:
    model.cuda()
    extractor.cuda()
    cudnn.benchmark = True

model.eval()
extractor.eval()

sample_input = Variable(torch.randn(1,3,224,224), volatile=True)
if use_gpu:
    sample_input = sample_input.cuda()

sample_output = extractor(sample_input)
featureSize = sample_output.size(1)
print("| Feature dimension = %d" %featureSize)

print("\n[Phase 3] : Feature & Score Extraction")

def is_image(f):
    return f.endswith(".png") or f.endswith(".jpg")

test_transform = transforms.Compose([
    transforms.Scale(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(cf.mean, cf.std)
])

if not os.path.isdir('vectors'):
    os.mkdir('vectors')

for subdir, dirs, files in os.walk(data_dir):
    for f in files:
        file_path = subdir + os.sep + f
        if (is_image(f)):
            vector_dict = {
                'file_path': "",
                'feature': [],
                'score': 0,
            }

            image = Image.open(file_path).convert('RGB')
            if test_transform is not None:
                image = test_transform(image)
            inputs = image
            inputs = Variable(inputs, volatile=True)
            if use_gpu:
                inputs = inputs.cuda()
            inputs = inputs.view(1, inputs.size(0), inputs.size(1), inputs.size(2)) # add batch dim in the front
            features = extractor(inputs).view(featureSize)

            outputs = model(inputs)
            softmax_res = softmax(outputs.data.cpu().numpy()[0])

            vector_dict['file_path'] = file_path
            vector_dict['feature'] = features
            vector_dict['score'] = softmax_res[1]

            vector_file = 'vectors' + os.sep + os.path.splitext(f)[0] + ".pickle"

            print(vector_file)
            print(vector_dict['feature'].size())
            print(vector_dict['score'])

            with open(vector_file, 'wb') as pkl:
                pickle.dump(vector_dict, pkl, protocol=pickle.HIGHEST_PROTOCOL)

--------------------------------------------------------------------------------
/imgs/pytorch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmsookim/fine-tuning.pytorch/91b45bbf1287a33603c344d64c06b6b1bf8f226e/imgs/pytorch.png
--------------------------------------------------------------------------------
/imgs/transfer-learning.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmsookim/fine-tuning.pytorch/91b45bbf1287a33603c344d64c06b6b1bf8f226e/imgs/transfer-learning.jpeg
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
# ************************************************************
# Author : Bumsoo Kim, 2017
# Github : https://github.com/meliketoy/fine-tuning.pytorch
#
# Korea University, Data-Mining Lab
# Deep Convolutional Network Fine tuning Implementation
#
# Description : inference.py
# The main code for the inference (test) phase of a trained model.
# ***********************************************************

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import config as cf
import torchvision
import time
import copy
import os
import sys
import argparse
import csv

from torchvision import datasets, models, transforms
from networks import *
from torch.autograd import Variable
from PIL import Image

parser = argparse.ArgumentParser(description='PyTorch Digital Mammography Training')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--net_type', default='resnet', type=str, help='model')
parser.add_argument('--depth', default=50, type=int, help='depth of model')
parser.add_argument('--finetune', '-f', action='store_true', help='Fine tune pretrained model')
parser.add_argument('--addlayer', '-a', action='store_true', help='Add additional layer in fine-tuning')
parser.add_argument('--path', default=cf.test_dir, type=str, help='inference path')
args = parser.parse_args()

# Phase 1 : Data Upload
print('\n[Phase 1] : Data Preparation')

cf.test_dir = args.path
data_dir = cf.test_dir
trainset_dir = cf.data_base.split("/")[-1] + os.sep
print("| Preparing %s dataset..." %(cf.test_dir.split("/")[-1]))

use_gpu = torch.cuda.is_available()

# Phase 2 : Model setup
print('\n[Phase 2] : Model setup')

def getNetwork(args):
    if (args.net_type == 'alexnet'):
        net = models.alexnet(pretrained=args.finetune)
        file_name = 'alexnet'
    elif (args.net_type == 'vggnet'):
        if(args.depth == 16):
            net = models.vgg16(pretrained=args.finetune)
        file_name = 'vgg-%s' %(args.depth)
    elif (args.net_type == 'inception'):
        net = models.inception_v3(pretrained=args.finetune)
        file_name = 'inception-v3'
    elif (args.net_type == 'resnet'):
        net = resnet(args.finetune, args.depth)
        file_name = 'resnet-%s' %(args.depth)
    else:
        print('Error : Network should be either [alexnet / vggnet / inception / resnet]')
        sys.exit(1)

    return net, file_name

def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=0)

print("| Loading checkpoint model for inference phase...")
assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
assert os.path.isdir('checkpoint/'+trainset_dir), 'Error: No model has been trained on the dataset!'
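# Unlike main.py's --testOnly mode, this phase has no labels: each image under
# --path (default: cf.test_dir) is scored on its own and written to a CSV.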
_, file_name = getNetwork(args)
checkpoint = torch.load('./checkpoint/'+trainset_dir+file_name+'.t7')
model = checkpoint['model']

if use_gpu:
    model.cuda()
    cudnn.benchmark = True

model.eval()

sample_input = Variable(torch.randn(1,3,224,224), volatile=True)
if use_gpu:
    sample_input = sample_input.cuda()

print("\n[Phase 3] : Score Inference")

def is_image(f):
    return f.endswith(".png") or f.endswith(".jpg")

test_transform = transforms.Compose([
    transforms.Scale(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(cf.mean, cf.std)
])

if not os.path.isdir('result'):
    os.mkdir('result')

output_file = "./result/"+cf.test_dir.split("/")[-1]+".csv"

with open(output_file, 'wb') as csvfile:
    fields = ['file_name', 'score']
    writer = csv.DictWriter(csvfile, fieldnames=fields)
    writer.writeheader()  # write the header row for the fields above
    for subdir, dirs, files in os.walk(data_dir):
        for f in files:
            file_path = subdir + os.sep + f
            if (is_image(f)):
                image = Image.open(file_path).convert('RGB')
                if test_transform is not None:
                    image = test_transform(image)
                inputs = image
                inputs = Variable(inputs, volatile=True)
                if use_gpu:
                    inputs = inputs.cuda()
                inputs = inputs.view(1, inputs.size(0), inputs.size(1), inputs.size(2)) # add batch dim in the front

                outputs = model(inputs)
                softmax_res = softmax(outputs.data.cpu().numpy()[0])
                score = softmax_res[1]

                print(file_path + "," + str(score))
                writer.writerow({'file_name': file_path, 'score': score})

--------------------------------------------------------------------------------
/inference/alexnet.sh:
--------------------------------------------------------------------------------
python inference.py \
--net_type alexnet

--------------------------------------------------------------------------------
/inference/recursive_resnet.sh:
--------------------------------------------------------------------------------
for ((i=7;i<=9;i++)); do
    python inference.py \
    --net_type resnet \
    --depth 152 \
    --path /home/bumsoo/Data/test/inbreast_patches_test_1_$i
done

--------------------------------------------------------------------------------
/inference/resnet.sh:
--------------------------------------------------------------------------------
python inference.py \
--net_type resnet \
--depth 152

--------------------------------------------------------------------------------
/inference/vggnet.sh:
--------------------------------------------------------------------------------
python inference.py \
--net_type vggnet \
--depth 16

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# ************************************************************
# Author : Bumsoo Kim, 2017
# Github : https://github.com/meliketoy/fine-tuning.pytorch
#
# Korea University, Data-Mining Lab
# Deep Convolutional Network Fine tuning Implementation
#
# Description : main.py
# The main code for training classification networks.
# ***********************************************************

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import config as cf
import torchvision
import time
import copy
import os
import sys
import argparse
import pretrainedmodels # exclude this for python2.7 users

from torchvision import datasets, models, transforms
from networks import *
from torch.autograd import Variable

parser = argparse.ArgumentParser(description='PyTorch Digital Mammography Training')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--net_type', default='resnet', type=str, help='model')
parser.add_argument('--depth', default=50, type=int, help='depth of model')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--finetune', '-f', action='store_true', help='Fine tune pretrained model')
parser.add_argument('--addlayer', '-a', action='store_true', help='Add additional layer in fine-tuning')
parser.add_argument('--resetClassifier', '-r', action='store_true', help='Reset classifier')
parser.add_argument('--testOnly', '-t', action='store_true', help='Test mode with the saved model')
args = parser.parse_args()

# Phase 1 : Data Upload
print('\n[Phase 1] : Data Preparation')

if args.net_type == 'inception' or args.net_type == 'xception':
    data_transforms = {
        'train': transforms.Compose([
            transforms.Scale(320),
            transforms.RandomSizedCrop(299),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cf.mean, cf.std)
        ]),
        'val': transforms.Compose([
            transforms.Scale(320),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            transforms.Normalize(cf.mean, cf.std)
        ]),
    }
else:
    data_transforms = {
        'train': transforms.Compose([
            transforms.Scale(256),
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cf.mean, cf.std)
        ]),
        'val': transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(cf.mean, cf.std)
        ]),
    }
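# Inception-v3 / Xception expect 299x299 inputs, hence the larger crops above;
# every other supported model uses the standard 224x224 ImageNet-style crop.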

data_dir = cf.aug_base
dataset_dir = cf.data_base.split("/")[-1] + os.sep
print("| Preparing model trained on %s dataset..." %(cf.data_base.split("/")[-1]))
dsets = {
    x : datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']
}
dset_loaders = {
    x : torch.utils.data.DataLoader(dsets[x], batch_size=cf.batch_size, shuffle=(x=='train'), num_workers=4)
    for x in ['train', 'val']
}

dset_sizes = {x: len(dsets[x]) for x in ['train', 'val']}
dset_classes = dsets['train'].classes

use_gpu = torch.cuda.is_available()

# Phase 2 : Model setup
print('\n[Phase 2] : Model setup')

def getNetwork(args):
    if (args.net_type == 'alexnet'):
        net = models.alexnet(pretrained=args.finetune)
        file_name = 'alexnet'
    elif (args.net_type == 'vggnet'):
        if(args.depth == 11):
            net = models.vgg11(pretrained=args.finetune)
        elif(args.depth == 13):
            net = models.vgg13(pretrained=args.finetune)
        elif(args.depth == 16):
            net = models.vgg16(pretrained=args.finetune)
        elif(args.depth == 19):
            net = models.vgg19(pretrained=args.finetune)
        else:
            print('Error : VGGnet should have depth of either [11, 13, 16, 19]')
            sys.exit(1)
        file_name = 'vgg-%s' %(args.depth)
    elif (args.net_type == 'squeezenet'):
        net = models.squeezenet1_0(pretrained=args.finetune)
        file_name = 'squeeze'
    elif (args.net_type == 'resnet'):
        net = resnet(args.finetune, args.depth)
        file_name = 'resnet-%s' %(args.depth)
    elif (args.net_type == 'inception'):
        net = pretrainedmodels.inceptionv3(num_classes=1000, pretrained='imagenet')
        file_name = 'inception-v3'
    elif (args.net_type == 'xception'):
        net = pretrainedmodels.xception(num_classes=1000, pretrained='imagenet')
        file_name = 'xception'
    else:
        print('Error : Network should be either [alexnet / squeezenet / vggnet / resnet / inception / xception]')
        sys.exit(1)

    return net, file_name

def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=0)

# Test only option
if (args.testOnly):
    print("| Loading checkpoint model for test phase...")
    assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
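    # Test-only mode: restore the best checkpoint saved for this dataset/network
    # pair and report top-1 accuracy on cf.test_dir, one image per batch.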
    _, file_name = getNetwork(args)
    print('| Loading '+file_name+".t7...")
    checkpoint = torch.load('./checkpoint/'+dataset_dir+'/'+file_name+'.t7')
    model = checkpoint['model']

    if use_gpu:
        model.cuda()
        # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
        # cudnn.benchmark = True

    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    testsets = datasets.ImageFolder(cf.test_dir, data_transforms['val'])

    testloader = torch.utils.data.DataLoader(
        testsets,
        batch_size=1,
        shuffle=False,
        num_workers=1
    )

    print("\n[Phase 3 : Inference on %s]" %cf.test_dir)
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_gpu:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = model(inputs)

        # print(outputs.data.cpu().numpy()[0])
        softmax_res = softmax(outputs.data.cpu().numpy()[0])

        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

    acc = 100.*correct/total
    print("| Test Result\tAcc@1 %.2f%%" %(acc))

    sys.exit(0)

# Training model
def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=cf.num_epochs):
    global dataset_dir
    since = time.time()

    best_model, best_acc = model, 0.0

    print('\n[Phase 3] : Training Model')
    print('| Training Epochs = %d' %num_epochs)
    print('| Initial Learning Rate = %f' %args.lr)
    print('| Optimizer = SGD')
    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            if phase == 'train':
                optimizer, lr = lr_scheduler(optimizer, epoch)
                print('\n=> Training Epoch #%d, LR=%f' %(epoch+1, lr))
                model.train(True)
            else:
                model.train(False)
                model.eval()

            running_loss, running_corrects, tot = 0.0, 0, 0

            for batch_idx, (inputs, labels) in enumerate(dset_loaders[phase]):
                if use_gpu:
                    inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                optimizer.zero_grad()

                # Forward Propagation
                outputs = model(inputs)
                if isinstance(outputs, tuple):
                    # inception v3 returns (x, aux); sum the loss over both heads
                    loss = sum((criterion(o, labels) for o in outputs))
                    outputs = outputs[0]
                else:
                    loss = criterion(outputs, labels)
                _, preds = torch.max(outputs.data, 1)

                # Backward Propagation
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # Statistics
                running_loss += loss.data[0]
                running_corrects += preds.eq(labels.data).cpu().sum()
                tot += labels.size(0)

                if (phase == 'train'):
                    sys.stdout.write('\r')
                    sys.stdout.write('| Epoch [%2d/%2d] Iter [%3d/%3d]\t\tLoss %.4f\tAcc %.2f%%'
                            %(epoch+1, num_epochs, batch_idx+1,
                            (len(dsets[phase])//cf.batch_size)+1, loss.data[0], 100.*running_corrects/tot))
                    sys.stdout.flush()
                    sys.stdout.write('\r')

            epoch_loss = running_loss / dset_sizes[phase]
            epoch_acc = running_corrects / dset_sizes[phase]

            if (phase == 'val'):
                print('\n| Validation Epoch #%d\t\t\tLoss %.4f\tAcc %.2f%%'
                        %(epoch+1, loss.data[0], 100.*epoch_acc))

                if epoch_acc > best_acc:  # and epoch > 80
                    print('| Saving Best model...\t\t\tTop1 %.2f%%' %(100.*epoch_acc))
                    best_acc = epoch_acc
                    best_model = copy.deepcopy(model)
                    state = {
                        'model': best_model,
                        'acc': epoch_acc,
                        'epoch': epoch,
                    }
                    if not os.path.isdir('checkpoint'):
                        os.mkdir('checkpoint')
                    save_point = './checkpoint/'+dataset_dir
                    if not os.path.isdir(save_point):
                        os.mkdir(save_point)
                    torch.save(state, save_point+file_name+'.t7')

    time_elapsed = time.time() - since
    print('\nTraining completed in\t{:.0f} min {:.0f} sec'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation Acc\t{:.2f}%'.format(best_acc*100))

    return best_model

def exp_lr_scheduler(optimizer, epoch, init_lr=args.lr, weight_decay=args.weight_decay, lr_decay_epoch=cf.lr_decay_epoch):
    lr = init_lr * (0.5**(epoch // lr_decay_epoch))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        param_group['weight_decay'] = weight_decay

    return optimizer, lr

model_ft, file_name = getNetwork(args)

if(args.resetClassifier):
    print('| Reset final classifier...')
    if(args.addlayer):
        print('| Add features of size %d' %cf.feature_size)
        num_ftrs = model_ft.fc.in_features
        feature_model = list(model_ft.fc.children())
        feature_model.append(nn.Linear(num_ftrs, cf.feature_size))
        feature_model.append(nn.BatchNorm1d(cf.feature_size))
        feature_model.append(nn.ReLU(inplace=True))
        feature_model.append(nn.Linear(cf.feature_size, len(dset_classes)))
        model_ft.fc = nn.Sequential(*feature_model)
    else:
        if(args.net_type == 'alexnet' or args.net_type == 'vggnet'):
            num_ftrs = model_ft.classifier[6].in_features
            feature_model = list(model_ft.classifier.children())
            feature_model.pop()
            feature_model.append(nn.Linear(num_ftrs, len(dset_classes)))
            model_ft.classifier = nn.Sequential(*feature_model)
        elif(args.net_type == 'resnet'):
            num_ftrs = model_ft.fc.in_features
            model_ft.fc = nn.Linear(num_ftrs, len(dset_classes))
        elif(args.net_type == 'inception' or args.net_type == 'xception'):
            num_ftrs = model_ft.last_linear.in_features
            model_ft.last_linear = nn.Linear(num_ftrs, len(dset_classes))

if use_gpu:
    model_ft = model_ft.cuda()
    model_ft = torch.nn.DataParallel(model_ft, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

if __name__ == "__main__":
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=cf.num_epochs)

--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
from .resnet import *

--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet']


model_urls = {
    'resnet18': 'http://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'http://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'http://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'http://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'http://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

def cfg(depth):
    depth_lst = [18, 34, 50, 101, 152]
    assert (depth in depth_lst), "Error : ResNet depth should be either 18, 34, 50, 101, 152"
    cf_dict = {
        '18' : (BasicBlock, [2,2, 2,2]),
        '34' : (BasicBlock, [3,4, 6,3]),
        '50' : (Bottleneck, [3,4, 6,3]),
        '101': (Bottleneck, [3,4,23,3]),
        '152': (Bottleneck, [3,8,36,3]),
    }

    return cf_dict[str(depth)]


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

def resnet(pretrained=False, depth=18, **kwargs):
    """Constructs ResNet models for various depths
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        depth (int): Integer input of either 18, 34, 50, 101, 152
    """
    block, num_blocks = cfg(depth)
    model = ResNet(block, num_blocks, **kwargs)
    if (pretrained):
        print("| Downloading ImageNet pretrained ResNet-%d..." %depth)
        model.load_state_dict(model_zoo.load_url(model_urls['resnet%d' %depth]))
    return model
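# Example (a sketch mirroring what main.py does with --resetClassifier;
# `n_classes` is a placeholder for your dataset's class count):
#
#   net = resnet(pretrained=True, depth=50)            # ImageNet weights
#   net.fc = nn.Linear(net.fc.in_features, n_classes)  # new classifier head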

--------------------------------------------------------------------------------
/test/alexnet.sh:
--------------------------------------------------------------------------------
python main.py \
--net_type alexnet \
--testOnly

--------------------------------------------------------------------------------
/test/resnet.sh:
--------------------------------------------------------------------------------
python main.py \
--net_type resnet \
--depth 152 \
--testOnly

--------------------------------------------------------------------------------
/test/vggnet.sh:
--------------------------------------------------------------------------------
python main.py \
--net_type vggnet \
--depth 16 \
--testOnly

--------------------------------------------------------------------------------
/train/alexnet.sh:
--------------------------------------------------------------------------------
python main.py \
--lr 1e-3 \
--weight_decay 1e-4 \
--net_type alexnet \
--resetClassifier \
--finetune

--------------------------------------------------------------------------------
/train/inception.sh:
--------------------------------------------------------------------------------
python3 main.py \
--lr 0.045 \
--weight_decay 4e-5 \
--net_type inception \
--depth 50 \
--resetClassifier \
--finetune

--------------------------------------------------------------------------------
/train/resnet.sh:
--------------------------------------------------------------------------------
python main.py \
--lr 1e-3 \
--weight_decay 5e-4 \
--net_type resnet \
--depth 152 \
--resetClassifier \
--finetune
# --addlayer

--------------------------------------------------------------------------------
/train/squeeze.sh:
--------------------------------------------------------------------------------
python main.py \
--lr 1e-3 \
--weight_decay 5e-4 \
--net_type squeezenet \
--resetClassifier \
--finetune

--------------------------------------------------------------------------------
/train/vggnet.sh:
--------------------------------------------------------------------------------
python main.py \
--lr 1e-3 \
--weight_decay 5e-4 \
--net_type vggnet \
--depth 16 \
--resetClassifier \
--finetune

--------------------------------------------------------------------------------
/train/xception.sh:
--------------------------------------------------------------------------------
python3 main.py \
--lr 0.045 \
--weight_decay 1e-5 \
--net_type xception \
--resetClassifier \
--finetune
# --testOnly
--------------------------------------------------------------------------------