├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── extract_features.py
├── imgs
│   ├── pytorch.png
│   └── transfer-learning.jpeg
├── inference.py
├── inference
│   ├── alexnet.sh
│   ├── recursive_resnet.sh
│   ├── resnet.sh
│   └── vggnet.sh
├── main.py
├── networks
│   ├── __init__.py
│   └── resnet.py
├── test
│   ├── alexnet.sh
│   ├── resnet.sh
│   └── vggnet.sh
└── train
    ├── alexnet.sh
    ├── inception.sh
    ├── resnet.sh
    ├── squeeze.sh
    ├── vggnet.sh
    └── xception.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # torch
104 | *.t7
105 | vectors/*
106 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Bumsoo Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # fine-tuning.pytorch
4 | PyTorch implementation of fine-tuning (transfer learning) CNN networks.
5 | This project is made by Bumsoo Kim.
6 |
7 | Korea University, Master's-Ph.D. integrated course.
8 | 
9 |
10 |
11 | ## Fine-Tuning
12 | In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.
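
A minimal sketch of the two approaches with torchvision (the ResNet-18 backbone and the 2-class head are illustrative):

```python
import torch.nn as nn
from torchvision import models

# (1) ConvNet as a fixed feature extractor: freeze the pretrained weights
#     and train only a newly attached final classifier.
net = models.resnet18(pretrained=True)
for param in net.parameters():
    param.requires_grad = False
net.fc = nn.Linear(net.fc.in_features, 2)   # e.g. 2 target classes

# (2) Fine-tuning: initialize from the pretrained weights and keep training
#     the whole network, usually with a small learning rate.
net = models.resnet18(pretrained=True)
net.fc = nn.Linear(net.fc.in_features, 2)
```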
13 |
14 | Further explanations can be found [here](http://cs231n.github.io/transfer-learning/).
15 |
16 | ## Requirements
17 | See the [installation instruction](INSTALL.md) for a step-by-step installation guide.
18 | See the [server instruction](SERVER.md) for server setup.
19 | - Install [cuda-8.0](https://developer.nvidia.com/cuda-downloads)
20 | - Install [cudnn v5.1](https://developer.nvidia.com/cudnn)
21 | - Download [PyTorch for Python 2.7](https://pytorch.org) and clone the repository.
22 | - Download [PyTorch for Python 3.5](https://pytorch.org) to use additional pretrained model libraries with Anaconda3.
23 | ```bash
24 | pip install http://download.pytorch.org/whl/cu80/torch-0.1.12.post2-cp27-none-linux_x86_64.whl
25 | pip install torchvision
26 | git clone https://github.com/meliketoy/resnet-fine-tuning.pytorch
27 | ```
28 |
29 | - Download pretrained models for PyTorch (Python 3.5 only)
30 | ```bash
31 | $ git clone https://github.com/Cadene/pretrained-models.pytorch.git
32 | $ cd pretrained-models.pytorch
33 | $ python setup.py install
34 | ```
35 |
36 | ## Basic Setups
37 | After you have cloned this repository into your file system, open [config.py](./config.py)
38 | and edit the lines below to point to your data directories.
39 | ```bash
40 | data_base = [:dir to your original dataset]
41 | aug_base = [:dir to your actually trained dataset]
42 | ```
43 |
44 | For training, your data should be organized in the following hierarchy.
45 | Code for organizing your data into this layout is provided [here](https://github.com/meliketoy/image-preprocessing); a loading sketch follows the layout below.
46 |
47 | ```bash
48 | [:data file name]
49 |
50 | |-train
51 | |-[:class 0]
52 | |-[:class 1]
53 | |-[:class 2]
54 | ...
55 | |-[:class n]
56 | |-val
57 | |-[:class 0]
58 | |-[:class 1]
59 | |-[:class 2]
60 | ...
61 | |-[:class n]
62 | ```
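
With this layout, the data can be loaded with `torchvision.datasets.ImageFolder`, which maps each class sub-directory to a label. A minimal sketch of how [main.py](./main.py) reads it (`aug_base` comes from [config.py](./config.py); the transform is simplified):

```python
import os
from torchvision import datasets, transforms
import config as cf

transform = transforms.Compose([transforms.ToTensor()])
dsets = {x: datasets.ImageFolder(os.path.join(cf.aug_base, x), transform)
         for x in ['train', 'val']}
print(dsets['train'].classes)   # the [:class 0] ... [:class n] folder names
```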
63 |
64 | ## How to run
65 | After you have cloned the repository, you can train on your dataset by running the commands below.
66 |
67 | You can set the dimension of the additional layer in [config.py](./config.py)
68 |
69 | The `--resetClassifier` option will automatically detect the number of classes in your data folder and reset the last classifier layer to match; a sketch of what it does follows the commands below.
70 |
71 | ```bash
72 | # zero-base training
73 | python main.py --lr [:lr] --depth [:depth] --resetClassifier
74 |
75 | # fine-tuning
76 | python main.py --finetune --lr [:lr] --depth [:depth]
77 |
78 | # fine-tuning with additional linear layers
79 | python main.py --finetune --addlayer --lr [:lr] --depth [:depth]
80 | ```
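
In essence, `--resetClassifier` replaces the final layer with a fresh one sized to your data, and `--addlayer` widens that replacement. A minimal sketch (the ResNet-18 backbone and class count are illustrative; `feature_size` mirrors `cf.feature_size`):

```python
import torch.nn as nn
from torchvision import models

n_classes = 4        # detected from the data folders at runtime
feature_size = 500   # cf.feature_size in config.py

# --resetClassifier: replace the final classifier layer
net = models.resnet18(pretrained=True)
net.fc = nn.Linear(net.fc.in_features, n_classes)

# --resetClassifier --addlayer: insert an intermediate feature block instead
net = models.resnet18(pretrained=True)
net.fc = nn.Sequential(
    nn.Linear(net.fc.in_features, feature_size),
    nn.BatchNorm1d(feature_size),
    nn.ReLU(inplace=True),
    nn.Linear(feature_size, n_classes),
)
```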
81 |
82 | ## Train various networks
83 |
84 | I have added fine-tuning & transfer learning scripts for AlexNet, VGG (11, 13, 16, 19),
85 | and ResNet (18, 34, 50, 101, 152).
86 |
87 | Please modify the [scripts](./train) and run the line below.
88 |
89 | ```bash
90 |
91 | $ ./train/[:network].sh
92 |
93 | # For example, if you want to pretrain alexnet, just run
94 | $ ./train/alexnet.sh
95 |
96 | ```
97 |
98 | ## Test (Inference) various networks
99 |
100 | To test (run inference with) your fine-tuned model on AlexNet, VGG (11, 13, 16, 19), or ResNet (18, 34, 50, 101, 152):
101 |
102 | First, set your test data directory as `test_dir` in [config.py](./config.py).
103 |
104 | Please modify the [scripts](./test) and run the line below.
105 |
106 | ```bash
107 |
108 | $ ./test/[:network].sh
109 |
110 | ```
111 | For example, if you have trained ResNet with 50 layers, first modify the [resnet test script](./test/resnet.sh)
112 |
113 | ```bash
114 | $ vi ./test/resnet.sh
115 |
116 | python main.py \
117 | --net_type resnet \
118 |     --depth 50 \
119 |     --testOnly
120 |
121 | $ ./test/resnet.sh
122 |
123 | ```
124 |
125 | The script above will load the checkpoint saved for the given network and depth, and evaluate it on your test dataset.
126 |
127 | ## Feature extraction
128 | For various training mechanisms, extracted feature vectors are needed.
129 |
130 | This repository provides not only feature extraction from pre-trained networks,
131 |
132 | but also extraction from a model that you trained yourself.
133 |
134 | Just set the test directory in [config.py](config.py) and run the code below.
135 |
136 | ```bash
137 | python extract_features.py
138 | ```
139 |
140 | This will automatically create pickles in a newly created 'vectors' directory,
141 |
142 | each of which contains the dictionary described below.
143 |
144 | Currently, 'score' only covers 0~1 scores for binary classification.
145 |
146 | Confidence scores for multiple classes will be added later.
147 |
148 | ```bash
149 | pickle_file [name : image base name]
150 |     |- 'file_path' : path of the test image
151 |     |- 'feature' : extracted feature vector
152 |     |- 'score' : score for binary classification
153 | ```
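
A quick way to read one of these pickles back (a sketch; the file name is illustrative):

```python
import pickle

with open('vectors/sample_image.pickle', 'rb') as pkl:
    vector_dict = pickle.load(pkl)

print(vector_dict['file_path'])        # path of the test image
print(vector_dict['feature'].size())   # extracted feature vector (a torch tensor)
print(vector_dict['score'])            # binary classification score in [0, 1]
```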
154 |
155 | Enjoy :-)
156 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # Configuration File
2 |
3 | # Base directory for data formats
4 | #name = 'INBREAST_5'
5 | #test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/HWEJIN_INBREAST_SPLIT'
6 |
7 | #name = 'GURO_CELL'
8 | #test_dir = '/home/bumsoo/Data/split/GURO_CELL/val/'
9 |
10 | # INBREAST
11 | #name = 'INBREAST_TRAIN'
12 | #test_dir = '/home/bumsoo/Data/test/HWEJIN_INBREAST_SPLIT'
13 |
14 | # GURO_SPLIT
15 | name = 'GURO_TRAIN'
16 | #test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/HWEJIN_GURO_SPLIT'
17 |
18 | # GURO_ALL -> INBREAST_ALL
19 | #name = 'GURO_ALL'
20 | #test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/INBREAST_ALL'
21 |
22 | # MIX_TRAIN -> MIX_TEST
23 | #name = 'MIX_TRAIN'
24 | #test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/HWEJIN_MIX_TEST'
25 | test_dir = '/home/bumsoo/inference_patches/guro_patches_test_8'
26 |
27 | # GURO80+INBREAST_ALL
28 | #name = 'GURO80+INBREAST'
29 | #test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/GURO80+INBREAST'
30 |
31 | # INBREAST80+GURO_ALL
32 | #name = 'INBREAST80+GURO'
33 | #test_dir = '/home/bumsoo/Data/test/FINAL_HWEJIN/INBREAST80+GURO'
34 |
35 | # Inference (INBREAST_test)
36 | #name = 'INBREAST_TRAIN'
37 | #test_dir = '/home/bumsoo/Data/test/inbreast_patches_test_9'
38 | #test_dir = '/mnt/datasets/inbreast_patches_test_9'
39 |
40 | # Inference (GURO_test)
41 | #name = 'GURO_TRAIN'
42 | #test_dir = '/home/bumsoo/Data/test/guro_patches_test_0'
43 | #test_dir = '/home/bumsoo/guro_patches_test_9'
44 |
45 | # Inference (GURO_ALL -> INBREAST_ALL)
46 | #name = 'GURO_ALL'
47 | #test_dir = '/home/bumsoo/Data/test/PATCH_INBREAST_TEST/inbreast_patches_4'
48 |
49 | data_base = '/home/mnt/datasets/'+name
50 | aug_base = '/home/bumsoo/Data/split/'+name
51 |
52 | # model option
53 | batch_size = 16
54 | num_epochs = 50
55 | lr_decay_epoch = 20
56 | feature_size = 500
57 |
58 | # meanstd options
59 | # INBREAST_SPLIT
60 | #mean = [0.601176900699946, 0.601176900699946, 0.601176900699946]
61 | #std = [0.083943294373731825, 0.083943294373731825, 0.083943294373731825]
62 |
63 | # GURO_SPLIT
64 | #mean = [0.49113493759286625, 0.49113493759286625, 0.49113493759286625]
65 | #std = [0.14704804249157166, 0.14704804249157166, 0.14704804249157166]
66 |
67 | # GURO_ALL
68 | #mean = [0.42641446119819587, 0.42641446119819587, 0.42641446119819587]
69 | #std = [0.19647293715592193, 0.19647293715592193, 0.19647293715592193]
70 |
71 | # GURO+INBREAST
72 | #mean = [0.53753781240686382, 0.53753781240686382, 0.53753781240686382]
73 | #std = [0.12187187243213095, 0.12187187243213095, 0.12187187243213095]
74 |
75 | # MIX_TRAIN
76 | #mean = [0.50528327792298555, 0.50528327792298555, 0.50528327792298555]
77 | #std = [0.13993786443871117, 0.13993786443871117, 0.13993786443871117]
78 |
79 | # GURO80+INBREAST_ALL
80 | #mean = [0.49977846189176656, 0.49977846189176656, 0.49977846189176656]
81 | #std = [0.14111615457915755, 0.14111615457915755, 0.14111615457915755]
82 |
83 | # INBREAST80+GURO_ALL
84 | mean = [0.4856586910840433, 0.4856586910840433, 0.4856586910840433]
85 | std = [0.14210993338737993, 0.14210993338737993, 0.14210993338737993]
86 |
87 | # GURO_CELL
88 | #mean = [0.78076776409256798, 0.61738499185119988, 0.62287074541563914]
89 | #std = [0.18391759503019442, 0.26082926658759176, 0.23288027411260487]
90 |
91 | # INBREAST_1
92 | #mean = [0.60284723168105081, 0.60284723168105081, 0.60284723168105081]
93 | #std = [0.081163047606150382, 0.081163047606150382, 0.081163047606150382]
94 |
95 | # INBREAST_2
96 | #mean = [0.61158796966756579, 0.61158796966756579, 0.61158796966756579]
97 | #std = [0.08487070239187032, 0.08487070239187032, 0.08487070239187032]
98 |
99 | # INBREAST_3
100 | #mean = [0.60108720150874573, 0.60108720150874573, 0.60108720150874573]
101 | #std = [0.081551750213639501, 0.081551750213639501, 0.081551750213639501]
102 |
103 | # INBREAST_4
104 | #mean = [0.60402172760178874, 0.60402172760178874, 0.60402172760178874]
105 | #std = [0.078366899563820674, 0.078366899563820674, 0.078366899563820674]
106 |
107 | # INBREAST_5
108 | #mean = [0.59631095620282071, 0.59631095620282071, 0.59631095620282071]
109 | #std = [0.080351500548752522, 0.080351500548752522, 0.080351500548752522]
110 |
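# How per-dataset mean/std values like the ones above can be computed
# (a commented-out sketch, not used anywhere in this repository; the
# batch size and the '/train' split are assumptions):
#
# import torch
# from torchvision import datasets, transforms
#
# dset = datasets.ImageFolder(aug_base + '/train', transforms.ToTensor())
# loader = torch.utils.data.DataLoader(dset, batch_size=1)
# pixels = torch.cat([img.view(3, -1) for img, _ in loader], dim=1)
# mean = pixels.mean(dim=1).tolist()
# std = pixels.std(dim=1).tolist()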
--------------------------------------------------------------------------------
/extract_features.py:
--------------------------------------------------------------------------------
1 | # ************************************************************
2 | # Author : Bumsoo Kim, 2017
3 | # Github : https://github.com/meliketoy/fine-tuning.pytorch
4 | #
5 | # Korea University, Data-Mining Lab
6 | # Deep Convolutional Network Fine tuning Implementation
7 | #
8 | # Description : extract_features.py
9 | # The main code for extracting features of a trained model.
10 | # ***********************************************************
11 |
12 | from __future__ import print_function, division
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.optim as optim
17 | import torch.backends.cudnn as cudnn
18 | import numpy as np
19 | import config as cf
20 | import torchvision
21 | import time
22 | import copy
23 | import os
24 | import sys
25 | import argparse
26 |
27 | from torchvision import datasets, models, transforms
28 | from networks import *
29 | from torch.autograd import Variable
30 | from PIL import Image
31 | import pickle
32 |
33 | parser = argparse.ArgumentParser(description='PyTorch Digital Mammography Training')
34 | parser.add_argument('--lr', default=1e-3, type=float, help='learning_rate')
35 | parser.add_argument('--net_type', default='resnet', type=str, help='model')
36 | parser.add_argument('--depth', default=50, type=int, help='depth of model')
37 | parser.add_argument('--finetune', '-f', action='store_true', help='Fine tune pretrained model')
38 | parser.add_argument('--addlayer','-a',action='store_true', help='Add additional layer in fine-tuning')
39 | parser.add_argument('--testOnly', '-t', action='store_true', help='Test mode with the saved model')
40 | args = parser.parse_args()
41 |
42 | # Phase 1 : Data Upload
43 | print('\n[Phase 1] : Data Preparation')
44 |
45 | data_dir = cf.test_dir
46 | trainset_dir = cf.data_base.split("/")[-1] + os.sep
47 | print("| Preparing %s dataset..." %(cf.test_dir.split("/")[-1]))
48 |
49 | use_gpu = torch.cuda.is_available()
50 |
51 | # Phase 2 : Model setup
52 | print('\n[Phase 2] : Model setup')
53 |
54 | def getNetwork(args):
55 | if (args.net_type == 'vggnet'):
56 | net = VGG(args.finetune, args.depth)
57 | file_name = 'vgg-%s' %(args.depth)
58 | elif (args.net_type == 'resnet'):
59 | net = resnet(args.finetune, args.depth)
60 | file_name = 'resnet-%s' %(args.depth)
61 | else:
62 | print('Error : Network should be either [VGGNet / ResNet]')
63 | sys.exit(1)
64 |
65 | return net, file_name
66 |
67 | def softmax(x):
68 | return np.exp(x) / np.sum(np.exp(x), axis=0)
69 |
70 | print("| Loading checkpoint model for feature extraction...")
71 | assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
72 | assert os.path.isdir('checkpoint/'+trainset_dir), 'Error: No model has been trained on the dataset!'
73 | _, file_name = getNetwork(args)
74 | checkpoint = torch.load('./checkpoint/'+trainset_dir+file_name+'.t7')
75 | model = checkpoint['model']
76 |
77 | print("| Constructing a feature extractor from the model...")
78 | if(args.net_type == 'alexnet' or args.net_type == 'vggnet'):
79 | feature_map = list(checkpoint['model'].module.classifier.children())
80 | feature_map.pop()
81 | new_classifier = nn.Sequential(*feature_map)
82 | extractor = copy.deepcopy(checkpoint['model'])
83 | extractor.module.classifier = new_classifier
84 | elif (args.net_type == 'resnet'):
85 | feature_map = list(model.module.children())
86 | feature_map.pop()
87 | extractor = nn.Sequential(*feature_map)
88 |
89 | if use_gpu:
90 | model.cuda()
91 | extractor.cuda()
92 | cudnn.benchmark = True
93 |
94 | model.eval()
95 | extractor.eval()
96 |
97 | sample_input = Variable(torch.randn(1,3,224,224), volatile=True)
98 | if use_gpu:
99 | sample_input = sample_input.cuda()
100 |
101 | sample_output = extractor(sample_input)
102 | featureSize = sample_output.size(1)
103 | print("| Feature dimension = %d" %featureSize)
104 |
105 | print("\n[Phase 3] : Feature & Score Extraction")
106 |
107 | def is_image(f):
108 | return f.endswith(".png") or f.endswith(".jpg")
109 |
110 | test_transform = transforms.Compose([
111 | transforms.Scale(224),
112 | transforms.CenterCrop(224),
113 | transforms.ToTensor(),
114 | transforms.Normalize(cf.mean, cf.std)
115 | ])
116 |
117 | if not os.path.isdir('vectors'):
118 | os.mkdir('vectors')
119 |
120 | for subdir, dirs, files in os.walk(data_dir):
121 | for f in files:
122 | file_path = subdir + os.sep + f
123 | if (is_image(f)):
124 | vector_dict = {
125 | 'file_path': "",
126 | 'feature': [],
127 | 'score': 0,
128 | }
129 |
130 | image = Image.open(file_path).convert('RGB')
131 | if test_transform is not None:
132 | image = test_transform(image)
133 | inputs = image
134 | inputs = Variable(inputs, volatile=True)
135 | if use_gpu:
136 | inputs = inputs.cuda()
137 | inputs = inputs.view(1, inputs.size(0), inputs.size(1), inputs.size(2)) # add batch dim in the front
138 | features = extractor(inputs).view(featureSize)
139 |
140 | outputs = model(inputs)
141 | softmax_res = softmax(outputs.data.cpu().numpy()[0])
142 |
143 | vector_dict['file_path'] = file_path
144 | vector_dict['feature'] = features
145 | vector_dict['score'] = softmax_res[1]
146 |
147 | vector_file = 'vectors' + os.sep + os.path.splitext(f)[0] + ".pickle"
148 |
149 | print(vector_file)
150 | print(vector_dict['feature'].size())
151 | print(vector_dict['score'])
152 |
153 | with open(vector_file, 'wb') as pkl:
154 | pickle.dump(vector_dict, pkl, protocol=pickle.HIGHEST_PROTOCOL)
155 |
--------------------------------------------------------------------------------
/imgs/pytorch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmsookim/fine-tuning.pytorch/91b45bbf1287a33603c344d64c06b6b1bf8f226e/imgs/pytorch.png
--------------------------------------------------------------------------------
/imgs/transfer-learning.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmsookim/fine-tuning.pytorch/91b45bbf1287a33603c344d64c06b6b1bf8f226e/imgs/transfer-learning.jpeg
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | # ************************************************************
2 | # Author : Bumsoo Kim, 2017
3 | # Github : https://github.com/meliketoy/fine-tuning.pytorch
4 | #
5 | # Korea University, Data-Mining Lab
6 | # Deep Convolutional Network Fine tuning Implementation
7 | #
8 | # Description : inference.py
9 | # The main code for the inference (test) phase of a trained model.
10 | # ***********************************************************
11 |
12 | from __future__ import print_function, division
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.optim as optim
17 | import torch.backends.cudnn as cudnn
18 | import numpy as np
19 | import config as cf
20 | import torchvision
21 | import time
22 | import copy
23 | import os
24 | import sys
25 | import argparse
26 | import csv
27 |
28 | from torchvision import datasets, models, transforms
29 | from networks import *
30 | from torch.autograd import Variable
31 | from PIL import Image
32 |
33 | parser = argparse.ArgumentParser(description='PyTorch Digital Mammography Training')
34 | parser.add_argument('--lr', default=1e-3, type=float, help='learning_rate')
35 | parser.add_argument('--net_type', default='resnet', type=str, help='model')
36 | parser.add_argument('--depth', default=50, type=int, help='depth of model')
37 | parser.add_argument('--finetune', '-f', action='store_true', help='Fine tune pretrained model')
38 | parser.add_argument('--addlayer','-a',action='store_true', help='Add additional layer in fine-tuning')
39 | parser.add_argument('--path', default=cf.test_dir, type=str, help='inference path')
40 | args = parser.parse_args()
41 |
42 | # Phase 1 : Data Upload
43 | print('\n[Phase 1] : Data Preparation')
44 |
45 | cf.test_dir = args.path
46 | data_dir = cf.test_dir
47 | trainset_dir = cf.data_base.split("/")[-1] + os.sep
48 | print("| Preparing %s dataset..." %(cf.test_dir.split("/")[-1]))
49 |
50 | use_gpu = torch.cuda.is_available()
51 |
52 | # Phase 2 : Model setup
53 | print('\n[Phase 2] : Model setup')
54 |
55 | def getNetwork(args):
56 | if (args.net_type == 'alexnet'):
57 | net = models.alexnet(pretrained=args.finetune)
58 | file_name = 'alexnet'
59 | elif (args.net_type == 'vggnet'):
60 | if(args.depth == 16):
61 | net = models.vgg16(pretrained=args.finetune)
62 | file_name = 'vgg-%s' %(args.depth)
63 |     elif (args.net_type == 'inception'):
64 |         net = models.inception_v3(pretrained=args.finetune)
65 |         file_name = 'inception-v3'
66 | elif (args.net_type == 'resnet'):
67 | net = resnet(args.finetune, args.depth)
68 | file_name = 'resnet-%s' %(args.depth)
69 | else:
70 |         print('Error : Network should be either [alexnet / vggnet / inception / resnet]')
71 | sys.exit(1)
72 |
73 | return net, file_name
74 |
75 | def softmax(x):
76 | return np.exp(x) / np.sum(np.exp(x), axis=0)
77 |
78 | print("| Loading checkpoint model for inference phase...")
79 | assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
80 | assert os.path.isdir('checkpoint/'+trainset_dir), 'Error: No model has been trained on the dataset!'
81 | _, file_name = getNetwork(args)
82 | checkpoint = torch.load('./checkpoint/'+trainset_dir+file_name+'.t7')
83 | model = checkpoint['model']
84 |
85 | if use_gpu:
86 | model.cuda()
87 | cudnn.benchmark = True
88 |
89 | model.eval()
90 |
91 | sample_input = Variable(torch.randn(1,3,224,224), volatile=True)
92 | if use_gpu:
93 | sample_input = sample_input.cuda()
94 |
95 | print("\n[Phase 3] : Score Inference")
96 |
97 | def is_image(f):
98 | return f.endswith(".png") or f.endswith(".jpg")
99 |
100 | test_transform = transforms.Compose([
101 | transforms.Scale(224),
102 | transforms.CenterCrop(224),
103 | transforms.ToTensor(),
104 | transforms.Normalize(cf.mean, cf.std)
105 | ])
106 |
107 | if not os.path.isdir('result'):
108 | os.mkdir('result')
109 |
110 | output_file = "./result/"+cf.test_dir.split("/")[-1]+".csv"
111 |
112 | with open(output_file, 'wb') as csvfile:
113 | fields = ['file_name', 'score']
114 | writer = csv.DictWriter(csvfile, fieldnames=fields)
115 | for subdir, dirs, files in os.walk(data_dir):
116 | for f in files:
117 | file_path = subdir + os.sep + f
118 | if (is_image(f)):
119 | image = Image.open(file_path).convert('RGB')
120 | if test_transform is not None:
121 | image = test_transform(image)
122 | inputs = image
123 | inputs = Variable(inputs, volatile=True)
124 | if use_gpu:
125 | inputs = inputs.cuda()
126 | inputs = inputs.view(1, inputs.size(0), inputs.size(1), inputs.size(2)) # add batch dim in the front
127 |
128 | outputs = model(inputs)
129 | softmax_res = softmax(outputs.data.cpu().numpy()[0])
130 | score = softmax_res[1]
131 |
132 | print(file_path + "," + str(score))
133 | writer.writerow({'file_name': file_path, 'score':score})
134 |
--------------------------------------------------------------------------------
/inference/alexnet.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --net_type alexnet
3 |
--------------------------------------------------------------------------------
/inference/recursive_resnet.sh:
--------------------------------------------------------------------------------
1 | for ((i=7;i<=9;i++)); do
2 | python inference.py \
3 | --net_type resnet \
4 | --depth 152 \
5 | --path /home/bumsoo/Data/test/inbreast_patches_test_1_$i
6 | done
7 |
--------------------------------------------------------------------------------
/inference/resnet.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --net_type resnet \
3 |     --depth 152
4 |
--------------------------------------------------------------------------------
/inference/vggnet.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --net_type vggnet \
3 | --depth 16
4 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # ************************************************************
2 | # Author : Bumsoo Kim, 2017
3 | # Github : https://github.com/meliketoy/fine-tuning.pytorch
4 | #
5 | # Korea University, Data-Mining Lab
6 | # Deep Convolutional Network Fine tuning Implementation
7 | #
8 | # Description : main.py
9 | # The main code for training classification networks.
10 | # ***********************************************************
11 |
12 | from __future__ import print_function, division
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.optim as optim
17 | import torch.backends.cudnn as cudnn
18 | import numpy as np
19 | import config as cf
20 | import torchvision
21 | import time
22 | import copy
23 | import os
24 | import sys
25 | import argparse
26 | import pretrainedmodels # exclude this for python2.7 users
27 |
28 | from torchvision import datasets, models, transforms
29 | from networks import *
30 | from torch.autograd import Variable
31 |
32 | parser = argparse.ArgumentParser(description='PyTorch Digital Mammography Training')
33 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
34 | parser.add_argument('--net_type', default='resnet', type=str, help='model')
35 | parser.add_argument('--depth', default=50, type=int, help='depth of model')
36 | parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
37 | parser.add_argument('--finetune', '-f', action='store_true', help='Fine tune pretrained model')
38 | parser.add_argument('--addlayer','-a',action='store_true', help='Add additional layer in fine-tuning')
39 | parser.add_argument('--resetClassifier', '-r', action='store_true', help='Reset classifier')
40 | parser.add_argument('--testOnly', '-t', action='store_true', help='Test mode with the saved model')
41 | args = parser.parse_args()
42 |
43 | # Phase 1 : Data Upload
44 | print('\n[Phase 1] : Data Preparation')
45 |
46 | if args.net_type == 'inception' or args.net_type == 'xception':
47 | data_transforms = {
48 | 'train': transforms.Compose([
49 | transforms.Scale(320),
50 | transforms.RandomSizedCrop(299),
51 | transforms.RandomHorizontalFlip(),
52 | transforms.ToTensor(),
53 | transforms.Normalize(cf.mean, cf.std)
54 | ]),
55 | 'val': transforms.Compose([
56 | transforms.Scale(320),
57 | transforms.CenterCrop(299),
58 | transforms.ToTensor(),
59 | transforms.Normalize(cf.mean, cf.std)
60 | ]),
61 | }
62 | else:
63 | data_transforms = {
64 | 'train': transforms.Compose([
65 | transforms.Scale(256),
66 | transforms.RandomSizedCrop(224),
67 | transforms.RandomHorizontalFlip(),
68 | transforms.ToTensor(),
69 | transforms.Normalize(cf.mean, cf.std)
70 | ]),
71 | 'val': transforms.Compose([
72 | transforms.Scale(256),
73 | transforms.CenterCrop(224),
74 | transforms.ToTensor(),
75 | transforms.Normalize(cf.mean, cf.std)
76 | ]),
77 | }
78 |
79 | data_dir = cf.aug_base
80 | dataset_dir = cf.data_base.split("/")[-1] + os.sep
81 | print("| Preparing model trained on %s dataset..." %(cf.data_base.split("/")[-1]))
82 | dsets = {
83 | x : datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
84 | for x in ['train', 'val']
85 | }
86 | dset_loaders = {
87 | x : torch.utils.data.DataLoader(dsets[x], batch_size = cf.batch_size, shuffle=(x=='train'), num_workers=4)
88 | for x in ['train', 'val']
89 | }
90 |
91 | dset_sizes = {x: len(dsets[x]) for x in ['train', 'val']}
92 | dset_classes = dsets['train'].classes
93 |
94 | use_gpu = torch.cuda.is_available()
95 |
96 | # Phase 2 : Model setup
97 | print('\n[Phase 2] : Model setup')
98 |
99 | def getNetwork(args):
100 | if (args.net_type == 'alexnet'):
101 | net = models.alexnet(pretrained=args.finetune)
102 | file_name = 'alexnet'
103 | elif (args.net_type == 'vggnet'):
104 | if(args.depth == 11):
105 | net = models.vgg11(pretrained=args.finetune)
106 | elif(args.depth == 13):
107 | net = models.vgg13(pretrained=args.finetune)
108 | elif(args.depth == 16):
109 | net = models.vgg16(pretrained=args.finetune)
110 | elif(args.depth == 19):
111 | net = models.vgg19(pretrained=args.finetune)
112 | else:
113 | print('Error : VGGnet should have depth of either [11, 13, 16, 19]')
114 | sys.exit(1)
115 | file_name = 'vgg-%s' %(args.depth)
116 | elif (args.net_type == 'squeezenet'):
117 | net = models.squeezenet1_0(pretrained=args.finetune)
118 | file_name = 'squeeze'
119 | elif (args.net_type == 'resnet'):
120 | net = resnet(args.finetune, args.depth)
121 | file_name = 'resnet-%s' %(args.depth)
122 | elif (args.net_type == 'inception'):
123 | net = pretrainedmodels.inceptionv3(num_classes=1000, pretrained='imagenet')
124 | file_name = 'inception-v3'
125 | elif (args.net_type == 'xception'):
126 | net = pretrainedmodels.xception(num_classes=1000, pretrained='imagenet')
127 | file_name = 'xception'
128 | else:
129 |         print('Error : Network should be either [alexnet / squeezenet / vggnet / resnet / inception / xception]')
130 | sys.exit(1)
131 |
132 | return net, file_name
133 |
134 | def softmax(x):
135 | return np.exp(x) / np.sum(np.exp(x), axis=0)
136 |
137 | # Test only option
138 | if (args.testOnly):
139 | print("| Loading checkpoint model for test phase...")
140 | assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
141 | _, file_name = getNetwork(args)
142 | print('| Loading '+file_name+".t7...")
143 | checkpoint = torch.load('./checkpoint/'+dataset_dir+'/'+file_name+'.t7')
144 | model = checkpoint['model']
145 |
146 | if use_gpu:
147 | model.cuda()
148 | # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
149 | # cudnn.benchmark = True
150 |
151 | model.eval()
152 | test_loss = 0
153 | correct = 0
154 | total = 0
155 |
156 | testsets = datasets.ImageFolder(cf.test_dir, data_transforms['val'])
157 |
158 | testloader = torch.utils.data.DataLoader(
159 | testsets,
160 | batch_size = 1,
161 | shuffle = False,
162 | num_workers=1
163 | )
164 |
165 | print("\n[Phase 3 : Inference on %s]" %cf.test_dir)
166 | for batch_idx, (inputs, targets) in enumerate(testloader):#dset_loaders['val']):
167 | if use_gpu:
168 | inputs, targets = inputs.cuda(), targets.cuda()
169 | inputs, targets = Variable(inputs, volatile=True), Variable(targets)
170 | outputs = model(inputs)
171 |
172 | # print(outputs.data.cpu().numpy()[0])
173 | softmax_res = softmax(outputs.data.cpu().numpy()[0])
174 |
175 | _, predicted = torch.max(outputs.data, 1)
176 | total += targets.size(0)
177 | correct += predicted.eq(targets.data).cpu().sum()
178 |
179 | acc = 100.*correct/total
180 | print("| Test Result\tAcc@1 %.2f%%" %(acc))
181 |
182 | sys.exit(0)
183 |
184 | # Training model
185 | def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=cf.num_epochs):
186 | global dataset_dir
187 | since = time.time()
188 |
189 | best_model, best_acc = model, 0.0
190 |
191 | print('\n[Phase 3] : Training Model')
192 | print('| Training Epochs = %d' %num_epochs)
193 | print('| Initial Learning Rate = %f' %args.lr)
194 | print('| Optimizer = SGD')
195 | for epoch in range(num_epochs):
196 | for phase in ['train', 'val']:
197 | if phase == 'train':
198 | optimizer, lr = lr_scheduler(optimizer, epoch)
199 | print('\n=> Training Epoch #%d, LR=%f' %(epoch+1, lr))
200 | model.train(True)
201 | else:
202 | model.train(False)
203 | model.eval()
204 |
205 | running_loss, running_corrects, tot = 0.0, 0, 0
206 |
207 | for batch_idx, (inputs, labels) in enumerate(dset_loaders[phase]):
208 | if use_gpu:
209 | inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
210 | else:
211 | inputs, labels = Variable(inputs), Variable(labels)
212 |
213 | optimizer.zero_grad()
214 |
215 | # Forward Propagation
216 | outputs = model(inputs)
217 | if isinstance(outputs, tuple):
218 | loss = sum((criterion(o, labels) for o in outputs))
219 | else:
220 | loss = criterion(outputs, labels)
221 | if isinstance(outputs, tuple):
222 | # inception v3 output will be (x, aux)
223 | outputs = outputs[0]
224 | _, preds = torch.max(outputs.data, 1)
225 |
226 | # Backward Propagation
227 | if phase == 'train':
228 | loss.backward()
229 | optimizer.step()
230 |
231 | # Statistics
232 | running_loss += loss.data[0]
233 | running_corrects += preds.eq(labels.data).cpu().sum()
234 | tot += labels.size(0)
235 |
236 | if (phase == 'train'):
237 | sys.stdout.write('\r')
238 | sys.stdout.write('| Epoch [%2d/%2d] Iter [%3d/%3d]\t\tLoss %.4f\tAcc %.2f%%'
239 | %(epoch+1, num_epochs, batch_idx+1,
240 | (len(dsets[phase])//cf.batch_size)+1, loss.data[0], 100.*running_corrects/tot))
241 | sys.stdout.flush()
242 | sys.stdout.write('\r')
243 |
244 | epoch_loss = running_loss / dset_sizes[phase]
245 | epoch_acc = running_corrects / dset_sizes[phase]
246 |
247 | if (phase == 'val'):
248 | print('\n| Validation Epoch #%d\t\t\tLoss %.4f\tAcc %.2f%%'
249 | %(epoch+1, loss.data[0], 100.*epoch_acc))
250 |
251 | if epoch_acc > best_acc :#and epoch > 80:
252 | print('| Saving Best model...\t\t\tTop1 %.2f%%' %(100.*epoch_acc))
253 | best_acc = epoch_acc
254 | best_model = copy.deepcopy(model)
255 | state = {
256 | 'model': best_model,
257 | 'acc': epoch_acc,
258 | 'epoch':epoch,
259 | }
260 | if not os.path.isdir('checkpoint'):
261 | os.mkdir('checkpoint')
262 | save_point = './checkpoint/'+dataset_dir
263 | if not os.path.isdir(save_point):
264 | os.mkdir(save_point)
265 | torch.save(state, save_point+file_name+'.t7')
266 |
267 | time_elapsed = time.time() - since
268 |     print('\nTraining completed in\t{:.0f} min {:.0f} sec'.format(time_elapsed // 60, time_elapsed % 60))
269 | print('Best validation Acc\t{:.2f}%'.format(best_acc*100))
270 |
271 | return best_model
272 |
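# Step decay: halve the learning rate every cf.lr_decay_epoch epochs,
# and (re)apply the weight decay to every parameter group.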
273 | def exp_lr_scheduler(optimizer, epoch, init_lr=args.lr, weight_decay=args.weight_decay, lr_decay_epoch=cf.lr_decay_epoch):
274 | lr = init_lr * (0.5**(epoch // lr_decay_epoch))
275 |
276 | for param_group in optimizer.param_groups:
277 | param_group['lr'] = lr
278 | param_group['weight_decay'] = weight_decay
279 |
280 | return optimizer, lr
281 |
282 | model_ft, file_name = getNetwork(args)
283 |
284 | if(args.resetClassifier):
285 | print('| Reset final classifier...')
286 | if(args.addlayer):
287 | print('| Add features of size %d' %cf.feature_size)
288 | num_ftrs = model_ft.fc.in_features
289 | feature_model = list(model_ft.fc.children())
290 | feature_model.append(nn.Linear(num_ftrs, cf.feature_size))
291 | feature_model.append(nn.BatchNorm1d(cf.feature_size))
292 | feature_model.append(nn.ReLU(inplace=True))
293 | feature_model.append(nn.Linear(cf.feature_size, len(dset_classes)))
294 | model_ft.fc = nn.Sequential(*feature_model)
295 | else:
296 | if(args.net_type == 'alexnet' or args.net_type == 'vggnet'):
297 | num_ftrs = model_ft.classifier[6].in_features
298 | feature_model = list(model_ft.classifier.children())
299 | feature_model.pop()
300 | feature_model.append(nn.Linear(num_ftrs, len(dset_classes)))
301 | model_ft.classifier = nn.Sequential(*feature_model)
302 | elif(args.net_type == 'resnet'):
303 | num_ftrs = model_ft.fc.in_features
304 | model_ft.fc = nn.Linear(num_ftrs, len(dset_classes))
305 | elif(args.net_type == 'inception' or args.net_type == 'xception'):
306 | num_ftrs = model_ft.last_linear.in_features
307 | model_ft.last_linear = nn.Linear(num_ftrs, len(dset_classes))
308 |
309 | if use_gpu:
310 | model_ft = model_ft.cuda()
311 | model_ft = torch.nn.DataParallel(model_ft, device_ids=range(torch.cuda.device_count()))
312 | cudnn.benchmark = True
313 |
314 | if __name__ == "__main__":
315 | criterion = nn.CrossEntropyLoss()
316 | optimizer_ft = optim.SGD(model_ft.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
317 | model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=cf.num_epochs)
318 |
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 |
--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | __all__ = ['ResNet', 'resnet']
6 |
7 |
8 | model_urls = {
9 | 'resnet18': 'http://download.pytorch.org/models/resnet18-5c106cde.pth',
10 | 'resnet34': 'http://download.pytorch.org/models/resnet34-333f7ec4.pth',
11 | 'resnet50': 'http://download.pytorch.org/models/resnet50-19c8e357.pth',
12 | 'resnet101': 'http://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13 | 'resnet152': 'http://download.pytorch.org/models/resnet152-b121ed2d.pth',
14 | }
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | "3x3 convolution with padding"
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=1, bias=False)
21 |
22 | def cfg(depth):
23 | depth_lst = [18, 34, 50, 101, 152]
24 | assert (depth in depth_lst), "Error : ResNet depth should be either 18, 34, 50, 101, 152"
25 | cf_dict = {
26 | '18' : (BasicBlock, [2,2, 2,2]),
27 | '34' : (BasicBlock, [3,4, 6,3]),
28 | '50' : (Bottleneck, [3,4, 6,3]),
29 | '101': (Bottleneck, [3,4,23,3]),
30 | '152': (Bottleneck, [3,8,36,3]),
31 | }
32 |
33 | return cf_dict[str(depth)]
34 |
35 |
36 | class BasicBlock(nn.Module):
37 | expansion = 1
38 |
39 | def __init__(self, inplanes, planes, stride=1, downsample=None):
40 | super(BasicBlock, self).__init__()
41 | self.conv1 = conv3x3(inplanes, planes, stride)
42 | self.bn1 = nn.BatchNorm2d(planes)
43 | self.relu = nn.ReLU(inplace=True)
44 | self.conv2 = conv3x3(planes, planes)
45 | self.bn2 = nn.BatchNorm2d(planes)
46 | self.downsample = downsample
47 | self.stride = stride
48 |
49 | def forward(self, x):
50 | residual = x
51 |
52 | out = self.conv1(x)
53 | out = self.bn1(out)
54 | out = self.relu(out)
55 |
56 | out = self.conv2(out)
57 | out = self.bn2(out)
58 |
59 | if self.downsample is not None:
60 | residual = self.downsample(x)
61 |
62 | out += residual
63 | out = self.relu(out)
64 |
65 | return out
66 |
67 |
68 | class Bottleneck(nn.Module):
69 | expansion = 4
70 |
71 | def __init__(self, inplanes, planes, stride=1, downsample=None):
72 | super(Bottleneck, self).__init__()
73 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
74 | self.bn1 = nn.BatchNorm2d(planes)
75 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
76 | padding=1, bias=False)
77 | self.bn2 = nn.BatchNorm2d(planes)
78 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
79 | self.bn3 = nn.BatchNorm2d(planes * 4)
80 | self.relu = nn.ReLU(inplace=True)
81 | self.downsample = downsample
82 | self.stride = stride
83 |
84 | def forward(self, x):
85 | residual = x
86 |
87 | out = self.conv1(x)
88 | out = self.bn1(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv2(out)
92 | out = self.bn2(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv3(out)
96 | out = self.bn3(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 | out += residual
102 | out = self.relu(out)
103 |
104 | return out
105 |
106 | class ResNet(nn.Module):
107 |
108 | def __init__(self, block, layers, num_classes=1000):
109 | self.inplanes = 64
110 | super(ResNet, self).__init__()
111 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
112 | bias=False)
113 | self.bn1 = nn.BatchNorm2d(64)
114 | self.relu = nn.ReLU(inplace=True)
115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
116 | self.layer1 = self._make_layer(block, 64, layers[0])
117 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
118 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
119 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
120 | self.avgpool = nn.AvgPool2d(7)
121 | self.fc = nn.Linear(512 * block.expansion, num_classes)
122 |
123 | for m in self.modules():
124 | if isinstance(m, nn.Conv2d):
125 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
126 | m.weight.data.normal_(0, math.sqrt(2. / n))
127 | elif isinstance(m, nn.BatchNorm2d):
128 | m.weight.data.fill_(1)
129 | m.bias.data.zero_()
130 |
131 | def _make_layer(self, block, planes, blocks, stride=1):
132 | downsample = None
133 | if stride != 1 or self.inplanes != planes * block.expansion:
134 | downsample = nn.Sequential(
135 | nn.Conv2d(self.inplanes, planes * block.expansion,
136 | kernel_size=1, stride=stride, bias=False),
137 | nn.BatchNorm2d(planes * block.expansion),
138 | )
139 |
140 | layers = []
141 | layers.append(block(self.inplanes, planes, stride, downsample))
142 | self.inplanes = planes * block.expansion
143 | for i in range(1, blocks):
144 | layers.append(block(self.inplanes, planes))
145 |
146 | return nn.Sequential(*layers)
147 |
148 | def forward(self, x):
149 | x = self.conv1(x)
150 | x = self.bn1(x)
151 | x = self.relu(x)
152 | x = self.maxpool(x)
153 |
154 | x = self.layer1(x)
155 | x = self.layer2(x)
156 | x = self.layer3(x)
157 | x = self.layer4(x)
158 |
159 | x = self.avgpool(x)
160 | x = x.view(x.size(0), -1)
161 | x = self.fc(x)
162 |
163 | return x
164 |
165 | def resnet(pretrained=False, depth=18, **kwargs):
166 | """Constructs ResNet models for various depths
167 | Args:
168 | pretrained (bool): If True, returns a model pre-trained on ImageNet
169 | depth (int) : Integer input of either 18, 34, 50, 101, 152
170 | """
171 | block, num_blocks = cfg(depth)
172 | model = ResNet(block, num_blocks, **kwargs)
173 | if (pretrained):
174 | print("| Downloading ImageNet fine-tuned ResNet-%d..." %depth)
175 | model.load_state_dict(model_zoo.load_url(model_urls['resnet%d' %depth]))
176 | return model
177 |
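# Usage sketch (values are illustrative):
#   import torch
#   net = resnet(pretrained=True, depth=50)     # ImageNet weights for ResNet-50
#   logits = net(torch.randn(1, 3, 224, 224))   # -> tensor of shape (1, 1000)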
--------------------------------------------------------------------------------
/test/alexnet.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --net_type alexnet \
3 | --testOnly
4 |
--------------------------------------------------------------------------------
/test/resnet.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --net_type resnet \
3 | --depth 152 \
4 | --testOnly
5 |
--------------------------------------------------------------------------------
/test/vggnet.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --net_type vggnet \
3 | --depth 16 \
4 | --testOnly
5 |
--------------------------------------------------------------------------------
/train/alexnet.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --lr 1e-3 \
3 | --weight_decay 1e-4 \
4 | --net_type alexnet \
5 | --resetClassifier \
6 | --finetune
7 |
--------------------------------------------------------------------------------
/train/inception.sh:
--------------------------------------------------------------------------------
1 | python3 main.py \
2 |     --lr 0.045 \
3 |     --weight_decay 4e-5 \
4 |     --net_type inception \
5 |     --depth 50 \
6 |     --resetClassifier \
7 |     --finetune
8 |
--------------------------------------------------------------------------------
/train/resnet.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --lr 1e-3 \
3 | --weight_decay 5e-4 \
4 | --net_type resnet \
5 | --depth 152 \
6 | --resetClassifier \
7 | --finetune \
8 | #--addlayer
9 |
--------------------------------------------------------------------------------
/train/squeeze.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --lr 1e-3 \
3 | --weight_decay 5e-4 \
4 | --net_type squeezenet \
5 | --resetClassifier \
6 | --finetune
7 |
--------------------------------------------------------------------------------
/train/vggnet.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --lr 1e-3 \
3 | --weight_decay 5e-4 \
4 | --net_type vggnet \
5 | --depth 16 \
6 | --resetClassifier \
7 | --finetune
8 |
--------------------------------------------------------------------------------
/train/xception.sh:
--------------------------------------------------------------------------------
1 | python3 main.py \
2 |     --lr 0.045 \
3 |     --weight_decay 1e-5 \
4 |     --net_type xception \
5 |     --resetClassifier \
6 |     --finetune
7 | #--testOnly
8 |
--------------------------------------------------------------------------------