├── .gitignore ├── LICENSE ├── README.md ├── model ├── build.py └── convnext.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Dongkyun Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow Implementation of "A ConvNet for the 2020s" 2 | 3 | This is the unofficial Tensorflow Implementation of ConvNeXt from "A ConvNet for the 2020s" 4 | 5 | paper: https://arxiv.org/abs/2201.03545 6 | 7 | reference offical pytorch code: https://github.com/facebookresearch/ConvNeXt 8 | 9 | ## Prerequisites 10 | 11 | + tensorflow 12 | + tensorflow_addons 13 | + python3 14 | + CUDA 15 | 16 | ## Citations 17 | 18 | ```bibtex 19 | @misc{liu2022convnet, 20 | title={A ConvNet for the 2020s}, 21 | author={Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie}, 22 | year={2022}, 23 | eprint={2201.03545}, 24 | archivePrefix={arXiv}, 25 | primaryClass={cs.CV} 26 | } 27 | ``` 28 | -------------------------------------------------------------------------------- /model/build.py: -------------------------------------------------------------------------------- 1 | from model.convnext import ConvNext 2 | 3 | depths_dims = dict( 4 | convnext_tiny = (dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])), 5 | convnext_small = (dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])), 6 | convnext_base = (dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])), 7 | convnext_large = (dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])), 8 | convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])), 9 | ) 10 | 11 | def build_model(name, **kwargs): 12 | model = ConvNext(**depths_dims[name], **kwargs) 13 | return model -------------------------------------------------------------------------------- /model/convnext.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_addons as tfa 3 | import tensorflow.keras.layers as layers 4 | 5 | 6 | kernel_initial = tf.keras.initializers.TruncatedNormal(stddev=0.2) 7 | bias_initial = tf.keras.initializers.Constant(value=0) 8 | 9 | 10 | class Downsampling(tf.keras.Sequential): 11 | def __init__(self, out_dim): 12 | super(Downsampling, self).__init__([ 13 | layers.LayerNormalization(), 14 | layers.Conv2D( 15 | out_dim, kernel_size=2, strides=2, padding='same', 16 | kernel_initializer=kernel_initial, bias_initializer=bias_initial 17 | ) 18 | ]) 19 | 20 | 21 | class ConvNextBlock(layers.Layer): 22 | """ConvNeXt Block using implementation 1 23 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 24 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 25 | 26 | Args: 27 | dim (int): Number of input channels. 28 | drop_path (float): Stochastic depth rate. Default: 0.0 29 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 30 | """ 31 | 32 | def __init__(self, dim, drop_prob=0, layer_scale_init_value=1e-6): 33 | super().__init__() 34 | self.layers = tf.keras.Sequential([ 35 | layers.Conv2D(dim, kernel_size=7, padding="same", groups=dim, 36 | kernel_initializer=kernel_initial, bias_initializer=bias_initial), 37 | layers.LayerNormalization(), 38 | layers.Conv2D(dim*4, kernel_size=1, padding="valid", 39 | kernel_initializer=kernel_initial, bias_initializer=bias_initial), 40 | layers.Activation('gelu'), 41 | layers.Conv2D(dim, kernel_size=1, padding="valid", 42 | kernel_initializer=kernel_initial, bias_initializer=bias_initial), 43 | ]) 44 | if layer_scale_init_value > 0: 45 | self.layer_scale_gamma = tf.Variable( 46 | initial_value=layer_scale_init_value*tf.ones((dim))) 47 | else: 48 | self.layer_scale_gamma = None 49 | self.stochastic_depth = tfa.layers.StochasticDepth(drop_prob) 50 | 51 | def call(self, x): 52 | skip = x 53 | x = self.layers(x) 54 | if self.layer_scale_gamma is not None: 55 | x = x * self.layer_scale_gamma 56 | x = self.stochastic_depth([skip, x]) 57 | return x 58 | 59 | 60 | class ConvNext(tf.keras.Model): 61 | """ A Tensorflow impl of : `A ConvNet for the 2020s` 62 | https://arxiv.org/pdf/2201.03545.pdf 63 | Args: 64 | in_chans (int): Number of input image channels. Default: 3 65 | num_classes (int): Number of classes for classification head. Default: 1000 66 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 67 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 68 | drop_path_rate (float): Stochastic depth rate. Default: 0. 69 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 70 | """ 71 | 72 | def __init__(self, num_classes=1000, 73 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 74 | layer_scale_init_value=1e-6): 75 | super().__init__() 76 | self.downsample_layers = [] 77 | self.downsample_layers.append(tf.keras.Sequential([ 78 | layers.Conv2D(dims[0], kernel_size=4, strides=4, padding="valid"), 79 | layers.LayerNormalization() 80 | ])) 81 | self.downsample_layers += [Downsampling(dim) for dim in dims[1:]] 82 | self.convnext_blocks = [tf.keras.Sequential([ConvNextBlock(dim, drop_path_rate, layer_scale_init_value) for _ in range( 83 | depths[i])]) for i, dim in enumerate(dims)] 84 | self.head = layers.Dense( 85 | num_classes, kernel_initializer=kernel_initial, bias_initializer=bias_initial) 86 | self.gap = layers.GlobalAveragePooling2D() 87 | self.norm = layers.LayerNormalization() 88 | 89 | def call_features(self, x): 90 | for i in range(4): 91 | x = self.downsample_layers[i](x) 92 | x = self.convnext_blocks[i](x) 93 | x = self.gap(x) 94 | return self.norm(x) 95 | 96 | def call(self, x): 97 | x = self.call_features(x) 98 | x = self.head(x) 99 | return x 100 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from model.build import build_model 3 | 4 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data(label_mode='fine') 5 | x_train = x_train.astype('float32') / 255 6 | x_test = x_test.astype('float32') / 255 7 | 8 | model = build_model('convnext_tiny', num_classes=100) 9 | model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), 10 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 11 | metrics=['accuracy']) 12 | model.fit(x_train, y_train, batch_size=64, epochs=10) --------------------------------------------------------------------------------