├── README.MD
├── convnext
│   └── model.py
├── figures
│   ├── ConvNeXt.png
│   ├── ConvNeXtModel.png
│   └── Logo.png
├── predict.py
├── resnet
│   ├── __init__.py
│   ├── components
│   │   ├── __init__.py
│   │   ├── block.py
│   │   └── stage.py
│   └── model.py
├── tests
│   ├── data_test.py
│   └── model_test.py
└── train.py
/README.MD:
--------------------------------------------------------------------------------
1 | # A ConvNet for the 2020s
2 |
3 | Our implementation of the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545), using [TensorFlow 2](https://www.tensorflow.org/).
4 |
5 |
9 | This library is part of our project: Building an AI library with ProtonX
10 |
11 | ConvNeXt Architecture:
12 |
13 | ![ConvNeXt Architecture](figures/ConvNeXt.png)
14 |
21 | Authors:
22 | - Github: thinguyenkhtn
23 | - Email: thinguyenkhtn@gmail.com
24 |
25 | Advisors:
26 | - Github: https://github.com/bangoc123
27 | - Email: protonxai@gmail.com
28 |
29 | Reviewers:
30 | - @Khoi: https://github.com/NKNK-vn
31 | - @Quynh: https://github.com/quynhtl
32 |
33 | ## I. Set up environment
34 | - Step 1: Make sure you have installed Miniconda. If not, see the setup guide [here](https://conda.io/en/latest/user-guide/install/index.html#regular-installation).
35 |
36 | - Step 2: Clone this repository: `git clone https://github.com/protonx-tf-04-projects/ConvNext-2020s`
37 |
38 | ## II. Set up your dataset
39 |
40 |
41 | 1. Download the data:
42 |    - Download the dataset [here](https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip)
43 | 2. Extract the file and put the `train` and `validation` folders in `./data`:
44 |    - the `train` folder is used for the training process
45 |    - the `validation` folder is used to validate the training result after each epoch
46 |
47 | This library uses the `ImageDataGenerator` API from TensorFlow 2 to load images. Make sure you have some understanding of how it works via [its documentation](https://keras.io/api/preprocessing/image/). A minimal loading sketch follows the folder layouts below.
48 | Structure of these folders in `./data`:
49 |
50 | ```
51 | train/
52 | ...cats/
53 | ......cat.0.jpg
54 | ......cat.1.jpg
55 | ...dogs/
56 | ......dog.0.jpg
57 | ......dog.1.jpg
58 | ```
59 |
60 | ```
61 | validation/
62 | ...cats/
63 | ......cat.2000.jpg
64 | ......cat.2001.jpg
65 | ...dogs/
66 | ......dog.2000.jpg
67 | ......dog.2001.jpg
68 | ```
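
For reference, here is a minimal sketch of how these folders are loaded (it mirrors the `flow_from_directory` calls in `train.py`; the paths and hyper-parameters are illustrative assumptions):

```python
from keras_preprocessing.image import ImageDataGenerator

# Rescale pixel values to [0, 1]; train.py additionally applies
# augmentation (rotation, shifts, shear, zoom) to the training set
train_datagen = ImageDataGenerator(rescale=1. / 255)

# Each subfolder of ./data/train (cats/, dogs/) becomes one class;
# class_mode='sparse' yields integer labels for SparseCategoricalCrossentropy
train_generator = train_datagen.flow_from_directory(
    './data/train',
    target_size=(224, 224),
    batch_size=32,
    class_mode='sparse')
```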
69 |
72 | ## III. Training Process
73 |
74 | Review training on Colab:
75 |
78 | Training script:
79 |
80 | ```bash
82 | python train.py --train-folder ${train_folder} --valid-folder ${valid_folder} --num-classes ${num_classes} --image-size ${image_size} --lr ${lr} --batch-size ${batch_size} --model ${model} --epochs ${epochs}
84 | ```
85 |
86 | Example:
87 |
88 | ```bash
90 | python train.py --train-folder $train_folder --valid-folder $valid_folder --num-classes 2 --image-size 224 --lr 0.0001 --batch-size 32 --model tiny --epochs 200
92 | ```
93 |
94 | There are some important arguments for the script you should consider when running it:
95 |
96 | - `train-folder`: The folder of training data
97 | - `valid-folder`: The folder of validation data
98 | - `model-folder`: Where the trained model is saved
99 | - `num-classes`: The number of classes in your problem
100 | - `batch-size`: The batch size of the dataset
101 | - `image-size`: The image size of the dataset
102 | - `lr`: The learning rate
103 | - `model`: Type of ConvNeXt model, valid options: tiny, small, base, large, xlarge
104 |
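Under the hood, `train.py` turns these arguments into a compiled Keras model. Here is a minimal sketch of what it builds for the example above (the weight decay shown is the script's default, not an argument passed in the example):

```python
import tensorflow_addons as tfa
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from convnext.model import ConvNeXtTiny

# Build the 'tiny' variant for 2 classes at 224x224 resolution
model = ConvNeXtTiny(num_classes=2, image_size=224)
model.build(input_shape=(None, 224, 224, 3))

# AdamW optimizer, as in train.py (default weight decay: 1e-4)
model.compile(
    optimizer=tfa.optimizers.AdamW(learning_rate=0.0001, weight_decay=1e-4),
    loss=SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
```
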
105 | ## IV. Predict Process
106 |
107 | ```bash
108 | python predict.py --test-image ${path_to_test_image} --model-folder ${model_folder}
109 | ```
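
In Python, this is roughly equivalent to the following sketch (it mirrors `predict.py`; the file paths are placeholder assumptions, and the 1/255 rescale matches the preprocessing used during training):

```python
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras import preprocessing

# Load the model saved by train.py and classify a single image
model = load_model('output/')
image = preprocessing.image.load_img('./test.png', target_size=(224, 224))
x = np.array([preprocessing.image.img_to_array(image)]) / 255.0  # match training rescale
print('Predicted class index: {}'.format(np.argmax(model.predict(x), axis=1)[0]))
```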
110 |
111 | ## V. Results and Comparison
112 |
113 | Our implementation:
114 | ```
115 | Epoch 195: val_accuracy did not improve from 0.81000
116 | 63/63 [==============================] - 74s 1s/step - loss: 0.1756 - accuracy: 0.9300 - val_loss: 0.5760 - val_accuracy: 0.7930
117 | Epoch 196/200
118 | 63/63 [==============================] - ETA: 0s - loss: 0.1788 - accuracy: 0.9270
119 | Epoch 196: val_accuracy did not improve from 0.81000
120 | 63/63 [==============================] - 74s 1s/step - loss: 0.1788 - accuracy: 0.9270 - val_loss: 0.5847 - val_accuracy: 0.7890
121 | Epoch 197/200
122 | 63/63 [==============================] - ETA: 0s - loss: 0.1796 - accuracy: 0.9290
123 | Epoch 197: val_accuracy did not improve from 0.81000
124 | 63/63 [==============================] - 74s 1s/step - loss: 0.1796 - accuracy: 0.9290 - val_loss: 0.5185 - val_accuracy: 0.7840
125 | Epoch 198/200
126 | 63/63 [==============================] - ETA: 0s - loss: 0.1768 - accuracy: 0.9290
127 | Epoch 198: val_accuracy did not improve from 0.81000
128 | 63/63 [==============================] - 74s 1s/step - loss: 0.1768 - accuracy: 0.9290 - val_loss: 0.5624 - val_accuracy: 0.7870
129 | Epoch 199/200
130 | 63/63 [==============================] - ETA: 0s - loss: 0.1744 - accuracy: 0.9340
131 | Epoch 199: val_accuracy did not improve from 0.81000
132 | 63/63 [==============================] - 74s 1s/step - loss: 0.1744 - accuracy: 0.9340 - val_loss: 0.5416 - val_accuracy: 0.7790
133 | Epoch 200/200
134 | 63/63 [==============================] - ETA: 0s - loss: 0.1995 - accuracy: 0.9230
135 | Epoch 200: val_accuracy did not improve from 0.81000
136 | 63/63 [==============================] - 74s 1s/step - loss: 0.1995 - accuracy: 0.9230 - val_loss: 0.4909 - val_accuracy: 0.7930
137 | ```
138 |
139 | ## VI. Feedback
140 | If you encounter any issues when using this library, please let us know via the issues tab.
141 |
142 |
143 |
--------------------------------------------------------------------------------
/convnext/model.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.models import Model
4 | from resnet import build_convnext
5 |
6 |
7 | class ConvNeXt(Model):
8 | def __init__(self, num_classes=10, image_size=224, layer=[3, 3, 9, 3], model='tiny'):
9 | """
10 | ConvNeXt Model
11 | Parameters
12 | ----------
13 | num_classes:
14 | number of classes
15 | image_size: int,
16 |             size of an image (H or W)
17 | """
18 | super(ConvNeXt, self).__init__()
19 |         # Build the functional ConvNeXt network (backbone + classification head)
20 |         input_shape = (image_size, image_size, 3)
21 |
22 |         self.convnext = build_convnext(
23 |             input_shape, num_classes, layer, model_name=model)
24 |
25 |     def call(self, inputs):
26 |         # Forward pass through the ConvNeXt network
27 |         # output shape: (..., num_classes)
28 |         output = self.convnext(inputs)
29 |
30 | return output
31 |
32 |
33 | class ConvNeXtTiny(ConvNeXt):
34 | def __init__(self, num_classes=10, image_size=224):
35 | super().__init__(num_classes=num_classes,
36 | image_size=image_size, layer=[3, 3, 9, 3], model='tiny')
37 |
38 |
39 | class ConvNeXtSmall(ConvNeXt):
40 | def __init__(self, num_classes=10, image_size=224):
41 | super().__init__(num_classes=num_classes,
42 | image_size=image_size, layer=[3, 3, 27, 3], model='small')
43 |
44 |
45 | class ConvNeXtBase(ConvNeXt):
46 | def __init__(self, num_classes=10, image_size=224):
47 | super().__init__(num_classes=num_classes,
48 | image_size=image_size, layer=[3, 3, 27, 3], model='base')
49 |
50 |
51 | class ConvNeXtLarge(ConvNeXt):
52 | def __init__(self, num_classes=10, image_size=224):
53 | super().__init__(num_classes=num_classes,
54 | image_size=image_size, layer=[3, 3, 27, 3], model='large')
55 |
56 |
57 | class ConvNeXtXLarge(ConvNeXt):
58 | def __init__(self, num_classes=10, image_size=224):
59 | super().__init__(num_classes=num_classes, image_size=image_size,
60 | layer=[3, 3, 27, 3], model='xlarge')
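61 |
62 | # Example usage (a minimal sketch, mirroring how train.py instantiates these classes):
63 | #   model = ConvNeXtTiny(num_classes=2, image_size=224)
64 | #   model.build(input_shape=(None, 224, 224, 3))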
--------------------------------------------------------------------------------
/figures/ConvNeXt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-04-projects/ConvNext-2020s/560a4767faff9705c70bbe01da47dec64f7e7cae/figures/ConvNeXt.png
--------------------------------------------------------------------------------
/figures/ConvNeXtModel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-04-projects/ConvNext-2020s/560a4767faff9705c70bbe01da47dec64f7e7cae/figures/ConvNeXtModel.png
--------------------------------------------------------------------------------
/figures/Logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-04-projects/ConvNext-2020s/560a4767faff9705c70bbe01da47dec64f7e7cae/figures/Logo.png
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.models import load_model
2 | from tensorflow.keras import preprocessing
3 | from argparse import ArgumentParser
4 | import numpy as np
5 |
6 | if __name__ == "__main__":
7 | parser = ArgumentParser()
8 | parser.add_argument(
9 | "--test-image", default='./test.png', type=str, required=True)
10 | parser.add_argument(
11 | "--model-folder", default='output/', type=str)
12 |
13 | args = parser.parse_args()
14 |
15 | # Project Description
16 |
17 | print('---------------------Welcome to ConvNeXt-------------------')
18 | print('Github: thinguyenkhtn')
19 | print('Email: thinguyenkhtn@gmail.com')
20 | print('---------------------------------------------------------------------')
21 |     print('Predicting with the ConvNeXt-2020s model:')
22 | print('===========================')
23 |
24 | # Loading Model
25 | model = load_model(args.model_folder)
26 |
27 | # Load test image
28 | image = preprocessing.image.load_img(args.test_image, target_size=(224, 224))
29 | input_arr = preprocessing.image.img_to_array(image)
30 |     x = np.array([input_arr]) / 255.0  # rescale to [0, 1] to match the training preprocessing
31 |
32 | predictions = model.predict(x)
33 |     print('Result: {}'.format(np.argmax(predictions, axis=1)))
34 |
35 |
--------------------------------------------------------------------------------
/resnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import build_convnext
2 |
--------------------------------------------------------------------------------
/resnet/components/__init__.py:
--------------------------------------------------------------------------------
1 | from .stage import stage, downsample
2 |
3 |
--------------------------------------------------------------------------------
/resnet/components/block.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.layers import Conv2D, LayerNormalization, Add, DepthwiseConv2D
3 |
4 | def downsampleblock(input, filter_num, stage_idx=-1, block_idx=-1):
5 |     '''A spatial downsampling block of ConvNeXt: a 2x2, stride-2 convolution between stages
6 |
7 |     Args:
8 |         filter_num: the number of filters in the convolution
9 |         stage_idx: index of current stage
10 |         block_idx: index of current block in stage (no downsampling when 0, i.e. right after the stem)
11 |     '''
12 |
13 | # Downsample
14 | down = input
15 | if block_idx > 0:
16 | down = Conv2D(filters=filter_num,
17 | kernel_size=2,
18 | strides=2,
19 | padding='same',
20 | kernel_initializer='he_normal',
21 | name='conv{}_block{}_downsample_conv'.format(stage_idx, block_idx))(input)
22 |
23 | return down
24 |
25 | def microblock(input, filter_num, stage_idx=-1, block_idx=-1):
26 |     '''ConvNeXt block body: a 7x7 depthwise convolution, LayerNorm, then an inverted bottleneck of two pointwise (1x1) convolutions with GELU in between
27 |
28 | Args:
29 | filter_num: the number of filters in the convolution
30 | stage_idx: index of current stage
31 | block_idx: index of current block in stage
32 | '''
33 |
34 | # Depthwise_Layer
35 | depthwise = DepthwiseConv2D(
36 | kernel_size=7, strides=1, padding='same')(input)
37 |
38 | nn1 = LayerNormalization(name='conv{}_block{}_1_nn'.format(
39 | stage_idx, block_idx))(depthwise)
40 |
41 | # Pointwise_Layer
42 | conv1 = Conv2D(filters=4*filter_num,
43 | kernel_size=1,
44 | strides=1,
45 | padding='same',
46 | kernel_initializer='he_normal',
47 | name='conv{}_block{}_1_conv'.format(stage_idx, block_idx))(nn1)
48 | gelu = tf.nn.gelu(conv1)
49 |
50 | # Pointwise_Layer
51 | conv2 = Conv2D(filters=filter_num,
52 | kernel_size=1,
53 | strides=1,
54 | padding='same',
55 | kernel_initializer='he_normal',
56 | name='conv{}_block{}_2_conv'.format(stage_idx, block_idx))(gelu)
57 |
58 | return conv2
59 |
60 | def resblock(input, filter_num, stage_idx=-1, block_idx=-1):
61 |     '''A complete residual unit of ConvNeXt: block body plus identity skip connection
62 |
63 | Args:
64 | filter_num: the number of filters in the convolution
65 | stage_idx: index of current stage
66 | block_idx: index of current block in stage
67 | '''
68 |
69 | residual = microblock(input, filter_num, stage_idx, block_idx)
70 |
71 | output = Add(name='conv{}_block{}_add'.format(
72 | stage_idx, block_idx))([input, residual])
73 |
74 | return output
75 |
--------------------------------------------------------------------------------
/resnet/components/stage.py:
--------------------------------------------------------------------------------
1 | from .block import resblock, downsampleblock
2 |
3 | def downsample(input, filter_num, block_idx, stage_idx=-1):
4 |     ''' -- Downsampling between two consecutive stages
5 |
6 |     Args:
7 |         filter_num: the number of filters in the downsampling convolution
8 |         block_idx: index of the current block (0 skips the downsample, right after the stem)
9 |         stage_idx: index of current stage
10 |     '''
11 |
12 | net = input
13 | net = downsampleblock(input=net, filter_num=filter_num,
14 | stage_idx=stage_idx, block_idx=block_idx)
15 |
16 | return net
17 |
18 | def stage(input, filter_num, num_block, stage_idx=-1):
19 | ''' -- Stacking Residual Units on the same stage
20 |
21 | Args:
22 | filter_num: the number of filters in the convolution used during stage
23 | num_block: number of `Residual Unit` in a stage
24 | stage_idx: index of current stage
25 | '''
26 |
27 | net = input
28 | for i in range(num_block):
29 | net = resblock(input=net, filter_num=filter_num,
30 | stage_idx=stage_idx, block_idx=i+1)
31 |
32 | return net
33 |
--------------------------------------------------------------------------------
/resnet/model.py:
--------------------------------------------------------------------------------
1 | from .components import stage, downsample
2 | from tensorflow.keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense, LayerNormalization
3 | from tensorflow.keras import Model
4 |
5 |
6 | def build_convnext(input_shape, num_classes, layers, model_name='base'):
7 |     '''Build a complete ConvNeXt model: patchify stem, four stages, and a classification head
8 | '''
9 | input = Input(input_shape, name='input')
10 |
11 |     filter_cnn = 96  # stem width; these defaults cover the 'tiny' and 'small' variants
12 | filters = [96, 192, 384, 768]
13 |
14 | if model_name == 'base':
15 | filter_cnn = 128
16 | filters = [128, 256, 512, 1024]
17 | elif model_name == 'large':
18 | filter_cnn = 192
19 | filters = [192, 384, 768, 1536]
20 | elif model_name == 'xlarge':
21 | filter_cnn = 256
22 | filters = [256, 512, 1024, 2048]
23 |
24 |     # conv1 (stem): replace the ResNet-style stem cell with a "patchify" layer,
25 |     # implemented as a 4x4, stride-4 convolution, so the input is
26 |     # downsampled by 4 before the first stage
27 | net = Conv2D(filters=filter_cnn,
28 | kernel_size=4,
29 | strides=4,
30 | padding='same',
31 | kernel_initializer='he_normal',
32 | name='conv1_conv')(input)
33 |
34 |     # Stages conv2_x..conv5_x: downsample between stages, then stack residual blocks
35 | for i in range(len(filters)):
36 | net = downsample(input=net,
37 | filter_num=filters[i],
38 | block_idx=i,
39 | stage_idx=i+2)
40 | net = stage(input=net,
41 | filter_num=filters[i],
42 | num_block=layers[i],
43 | stage_idx=i+2)
44 |
45 | net = GlobalAveragePooling2D(name='avg_pool')(net)
46 | net = LayerNormalization(name='norm')(net)
47 | output = Dense(num_classes, activation='softmax', name='predictions')(net)
48 | model = Model(input, output)
49 |
50 | return model
--------------------------------------------------------------------------------
/tests/data_test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-04-projects/ConvNext-2020s/560a4767faff9705c70bbe01da47dec64f7e7cae/tests/data_test.py
--------------------------------------------------------------------------------
/tests/model_test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protonx-tf-04-projects/ConvNext-2020s/560a4767faff9705c70bbe01da47dec64f7e7cae/tests/model_test.py
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from convnext.model import ConvNeXt, ConvNeXtTiny, ConvNeXtSmall, ConvNeXtBase, ConvNeXtLarge, ConvNeXtXLarge
2 | from tensorflow.keras.losses import SparseCategoricalCrossentropy
3 | from tensorflow.keras.optimizers import Adam
4 | from argparse import ArgumentParser
5 | from tensorflow.keras.callbacks import ModelCheckpoint
6 | from keras_preprocessing.image import ImageDataGenerator
7 | import tensorflow_addons as tfa
8 |
9 | if __name__ == "__main__":
10 | parser = ArgumentParser()
11 |
12 | # Arguments users used when running command lines
13 | parser.add_argument('--model', default='tiny', type=str,
14 | help='Type of ConvNeXt model, valid option: tiny, small, base, large, xlarge')
15 | parser.add_argument('--lr', default=0.001,
16 | type=float, help='Learning rate')
17 | parser.add_argument('--weight-decay', default=1e-4,
18 | type=float, help='Weight decay')
19 | parser.add_argument("--batch-size", default=32, type=int)
20 | parser.add_argument("--epochs", default=1000, type=int)
21 | parser.add_argument('--num-classes', default=10,
22 | type=int, help='Number of classes')
23 | parser.add_argument('--image-size', default=224,
24 | type=int, help='Size of input image')
25 |     parser.add_argument('--image-channels', default=3,
26 |                         type=int, help='Number of channels of the input image')
27 | parser.add_argument('--train-folder', default='', type=str,
28 | help='Where training data is located')
29 | parser.add_argument('--valid-folder', default='', type=str,
30 | help='Where validation data is located')
31 | parser.add_argument('--class-mode', default='sparse',
32 | type=str, help='Class mode to compile')
33 | parser.add_argument('--model-folder', default='output/',
34 | type=str, help='Folder to save trained model')
35 |
36 | args = parser.parse_args()
37 |
38 | # Project Description
39 |
40 | print('---------------------Welcome to ConvNext-2020s paper Implementation-------------------')
41 | print('Github: thinguyenkhtn')
42 | print('Email: thinguyenkhtn@gmail.com')
43 | print('---------------------------------------------------------------------')
44 | print('Training ConvNext2020s model with hyper-params:')
45 | print('===========================')
46 | for i, arg in enumerate(vars(args)):
47 | print('{}.{}: {}'.format(i, arg, vars(args)[arg]))
48 | print('===========================')
49 |
50 | # Assign arguments to variables to avoid repetition
51 | train_folder = args.train_folder
52 | valid_folder = args.valid_folder
53 | batch_size = args.batch_size
54 | image_size = args.image_size
55 | image_channels = args.image_channels
56 | num_classes = args.num_classes
57 |     epochs = args.epochs
58 | class_mode = args.class_mode
59 | lr = args.lr
60 | weight_decay = args.weight_decay
61 |
62 | # Data loader
63 | training_datagen = ImageDataGenerator(
64 | rescale=1. / 255,
65 | rotation_range=20,
66 | width_shift_range=0.2,
67 | height_shift_range=0.2,
68 | shear_range=0.2,
69 | zoom_range=0.2)
70 | val_datagen = ImageDataGenerator(rescale=1. / 255)
71 |
72 |     train_generator = training_datagen.flow_from_directory(train_folder, target_size=(image_size, image_size), batch_size=batch_size, class_mode=class_mode)
73 |     val_generator = val_datagen.flow_from_directory(valid_folder, target_size=(image_size, image_size), batch_size=batch_size, class_mode=class_mode)
74 |
75 | # ConvNeXt
76 |     if args.model == 'tiny':
77 |         model = ConvNeXtTiny(num_classes=num_classes, image_size=image_size)
78 |     elif args.model == 'small':
79 |         model = ConvNeXtSmall(num_classes=num_classes, image_size=image_size)
80 |     elif args.model == 'base':
81 |         model = ConvNeXtBase(num_classes=num_classes, image_size=image_size)
82 |     elif args.model == 'large':
83 |         model = ConvNeXtLarge(num_classes=num_classes, image_size=image_size)
84 |     elif args.model == 'xlarge':
85 |         model = ConvNeXtXLarge(num_classes=num_classes, image_size=image_size)
86 | else:
87 | model = ConvNeXt(
88 | num_classes=num_classes,
89 | image_size=image_size
90 | )
91 |
92 | model.build(input_shape=(None, image_size,
93 | image_size, image_channels))
94 |
95 | optimizer = tfa.optimizers.AdamW(
96 | learning_rate=lr, weight_decay=weight_decay)
97 |
98 | model.compile(optimizer=optimizer,
99 | loss=SparseCategoricalCrossentropy(),
100 | metrics=['accuracy'])
101 |
102 | best_model = ModelCheckpoint(args.model_folder,
103 | save_weights_only=False,
104 | monitor='val_accuracy',
105 | verbose=1,
106 | mode='max',
107 | save_best_only=True)
108 |     # Training
109 | model.fit(
110 | train_generator,
111 |         epochs=epochs,
112 | verbose=1,
113 | validation_data=val_generator,
114 | callbacks=[best_model])
115 |
116 | # Save model
117 | model.save(args.model_folder)
118 |
--------------------------------------------------------------------------------