├── .gitignore
├── .ipynb_checkpoints
│   └── Untitled-checkpoint.ipynb
├── Commands.MD
├── README.MD
├── constant.py
├── data.py
├── folerStructure.py
├── model
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── block.cpython-37.pyc
│   │   ├── block.cpython-38.pyc
│   │   ├── stage.cpython-37.pyc
│   │   └── stage.cpython-38.pyc
│   ├── components
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── block.cpython-37.pyc
│   │   │   ├── block.cpython-38.pyc
│   │   │   ├── stage.cpython-37.pyc
│   │   │   └── stage.cpython-38.pyc
│   │   ├── block.py
│   │   └── stage.py
│   └── model.py
├── predict.py
├── requirements.txt
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-03-projects/ResNet/353704900385218def9523ee96ab368d44c2e80a/.gitignore
--------------------------------------------------------------------------------
/.ipynb_checkpoints/Untitled-checkpoint.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
--------------------------------------------------------------------------------
/Commands.MD:
--------------------------------------------------------------------------------
## Create Environment

```bash
conda create -n base-library python==3.7.0
```

## Activate Environment

```bash
conda activate base-library
```

## Install Dependencies

```bash
pip install tensorflow==2.5.0
```

## Export Dependencies

```bash
conda list -e > requirements.txt
```
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
# ResNet

Implementation of ResNet: [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385). Give us a star if you like this repo.

Slides explaining how we analyzed and implemented ResNet: [Click here](https://docs.google.com/presentation/d/1a18IyR1nc6GLTU5hibw3VabLVFbPjeE1/edit?usp=sharing&ouid=110166614717096573075&rtpof=true&sd=true)

Authors:
- Github: dark-kazansky, bdghuy, hoangcaobao, hoangduc199891, sonnymetvn
- Email: dark.kazansky@gmail.com, caobaohoang03@gmail.com, hoangduc199892@gmail.com, bdghuy@gmail.com, sonny.metvn@gmail.com

Advisors:
- Github: [bangoc123](https://github.com/bangoc123)
- Email: protonxai@gmail.com

## I. Set up environment
- Step 1: Make sure you have installed conda, e.g. Miniconda or Anaconda.
- Step 2: From your terminal, `cd` into the ResNet folder, then run `conda env create -f environment.yml`.

## II. Set up your dataset

- Option 1: Run `python data.py` to download the cats-and-dogs dataset.
- Option 2: Set up a custom dataset:

```bash
python folerStructure.py
```
Run the command above and follow the prompts to create the class folders, then copy your images into them.
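`ImageDataGenerator.flow_from_directory` infers the class labels from the subfolder names and assigns label indices in alphanumeric order, so `train` and `validation` should contain exactly the same class folders. A minimal sketch of that check, assuming the `data/train` and `data/validation` layout described above; `class_indices` and `check_splits` are hypothetical helpers, not part of this repo:

```python
import os


def class_indices(split_dir):
    """Map each class subfolder name to an integer label, in sorted order."""
    classes = sorted(
        entry for entry in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, entry))
    )
    return {name: index for index, name in enumerate(classes)}


def check_splits(root):
    """Raise ValueError if root/train and root/validation define different classes."""
    train = class_indices(os.path.join(root, 'train'))
    valid = class_indices(os.path.join(root, 'validation'))
    if train != valid:
        raise ValueError('class folders differ: {} vs {}'.format(sorted(train), sorted(valid)))
    return train
```

For a dataset created with the classes `dogs` and `cats`, `check_splits('data')` would return `{'cats': 0, 'dogs': 1}`: label 0 is `cats` regardless of the order in which the classes were typed in.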

## III. Training Process
Training script:

```bash
python train.py --train-folder ${train_folder} --valid-folder ${valid_folder} --num-classes ${num_classes} --epochs ${epochs}
```
Example:

```bash
python train.py --model 'resnet34' --epochs 120 --num-classes 2 --train-folder $train_folder --valid-folder $valid_folder
```
Some important arguments to consider when running the script:

- `train-folder`: The folder of training data
- `valid-folder`: The folder of validation data
- ...

Notebook training: [Open in Colab](https://colab.research.google.com/drive/1cb7Nkkn5o_U7cVvnxXmfx-jdyoc4slhu?usp=sharing)

## IV. Predict Process

```bash
python predict.py --test-file-path ${path_to_test_image}
```

## V. Result and Comparison

Training log of our implementation (last epochs of a 200-epoch run):
```
Epoch 00189: val_accuracy did not improve from 0.87700
Epoch 190/200
32/32 [==============================] - 55s 2s/step - loss: 0.0667 - accuracy: 0.9770 - val_loss: 0.8112 - val_accuracy: 0.8360

Epoch 00190: val_accuracy did not improve from 0.87700
Epoch 191/200
32/32 [==============================] - 55s 2s/step - loss: 0.0517 - accuracy: 0.9800 - val_loss: 0.7989 - val_accuracy: 0.8460

Epoch 00191: val_accuracy did not improve from 0.87700
Epoch 192/200
32/32 [==============================] - 54s 2s/step - loss: 0.0486 - accuracy: 0.9845 - val_loss: 0.6213 - val_accuracy: 0.8630

Epoch 00192: val_accuracy did not improve from 0.87700
Epoch 193/200
32/32 [==============================] - 55s 2s/step - loss: 0.0464 - accuracy: 0.9835 - val_loss: 0.5506 - val_accuracy: 0.8450

Epoch 00193: val_accuracy did not improve from 0.87700
Epoch 194/200
32/32 [==============================] - 55s 2s/step - loss: 0.0471 - accuracy: 0.9835 - val_loss: 1.0926 - val_accuracy: 0.8220

Epoch 00194: val_accuracy did not improve from 0.87700
Epoch 195/200
32/32 [==============================] - 55s 2s/step - loss: 0.0713 - accuracy: 0.9770 - val_loss: 1.0000 - val_accuracy: 0.8200

Epoch 00195: val_accuracy did not improve from 0.87700
Epoch 196/200
32/32 [==============================] - 55s 2s/step - loss: 0.0512 - accuracy: 0.9835 - val_loss: 1.9371 - val_accuracy: 0.6830

Epoch 00196: val_accuracy did not improve from 0.87700
Epoch 197/200
32/32 [==============================] - 55s 2s/step - loss: 0.0575 - accuracy: 0.9805 - val_loss: 1.1376 - val_accuracy: 0.7760

Epoch 00197: val_accuracy did not improve from 0.87700
Epoch 198/200
32/32 [==============================] - 55s 2s/step - loss: 0.0484 - accuracy: 0.9825 - val_loss: 0.6597 - val_accuracy: 0.8590

Epoch 00198: val_accuracy did not improve from 0.87700
Epoch 199/200
32/32 [==============================] - 55s 2s/step - loss: 0.0712 - accuracy: 0.9720 - val_loss: 1.4779 - val_accuracy: 0.8010

Epoch 00199: val_accuracy did not improve from 0.87700
Epoch 200/200
32/32 [==============================] - 55s 2s/step - loss: 0.0484 - accuracy: 0.9825 - val_loss: 0.6597 - val_accuracy: 0.8590

Epoch 00200: val_accuracy did not improve from 0.87700

```

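With `monitor='val_accuracy'`, `mode='max'` and `save_best_only=True`, `ModelCheckpoint` only rewrites `best_model.h5` when validation accuracy beats the best value seen so far, which is why every epoch above reports "did not improve from 0.87700". The sketch below replays that rule on the logged values; it is a simplified model of the callback (a strict "greater than" comparison), not its actual source. It also shows where "32/32" comes from, assuming the 2,000-image training split that `data.py` downloads and the batch size of 64 used by `train.py`. Both function names are illustrative only.

```python
import math


def checkpoint_epochs(val_accuracies, best_so_far=float('-inf'), first_epoch=1):
    """Return the epochs at which a save_best_only checkpoint (mode='max') would save."""
    saved = []
    best = best_so_far
    for epoch, value in enumerate(val_accuracies, start=first_epoch):
        if value > best:  # only a strict improvement triggers a save
            best = value
            saved.append(epoch)
    return saved


def steps_per_epoch(num_samples, batch_size):
    """Number of batches per epoch; the last batch may be partial."""
    return math.ceil(num_samples / batch_size)


# val_accuracy for epochs 190-200, copied from the log above
logged = [0.8360, 0.8460, 0.8630, 0.8450, 0.8220, 0.8200,
          0.6830, 0.7760, 0.8590, 0.8010, 0.8590]

print(checkpoint_epochs(logged, best_so_far=0.877, first_epoch=190))  # []
print(steps_per_epoch(2000, 64))  # 32
```

Because the file on disk is only replaced on improvement, `best_model.h5` after this run holds the 0.877 model, not the final-epoch weights.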
## VI. Running Test

The trained ResNet50 model (`best_model.h5`) is too large to commit to GitHub, so you can download it [here](https://drive.google.com/file/d/1pDfrAt7wHvDZrX4uolFxwE7KeUVTJvY8/view?usp=sharing). Then copy it into the base folder so `predict.py` can load it. `predict.py` also loads `class_names.pkl`, which `train.py` writes during training.
In the `./ResNet` folder, run `python predict.py --test-file-path "image-path"` to make a prediction.

Or you can try it in Colab: [Open in Colab](https://colab.research.google.com/drive/1ySFObB6ZPgxJyq8G9_dHpQypExXAgAI2?usp=sharing)

Here are some results from testing on a few ordinary dog and cat pictures:
--------------------------------------------------------------------------------
/constant.py:
--------------------------------------------------------------------------------
# Define constant variables
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import os

import tensorflow as tf

URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
# Download the archive into the current working directory and extract it there
zip_dir = tf.keras.utils.get_file('cats_and_dogs_filtered.zip', origin=URL, extract=True, cache_subdir=os.getcwd())
os.rename('cats_and_dogs_filtered', 'Data')
os.remove(zip_dir)
# Use the same casing as the renamed folder ('data' fails on case-sensitive file systems)
base_dir = os.path.join(os.path.dirname(zip_dir), 'Data')
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))
num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))
total_train = num_cats_tr + num_dogs_tr
total_val = num_cats_val + num_dogs_val

print('total training cat images:', num_cats_tr)
print('total training dog images:', num_dogs_tr)
print('total validation cat images:', num_cats_val)
print('total validation dog images:', num_dogs_val)
print("--")
print("Total training images:", total_train)
print("Total validation images:", total_val)
--------------------------------------------------------------------------------
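`data.py` counts the images of each class with one hard-coded `os.listdir` call per folder. A minimal sketch that generalizes the count to any number of classes, assuming the `root/split/class/` layout used throughout this repo; `count_images` and the `IMAGE_EXTENSIONS` filter are illustrative assumptions, not repo code:

```python
import os

# Hypothetical filter: file extensions treated as images
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')


def count_images(split_dir):
    """Return {class_name: number_of_image_files} for every class folder in split_dir."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files next to the class folders
        counts[class_name] = sum(
            1 for f in os.listdir(class_dir)
            if f.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts
```

After running `data.py`, `count_images('Data/train')` should report the same per-class numbers the script prints, and `sum(count_images('Data/train').values())` the training total. Unlike a bare `os.listdir`, the extension filter ignores non-image files such as `Thumbs.db`.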
/folerStructure.py:
--------------------------------------------------------------------------------
"""
Input:
    n: int - the number of classes
Output:
    Folder structure:
    main_directory/
        data/
            train/
                class_a/
                    a_image_1.jpg
                    a_image_2.jpg
                class_b/
                    b_image_1.jpg
                    b_image_2.jpg
            validation/
                class_a/
                    a_image_1.jpg
                    a_image_2.jpg
                class_b/
                    b_image_1.jpg
                    b_image_2.jpg
"""
import os

try:
    number = int(input('Enter the number of classes: '))
    if number > 1:
        names = []
        for idx in range(number):
            names.append(input("Type the name of class {}: ".format(idx + 1)))
        os.mkdir('data')
        os.chdir('data')
        for split in ['train', 'validation']:
            os.mkdir(split)
            os.chdir(split)
            # One sub-folder per class inside each split
            for name in names:
                os.mkdir(name)
            os.chdir('..')
        print('Data folders for {} classes have been created!'.format(number))
    else:
        print('Please type a number of classes greater than 1!')
except ValueError:
    print('Please type a number!')
except FileExistsError:
    print('Cannot create a folder that already exists!')
    print('Please remove the "data" folder!')
--------------------------------------------------------------------------------
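`folerStructure.py` reads the class names interactively and builds the folders step by step with `os.mkdir` and `os.chdir`. For scripted setups, the same layout can be created non-interactively with `os.makedirs`, without changing the working directory. `make_dataset_folders` below is a hypothetical helper, not part of the repo; like the script, it requires at least two classes and refuses to reuse an existing root folder.

```python
import os


def make_dataset_folders(root, class_names, splits=('train', 'validation')):
    """Create root/<split>/<class>/ for every split and class; return the created paths.

    Raises ValueError for fewer than two classes (the script's "greater than 1" check)
    and FileExistsError if root already exists.
    """
    if len(class_names) < 2:
        raise ValueError('Please provide at least 2 classes')
    os.makedirs(root)  # exist_ok defaults to False, so an existing root raises
    created = []
    for split in splits:
        for name in class_names:
            path = os.path.join(root, split, name)
            os.makedirs(path)
            created.append(path)
    return created
```

For example, `make_dataset_folders('data', ['cats', 'dogs'])` creates the four folders `data/train/cats`, `data/train/dogs`, `data/validation/cats` and `data/validation/dogs`.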
/model/__init__.py:
--------------------------------------------------------------------------------
from .model import resnet18, resnet34, resnet50, resnet101, resnet152
--------------------------------------------------------------------------------
/model/__pycache__/block.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-03-projects/ResNet/353704900385218def9523ee96ab368d44c2e80a/model/__pycache__/block.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/block.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-03-projects/ResNet/353704900385218def9523ee96ab368d44c2e80a/model/__pycache__/block.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/stage.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-03-projects/ResNet/353704900385218def9523ee96ab368d44c2e80a/model/__pycache__/stage.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/stage.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-03-projects/ResNet/353704900385218def9523ee96ab368d44c2e80a/model/__pycache__/stage.cpython-38.pyc
--------------------------------------------------------------------------------
/model/components/__init__.py:
--------------------------------------------------------------------------------
from .stage import stage
--------------------------------------------------------------------------------
/model/components/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-03-projects/ResNet/353704900385218def9523ee96ab368d44c2e80a/model/components/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/components/__pycache__/block.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-03-projects/ResNet/353704900385218def9523ee96ab368d44c2e80a/model/components/__pycache__/block.cpython-37.pyc
--------------------------------------------------------------------------------
/model/components/__pycache__/block.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-03-projects/ResNet/353704900385218def9523ee96ab368d44c2e80a/model/components/__pycache__/block.cpython-38.pyc
--------------------------------------------------------------------------------
/model/components/__pycache__/stage.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-03-projects/ResNet/353704900385218def9523ee96ab368d44c2e80a/model/components/__pycache__/stage.cpython-37.pyc
--------------------------------------------------------------------------------
/model/components/__pycache__/stage.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-03-projects/ResNet/353704900385218def9523ee96ab368d44c2e80a/model/components/__pycache__/stage.cpython-38.pyc
--------------------------------------------------------------------------------
/model/components/block.py:
--------------------------------------------------------------------------------
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, Add


def basic_block(input, filter_num, stride=1, stage_idx=-1, block_idx=-1):
    '''BasicBlock: a stack of two 3x3 convolution layers

    Args:
        input: input tensor
        filter_num: the number of filters in each convolution
        stride: the stride of the first convolution. Use stride=1 to keep the
            output's spatial size equal to the input's
        stage_idx: index of the current stage
        block_idx: index of the current block in the stage
    '''
    # conv3x3
    conv1 = Conv2D(filters=filter_num,
                   kernel_size=3,
                   strides=stride,
                   padding='same',
                   kernel_initializer='he_normal',
                   name='conv{}_block{}_1_conv'.format(stage_idx, block_idx))(input)
    bn1 = BatchNormalization(name='conv{}_block{}_1_bn'.format(stage_idx, block_idx))(conv1)
    relu1 = ReLU(name='conv{}_block{}_1_relu'.format(stage_idx, block_idx))(bn1)
    # conv3x3
    conv2 = Conv2D(filters=filter_num,
                   kernel_size=3,
                   strides=1,
                   padding='same',
                   kernel_initializer='he_normal',
                   name='conv{}_block{}_2_conv'.format(stage_idx, block_idx))(relu1)
    bn2 = BatchNormalization(name='conv{}_block{}_2_bn'.format(stage_idx, block_idx))(conv2)

    # The final ReLU is applied in resblock, after the shortcut is added
    return bn2


def bottleneck_block(input, filter_num, stride=1, stage_idx=-1, block_idx=-1):
    '''BottleNeckBlock: a stack of three layers, 1x1, 3x3 and 1x1 convolutions

    The last 1x1 convolution expands the output to 4*filter_num channels.

    Args:
        input: input tensor
        filter_num: the number of filters in the first two convolutions
        stride: the stride of the first convolution. Use stride=1 to keep the
            output's spatial size equal to the input's
        stage_idx: index of the current stage
        block_idx: index of the current block in the stage
    '''
    # conv1x1 (reduces the channels; carries the stride when down-sampling)
    conv1 = Conv2D(filters=filter_num,
                   kernel_size=1,
                   strides=stride,
                   padding='valid',
                   kernel_initializer='he_normal',
                   name='conv{}_block{}_1_conv'.format(stage_idx, block_idx))(input)
    bn1 = BatchNormalization(name='conv{}_block{}_1_bn'.format(stage_idx, block_idx))(conv1)
    relu1 = ReLU(name='conv{}_block{}_1_relu'.format(stage_idx, block_idx))(bn1)
    # conv3x3
    conv2 = Conv2D(filters=filter_num,
                   kernel_size=3,
                   strides=1,
                   padding='same',
                   kernel_initializer='he_normal',
                   name='conv{}_block{}_2_conv'.format(stage_idx, block_idx))(relu1)
    bn2 = BatchNormalization(name='conv{}_block{}_2_bn'.format(stage_idx, block_idx))(conv2)
    relu2 = ReLU(name='conv{}_block{}_2_relu'.format(stage_idx, block_idx))(bn2)
    # conv1x1 (expands the channels to 4*filter_num)
    conv3 = Conv2D(filters=4 * filter_num,
                   kernel_size=1,
                   strides=1,
                   padding='valid',
                   kernel_initializer='he_normal',
                   name='conv{}_block{}_3_conv'.format(stage_idx, block_idx))(relu2)
    bn3 = BatchNormalization(name='conv{}_block{}_3_bn'.format(stage_idx, block_idx))(conv3)

    return bn3


def resblock(input, filter_num, stride=1, use_bottleneck=False, stage_idx=-1, block_idx=-1):
    '''A complete `Residual Unit` of ResNet

    Args:
        input: input tensor
        filter_num: the number of filters in the convolution
        stride: the stride of the first convolution. Use stride=1 to keep the
            output's spatial size equal to the input's
        use_bottleneck: type of block: basic (False) or bottleneck (True)
        stage_idx: index of the current stage
        block_idx: index of the current block in the stage
    '''
    if use_bottleneck:
        residual = bottleneck_block(input, filter_num, stride, stage_idx, block_idx)
        expansion = 4
    else:
        residual = basic_block(input, filter_num, stride, stage_idx, block_idx)
        expansion = 1

    shortcut = input
    # Use a projection shortcut (1x1 conv + BN) when the spatial size or the number of channels changes
    if stride > 1 or input.shape[3] != residual.shape[3]:
        shortcut = Conv2D(expansion * filter_num,
                          kernel_size=1,
                          strides=stride,
                          padding='valid',
                          kernel_initializer='he_normal',
                          name='conv{}_block{}_projection-shortcut_conv'.format(stage_idx, block_idx))(input)
        shortcut = BatchNormalization(name='conv{}_block{}_projection-shortcut_bn'.format(stage_idx, block_idx))(shortcut)

    output = Add(name='conv{}_block{}_add'.format(stage_idx, block_idx))([residual, shortcut])

    return ReLU(name='conv{}_block{}_relu'.format(stage_idx, block_idx))(output)
--------------------------------------------------------------------------------
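The two block types trade off differently: a basic block applies two 3x3 convolutions at width `filter_num`, while a bottleneck block squeezes to `filter_num` channels with a 1x1 convolution, applies the 3x3 there, and expands back to `4*filter_num`. A small arithmetic sketch of the resulting parameter counts, assuming Keras defaults (`Conv2D` has a bias; `BatchNormalization` keeps 4 values per channel: gamma, beta, moving mean and moving variance). The helper names are hypothetical and no layers are built.

```python
def conv_params(kernel_size, in_channels, filters):
    """Weights plus biases of a Conv2D layer (use_bias=True is the Keras default)."""
    return kernel_size * kernel_size * in_channels * filters + filters


def bn_params(channels):
    """gamma, beta, moving_mean and moving_variance for every channel."""
    return 4 * channels


def basic_block_params(in_channels, filter_num):
    # conv3x3 -> BN -> ReLU -> conv3x3 -> BN (mirrors basic_block)
    return (conv_params(3, in_channels, filter_num) + bn_params(filter_num)
            + conv_params(3, filter_num, filter_num) + bn_params(filter_num))


def bottleneck_block_params(in_channels, filter_num):
    # conv1x1 -> BN -> ReLU -> conv3x3 -> BN -> ReLU -> conv1x1 (4x wider) -> BN
    out_channels = 4 * filter_num
    return (conv_params(1, in_channels, filter_num) + bn_params(filter_num)
            + conv_params(3, filter_num, filter_num) + bn_params(filter_num)
            + conv_params(1, filter_num, out_channels) + bn_params(out_channels))


# Non-first units of stage conv2_x (no projection shortcut)
print(basic_block_params(64, 64))        # 74368 parameters, 64 channels out
print(bottleneck_block_params(256, 64))  # 71552 parameters, 256 channels out
```

In this stage the bottleneck unit produces a 4x wider output with slightly fewer parameters than the basic unit, which is what makes the deeper ResNet-50/101/152 variants affordable.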
/model/components/stage.py:
--------------------------------------------------------------------------------
from .block import resblock


def stage(input, filter_num, num_block, use_downsample=True, use_bottleneck=False, stage_idx=-1):
    '''Stacks the Residual Units that belong to the same stage

    Args:
        input: input tensor
        filter_num: the number of filters in the convolutions used during the stage
        num_block: number of `Residual Unit`s in the stage
        use_downsample: down-sampling is performed by conv3_1, conv4_1 and conv5_1 with a stride of 2
        use_bottleneck: type of block: basic or bottleneck
        stage_idx: index of the current stage
    '''
    # Only the first unit of a stage may down-sample
    net = resblock(input=input, filter_num=filter_num, stride=2 if use_downsample else 1,
                   use_bottleneck=use_bottleneck, stage_idx=stage_idx, block_idx=1)

    for i in range(1, num_block):
        net = resblock(input=net, filter_num=filter_num, stride=1,
                       use_bottleneck=use_bottleneck, stage_idx=stage_idx, block_idx=i + 1)

    return net
--------------------------------------------------------------------------------
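Because `stage` only down-samples in its first unit, and `resblock` switches to a projection shortcut whenever the stride is larger than 1 or the channel count changes, only the first unit of a stage can need a projection. The pure-Python sketch below mirrors those two rules (it builds no layers) to list which units use a projection for a given `layers` configuration. `trace_units` is a hypothetical helper and assumes the 64-channel output of `conv1` and max pooling that `model.py` feeds into `conv2_x`.

```python
def trace_units(layers, use_bottleneck, filters=(64, 128, 256, 512), stem_channels=64):
    """List (stage_idx, block_idx, in_channels, out_channels, stride, projection) per unit."""
    expansion = 4 if use_bottleneck else 1
    in_channels = stem_channels
    units = []
    for i, (filter_num, num_block) in enumerate(zip(filters, layers)):
        for block in range(num_block):
            # use_downsample = i != 0 in model.py; only block 1 of a stage gets stride 2
            stride = 2 if (i != 0 and block == 0) else 1
            out_channels = expansion * filter_num
            projection = stride > 1 or in_channels != out_channels
            units.append((i + 2, block + 1, in_channels, out_channels, stride, projection))
            in_channels = out_channels
    return units


for unit in trace_units([2, 2, 2, 2], use_bottleneck=False):  # ResNet-18
    if unit[-1]:
        print(unit)
```

For ResNet-18 this prints the first units of `conv3_x`, `conv4_x` and `conv5_x` only. For ResNet-50 (`[3, 4, 6, 3]` with bottlenecks) the trace also flags `conv2_block1`, because its input has 64 channels but its output has 256, giving four projection shortcuts out of 16 units.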
/model/model.py:
--------------------------------------------------------------------------------
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, MaxPooling2D, GlobalAveragePooling2D, ReLU, Dense

from .components import stage


def build(input_shape, num_classes, layers, use_bottleneck=False):
    '''Builds a complete ResNet model

    Args:
        input_shape: shape of one input image, e.g. (224, 224, 3)
        num_classes: number of output classes
        layers: number of Residual Units in each stage conv2_x, conv3_x, conv4_x, conv5_x
        use_bottleneck: use bottleneck blocks instead of basic blocks
    '''
    inputs = Input(input_shape, name='input')
    # conv1
    net = Conv2D(filters=64,
                 kernel_size=7,
                 strides=2,
                 padding='same',
                 kernel_initializer='he_normal',
                 name='conv1_conv')(inputs)
    net = BatchNormalization(name='conv1_bn')(net)
    net = ReLU(name='conv1_relu')(net)
    net = MaxPooling2D(pool_size=3,
                       strides=2,
                       padding='same',
                       name='conv1_max_pool')(net)

    # conv2_x, conv3_x, conv4_x, conv5_x
    filters = [64, 128, 256, 512]
    for i in range(len(filters)):
        net = stage(input=net,
                    filter_num=filters[i],
                    num_block=layers[i],
                    use_downsample=i != 0,  # conv2_x follows max pooling and keeps the size
                    use_bottleneck=use_bottleneck,
                    stage_idx=i + 2)

    net = GlobalAveragePooling2D(name='avg_pool')(net)
    output = Dense(num_classes, activation='softmax', name='predictions')(net)
    model = Model(inputs, output)

    return model


def resnet18(input_shape=(224, 224, 3), num_classes=1000):
    return build(input_shape, num_classes, [2, 2, 2, 2], use_bottleneck=False)


def resnet34(input_shape=(224, 224, 3), num_classes=1000):
    return build(input_shape, num_classes, [3, 4, 6, 3], use_bottleneck=False)


def resnet50(input_shape=(224, 224, 3), num_classes=1000):
    return build(input_shape, num_classes, [3, 4, 6, 3], use_bottleneck=True)


def resnet101(input_shape=(224, 224, 3), num_classes=1000):
    return build(input_shape, num_classes, [3, 4, 23, 3], use_bottleneck=True)


def resnet152(input_shape=(224, 224, 3), num_classes=1000):
    return build(input_shape, num_classes, [3, 8, 36, 3], use_bottleneck=True)
--------------------------------------------------------------------------------
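The names of the five builders follow from their `layers` lists: counting `conv1`, every convolution inside the residual units (2 per basic block, 3 per bottleneck block) and the final `Dense` layer gives the network depth, with projection shortcuts not counted. The spatial size also follows from the strides, since a `'same'`-padded layer with stride s turns n pixels into ceil(n / s). A small sketch of both calculations, assuming the 224x224 default input; the function names are illustrative only.

```python
import math

# Residual units per stage (conv2_x .. conv5_x) and block type, as in model.py
CONFIGS = {
    'resnet18': ([2, 2, 2, 2], False),
    'resnet34': ([3, 4, 6, 3], False),
    'resnet50': ([3, 4, 6, 3], True),
    'resnet101': ([3, 4, 23, 3], True),
    'resnet152': ([3, 8, 36, 3], True),
}


def depth(layers, use_bottleneck):
    """conv1 + convolutions inside the residual units + the final Dense layer."""
    convs_per_unit = 3 if use_bottleneck else 2
    return 1 + convs_per_unit * sum(layers) + 1


def stage_sizes(image_size=224):
    """Spatial size after each of conv2_x .. conv5_x with 'same' padding."""
    size = math.ceil(image_size / 2)  # conv1, stride 2
    size = math.ceil(size / 2)        # max pooling, stride 2
    sizes = [size]                    # conv2_x keeps the size (stride 1)
    for _ in range(3):                # conv3_x .. conv5_x each halve it
        size = math.ceil(size / 2)
        sizes.append(size)
    return sizes


for name, (layers, use_bottleneck) in CONFIGS.items():
    print(name, depth(layers, use_bottleneck))
print(stage_sizes(224))  # [56, 28, 14, 7]
```

The final 7x7 feature map is what `GlobalAveragePooling2D` averages before the `Dense` classifier.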
/predict.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser
import pickle

import numpy as np
from tensorflow.keras import preprocessing
from tensorflow.keras.models import load_model

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--test-file-path", type=str, required=True)
    parser.add_argument("--model-path", default="best_model.h5", type=str)
    parser.add_argument("--class-names-path", default='class_names.pkl', type=str)

    args = parser.parse_args()

    print('---------------------Welcome to ResNet-------------------')
    print("Team leader")
    print('Github: dark-kazansky')
    print("Team member")
    print('1. Github: hoangcaobao')
    print('2. Github: sonnymetvn')
    print('3. Github: hoangduc199891')
    print('4. Github: bdghuy')
    print('-------------------------------------------------------- ')
    print('Predict using ResNet model for test file path {0}'.format(args.test_file_path))
    print('===========================')

    # Loading class names (saved by train.py in label-index order)
    with open(args.class_names_path, 'rb') as fp:
        class_names = pickle.load(fp)

    # Loading model
    model = load_model(args.model_path)

    # Load the test image and rescale it exactly as during training (1/255)
    image = preprocessing.image.load_img(args.test_file_path, target_size=(224, 224))
    input_arr = preprocessing.image.img_to_array(image) / 255
    x = np.expand_dims(input_arr, axis=0)

    predictions = model.predict(x)
    label = np.argmax(predictions)
    print('Result: {}'.format(class_names[label]))
--------------------------------------------------------------------------------
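`predict.py` depends on two things produced during training: the class names that `train.py` pickles from `class_indices` (position i must hold the name of label i), and the same 1/255 pixel rescaling used by the `ImageDataGenerator`s. A dependency-free sketch of that contract; `save_class_names`, `rescale` and `decode` are hypothetical stand-ins, and the probability list plays the role of one row of `model.predict` output.

```python
import pickle


def save_class_names(class_indices, path):
    """Store class names so that position i holds the name of label i."""
    names = [name for name, _ in sorted(class_indices.items(), key=lambda item: item[1])]
    with open(path, 'wb') as fp:
        pickle.dump(names, fp)
    return names


def rescale(pixels):
    """Apply the same 1/255 scaling as ImageDataGenerator(rescale=1. / 255)."""
    return [value / 255 for value in pixels]


def decode(probabilities, path):
    """Return the class name with the highest predicted probability."""
    with open(path, 'rb') as fp:
        class_names = pickle.load(fp)
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return class_names[best]
```

Sorting by the label index, rather than relying on dictionary order, keeps the name list aligned with the model's output units however the mapping was built. Dividing by 225 instead of 255 would stretch pixel values up to about 1.13 instead of 1.0, a range the model never saw during training.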
/requirements.txt:
--------------------------------------------------------------------------------
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: win-64
absl-py=0.13.0
astunparse=1.6.3
cached-property=1.5.2
cachetools=4.2.2
certifi=2021.5.30
charset-normalizer=2.0.4
flatbuffers=1.12
gast=0.4.0
google-auth=1.34.0
google-auth-oauthlib=0.4.5
google-pasta=0.2.0
grpcio=1.34.1
h5py=3.1.0
idna=3.2
importlib-metadata=4.6.3
keras-nightly=2.5.0.dev2021032900
keras-preprocessing=1.1.2
markdown=3.3.4
numpy=1.19.5
oauthlib=3.1.1
opt-einsum=3.3.0
pip=21.2.2
protobuf=3.17.3
pyasn1=0.4.8
pyasn1-modules=0.2.8
python=3.7.0
requests=2.26.0
requests-oauthlib
rsa=4.7.2
setuptools=52.0.0
six=1.15.0
tensorboard=2.6.0
tensorboard-data-server=0.6.1
tensorboard-plugin-wit=1.8.0
tensorflow=2.5.0
tensorflow-estimator=2.5.0
termcolor=1.1.0
typing-extensions=3.7.4.3
urllib3=1.26.6
vc=14.2
vs2015_runtime=14.27.29016
werkzeug=2.0.1
wheel=0.36.2
wincertstore=0.2
wrapt=1.12.1
zipp=3.5.0
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser
import pickle

from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam, SGD, RMSprop, Adadelta, Adamax
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from model import resnet18, resnet34, resnet50, resnet101, resnet152

MODELS = {
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
}

OPTIMIZERS = {
    'adam': Adam,
    'sgd': SGD,
    'rmsprop': RMSprop,
    'adadelta': Adadelta,
    'adamax': Adamax,
}

if __name__ == "__main__":
    parser = ArgumentParser()

    # Arguments users used when running command lines
    parser.add_argument('--train-folder', default='Data/train', type=str, help='Where training data is located')
    parser.add_argument('--valid-folder', default='Data/validation', type=str, help='Where validation data is located')
    parser.add_argument('--model', default='resnet50', type=str, help='Type of model')
    parser.add_argument('--num-classes', default=2, type=int, help='Number of classes')
    parser.add_argument('--batch-size', default=64, type=int, help='Batch size')
    parser.add_argument('--image-size', default=224, type=int, help='Size of input image')
    parser.add_argument('--optimizer', default='adam', type=str, help='Types of optimizers')
    parser.add_argument('--lr', default=0.001, type=float, help='Learning rate')
    parser.add_argument('--epochs', default=120, type=int, help='Number of epochs')
    parser.add_argument('--image-channels', default=3, type=int, help='Number channel of input image')
    parser.add_argument('--class-mode', default='sparse', type=str, help='Class mode to compile')
    parser.add_argument('--model-path', default='best_model.h5', type=str, help='Path to save trained model')
    parser.add_argument('--class-names-path', default='class_names.pkl', type=str, help='Path to save class names')

    # parser.add_argument('--model-folder', default='.output/', type=str, help='Folder to save trained model')
    args = parser.parse_args()

    # Project Description
    print('---------------------Welcome to resnet-------------------')
    print('Github: hoangduc199891')
    print('Email: hoangduc199892@gmail.com')
    print('---------------------------------------------------------------------')
    print('Training resnet model with hyper-params:')
    for key, value in vars(args).items():
        print('{}: {}'.format(key, value))
    print('===========================')

    # Fail early on unknown names instead of crashing later
    if args.model not in MODELS:
        raise ValueError('Wrong resnet name, please choose one of these models: ' + ', '.join(MODELS))
    if args.optimizer not in OPTIMIZERS:
        raise ValueError('Invalid optimizer. Valid options: ' + ', '.join(OPTIMIZERS))

    # Invoke folder path
    TRAINING_DIR = args.train_folder
    VALID_DIR = args.valid_folder
    image_size = (args.image_size, args.image_size)

    training_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2)
    val_datagen = ImageDataGenerator(rescale=1. / 255)

    train_generator = training_datagen.flow_from_directory(
        TRAINING_DIR, target_size=image_size, batch_size=args.batch_size, class_mode=args.class_mode)
    val_generator = val_datagen.flow_from_directory(
        VALID_DIR, target_size=image_size, batch_size=args.batch_size, class_mode=args.class_mode)

    # Save class names in label-index order so predict.py can decode predictions
    class_names = list(train_generator.class_indices.keys())
    with open(args.class_names_path, 'wb') as fp:
        pickle.dump(class_names, fp)

    # Create model; its input shape follows --image-size and --image-channels
    model = MODELS[args.model](input_shape=(args.image_size, args.image_size, args.image_channels),
                               num_classes=args.num_classes)
    model.summary()

    optimizer = OPTIMIZERS[args.optimizer](learning_rate=args.lr)

    # Integer labels (class_mode='sparse') pair with SparseCategoricalCrossentropy
    model.compile(optimizer=optimizer,
                  loss=SparseCategoricalCrossentropy(),
                  metrics=['accuracy'])

    # Keep only the checkpoint with the best validation accuracy
    best_model = ModelCheckpoint(args.model_path,
                                 save_weights_only=False,
                                 monitor='val_accuracy',
                                 verbose=1,
                                 mode='max',
                                 save_best_only=True)
    # Training
    model.fit(
        train_generator,
        epochs=args.epochs,
        verbose=1,
        validation_data=val_generator,
        callbacks=[best_model])
--------------------------------------------------------------------------------