├── .gitignore ├── requirements.txt ├── hub_utilities ├── README.md ├── export_to_hub.py └── generate_doc.py ├── i1k_eval ├── README.md ├── eval.ipynb └── imagenet_class_index.json ├── models ├── model_configs.py ├── convnext_tf.py └── convnext.py ├── convert_all_available_models.py ├── README.md ├── notebooks ├── classification.ipynb └── finetune.ipynb ├── convert.py └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | .DS_Store 3 | .idea 4 | *.pb 5 | *.pyc 6 | saved_models/ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==2.7.0 2 | torch==1.10.1 3 | torchvision==0.11.2 4 | timm==0.4.12 5 | ml_collections==0.1.0 6 | tqdm -------------------------------------------------------------------------------- /hub_utilities/README.md: -------------------------------------------------------------------------------- 1 | The scripts in this directory help publish the converted models to TF-Hub. The following utilities are supported: 2 | 3 | * `export_to_hub.py`: Exports SavedModels in bulk as the `tar.gz` archives needed by TF-Hub. 4 | * `generate_doc.py`: Generates documentation for the models in bulk. 5 | -------------------------------------------------------------------------------- /i1k_eval/README.md: -------------------------------------------------------------------------------- 1 | This directory provides a notebook and an ImageNet-1k class mapping file to run 2 | evaluation on the ImageNet-1k `val` split using the TF/Keras converted ConvNeXt 3 | models. The notebook assumes the following files are present in your working 4 | directory: 5 | 6 | * The `val` split directory of ImageNet-1k. 7 | * The class mapping file (`imagenet_class_index.json`). -------------------------------------------------------------------------------- /models/model_configs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configurations for different ConvNeXt variants. 
3 | 4 | Referred from: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py 5 | """ 6 | 7 | 8 | import ml_collections 9 | 10 | 11 | def convnext_tiny_config() -> ml_collections.ConfigDict: 12 | configs = ml_collections.ConfigDict() 13 | configs.depths = [3, 3, 9, 3] 14 | configs.dims = [96, 192, 384, 768] 15 | return configs 16 | 17 | 18 | def convnext_small_config() -> ml_collections.ConfigDict: 19 | configs = convnext_tiny_config() 20 | configs.depths = [3, 3, 27, 3] 21 | return configs 22 | 23 | 24 | def convnext_base_config() -> ml_collections.ConfigDict: 25 | configs = convnext_small_config() 26 | configs.dims = [128, 256, 512, 1024] 27 | return configs 28 | 29 | 30 | def convnext_large_config() -> ml_collections.ConfigDict: 31 | configs = convnext_base_config() 32 | configs.dims = [192, 384, 768, 1536] 33 | return configs 34 | 35 | 36 | def convnext_xlarge_config() -> ml_collections.ConfigDict: 37 | configs = convnext_large_config() 38 | configs.dims = [256, 512, 1024, 2048] 39 | return configs 40 | 41 | 42 | def get_model_config(model_name: str) -> ml_collections.ConfigDict: 43 | if model_name == "convnext_tiny": 44 | return convnext_tiny_config() 45 | elif model_name == "convnext_small": 46 | return convnext_small_config() 47 | elif model_name == "convnext_base": 48 | return convnext_base_config() 49 | elif model_name == "convnext_large": 50 | return convnext_large_config() 51 | else: 52 | return convnext_xlarge_config() 53 | -------------------------------------------------------------------------------- /hub_utilities/export_to_hub.py: -------------------------------------------------------------------------------- 1 | """Generates .tar.gz archives from SavedModels and serializes them.""" 2 | 3 | 4 | from typing import List 5 | import tensorflow as tf 6 | import os 7 | 8 | 9 | TF_MODEL_ROOT = "gs://convnext/saved_models" 10 | TAR_ARCHIVES = os.path.join(TF_MODEL_ROOT, "tars/") 11 | 12 | 13 | def generate_fe(model: tf.keras.Model) -> tf.keras.Model: 14 | """Generates a feature extractor from a classifier.""" 15 | feature_extractor = tf.keras.Model(model.inputs, model.layers[-2].output) 16 | return feature_extractor 17 | 18 | 19 | def prepare_archive(model_name: str) -> None: 20 | """Prepares a tar archive.""" 21 | archive_name = f"{model_name}.tar.gz" 22 | print(f"Archiving to {archive_name}.") 23 | archive_command = f"cd {model_name} && tar -czvf ../{archive_name} *" 24 | os.system(archive_command) 25 | os.system(f"rm -rf {model_name}") 26 | 27 | 28 | def save_to_gcs(model_paths: List[str]) -> None: 29 | """Prepares tar archives and saves them inside a GCS bucket.""" 30 | for path in model_paths: 31 | print(f"Preparing classification model: {path}.") 32 | model_name = path.strip("/") 33 | abs_model_path = os.path.join(TF_MODEL_ROOT, model_name) 34 | 35 | print(f"Copying from {abs_model_path}.") 36 | os.system(f"gsutil cp -r {abs_model_path} .") 37 | prepare_archive(model_name) 38 | 39 | print("Preparing feature extractor.") 40 | model = tf.keras.models.load_model(abs_model_path) 41 | fe_model = generate_fe(model) 42 | fe_model_name = f"{model_name}_fe" 43 | fe_model.save(fe_model_name) 44 | prepare_archive(fe_model_name) 45 | 46 | os.system(f"gsutil -m cp -r *.tar.gz {TAR_ARCHIVES}") 47 | os.system("rm -rf *.tar.gz") 48 | 49 | 50 | model_paths = tf.io.gfile.listdir(TF_MODEL_ROOT) 51 | print(f"Total models: {len(model_paths)}.") 52 | 53 | print("Preparing archives for the classification and feature extractor models.") 54 | save_to_gcs(model_paths) 55 | 
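56 | # Note (illustrative): once an archive is published on TF-Hub, the model can be
57 | # consumed through `hub.KerasLayer` (see the notebooks in this repository), e.g.:
58 | #
59 | #   import tensorflow_hub as hub
60 | #   classifier = hub.KerasLayer("https://tfhub.dev/sayakpaul/convnext_tiny_1k_224/1")
61 | #   features = hub.KerasLayer("https://tfhub.dev/sayakpaul/convnext_tiny_1k_224_fe/1")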
-------------------------------------------------------------------------------- /convert_all_available_models.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | 4 | """ 5 | Details about these checkpoints are available here: 6 | https://github.com/facebookresearch/ConvNeXt#results-and-pre-trained-models. 7 | """ 8 | 9 | imagenet_1k_224 = { 10 | "convnext_tiny": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 11 | "convnext_small": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 12 | "convnext_base": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 13 | "convnext_large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 14 | } 15 | 16 | imagenet_1k_384 = { 17 | "convnext_base": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth", 18 | "convnext_large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth", 19 | } 20 | 21 | imagenet_21k_224 = { 22 | "convnext_base": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 23 | "convnext_large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 24 | "convnext_xlarge": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 25 | } 26 | 27 | imagenet_21k_1k_224 = { 28 | "convnext_base": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth", 29 | "convnext_large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth", 30 | "convnext_xlarge": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth", 31 | } 32 | 33 | imagenet_21k_1k_384 = { 34 | "convnext_base": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth", 35 | "convnext_large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth", 36 | "convnext_xlarge": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth", 37 | } 38 | 39 | print("Converting 224x224 resolution ImageNet-1k models.") 40 | for model in tqdm(imagenet_1k_224): 41 | print(f"Converting {model}.") 42 | command = f"python convert.py -m {model} -c {imagenet_1k_224[model]}" 43 | os.system(command) 44 | 45 | 46 | print("Converting 384x384 resolution ImageNet-1k models.") 47 | for model in tqdm(imagenet_1k_384): 48 | print(f"Converting {model}.") 49 | command = f"python convert.py -m {model} -c {imagenet_1k_384[model]} -r 384" 50 | os.system(command) 51 | 52 | 53 | print("Converting 224x224 resolution ImageNet-21k models.") 54 | for model in tqdm(imagenet_21k_224): 55 | print(f"Converting {model}.") 56 | command = f"python convert.py -d imagenet-21k -m {model} -c {imagenet_21k_224[model]} -r 224" 57 | os.system(command) 58 | 59 | 60 | print( 61 | "Converting 224x224 resolution ImageNet-21k trained ImageNet-1k fine-tuned models." 62 | ) 63 | for model in tqdm(imagenet_21k_1k_224): 64 | print(f"Converting {model}.") 65 | command = f"python convert.py -m {model} -c {imagenet_21k_1k_224[model]} -r 224" 66 | os.system(command) 67 | 68 | 69 | print( 70 | "Converting 384x384 resolution ImageNet-21k trained ImageNet-1k fine-tuned models." 
71 | ) 72 | for model in tqdm(imagenet_21k_1k_384): 73 | print(f"Converting {model}.") 74 | command = f"python convert.py -m {model} -c {imagenet_21k_1k_384[model]} -r 384" 75 | os.system(command) 76 | -------------------------------------------------------------------------------- /hub_utilities/generate_doc.py: -------------------------------------------------------------------------------- 1 | """Generates model documentation for ConvNeXt-TF models. 2 | 3 | Credits: Willi Gierke 4 | """ 5 | 6 | from string import Template 7 | import attr 8 | import os 9 | 10 | template = Template( 11 | """# Module $HANDLE 12 | 13 | Fine-tunable ConvNeXt model pre-trained on the $DATASET_DESCRIPTION. 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | ## Overview 24 | 25 | This model is a ConvNeXt [1] model pre-trained on the $DATASET_DESCRIPTION. You can find the complete 26 | collection of ConvNeXt models on TF-Hub on [this page](https://tfhub.dev/sayakpaul/collections/convnext/1). 27 | 28 | You can use this model for feature extraction and fine-tuning. Please refer to 29 | the Colab Notebook linked on this page for more details. 30 | 31 | ## Notes 32 | 33 | * The original model weights come from [2]. They were ported to Keras models 34 | (`tf.keras.Model`) and then serialized as TensorFlow SavedModels. The porting 35 | steps are available in [3]. 36 | * The model can be unrolled into a standard Keras model and you can inspect its topology. 37 | To do so, first download the model from TF-Hub and then load it using `tf.keras.models.load_model` 38 | providing the path to the downloaded model folder. 39 | 40 | ## References 41 | 42 | [1] [A ConvNet for the 2020s by Liu et al.](https://arxiv.org/abs/2201.03545) 43 | [2] [ConvNeXt GitHub](https://github.com/facebookresearch/ConvNeXt) 44 | [3] [ConvNeXt-TF GitHub](https://github.com/sayakpaul/ConvNeXt-TF) 45 | 46 | ## Acknowledgements 47 | 48 | * [Vasudev Gupta](https://github.com/vasudevgupta7) 49 | * [Gus](https://twitter.com/gusthema) 50 | * [Willi](https://ch.linkedin.com/in/willi-gierke) 51 | * [ML-GDE program](https://developers.google.com/programs/experts/) 52 | 53 | """ 54 | ) 55 | 56 | 57 | @attr.s 58 | class Config: 59 | size = attr.ib(type=str) 60 | dataset = attr.ib(type=str) 61 | single_resolution = attr.ib(type=int) 62 | 63 | def two_d_resolution(self): 64 | return f"{self.single_resolution}x{self.single_resolution}" 65 | 66 | def gcs_folder_name(self): 67 | return f"convnext_{self.size}_{self.dataset}_{self.single_resolution}_fe" 68 | 69 | def handle(self): 70 | return f"sayakpaul/{self.gcs_folder_name()}/1" 71 | 72 | def rel_doc_file_path(self): 73 | """Relative to the tfhub.dev directory.""" 74 | return f"assets/docs/{self.handle()}.md" 75 | 76 | 77 | for c in [ 78 | Config("tiny", "1k", 224), 79 | Config("small", "1k", 224), 80 | Config("base", "1k", 224), 81 | Config("base", "1k", 384), 82 | Config("large", "1k", 224), 83 | Config("large", "1k", 384), 84 | Config("base", "21k_1k", 224), 85 | Config("base", "21k_1k", 384), 86 | Config("large", "21k_1k", 224), 87 | Config("large", "21k_1k", 384), 88 | Config("xlarge", "21k_1k", 224), 89 | Config("xlarge", "21k_1k", 384), 90 | Config("base", "21k", 224), 91 | Config("large", "21k", 224), 92 | Config("xlarge", "21k", 224), 93 | ]: 94 | if c.dataset == "1k": 95 | dataset_text = "ImageNet-1k dataset" 96 | elif c.dataset == "21k": 97 | dataset_text = "ImageNet-21k dataset" 98 | else: 99 | dataset_text = ( 100 | "ImageNet-21k" 101 | " dataset and" 102 | " was then " 103 | "fine-tuned " 104
| "on the " 105 | "ImageNet-1k " 106 | "dataset" 107 | ) 108 | 109 | save_path = os.path.join( 110 | "/Users/sayakpaul/Downloads/", "tfhub.dev", c.rel_doc_file_path() 111 | ) 112 | model_folder = save_path.split("/")[-2] 113 | model_abs_path = "/".join(save_path.split("/")[:-1]) 114 | 115 | if not os.path.exists(model_abs_path): 116 | os.makedirs(model_abs_path, exist_ok=True) 117 | 118 | with open(save_path, "w") as f: 119 | f.write( 120 | template.substitute( 121 | HANDLE=c.handle(), 122 | DATASET_DESCRIPTION=dataset_text, 123 | INPUT_RESOLUTION=c.two_d_resolution(), 124 | ARCHIVE_NAME=c.gcs_folder_name(), 125 | ) 126 | ) 127 | -------------------------------------------------------------------------------- /models/convnext_tf.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow import keras 4 | from tensorflow.keras import layers 5 | 6 | 7 | class StochasticDepth(layers.Layer): 8 | """Stochastic Depth module. 9 | 10 | It is also referred to as Drop Path in `timm`. 11 | References: 12 | (1) github.com:rwightman/pytorch-image-models 13 | """ 14 | 15 | def __init__(self, drop_path, **kwargs): 16 | super(StochasticDepth, self).__init__(**kwargs) 17 | self.drop_path = drop_path 18 | 19 | def call(self, x, training=None): 20 | if training: 21 | keep_prob = 1 - self.drop_path 22 | shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) 23 | random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) 24 | random_tensor = tf.floor(random_tensor) 25 | return (x / keep_prob) * random_tensor 26 | return x 27 | 28 | 29 | class Block(tf.keras.Model): 30 | """ConvNeXt block. 31 | 32 | References: 33 | (1) https://arxiv.org/abs/2201.03545 34 | (2) https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py 35 | """ 36 | 37 | def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6, **kwargs): 38 | super(Block, self).__init__(**kwargs) 39 | self.dim = dim 40 | if layer_scale_init_value > 0: 41 | self.gamma = tf.Variable(layer_scale_init_value * tf.ones((dim,))) 42 | else: 43 | self.gamma = None 44 | self.dw_conv_1 = layers.Conv2D( 45 | filters=dim, kernel_size=7, padding="same", groups=dim 46 | ) 47 | self.layer_norm = layers.LayerNormalization(epsilon=1e-6) 48 | self.pw_conv_1 = layers.Dense(4 * dim) 49 | self.act_fn = layers.Activation("gelu") 50 | self.pw_conv_2 = layers.Dense(dim) 51 | self.drop_path = ( 52 | StochasticDepth(drop_path) 53 | if drop_path > 0.0 54 | else layers.Activation("linear") 55 | ) 56 | 57 | def call(self, inputs): 58 | x = inputs 59 | 60 | x = self.dw_conv_1(x) 61 | x = self.layer_norm(x) 62 | x = self.pw_conv_1(x) 63 | x = self.act_fn(x) 64 | x = self.pw_conv_2(x) 65 | 66 | if self.gamma is not None: 67 | x = self.gamma * x 68 | 69 | return inputs + self.drop_path(x) 70 | 71 | 72 | def get_convnext_model( 73 | model_name="convnext_tiny_1k", 74 | input_shape=(224, 224, 3), 75 | num_classes=1000, 76 | depths=[3, 3, 9, 3], 77 | dims=[96, 192, 384, 768], 78 | drop_path_rate=0.0, 79 | layer_scale_init_value=1e-6, 80 | ) -> keras.Model: 81 | """Implements ConvNeXt family of models given a configuration. 82 | 83 | References: 84 | (1) https://arxiv.org/abs/2201.03545 85 | (2) https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py 86 | 87 | Note: `predict()` fails on CPUs because of group convolutions. The fix is recent at 88 | the time of the development: https://github.com/keras-team/keras/pull/15868. It's 89 | recommended to use a GPU / TPU. 
90 | """ 91 | 92 | inputs = layers.Input(input_shape) 93 | stem = keras.Sequential( 94 | [ 95 | layers.Conv2D(dims[0], kernel_size=4, strides=4), 96 | layers.LayerNormalization(epsilon=1e-6), 97 | ], 98 | name="stem", 99 | ) 100 | 101 | downsample_layers = [] 102 | downsample_layers.append(stem) 103 | for i in range(3): 104 | downsample_layer = keras.Sequential( 105 | [ 106 | layers.LayerNormalization(epsilon=1e-6), 107 | layers.Conv2D(dims[i + 1], kernel_size=2, strides=2), 108 | ], 109 | name=f"downsampling_block_{i}", 110 | ) 111 | downsample_layers.append(downsample_layer) 112 | 113 | stages = [] 114 | dp_rates = [x for x in tf.linspace(0.0, drop_path_rate, sum(depths))] 115 | cur = 0 116 | for i in range(4): 117 | stage = keras.Sequential( 118 | [ 119 | *[ 120 | Block( 121 | dim=dims[i], 122 | drop_path=dp_rates[cur + j], 123 | layer_scale_init_value=layer_scale_init_value, 124 | name=f"convnext_block_{i}_{j}", 125 | ) 126 | for j in range(depths[i]) 127 | ] 128 | ], 129 | name=f"convnext_stage_{i}", 130 | ) 131 | stages.append(stage) 132 | cur += depths[i] 133 | 134 | x = inputs 135 | for i in range(len(stages)): 136 | x = downsample_layers[i](x) 137 | x = stages[i](x) 138 | 139 | x = layers.GlobalAvgPool2D()(x) 140 | x = layers.LayerNormalization(epsilon=1e-6)(x) 141 | 142 | outputs = layers.Dense(num_classes, name="classification_head")(x) 143 | 144 | return keras.Model(inputs, outputs, name=model_name) 145 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConvNeXt-TF 2 | 3 | This repository provides TensorFlow / Keras implementations of different ConvNeXt 4 | [1] variants. It also provides the TensorFlow / Keras models that have been 5 | populated with the original ConvNeXt pre-trained weights available from [2]. These 6 | models are not blackbox SavedModels i.e., they can be fully expanded into `tf.keras.Model` 7 | objects and one can call all the utility functions on them (example: `.summary()`). 8 | 9 | As of today, all the TensorFlow / Keras variants of the models listed 10 | [here](https://github.com/facebookresearch/ConvNeXt#results-and-pre-trained-models) 11 | are available in this repository except for the 12 | [isotropic ones](https://github.com/facebookresearch/ConvNeXt#imagenet-1k-trained-models-isotropic). 13 | This list includes the ImageNet-1k as well as ImageNet-21k models. 14 | 15 | Refer to the ["Using the models"](https://github.com/sayakpaul/ConvNeXt-TF#using-the-models) 16 | section to get started. Additionally, here's a [related blog post](https://sayak.dev/convnext-tfhub/) 17 | that jots down my experience. 18 | 19 | ## Conversion 20 | 21 | TensorFlow / Keras implementations are available in `models/convnext_tf.py`. 22 | Conversion utilities are in `convert.py`. 23 | 24 | ## Models 25 | 26 | The converted models are available on [TF-Hub](https://tfhub.dev/sayakpaul/collections/convnext/1). 27 | 28 | There should be a total of 15 different models each having two variants: classifier and 29 | feature extractor. 
You can load any model and get started like so: 30 | 31 | ```py 32 | import tensorflow as tf 33 | 34 | model_gcs_path = "gs://tfhub-modules/sayakpaul/convnext_tiny_1k_224/1/uncompressed" 35 | model = tf.keras.models.load_model(model_gcs_path) 36 | print(model.summary(expand_nested=True)) 37 | ``` 38 | 39 | The model names are interpreted as follows: 40 | 41 | * `convnext_large_21k_1k_384`: This means that the model was first pre-trained 42 | on the ImageNet-21k dataset and was then fine-tuned on the ImageNet-1k dataset. 43 | Resolution used during pre-training and fine-tuning: 384x384. `large` denotes 44 | the topology of the underlying model. 45 | * `convnext_large_1k_224`: This means that the model was pre-trained on the ImageNet-1k 46 | dataset with a resolution of 224x224. 47 | 48 | ## Results 49 | 50 | Results are on the ImageNet-1k validation set (top-1 accuracy). 51 | 52 | | name | original acc@1 | keras acc@1 | 53 | |:---:|:---:|:---:| 54 | | convnext_tiny_1k_224 | 82.1 | 81.312 | 55 | | convnext_small_1k_224 | 83.1 | 82.392 | 56 | | convnext_base_1k_224 | 83.8 | 83.28 | 57 | | convnext_base_1k_384 | 85.1 | 84.876 | 58 | | convnext_large_1k_224 | 84.3 | 83.844 | 59 | | convnext_large_1k_384 | 85.5 | 85.376 | 60 | | | | | 61 | | convnext_base_21k_1k_224 | 85.8 | 85.364 | 62 | | convnext_base_21k_1k_384 | 86.8 | 86.79 | 63 | | convnext_large_21k_1k_224 | 86.6 | 86.36 | 64 | | convnext_large_21k_1k_384 | 87.5 | 87.504 | 65 | | convnext_xlarge_21k_1k_224 | 87.0 | 86.732 | 66 | | convnext_xlarge_21k_1k_384 | 87.8 | 87.68 | 67 | 68 | Differences in the results are primarily because of differences in the library 69 | implementations, especially in how image resizing is implemented in PyTorch and 70 | TensorFlow. Results can be verified with the code in `i1k_eval`. Logs 71 | are available at [this URL](https://tensorboard.dev/experiment/odN7OPCqQvGYCRpJP1GhRQ/). 72 | 73 | ## Using the models 74 | 75 | **Pre-trained models**: 76 | 77 | * Off-the-shelf classification: [Colab Notebook](https://colab.research.google.com/github/sayakpaul/ConvNeXt-TF/blob/main/notebooks/classification.ipynb) 78 | * Fine-tuning: [Colab Notebook](https://colab.research.google.com/github/sayakpaul/ConvNeXt-TF/blob/main/notebooks/finetune.ipynb) 79 | 80 | **Randomly initialized models**: 81 | 82 | ```py 83 | from models.convnext_tf import get_convnext_model 84 | 85 | convnext_tiny = get_convnext_model() 86 | print(convnext_tiny.summary(expand_nested=True)) 87 | ``` 88 | 89 | To view different model configurations, refer [here](https://github.com/sayakpaul/ConvNeXt-TF/blob/main/models/model_configs.py).
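The feature extractor variants have their TF-Hub handles suffixed with `_fe` and come without the classification head, so you can attach your own. Below is a minimal fine-tuning sketch; the number of classes and the compilation settings are placeholders, and the fine-tuning notebook linked above covers the complete workflow:

```py
import tensorflow as tf
import tensorflow_hub as hub

num_classes = 10  # placeholder for your dataset

model = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer((224, 224, 3)),
        hub.KerasLayer(
            "https://tfhub.dev/sayakpaul/convnext_tiny_1k_224_fe/1",
            trainable=True,  # set to False for off-the-shelf feature extraction
        ),
        tf.keras.layers.Dense(num_classes),
    ]
)
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
```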
90 | 91 | ## Upcoming (contributions welcome) 92 | 93 | - [ ] Align layer initializers (useful if someone wanted to train the models 94 | from scratch) 95 | - [ ] Allow the models to accept arbitrary shapes (useful for downstream tasks) 96 | - [ ] Convert the [isotropic models](https://github.com/facebookresearch/ConvNeXt#imagenet-1k-trained-models-isotropic) as well 97 | - [x] Fine-tuning notebook (thanks to [awsaf49](https://github.com/awsaf49)) 98 | - [x] Off-the-shelf-classification notebook 99 | - [x] Publish models on TF-Hub 100 | 101 | ## References 102 | 103 | [1] ConvNeXt paper: https://arxiv.org/abs/2201.03545 104 | 105 | [2] Official ConvNeXt code: https://github.com/facebookresearch/ConvNeXt 106 | 107 | ## Acknowledgements 108 | 109 | * [Vasudev Gupta](https://github.com/vasudevgupta7) 110 | * [Gus](https://twitter.com/gusthema) 111 | * [Willi](https://ch.linkedin.com/in/willi-gierke) 112 | * [ML-GDE program](https://developers.google.com/programs/experts/) 113 | -------------------------------------------------------------------------------- /notebooks/classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "zpiEiO2BUeO5", 6 | "metadata": { 7 | "id": "zpiEiO2BUeO5" 8 | }, 9 | "source": [ 10 | "# Off-the-shelf image classification with ConvNeXt models on TF-Hub\n", 11 | "\n", 12 | "\n", 13 | " \n", 16 | " \n", 19 | " \n", 22 | "
\n", 14 | " Run in Google Colab\n", 15 | " \n", 17 | " View on GitHub\n", 18 | " \n", 20 | " See TF Hub models\n", 21 | "
" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "661e6538", 28 | "metadata": { 29 | "id": "661e6538" 30 | }, 31 | "source": [ 32 | "## Setup" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "f2b73e50-6538-4af5-9878-ed99489409f5", 39 | "metadata": { 40 | "id": "f2b73e50-6538-4af5-9878-ed99489409f5" 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "43974820-4eeb-4b3a-90b4-9ddfa00d1cb9", 51 | "metadata": { 52 | "id": "43974820-4eeb-4b3a-90b4-9ddfa00d1cb9" 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "import tensorflow as tf\n", 57 | "import tensorflow_hub as hub\n", 58 | "from tensorflow import keras\n", 59 | "\n", 60 | "\n", 61 | "from PIL import Image\n", 62 | "from io import BytesIO\n", 63 | "\n", 64 | "import matplotlib.pyplot as plt\n", 65 | "import numpy as np\n", 66 | "import requests" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "z5l1cRpiSavW", 72 | "metadata": { 73 | "id": "z5l1cRpiSavW" 74 | }, 75 | "source": [ 76 | "## Select a ConvNeXt ImageNet-1k model" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "a0wM8idaSaOq", 83 | "metadata": { 84 | "cellView": "form", 85 | "id": "a0wM8idaSaOq" 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "model_name = \"convnext_tiny_1k_224\" #@param [\"convnext_tiny_1k_224\", \"convnext_small_1k_224\", \"convnext_base_1k_224\", \"convnext_base_1k_384\", \"convnext_large_1k_224\", \"convnext_large_1k_384\", \"convnext_base_21k_1k_224\", \"convnext_base_21k_1k_384\", \"convnext_large_21k_1k_224\", \"convnext_large_21k_1k_384\", \"convnext_xlarge_21k_1k_224\", \"convnext_xlarge_21k_1k_384\"]\n", 90 | "\n", 91 | "model_handle_map ={\n", 92 | " \"convnext_tiny_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_tiny_1k_224/1\",\n", 93 | " \"convnext_small_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_small_1k_224/1\",\n", 94 | " \"convnext_base_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_base_1k_224/1\",\n", 95 | " \"convnext_base_1k_384\": \"https://tfhub.dev/sayakpaul/convnext_base_1k_384/1\",\n", 96 | " \"convnext_large_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_large_1k_224/1\",\n", 97 | " \"convnext_large_1k_384\": \"https://tfhub.dev/sayakpaul/convnext_large_1k_384/1\",\n", 98 | " \"convnext_base_21k_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_base_21k_1k_224/1\",\n", 99 | " \"convnext_base_21k_1k_384\": \"https://tfhub.dev/sayakpaul/convnext_base_21k_1k_384/1\",\n", 100 | " \"convnext_large_21k_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_large_21k_1k_224/1\",\n", 101 | " \"convnext_large_21k_1k_384\": \"https://tfhub.dev/sayakpaul/convnext_large_21k_1k_384/1\",\n", 102 | " \"convnext_xlarge_21k_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_xlarge_21k_1k_224/1\",\n", 103 | " \"convnext_xlarge_21k_1k_384\": \"https://tfhub.dev/sayakpaul/convnext_xlarge_21k_1k_384/1\",\n", 104 | "\n", 105 | "}\n", 106 | "\n", 107 | "input_resolution = int(model_name.split(\"_\")[-1])\n", 108 | "model_handle = model_handle_map[model_name]\n", 109 | "print(f\"Input resolution: {input_resolution} x {input_resolution} x 3.\")\n", 110 | "print(f\"TF-Hub handle: {model_handle}.\")" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "id": "441b5361", 116 | "metadata": { 117 | "id": "441b5361" 118 | }, 119 | "source": [ 120 | "## Image preprocessing 
utilities " 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "63e76ff1-e1e0-4c6a-91b2-4114aad60e5b", 127 | "metadata": { 128 | "id": "63e76ff1-e1e0-4c6a-91b2-4114aad60e5b" 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "crop_layer = keras.layers.CenterCrop(224, 224)\n", 133 | "norm_layer = keras.layers.Normalization(\n", 134 | " mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],\n", 135 | " variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],\n", 136 | ")\n", 137 | "\n", 138 | "\n", 139 | "def preprocess_image(image, size=input_resolution):\n", 140 | " image = np.array(image)\n", 141 | " image_resized = tf.expand_dims(image, 0)\n", 142 | " \n", 143 | " if size == 224:\n", 144 | " image_resized = tf.image.resize(image_resized, (256, 256), method=\"bicubic\")\n", 145 | " image_resized = crop_layer(image_resized)\n", 146 | " elif size == 384:\n", 147 | " image_resized = tf.image.resize(image, (size, size), method=\"bicubic\")\n", 148 | " \n", 149 | " return norm_layer(image_resized).numpy()\n", 150 | " \n", 151 | "\n", 152 | "def load_image_from_url(url):\n", 153 | " # Credit: Willi Gierke\n", 154 | " response = requests.get(url)\n", 155 | " image = Image.open(BytesIO(response.content))\n", 156 | " preprocessed_image = preprocess_image(image)\n", 157 | " return image, preprocessed_image" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "id": "8b961e14", 163 | "metadata": { 164 | "id": "8b961e14" 165 | }, 166 | "source": [ 167 | "## Load ImageNet-1k labels and a demo image" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "8dc9250a-5eb6-4547-8893-dd4c746ab53b", 174 | "metadata": { 175 | "id": "8dc9250a-5eb6-4547-8893-dd4c746ab53b" 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n", 180 | " lines = f.readlines()\n", 181 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n", 182 | "\n", 183 | "img_url = \"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n", 184 | "image, preprocessed_image = load_image_from_url(img_url)\n", 185 | "\n", 186 | "plt.imshow(image)\n", 187 | "plt.show()" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "id": "9006a643", 193 | "metadata": { 194 | "id": "9006a643" 195 | }, 196 | "source": [ 197 | "## Run inference" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "8dfd2c7d-e454-48da-a40b-cd5d6f6c4908", 204 | "metadata": { 205 | "id": "8dfd2c7d-e454-48da-a40b-cd5d6f6c4908" 206 | }, 207 | "outputs": [], 208 | "source": [ 209 | "classification_model = tf.keras.Sequential(\n", 210 | " [hub.KerasLayer(model_handle)]\n", 211 | ") \n", 212 | "predictions = classification_model.predict(preprocessed_image)\n", 213 | "predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]\n", 214 | "print(predicted_label)" 215 | ] 216 | } 217 | ], 218 | "metadata": { 219 | "accelerator": "GPU", 220 | "colab": { 221 | "machine_shape": "hm", 222 | "name": "classification.ipynb", 223 | "provenance": [] 224 | }, 225 | "environment": { 226 | "kernel": "python3", 227 | "name": "tf2-gpu.2-7.m84", 228 | "type": "gcloud", 229 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m84" 230 | }, 231 | "kernelspec": { 232 | "display_name": "Python 3 (ipykernel)", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | 
"version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.8.2" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 5 251 | } 252 | -------------------------------------------------------------------------------- /models/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Originally from: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py 10 | 11 | This script has been slightly modified to support the conversion. 12 | Contact: spsayakpaul@gmail.com 13 | 14 | """ 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | from timm.models.layers import trunc_normal_, DropPath 21 | from timm.models.registry import register_model 22 | 23 | 24 | class Block(nn.Module): 25 | r"""ConvNeXt Block. There are two equivalent implementations: 26 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 27 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 28 | We use (2) as we find it slightly faster in PyTorch 29 | 30 | Args: 31 | dim (int): Number of input channels. 32 | drop_path (float): Stochastic depth rate. Default: 0.0 33 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 34 | """ 35 | 36 | def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6): 37 | super().__init__() 38 | self.dwconv = nn.Conv2d( 39 | dim, dim, kernel_size=7, padding=3, groups=dim 40 | ) # depthwise conv 41 | self.norm = LayerNorm(dim, eps=1e-6) 42 | self.pwconv1 = nn.Linear( 43 | dim, 4 * dim 44 | ) # pointwise/1x1 convs, implemented with linear layers 45 | self.act = nn.GELU() 46 | self.pwconv2 = nn.Linear(4 * dim, dim) 47 | self.gamma = ( 48 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 49 | if layer_scale_init_value > 0 50 | else None 51 | ) 52 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 53 | 54 | def forward(self, x): 55 | input = x 56 | x = self.dwconv(x) 57 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 58 | x = self.norm(x) 59 | x = self.pwconv1(x) 60 | x = self.act(x) 61 | x = self.pwconv2(x) 62 | if self.gamma is not None: 63 | x = self.gamma * x 64 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 65 | 66 | x = input + self.drop_path(x) 67 | return x 68 | 69 | 70 | class ConvNeXt(nn.Module): 71 | r"""ConvNeXt 72 | A PyTorch impl of : `A ConvNet for the 2020s` - 73 | https://arxiv.org/pdf/2201.03545.pdf 74 | 75 | Args: 76 | in_chans (int): Number of input image channels. Default: 3 77 | num_classes (int): Number of classes for classification head. Default: 1000 78 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 79 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 80 | drop_path_rate (float): Stochastic depth rate. Default: 0. 81 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 82 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
83 | """ 84 | 85 | def __init__( 86 | self, 87 | in_chans=3, 88 | num_classes=1000, 89 | depths=[3, 3, 9, 3], 90 | dims=[96, 192, 384, 768], 91 | drop_path_rate=0.0, 92 | layer_scale_init_value=1e-6, 93 | head_init_scale=1.0, 94 | ): 95 | super().__init__() 96 | 97 | self.downsample_layers = ( 98 | nn.ModuleList() 99 | ) # stem and 3 intermediate downsampling conv layers 100 | stem = nn.Sequential( 101 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 102 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), 103 | ) 104 | self.downsample_layers.append(stem) 105 | for i in range(3): 106 | downsample_layer = nn.Sequential( 107 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 108 | nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), 109 | ) 110 | self.downsample_layers.append(downsample_layer) 111 | 112 | self.stages = ( 113 | nn.ModuleList() 114 | ) # 4 feature resolution stages, each consisting of multiple residual blocks 115 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 116 | cur = 0 117 | for i in range(4): 118 | stage = nn.Sequential( 119 | *[ 120 | Block( 121 | dim=dims[i], 122 | drop_path=dp_rates[cur + j], 123 | layer_scale_init_value=layer_scale_init_value, 124 | ) 125 | for j in range(depths[i]) 126 | ] 127 | ) 128 | self.stages.append(stage) 129 | cur += depths[i] 130 | 131 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 132 | self.head = nn.Linear(dims[-1], num_classes) 133 | 134 | self.apply(self._init_weights) 135 | self.head.weight.data.mul_(head_init_scale) 136 | self.head.bias.data.mul_(head_init_scale) 137 | 138 | def _init_weights(self, m): 139 | if isinstance(m, (nn.Conv2d, nn.Linear)): 140 | trunc_normal_(m.weight, std=0.02) 141 | nn.init.constant_(m.bias, 0) 142 | 143 | def forward_features(self, x): 144 | for i in range(4): 145 | x = self.downsample_layers[i](x) 146 | x = self.stages[i](x) 147 | return self.norm( 148 | x.mean([-2, -1]) 149 | ) # global average pooling, (N, C, H, W) -> (N, C) 150 | 151 | def forward(self, x): 152 | x = self.forward_features(x) 153 | x = self.head(x) 154 | return x 155 | 156 | 157 | class LayerNorm(nn.Module): 158 | r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. 159 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 160 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 161 | with shape (batch_size, channels, height, width). 
162 | """ 163 | 164 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 165 | super().__init__() 166 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 167 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 168 | self.eps = eps 169 | self.data_format = data_format 170 | if self.data_format not in ["channels_last", "channels_first"]: 171 | raise NotImplementedError 172 | self.normalized_shape = (normalized_shape,) 173 | 174 | def forward(self, x): 175 | if self.data_format == "channels_last": 176 | return F.layer_norm( 177 | x, self.normalized_shape, self.weight, self.bias, self.eps 178 | ) 179 | elif self.data_format == "channels_first": 180 | u = x.mean(1, keepdim=True) 181 | s = (x - u).pow(2).mean(1, keepdim=True) 182 | x = (x - u) / torch.sqrt(s + self.eps) 183 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 184 | return x 185 | 186 | 187 | @register_model 188 | def convnext_tiny(url, **kwargs): 189 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 190 | checkpoint = torch.hub.load_state_dict_from_url( 191 | url=url, map_location="cpu", check_hash=True 192 | ) 193 | model.load_state_dict(checkpoint["model"]) 194 | return model 195 | 196 | 197 | @register_model 198 | def convnext_small(url, **kwargs): 199 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 200 | checkpoint = torch.hub.load_state_dict_from_url( 201 | url=url, map_location="cpu", check_hash=True 202 | ) 203 | model.load_state_dict(checkpoint["model"]) 204 | return model 205 | 206 | 207 | @register_model 208 | def convnext_base(url, **kwargs): 209 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 210 | checkpoint = torch.hub.load_state_dict_from_url( 211 | url=url, map_location="cpu", check_hash=True 212 | ) 213 | model.load_state_dict(checkpoint["model"]) 214 | return model 215 | 216 | 217 | @register_model 218 | def convnext_large(url, **kwargs): 219 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 220 | checkpoint = torch.hub.load_state_dict_from_url( 221 | url=url, map_location="cpu", check_hash=True 222 | ) 223 | model.load_state_dict(checkpoint["model"]) 224 | return model 225 | 226 | 227 | @register_model 228 | def convnext_xlarge(url, **kwargs): 229 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 230 | checkpoint = torch.hub.load_state_dict_from_url( 231 | url=url, map_location="cpu", check_hash=True 232 | ) 233 | model.load_state_dict(checkpoint["model"]) 234 | return model 235 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | from models.convnext_tf import get_convnext_model 2 | from models.model_configs import get_model_config 3 | from models import convnext 4 | 5 | from tensorflow.keras import layers 6 | import tensorflow as tf 7 | import torch 8 | 9 | import os 10 | import argparse 11 | import numpy as np 12 | 13 | torch.set_grad_enabled(False) 14 | 15 | DATASET_TO_CLASSES = { 16 | "imagenet-1k": 1000, 17 | "imagenet-21k": 21841, 18 | } 19 | MODEL_TO_METHOD = { 20 | "convnext_tiny": convnext.convnext_tiny, 21 | "convnext_small": convnext.convnext_small, 22 | "convnext_base": convnext.convnext_base, 23 | "convnext_large": convnext.convnext_large, 24 | "convnext_xlarge": convnext.convnext_xlarge, 25 | } 26 | TF_MODEL_ROOT = "saved_models" 27 | 28 | 29 | def parse_args(): 30 | parser = 
argparse.ArgumentParser( 31 | description="Conversion of the PyTorch pre-trained ConvNeXt weights to TensorFlow." 32 | ) 33 | parser.add_argument( 34 | "-d", 35 | "--dataset", 36 | default="imagenet-1k", 37 | type=str, 38 | required=False, 39 | choices=["imagenet-1k", "imagenet-21k"], 40 | help="Name of the pretraining dataset.", 41 | ) 42 | parser.add_argument( 43 | "-m", 44 | "--model-name", 45 | default="convnext_tiny", 46 | type=str, 47 | required=False, 48 | choices=[ 49 | "convnext_tiny", 50 | "convnext_small", 51 | "convnext_base", 52 | "convnext_large", 53 | "convnext_xlarge", 54 | ], 55 | help="Name of the ConvNeXt model variant.", 56 | ) 57 | parser.add_argument( 58 | "-r", 59 | "--resolution", 60 | default=224, 61 | type=int, 62 | required=False, 63 | choices=[224, 384], 64 | help="Image resolution.", 65 | ) 66 | parser.add_argument( 67 | "-c", 68 | "--checkpoint-path", 69 | default="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 70 | type=str, 71 | required=False, 72 | help="URL of the checkpoint to be loaded.", 73 | ) 74 | return vars(parser.parse_args()) 75 | 76 | 77 | def main(args): 78 | print(f'Model: {args["model_name"]}') 79 | print(f'Image resolution: {args["resolution"]}') 80 | print(f'Dataset: {args["dataset"]}') 81 | print(f'Checkpoint URL: {args["checkpoint_path"]}') 82 | 83 | print("Instantiating PyTorch model and populating weights...") 84 | model_method = MODEL_TO_METHOD[args["model_name"]] 85 | convnext_model_pt = model_method( 86 | args["checkpoint_path"], num_classes=DATASET_TO_CLASSES[args["dataset"]] 87 | ) 88 | convnext_model_pt.eval() 89 | 90 | print("Instantiating TensorFlow model...") 91 | model_config = get_model_config(args["model_name"]) 92 | 93 | if "22k_1k" not in args["checkpoint_path"]: 94 | model_name = ( 95 | f'{args["model_name"]}_1k' 96 | if args["dataset"] == "imagenet-1k" 97 | else f'{args["model_name"]}_21k' 98 | ) 99 | else: 100 | model_name = f'{args["model_name"]}_21k_1k' 101 | 102 | convnext_model_tf = get_convnext_model( 103 | model_name=model_name, 104 | input_shape=(args["resolution"], args["resolution"], 3), 105 | num_classes=DATASET_TO_CLASSES[args["dataset"]], 106 | depths=model_config.depths, 107 | dims=model_config.dims, 108 | ) 109 | assert convnext_model_tf.count_params() == sum( 110 | p.numel() for p in convnext_model_pt.parameters() 111 | ) 112 | print("TensorFlow model instantiated, populating pretrained weights...") 113 | 114 | # Fetch the pretrained parameters. 115 | param_list = list(convnext_model_pt.parameters()) 116 | model_states = convnext_model_pt.state_dict() 117 | state_list = list(model_states.keys()) 118 | 119 | # Stem block. 120 | stem_block = convnext_model_tf.get_layer("stem") 121 | 122 | for layer in stem_block.layers: 123 | if isinstance(layer, layers.Conv2D): 124 | layer.kernel.assign( 125 | tf.Variable(param_list[0].numpy().transpose(2, 3, 1, 0)) 126 | ) 127 | layer.bias.assign(tf.Variable(param_list[1].numpy())) 128 | elif isinstance(layer, layers.LayerNormalization): 129 | layer.gamma.assign(tf.Variable(param_list[2].numpy())) 130 | layer.beta.assign(tf.Variable(param_list[3].numpy())) 131 | 132 | # Downsampling layers. 
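# Layout note: PyTorch Conv2d kernels are stored as (out_channels, in_channels, kh, kw),
# while Keras Conv2D expects (kh, kw, in_channels, out_channels) -- hence the
# transpose(2, 3, 1, 0) used above and below. PyTorch Linear weights are (out, in),
# while Keras Dense kernels are (in, out) -- hence the plain transpose().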
133 | for i in range(3): 134 | downsampling_block = convnext_model_tf.get_layer(f"downsampling_block_{i}") 135 | pytorch_layer_prefix = f"downsample_layers.{i + 1}" 136 | 137 | for l in downsampling_block.layers: 138 | if isinstance(l, layers.LayerNormalization): 139 | l.gamma.assign( 140 | tf.Variable( 141 | model_states[f"{pytorch_layer_prefix}.0.weight"].numpy() 142 | ) 143 | ) 144 | l.beta.assign( 145 | tf.Variable(model_states[f"{pytorch_layer_prefix}.0.bias"].numpy()) 146 | ) 147 | elif isinstance(l, layers.Conv2D): 148 | l.kernel.assign( 149 | tf.Variable( 150 | model_states[f"{pytorch_layer_prefix}.1.weight"] 151 | .numpy() 152 | .transpose(2, 3, 1, 0) 153 | ) 154 | ) 155 | l.bias.assign( 156 | tf.Variable(model_states[f"{pytorch_layer_prefix}.1.bias"].numpy()) 157 | ) 158 | 159 | # ConvNeXt stages. 160 | num_stages = 4 161 | 162 | for m in range(num_stages): 163 | stage_name = f"convnext_stage_{m}" 164 | num_blocks = len(convnext_model_tf.get_layer(stage_name).layers) 165 | 166 | for i in range(num_blocks): 167 | stage_block = convnext_model_tf.get_layer(stage_name).get_layer( 168 | f"convnext_block_{m}_{i}" 169 | ) 170 | stage_prefix = f"stages.{m}.{i}" 171 | 172 | for j, layer in enumerate(stage_block.layers): 173 | if isinstance(layer, layers.Conv2D): 174 | layer.kernel.assign( 175 | tf.Variable( 176 | model_states[f"{stage_prefix}.dwconv.weight"] 177 | .numpy() 178 | .transpose(2, 3, 1, 0) 179 | ) 180 | ) 181 | layer.bias.assign( 182 | tf.Variable(model_states[f"{stage_prefix}.dwconv.bias"].numpy()) 183 | ) 184 | elif isinstance(layer, layers.Dense): 185 | if j == 2: 186 | layer.kernel.assign( 187 | tf.Variable( 188 | model_states[f"{stage_prefix}.pwconv1.weight"] 189 | .numpy() 190 | .transpose() 191 | ) 192 | ) 193 | layer.bias.assign( 194 | tf.Variable( 195 | model_states[f"{stage_prefix}.pwconv1.bias"].numpy() 196 | ) 197 | ) 198 | elif j == 4: 199 | layer.kernel.assign( 200 | tf.Variable( 201 | model_states[f"{stage_prefix}.pwconv2.weight"] 202 | .numpy() 203 | .transpose() 204 | ) 205 | ) 206 | layer.bias.assign( 207 | tf.Variable( 208 | model_states[f"{stage_prefix}.pwconv2.bias"].numpy() 209 | ) 210 | ) 211 | elif isinstance(layer, layers.LayerNormalization): 212 | layer.gamma.assign( 213 | tf.Variable(model_states[f"{stage_prefix}.norm.weight"].numpy()) 214 | ) 215 | layer.beta.assign( 216 | tf.Variable(model_states[f"{stage_prefix}.norm.bias"].numpy()) 217 | ) 218 | 219 | stage_block.gamma.assign( 220 | tf.Variable(model_states[f"{stage_prefix}.gamma"].numpy()) 221 | ) 222 | 223 | # Final LayerNormalization layer and classifier head. 
224 | convnext_model_tf.layers[-2].gamma.assign( 225 | tf.Variable(model_states[state_list[-4]].numpy()) 226 | ) 227 | convnext_model_tf.layers[-2].beta.assign( 228 | tf.Variable(model_states[state_list[-3]].numpy()) 229 | ) 230 | 231 | convnext_model_tf.layers[-1].kernel.assign( 232 | tf.Variable(model_states[state_list[-2]].numpy().transpose()) 233 | ) 234 | convnext_model_tf.layers[-1].bias.assign( 235 | tf.Variable(model_states[state_list[-1]].numpy()) 236 | ) 237 | print("Weight population successful, serializing TensorFlow model...") 238 | 239 | model_name = f'{model_name}_{args["resolution"]}' 240 | save_path = os.path.join(TF_MODEL_ROOT, model_name) 241 | convnext_model_tf.save(save_path) 242 | print(f"TensorFlow model serialized to: {save_path}...") 243 | 244 | 245 | if __name__ == "__main__": 246 | args = parse_args() 247 | main(args) 248 | -------------------------------------------------------------------------------- /i1k_eval/eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "926bae46", 6 | "metadata": {}, 7 | "source": [ 8 | "## Imports" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "bae636ae-24d1-4523-9997-696731318a81", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from tensorflow.keras import layers\n", 19 | "from tensorflow import keras\n", 20 | "import tensorflow_hub as hub\n", 21 | "import tensorflow as tf\n", 22 | "\n", 23 | "from imutils import paths\n", 24 | "import json\n", 25 | "import re" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "id": "96c4f0a2", 31 | "metadata": {}, 32 | "source": [ 33 | "## Constants" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "f8238055-08bf-44e1-8f3b-98e7768f1603", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "AUTO = tf.data.AUTOTUNE\n", 44 | "BATCH_SIZE = 256\n", 45 | "IMAGE_SIZE = 224\n", 46 | "TF_MODEL_ROOT = \"gs://convnext/saved_models\"" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "74edcf20", 52 | "metadata": {}, 53 | "source": [ 54 | "## Set up ImageNet-1k labels" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "334993ee-0d91-4572-9721-03e67af28cb3", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "with open(\"imagenet_class_index.json\", \"r\") as read_file:\n", 65 | " imagenet_labels = json.load(read_file)\n", 66 | "\n", 67 | "MAPPING_DICT = {}\n", 68 | "LABEL_NAMES = {}\n", 69 | "for label_id in list(imagenet_labels.keys()):\n", 70 | " MAPPING_DICT[imagenet_labels[label_id][0]] = int(label_id)\n", 71 | " LABEL_NAMES[int(label_id)] = imagenet_labels[label_id][1]" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 4, 77 | "id": "01ad5447-3e28-4c86-941f-f64b45be603a", 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "data": { 82 | "text/plain": [ 83 | "(['val/n01751748/ILSVRC2012_val_00031060.JPEG',\n", 84 | " 'val/n01751748/ILSVRC2012_val_00013492.JPEG',\n", 85 | " 'val/n01751748/ILSVRC2012_val_00033108.JPEG',\n", 86 | " 'val/n01751748/ILSVRC2012_val_00021437.JPEG',\n", 87 | " 'val/n01751748/ILSVRC2012_val_00025096.JPEG'],\n", 88 | " [65, 65, 65, 65, 65])" 89 | ] 90 | }, 91 | "execution_count": 4, 92 | "metadata": {}, 93 | "output_type": "execute_result" 94 | } 95 | ], 96 | "source": [ 97 | "all_val_paths = list(paths.list_images(\"val\"))\n", 98 | "all_val_labels = [MAPPING_DICT[x.split(\"/\")[1]] for x 
in all_val_paths]\n", 99 | "\n", 100 | "all_val_paths[:5], all_val_labels[:5]" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "id": "1124817d", 106 | "metadata": {}, 107 | "source": [ 108 | "## Preprocessing utilities" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 5, 114 | "id": "5a4f03d8-25d1-4660-9858-1b197425d5d9", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "def load_and_prepare(path, label):\n", 119 | " image = tf.io.read_file(path)\n", 120 | " image = tf.image.decode_png(image, channels=3)\n", 121 | " image = tf.image.resize(image, (256, 256), method=\"bicubic\")\n", 122 | " return image, label" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 6, 128 | "id": "d0a2f457-ad4d-4cbd-8dfa-db544c7f6531", 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "# Reference: https://github.com/facebookresearch/ConvNeXt/blob/main/datasets.py\n", 133 | "def get_preprocessing_model(input_size=224):\n", 134 | " preprocessing_model = keras.Sequential()\n", 135 | "\n", 136 | " preprocessing_model.add(layers.CenterCrop(input_size, input_size))\n", 137 | " preprocessing_model.add(layers.Normalization(\n", 138 | " mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],\n", 139 | " variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],\n", 140 | " ))\n", 141 | "\n", 142 | " return preprocessing_model" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "id": "56d33240", 148 | "metadata": {}, 149 | "source": [ 150 | "## Prepare `tf.data.Dataset`" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 7, 156 | "id": "f3518397-2ab0-4d79-adea-ae5f1cb66add", 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stderr", 161 | "output_type": "stream", 162 | "text": [ 163 | "2022-01-31 03:20:05.306146: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n", 164 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 165 | "2022-01-31 03:20:05.828828: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38444 MB memory: -> device: 0, name: A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0\n" 166 | ] 167 | }, 168 | { 169 | "data": { 170 | "text/plain": [ 171 | "(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None),\n", 172 | " TensorSpec(shape=(None,), dtype=tf.int32, name=None))" 173 | ] 174 | }, 175 | "execution_count": 7, 176 | "metadata": {}, 177 | "output_type": "execute_result" 178 | } 179 | ], 180 | "source": [ 181 | "preprocessor = get_preprocessing_model()\n", 182 | "\n", 183 | "dataset = tf.data.Dataset.from_tensor_slices((all_val_paths, all_val_labels))\n", 184 | "dataset = dataset.map(load_and_prepare, num_parallel_calls=AUTO).batch(BATCH_SIZE)\n", 185 | "dataset = dataset.map(lambda x, y: (preprocessor(x), y), num_parallel_calls=AUTO)\n", 186 | "dataset = dataset.prefetch(AUTO)\n", 187 | "dataset.element_spec" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "id": "ea42076a", 193 | "metadata": {}, 194 | "source": [ 195 | "## Fetch model paths and filter the 224x224 models" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 8, 201 | "id": "13a7e46e-31b2-48b9-9a57-2873fe27397a", 202 | "metadata": 
{}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "['convnext_base_1k_224/', 'convnext_base_21k_1k_224/', 'convnext_large_1k_224/', 'convnext_large_21k_1k_224/', 'convnext_small_1k_224/', 'convnext_tiny_1k_224/', 'convnext_xlarge_21k_1k_224/']\n" 209 | ] 210 | } 211 | ], 212 | "source": [ 213 | "model_paths = tf.io.gfile.listdir(TF_MODEL_ROOT)\n", 214 | "models_res_224 = [model_path for model_path in model_paths if str(IMAGE_SIZE) in model_path]\n", 215 | "p = re.compile('.*_21k_224')\n", 216 | "i1k_paths = [path for path in models_res_224 if not p.match(path)]\n", 217 | "\n", 218 | "print(i1k_paths)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "id": "2b7e8e68", 224 | "metadata": {}, 225 | "source": [ 226 | "## Run evaluation" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 13, 232 | "id": "63a3da22-a60f-48b8-a0b0-02e54b2d012f", 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "def get_model(model_url):\n", 237 | " classification_model = tf.keras.Sequential(\n", 238 | " [\n", 239 | " layers.InputLayer((224, 224, 3)),\n", 240 | " hub.KerasLayer(model_url),\n", 241 | " ]\n", 242 | " )\n", 243 | " return classification_model\n", 244 | "\n", 245 | "\n", 246 | "def evaluate_model(model_name):\n", 247 | " tb_callback = tf.keras.callbacks.TensorBoard(log_dir=f\"logs_{model_name}\")\n", 248 | " model_url = TF_MODEL_ROOT + \"/\" + model_name\n", 249 | " \n", 250 | " model = get_model(model_url)\n", 251 | " model.compile(metrics=[\"accuracy\"])\n", 252 | " _, accuracy = model.evaluate(dataset, callbacks=[tb_callback])\n", 253 | " accuracy = round(accuracy * 100, 4)\n", 254 | " print(f\"{model_name}: {accuracy}%.\", file=open(f\"{model_name.strip('/')}.txt\", \"w\"))" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 14, 260 | "id": "4db846db-86bc-4ae8-a699-acb69331d93c", 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "name": "stdout", 265 | "output_type": "stream", 266 | "text": [ 267 | "Evaluating convnext_base_1k_224/.\n", 268 | "196/196 [==============================] - 118s 586ms/step - loss: 0.0000e+00 - accuracy: 0.8328\n", 269 | "Evaluating convnext_base_21k_1k_224/.\n", 270 | "196/196 [==============================] - 118s 585ms/step - loss: 0.0000e+00 - accuracy: 0.8536\n", 271 | "Evaluating convnext_large_1k_224/.\n", 272 | "196/196 [==============================] - 177s 879ms/step - loss: 0.0000e+00 - accuracy: 0.8384\n", 273 | "Evaluating convnext_large_21k_1k_224/.\n", 274 | "196/196 [==============================] - 175s 876ms/step - loss: 0.0000e+00 - accuracy: 0.8636\n", 275 | "Evaluating convnext_small_1k_224/.\n", 276 | "196/196 [==============================] - 92s 451ms/step - loss: 0.0000e+00 - accuracy: 0.8239\n", 277 | "Evaluating convnext_tiny_1k_224/.\n", 278 | "196/196 [==============================] - 60s 293ms/step - loss: 0.0000e+00 - accuracy: 0.8131\n", 279 | "Evaluating convnext_xlarge_21k_1k_224/.\n", 280 | "196/196 [==============================] - 241s 1s/step - loss: 0.0000e+00 - accuracy: 0.8673\n" 281 | ] 282 | } 283 | ], 284 | "source": [ 285 | "for i1k_path in i1k_paths:\n", 286 | " print(f\"Evaluating {i1k_path}.\")\n", 287 | " evaluate_model(i1k_path)" 288 | ] 289 | } 290 | ], 291 | "metadata": { 292 | "environment": { 293 | "kernel": "python3", 294 | "name": "tf2-gpu.2-7.m84", 295 | "type": "gcloud", 296 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m84" 297 | }, 298 | "kernelspec": 
{ 299 | "display_name": "Python 3", 300 | "language": "python", 301 | "name": "python3" 302 | }, 303 | "language_info": { 304 | "codemirror_mode": { 305 | "name": "ipython", 306 | "version": 3 307 | }, 308 | "file_extension": ".py", 309 | "mimetype": "text/x-python", 310 | "name": "python", 311 | "nbconvert_exporter": "python", 312 | "pygments_lexer": "ipython3", 313 | "version": "3.7.12" 314 | } 315 | }, 316 | "nbformat": 4, 317 | "nbformat_minor": 5 318 | } 319 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Meta Platforms Inc. and Sayak Paul 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /notebooks/finetune.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "# Fine-tuning for image classification with ConvNeXt models on TF-Hub\n", 11 | "\n", 12 | "\n", 13 | " \n", 16 | " \n", 19 | " \n", 22 | "
\n", 14 | " Run in Google Colab\n", 15 | " \n", 17 | " View on GitHub\n", 18 | " \n", 20 | " See TF Hub models\n", 21 | "
" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": { 28 | "id": "89B27-TGiDNB" 29 | }, 30 | "source": [ 31 | "## Imports" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "id": "9u3d4Z7uQsmp" 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "from tensorflow import keras\n", 43 | "import tensorflow as tf\n", 44 | "import tensorflow_hub as hub\n", 45 | "import tensorflow_datasets as tfds\n", 46 | "\n", 47 | "tfds.disable_progress_bar()\n", 48 | "\n", 49 | "import os\n", 50 | "import sys\n", 51 | "import math\n", 52 | "import numpy as np\n", 53 | "import pandas as pd\n", 54 | "import matplotlib.pyplot as plt" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "id": "mPo10cahZXXQ" 61 | }, 62 | "source": [ 63 | "## TPU/GPU detection" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "id": "FpvUOuC3j27n" 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "try: # detect TPUs\n", 75 | " tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # TPU detection\n", 76 | " strategy = tf.distribute.TPUStrategy(tpu)\n", 77 | "except ValueError: # detect GPUs\n", 78 | " tpu = False\n", 79 | " strategy = (\n", 80 | " tf.distribute.get_strategy()\n", 81 | " ) # default strategy that works on CPU and single GPU\n", 82 | "print(\"Number of Accelerators: \", strategy.num_replicas_in_sync)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": { 88 | "id": "w9S3uKC_iXY5" 89 | }, 90 | "source": [ 91 | "## Configuration" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": { 98 | "id": "kCc6tdUGnD4C" 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "# Model\n", 103 | "IMAGE_SIZE = [224, 224]\n", 104 | "MODEL_PATH = \"https://tfhub.dev/sayakpaul/convnext_tiny_1k_224_fe/1\"\n", 105 | "\n", 106 | "# TPU\n", 107 | "if tpu:\n", 108 | " BATCH_SIZE = (\n", 109 | " 16 * strategy.num_replicas_in_sync\n", 110 | " ) # a TPU has 8 cores so this will be 128\n", 111 | "else:\n", 112 | " BATCH_SIZE = 64 # on Colab/GPU, a higher batch size may throw(OOM)\n", 113 | "\n", 114 | "# Dataset\n", 115 | "CLASSES = [\n", 116 | " \"dandelion\",\n", 117 | " \"daisy\",\n", 118 | " \"tulips\",\n", 119 | " \"sunflowers\",\n", 120 | " \"roses\",\n", 121 | "] # don't change the order\n", 122 | "\n", 123 | "# Other constants\n", 124 | "MEAN = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255]) # imagenet mean\n", 125 | "STD = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255]) # imagenet std\n", 126 | "AUTO = tf.data.AUTOTUNE" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": { 132 | "id": "9iTImGI5qMQT" 133 | }, 134 | "source": [ 135 | "# Data Pipeline" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": { 142 | "id": "h29TLx7gqN_7" 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "def make_dataset(dataset: tf.data.Dataset, train: bool, image_size: int = IMAGE_SIZE):\n", 147 | " def preprocess(image, label):\n", 148 | " # for training, do augmentation\n", 149 | " if train:\n", 150 | " if tf.random.uniform(shape=[]) > 0.5:\n", 151 | " image = tf.image.flip_left_right(image)\n", 152 | " image = tf.image.resize(image, size=image_size, method=\"bicubic\")\n", 153 | " image = (image - MEAN) / STD # normalization\n", 154 | " return image, label\n", 155 | "\n", 156 | " if train:\n", 157 | " dataset = dataset.shuffle(BATCH_SIZE * 10)\n", 158 | "\n", 159 | 
" return dataset.map(preprocess, AUTO).batch(BATCH_SIZE).prefetch(AUTO)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": { 165 | "id": "AMQ3Qs9_pddU" 166 | }, 167 | "source": [ 168 | "# Flower Dataset" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": { 175 | "id": "M3G-2aUBQJ-H" 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "train_dataset, val_dataset = tfds.load(\n", 180 | " \"tf_flowers\",\n", 181 | " split=[\"train[:90%]\", \"train[90%:]\"],\n", 182 | " as_supervised=True,\n", 183 | " try_gcs=False, # gcs_path is necessary for tpu,\n", 184 | ")\n", 185 | "\n", 186 | "num_train = tf.data.experimental.cardinality(train_dataset)\n", 187 | "num_val = tf.data.experimental.cardinality(val_dataset)\n", 188 | "print(f\"Number of training examples: {num_train}\")\n", 189 | "print(f\"Number of validation examples: {num_val}\")\n" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": { 195 | "id": "l2X7sE3oRLXN" 196 | }, 197 | "source": [ 198 | "## Prepare dataset" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": { 205 | "id": "oftrfYw1qXei" 206 | }, 207 | "outputs": [], 208 | "source": [ 209 | "train_dataset = make_dataset(train_dataset, True)\n", 210 | "val_dataset = make_dataset(val_dataset, False)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": { 216 | "id": "kNyCCM6PRM8I" 217 | }, 218 | "source": [ 219 | "## Visualize" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": { 226 | "id": "IaGzFUUVqjaC" 227 | }, 228 | "outputs": [], 229 | "source": [ 230 | "sample_images, sample_labels = next(iter(train_dataset))\n", 231 | "\n", 232 | "plt.figure(figsize=(5 * 3, 3 * 3))\n", 233 | "for n in range(15):\n", 234 | " ax = plt.subplot(3, 5, n + 1)\n", 235 | " image = (sample_images[n] * STD + MEAN).numpy()\n", 236 | " image = (image - image.min()) / (\n", 237 | " image.max() - image.min()\n", 238 | " ) # convert to [0, 1] for avoiding matplotlib warning\n", 239 | " plt.imshow(image)\n", 240 | " plt.title(CLASSES[sample_labels[n]])\n", 241 | " plt.axis(\"off\")\n", 242 | "plt.tight_layout()\n", 243 | "plt.show()\n" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": { 249 | "id": "Qf6u_7tt8BYy" 250 | }, 251 | "source": [ 252 | "# LR Scheduler Utility" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": { 259 | "id": "oVTbnkJL79T_" 260 | }, 261 | "outputs": [], 262 | "source": [ 263 | "# Reference:\n", 264 | "# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2\n", 265 | "\n", 266 | "\n", 267 | "class WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):\n", 268 | " def __init__(\n", 269 | " self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps\n", 270 | " ):\n", 271 | " super(WarmUpCosine, self).__init__()\n", 272 | "\n", 273 | " self.learning_rate_base = learning_rate_base\n", 274 | " self.total_steps = total_steps\n", 275 | " self.warmup_learning_rate = warmup_learning_rate\n", 276 | " self.warmup_steps = warmup_steps\n", 277 | " self.pi = tf.constant(np.pi)\n", 278 | "\n", 279 | " def __call__(self, step):\n", 280 | " if self.total_steps < self.warmup_steps:\n", 281 | " raise ValueError(\"Total_steps must be larger or equal to warmup_steps.\")\n", 282 | " learning_rate = (\n", 283 | " 0.5\n", 284 | " * self.learning_rate_base\n", 285 | " * (\n", 286 | 
" 1\n", 287 | " + tf.cos(\n", 288 | " self.pi\n", 289 | " * (tf.cast(step, tf.float32) - self.warmup_steps)\n", 290 | " / float(self.total_steps - self.warmup_steps)\n", 291 | " )\n", 292 | " )\n", 293 | " )\n", 294 | "\n", 295 | " if self.warmup_steps > 0:\n", 296 | " if self.learning_rate_base < self.warmup_learning_rate:\n", 297 | " raise ValueError(\n", 298 | " \"Learning_rate_base must be larger or equal to \"\n", 299 | " \"warmup_learning_rate.\"\n", 300 | " )\n", 301 | " slope = (\n", 302 | " self.learning_rate_base - self.warmup_learning_rate\n", 303 | " ) / self.warmup_steps\n", 304 | " warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n", 305 | " learning_rate = tf.where(\n", 306 | " step < self.warmup_steps, warmup_rate, learning_rate\n", 307 | " )\n", 308 | " return tf.where(\n", 309 | " step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n", 310 | " )" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": { 316 | "id": "ALtRUlxhw8Vt" 317 | }, 318 | "source": [ 319 | "# Model Utility" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": { 326 | "id": "JD9SI_Q9JdAB" 327 | }, 328 | "outputs": [], 329 | "source": [ 330 | "def get_model(model_path=MODEL_PATH, res=224, num_classes=5):\n", 331 | " hub_layer = hub.KerasLayer(model_path, trainable=True)\n", 332 | "\n", 333 | " model = keras.Sequential(\n", 334 | " [\n", 335 | " keras.layers.InputLayer((res, res, 3)),\n", 336 | " hub_layer,\n", 337 | " keras.layers.Dense(num_classes, activation=\"softmax\"),\n", 338 | " ]\n", 339 | " )\n", 340 | " return model" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": { 347 | "id": "wpZApp9u9_Y-" 348 | }, 349 | "outputs": [], 350 | "source": [ 351 | "get_model().summary()" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": { 357 | "id": "dMfenMQcxAAb" 358 | }, 359 | "source": [ 360 | "# Training Hyperparameters" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": { 367 | "id": "1D7Iu7oD8WzX" 368 | }, 369 | "outputs": [], 370 | "source": [ 371 | "EPOCHS = 10\n", 372 | "WARMUP_STEPS = 10\n", 373 | "INIT_LR = 0.03\n", 374 | "WAMRUP_LR = 0.006\n", 375 | "\n", 376 | "TOTAL_STEPS = int((num_train / BATCH_SIZE) * EPOCHS)" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "metadata": { 383 | "id": "zmolHMov8als" 384 | }, 385 | "outputs": [], 386 | "source": [ 387 | "scheduled_lrs = WarmUpCosine(\n", 388 | " learning_rate_base=INIT_LR,\n", 389 | " total_steps=TOTAL_STEPS,\n", 390 | " warmup_learning_rate=WAMRUP_LR,\n", 391 | " warmup_steps=WARMUP_STEPS,\n", 392 | ")\n", 393 | "\n", 394 | "lrs = [scheduled_lrs(step) for step in range(TOTAL_STEPS)]\n", 395 | "plt.figure(figsize=(10, 6))\n", 396 | "plt.plot(lrs)\n", 397 | "plt.xlabel(\"Step\", fontsize=14)\n", 398 | "plt.ylabel(\"LR\", fontsize=14)\n", 399 | "plt.show()\n" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": { 406 | "id": "M-ID7vP5mIKs" 407 | }, 408 | "outputs": [], 409 | "source": [ 410 | "optimizer = keras.optimizers.SGD(scheduled_lrs)\n", 411 | "loss = keras.losses.SparseCategoricalCrossentropy()" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "id": "E9p4ymNh9y7d" 418 | }, 419 | "source": [ 420 | "# Training & Validation" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | 
"execution_count": null, 426 | "metadata": { 427 | "id": "VnZTSd8K90Mq" 428 | }, 429 | "outputs": [], 430 | "source": [ 431 | "with strategy.scope(): # this line is all that is needed to run on TPU (or multi-GPU, ...)\n", 432 | " model = get_model(MODEL_PATH)\n", 433 | " model.compile(loss=loss, optimizer=optimizer, metrics=[\"accuracy\"])\n", 434 | "\n", 435 | "history = model.fit(train_dataset, validation_data=val_dataset, epochs=EPOCHS)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "metadata": { 442 | "id": "jc7LMVz5Cbx6" 443 | }, 444 | "outputs": [], 445 | "source": [ 446 | "result = pd.DataFrame(history.history)\n", 447 | "fig, ax = plt.subplots(2, 1, figsize=(10, 10))\n", 448 | "result[[\"accuracy\", \"val_accuracy\"]].plot(xlabel=\"epoch\", ylabel=\"score\", ax=ax[0])\n", 449 | "result[[\"loss\", \"val_loss\"]].plot(xlabel=\"epoch\", ylabel=\"score\", ax=ax[1])\n" 450 | ] 451 | }, 452 | { 453 | "cell_type": "markdown", 454 | "metadata": { 455 | "id": "MKFMWzh0Yxsq" 456 | }, 457 | "source": [ 458 | "# Predictions" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": { 465 | "id": "yMEsR851VDZb" 466 | }, 467 | "outputs": [], 468 | "source": [ 469 | "sample_images, sample_labels = next(iter(val_dataset))\n", 470 | "\n", 471 | "predictions = model.predict(sample_images, batch_size=16).argmax(axis=-1)\n", 472 | "evaluations = model.evaluate(sample_images, sample_labels, batch_size=16)\n", 473 | "\n", 474 | "print(\"[val_loss, val_acc]\", evaluations)" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": { 481 | "id": "qzCCDL1CZFx6" 482 | }, 483 | "outputs": [], 484 | "source": [ 485 | "plt.figure(figsize=(5 * 3, 3 * 3))\n", 486 | "for n in range(15):\n", 487 | " ax = plt.subplot(3, 5, n + 1)\n", 488 | " image = (sample_images[n] * STD + MEAN).numpy()\n", 489 | " image = (image - image.min()) / (\n", 490 | " image.max() - image.min()\n", 491 | " ) # convert to [0, 1] for avoiding matplotlib warning\n", 492 | " plt.imshow(image)\n", 493 | " target = CLASSES[sample_labels[n]]\n", 494 | " pred = CLASSES[predictions[n]]\n", 495 | " plt.title(\"{} ({})\".format(target, pred))\n", 496 | " plt.axis(\"off\")\n", 497 | "plt.tight_layout()\n", 498 | "plt.show()\n" 499 | ] 500 | }, 501 | { 502 | "cell_type": "markdown", 503 | "metadata": { 504 | "id": "2e5oy9zmNNID" 505 | }, 506 | "source": [ 507 | "# Reference\n", 508 | "* [ConvNeXt-TF](https://github.com/sayakpaul/ConvNeXt-TF)\n", 509 | "* [Keras Flowers on TPU (solution)](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_solution.ipynb)" 510 | ] 511 | } 512 | ], 513 | "metadata": { 514 | "accelerator": "GPU", 515 | "colab": { 516 | "collapsed_sections": [], 517 | "include_colab_link": true, 518 | "machine_shape": "hm", 519 | "name": "ConvNext-TF: Flower Classification (TPU/GPU).ipynb", 520 | "provenance": [], 521 | "toc_visible": true 522 | }, 523 | "environment": { 524 | "name": "tf22-cpu.2-2.m47", 525 | "type": "gcloud", 526 | "uri": "gcr.io/deeplearning-platform-release/tf22-cpu.2-2:m47" 527 | }, 528 | "kernelspec": { 529 | "display_name": "Python 3 (ipykernel)", 530 | "language": "python", 531 | "name": "python3" 532 | }, 533 | "language_info": { 534 | "codemirror_mode": { 535 | "name": "ipython", 536 | "version": 3 537 | }, 538 | "file_extension": ".py", 539 | "mimetype": "text/x-python", 540 | "name": 
"python", 541 | "nbconvert_exporter": "python", 542 | "pygments_lexer": "ipython3", 543 | "version": "3.8.2" 544 | } 545 | }, 546 | "nbformat": 4, 547 | "nbformat_minor": 1 548 | } 549 | -------------------------------------------------------------------------------- /i1k_eval/imagenet_class_index.json: -------------------------------------------------------------------------------- 1 | {"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], "2": ["n01484850", "great_white_shark"], "3": ["n01491361", "tiger_shark"], "4": ["n01494475", "hammerhead"], "5": ["n01496331", "electric_ray"], "6": ["n01498041", "stingray"], "7": ["n01514668", "cock"], "8": ["n01514859", "hen"], "9": ["n01518878", "ostrich"], "10": ["n01530575", "brambling"], "11": ["n01531178", "goldfinch"], "12": ["n01532829", "house_finch"], "13": ["n01534433", "junco"], "14": ["n01537544", "indigo_bunting"], "15": ["n01558993", "robin"], "16": ["n01560419", "bulbul"], "17": ["n01580077", "jay"], "18": ["n01582220", "magpie"], "19": ["n01592084", "chickadee"], "20": ["n01601694", "water_ouzel"], "21": ["n01608432", "kite"], "22": ["n01614925", "bald_eagle"], "23": ["n01616318", "vulture"], "24": ["n01622779", "great_grey_owl"], "25": ["n01629819", "European_fire_salamander"], "26": ["n01630670", "common_newt"], "27": ["n01631663", "eft"], "28": ["n01632458", "spotted_salamander"], "29": ["n01632777", "axolotl"], "30": ["n01641577", "bullfrog"], "31": ["n01644373", "tree_frog"], "32": ["n01644900", "tailed_frog"], "33": ["n01664065", "loggerhead"], "34": ["n01665541", "leatherback_turtle"], "35": ["n01667114", "mud_turtle"], "36": ["n01667778", "terrapin"], "37": ["n01669191", "box_turtle"], "38": ["n01675722", "banded_gecko"], "39": ["n01677366", "common_iguana"], "40": ["n01682714", "American_chameleon"], "41": ["n01685808", "whiptail"], "42": ["n01687978", "agama"], "43": ["n01688243", "frilled_lizard"], "44": ["n01689811", "alligator_lizard"], "45": ["n01692333", "Gila_monster"], "46": ["n01693334", "green_lizard"], "47": ["n01694178", "African_chameleon"], "48": ["n01695060", "Komodo_dragon"], "49": ["n01697457", "African_crocodile"], "50": ["n01698640", "American_alligator"], "51": ["n01704323", "triceratops"], "52": ["n01728572", "thunder_snake"], "53": ["n01728920", "ringneck_snake"], "54": ["n01729322", "hognose_snake"], "55": ["n01729977", "green_snake"], "56": ["n01734418", "king_snake"], "57": ["n01735189", "garter_snake"], "58": ["n01737021", "water_snake"], "59": ["n01739381", "vine_snake"], "60": ["n01740131", "night_snake"], "61": ["n01742172", "boa_constrictor"], "62": ["n01744401", "rock_python"], "63": ["n01748264", "Indian_cobra"], "64": ["n01749939", "green_mamba"], "65": ["n01751748", "sea_snake"], "66": ["n01753488", "horned_viper"], "67": ["n01755581", "diamondback"], "68": ["n01756291", "sidewinder"], "69": ["n01768244", "trilobite"], "70": ["n01770081", "harvestman"], "71": ["n01770393", "scorpion"], "72": ["n01773157", "black_and_gold_garden_spider"], "73": ["n01773549", "barn_spider"], "74": ["n01773797", "garden_spider"], "75": ["n01774384", "black_widow"], "76": ["n01774750", "tarantula"], "77": ["n01775062", "wolf_spider"], "78": ["n01776313", "tick"], "79": ["n01784675", "centipede"], "80": ["n01795545", "black_grouse"], "81": ["n01796340", "ptarmigan"], "82": ["n01797886", "ruffed_grouse"], "83": ["n01798484", "prairie_chicken"], "84": ["n01806143", "peacock"], "85": ["n01806567", "quail"], "86": ["n01807496", "partridge"], "87": ["n01817953", "African_grey"], "88": ["n01818515", "macaw"], 
"89": ["n01819313", "sulphur-crested_cockatoo"], "90": ["n01820546", "lorikeet"], "91": ["n01824575", "coucal"], "92": ["n01828970", "bee_eater"], "93": ["n01829413", "hornbill"], "94": ["n01833805", "hummingbird"], "95": ["n01843065", "jacamar"], "96": ["n01843383", "toucan"], "97": ["n01847000", "drake"], "98": ["n01855032", "red-breasted_merganser"], "99": ["n01855672", "goose"], "100": ["n01860187", "black_swan"], "101": ["n01871265", "tusker"], "102": ["n01872401", "echidna"], "103": ["n01873310", "platypus"], "104": ["n01877812", "wallaby"], "105": ["n01882714", "koala"], "106": ["n01883070", "wombat"], "107": ["n01910747", "jellyfish"], "108": ["n01914609", "sea_anemone"], "109": ["n01917289", "brain_coral"], "110": ["n01924916", "flatworm"], "111": ["n01930112", "nematode"], "112": ["n01943899", "conch"], "113": ["n01944390", "snail"], "114": ["n01945685", "slug"], "115": ["n01950731", "sea_slug"], "116": ["n01955084", "chiton"], "117": ["n01968897", "chambered_nautilus"], "118": ["n01978287", "Dungeness_crab"], "119": ["n01978455", "rock_crab"], "120": ["n01980166", "fiddler_crab"], "121": ["n01981276", "king_crab"], "122": ["n01983481", "American_lobster"], "123": ["n01984695", "spiny_lobster"], "124": ["n01985128", "crayfish"], "125": ["n01986214", "hermit_crab"], "126": ["n01990800", "isopod"], "127": ["n02002556", "white_stork"], "128": ["n02002724", "black_stork"], "129": ["n02006656", "spoonbill"], "130": ["n02007558", "flamingo"], "131": ["n02009229", "little_blue_heron"], "132": ["n02009912", "American_egret"], "133": ["n02011460", "bittern"], "134": ["n02012849", "crane"], "135": ["n02013706", "limpkin"], "136": ["n02017213", "European_gallinule"], "137": ["n02018207", "American_coot"], "138": ["n02018795", "bustard"], "139": ["n02025239", "ruddy_turnstone"], "140": ["n02027492", "red-backed_sandpiper"], "141": ["n02028035", "redshank"], "142": ["n02033041", "dowitcher"], "143": ["n02037110", "oystercatcher"], "144": ["n02051845", "pelican"], "145": ["n02056570", "king_penguin"], "146": ["n02058221", "albatross"], "147": ["n02066245", "grey_whale"], "148": ["n02071294", "killer_whale"], "149": ["n02074367", "dugong"], "150": ["n02077923", "sea_lion"], "151": ["n02085620", "Chihuahua"], "152": ["n02085782", "Japanese_spaniel"], "153": ["n02085936", "Maltese_dog"], "154": ["n02086079", "Pekinese"], "155": ["n02086240", "Shih-Tzu"], "156": ["n02086646", "Blenheim_spaniel"], "157": ["n02086910", "papillon"], "158": ["n02087046", "toy_terrier"], "159": ["n02087394", "Rhodesian_ridgeback"], "160": ["n02088094", "Afghan_hound"], "161": ["n02088238", "basset"], "162": ["n02088364", "beagle"], "163": ["n02088466", "bloodhound"], "164": ["n02088632", "bluetick"], "165": ["n02089078", "black-and-tan_coonhound"], "166": ["n02089867", "Walker_hound"], "167": ["n02089973", "English_foxhound"], "168": ["n02090379", "redbone"], "169": ["n02090622", "borzoi"], "170": ["n02090721", "Irish_wolfhound"], "171": ["n02091032", "Italian_greyhound"], "172": ["n02091134", "whippet"], "173": ["n02091244", "Ibizan_hound"], "174": ["n02091467", "Norwegian_elkhound"], "175": ["n02091635", "otterhound"], "176": ["n02091831", "Saluki"], "177": ["n02092002", "Scottish_deerhound"], "178": ["n02092339", "Weimaraner"], "179": ["n02093256", "Staffordshire_bullterrier"], "180": ["n02093428", "American_Staffordshire_terrier"], "181": ["n02093647", "Bedlington_terrier"], "182": ["n02093754", "Border_terrier"], "183": ["n02093859", "Kerry_blue_terrier"], "184": ["n02093991", "Irish_terrier"], "185": 
["n02094114", "Norfolk_terrier"], "186": ["n02094258", "Norwich_terrier"], "187": ["n02094433", "Yorkshire_terrier"], "188": ["n02095314", "wire-haired_fox_terrier"], "189": ["n02095570", "Lakeland_terrier"], "190": ["n02095889", "Sealyham_terrier"], "191": ["n02096051", "Airedale"], "192": ["n02096177", "cairn"], "193": ["n02096294", "Australian_terrier"], "194": ["n02096437", "Dandie_Dinmont"], "195": ["n02096585", "Boston_bull"], "196": ["n02097047", "miniature_schnauzer"], "197": ["n02097130", "giant_schnauzer"], "198": ["n02097209", "standard_schnauzer"], "199": ["n02097298", "Scotch_terrier"], "200": ["n02097474", "Tibetan_terrier"], "201": ["n02097658", "silky_terrier"], "202": ["n02098105", "soft-coated_wheaten_terrier"], "203": ["n02098286", "West_Highland_white_terrier"], "204": ["n02098413", "Lhasa"], "205": ["n02099267", "flat-coated_retriever"], "206": ["n02099429", "curly-coated_retriever"], "207": ["n02099601", "golden_retriever"], "208": ["n02099712", "Labrador_retriever"], "209": ["n02099849", "Chesapeake_Bay_retriever"], "210": ["n02100236", "German_short-haired_pointer"], "211": ["n02100583", "vizsla"], "212": ["n02100735", "English_setter"], "213": ["n02100877", "Irish_setter"], "214": ["n02101006", "Gordon_setter"], "215": ["n02101388", "Brittany_spaniel"], "216": ["n02101556", "clumber"], "217": ["n02102040", "English_springer"], "218": ["n02102177", "Welsh_springer_spaniel"], "219": ["n02102318", "cocker_spaniel"], "220": ["n02102480", "Sussex_spaniel"], "221": ["n02102973", "Irish_water_spaniel"], "222": ["n02104029", "kuvasz"], "223": ["n02104365", "schipperke"], "224": ["n02105056", "groenendael"], "225": ["n02105162", "malinois"], "226": ["n02105251", "briard"], "227": ["n02105412", "kelpie"], "228": ["n02105505", "komondor"], "229": ["n02105641", "Old_English_sheepdog"], "230": ["n02105855", "Shetland_sheepdog"], "231": ["n02106030", "collie"], "232": ["n02106166", "Border_collie"], "233": ["n02106382", "Bouvier_des_Flandres"], "234": ["n02106550", "Rottweiler"], "235": ["n02106662", "German_shepherd"], "236": ["n02107142", "Doberman"], "237": ["n02107312", "miniature_pinscher"], "238": ["n02107574", "Greater_Swiss_Mountain_dog"], "239": ["n02107683", "Bernese_mountain_dog"], "240": ["n02107908", "Appenzeller"], "241": ["n02108000", "EntleBucher"], "242": ["n02108089", "boxer"], "243": ["n02108422", "bull_mastiff"], "244": ["n02108551", "Tibetan_mastiff"], "245": ["n02108915", "French_bulldog"], "246": ["n02109047", "Great_Dane"], "247": ["n02109525", "Saint_Bernard"], "248": ["n02109961", "Eskimo_dog"], "249": ["n02110063", "malamute"], "250": ["n02110185", "Siberian_husky"], "251": ["n02110341", "dalmatian"], "252": ["n02110627", "affenpinscher"], "253": ["n02110806", "basenji"], "254": ["n02110958", "pug"], "255": ["n02111129", "Leonberg"], "256": ["n02111277", "Newfoundland"], "257": ["n02111500", "Great_Pyrenees"], "258": ["n02111889", "Samoyed"], "259": ["n02112018", "Pomeranian"], "260": ["n02112137", "chow"], "261": ["n02112350", "keeshond"], "262": ["n02112706", "Brabancon_griffon"], "263": ["n02113023", "Pembroke"], "264": ["n02113186", "Cardigan"], "265": ["n02113624", "toy_poodle"], "266": ["n02113712", "miniature_poodle"], "267": ["n02113799", "standard_poodle"], "268": ["n02113978", "Mexican_hairless"], "269": ["n02114367", "timber_wolf"], "270": ["n02114548", "white_wolf"], "271": ["n02114712", "red_wolf"], "272": ["n02114855", "coyote"], "273": ["n02115641", "dingo"], "274": ["n02115913", "dhole"], "275": ["n02116738", "African_hunting_dog"], 
"276": ["n02117135", "hyena"], "277": ["n02119022", "red_fox"], "278": ["n02119789", "kit_fox"], "279": ["n02120079", "Arctic_fox"], "280": ["n02120505", "grey_fox"], "281": ["n02123045", "tabby"], "282": ["n02123159", "tiger_cat"], "283": ["n02123394", "Persian_cat"], "284": ["n02123597", "Siamese_cat"], "285": ["n02124075", "Egyptian_cat"], "286": ["n02125311", "cougar"], "287": ["n02127052", "lynx"], "288": ["n02128385", "leopard"], "289": ["n02128757", "snow_leopard"], "290": ["n02128925", "jaguar"], "291": ["n02129165", "lion"], "292": ["n02129604", "tiger"], "293": ["n02130308", "cheetah"], "294": ["n02132136", "brown_bear"], "295": ["n02133161", "American_black_bear"], "296": ["n02134084", "ice_bear"], "297": ["n02134418", "sloth_bear"], "298": ["n02137549", "mongoose"], "299": ["n02138441", "meerkat"], "300": ["n02165105", "tiger_beetle"], "301": ["n02165456", "ladybug"], "302": ["n02167151", "ground_beetle"], "303": ["n02168699", "long-horned_beetle"], "304": ["n02169497", "leaf_beetle"], "305": ["n02172182", "dung_beetle"], "306": ["n02174001", "rhinoceros_beetle"], "307": ["n02177972", "weevil"], "308": ["n02190166", "fly"], "309": ["n02206856", "bee"], "310": ["n02219486", "ant"], "311": ["n02226429", "grasshopper"], "312": ["n02229544", "cricket"], "313": ["n02231487", "walking_stick"], "314": ["n02233338", "cockroach"], "315": ["n02236044", "mantis"], "316": ["n02256656", "cicada"], "317": ["n02259212", "leafhopper"], "318": ["n02264363", "lacewing"], "319": ["n02268443", "dragonfly"], "320": ["n02268853", "damselfly"], "321": ["n02276258", "admiral"], "322": ["n02277742", "ringlet"], "323": ["n02279972", "monarch"], "324": ["n02280649", "cabbage_butterfly"], "325": ["n02281406", "sulphur_butterfly"], "326": ["n02281787", "lycaenid"], "327": ["n02317335", "starfish"], "328": ["n02319095", "sea_urchin"], "329": ["n02321529", "sea_cucumber"], "330": ["n02325366", "wood_rabbit"], "331": ["n02326432", "hare"], "332": ["n02328150", "Angora"], "333": ["n02342885", "hamster"], "334": ["n02346627", "porcupine"], "335": ["n02356798", "fox_squirrel"], "336": ["n02361337", "marmot"], "337": ["n02363005", "beaver"], "338": ["n02364673", "guinea_pig"], "339": ["n02389026", "sorrel"], "340": ["n02391049", "zebra"], "341": ["n02395406", "hog"], "342": ["n02396427", "wild_boar"], "343": ["n02397096", "warthog"], "344": ["n02398521", "hippopotamus"], "345": ["n02403003", "ox"], "346": ["n02408429", "water_buffalo"], "347": ["n02410509", "bison"], "348": ["n02412080", "ram"], "349": ["n02415577", "bighorn"], "350": ["n02417914", "ibex"], "351": ["n02422106", "hartebeest"], "352": ["n02422699", "impala"], "353": ["n02423022", "gazelle"], "354": ["n02437312", "Arabian_camel"], "355": ["n02437616", "llama"], "356": ["n02441942", "weasel"], "357": ["n02442845", "mink"], "358": ["n02443114", "polecat"], "359": ["n02443484", "black-footed_ferret"], "360": ["n02444819", "otter"], "361": ["n02445715", "skunk"], "362": ["n02447366", "badger"], "363": ["n02454379", "armadillo"], "364": ["n02457408", "three-toed_sloth"], "365": ["n02480495", "orangutan"], "366": ["n02480855", "gorilla"], "367": ["n02481823", "chimpanzee"], "368": ["n02483362", "gibbon"], "369": ["n02483708", "siamang"], "370": ["n02484975", "guenon"], "371": ["n02486261", "patas"], "372": ["n02486410", "baboon"], "373": ["n02487347", "macaque"], "374": ["n02488291", "langur"], "375": ["n02488702", "colobus"], "376": ["n02489166", "proboscis_monkey"], "377": ["n02490219", "marmoset"], "378": ["n02492035", "capuchin"], "379": ["n02492660", 
"howler_monkey"], "380": ["n02493509", "titi"], "381": ["n02493793", "spider_monkey"], "382": ["n02494079", "squirrel_monkey"], "383": ["n02497673", "Madagascar_cat"], "384": ["n02500267", "indri"], "385": ["n02504013", "Indian_elephant"], "386": ["n02504458", "African_elephant"], "387": ["n02509815", "lesser_panda"], "388": ["n02510455", "giant_panda"], "389": ["n02514041", "barracouta"], "390": ["n02526121", "eel"], "391": ["n02536864", "coho"], "392": ["n02606052", "rock_beauty"], "393": ["n02607072", "anemone_fish"], "394": ["n02640242", "sturgeon"], "395": ["n02641379", "gar"], "396": ["n02643566", "lionfish"], "397": ["n02655020", "puffer"], "398": ["n02666196", "abacus"], "399": ["n02667093", "abaya"], "400": ["n02669723", "academic_gown"], "401": ["n02672831", "accordion"], "402": ["n02676566", "acoustic_guitar"], "403": ["n02687172", "aircraft_carrier"], "404": ["n02690373", "airliner"], "405": ["n02692877", "airship"], "406": ["n02699494", "altar"], "407": ["n02701002", "ambulance"], "408": ["n02704792", "amphibian"], "409": ["n02708093", "analog_clock"], "410": ["n02727426", "apiary"], "411": ["n02730930", "apron"], "412": ["n02747177", "ashcan"], "413": ["n02749479", "assault_rifle"], "414": ["n02769748", "backpack"], "415": ["n02776631", "bakery"], "416": ["n02777292", "balance_beam"], "417": ["n02782093", "balloon"], "418": ["n02783161", "ballpoint"], "419": ["n02786058", "Band_Aid"], "420": ["n02787622", "banjo"], "421": ["n02788148", "bannister"], "422": ["n02790996", "barbell"], "423": ["n02791124", "barber_chair"], "424": ["n02791270", "barbershop"], "425": ["n02793495", "barn"], "426": ["n02794156", "barometer"], "427": ["n02795169", "barrel"], "428": ["n02797295", "barrow"], "429": ["n02799071", "baseball"], "430": ["n02802426", "basketball"], "431": ["n02804414", "bassinet"], "432": ["n02804610", "bassoon"], "433": ["n02807133", "bathing_cap"], "434": ["n02808304", "bath_towel"], "435": ["n02808440", "bathtub"], "436": ["n02814533", "beach_wagon"], "437": ["n02814860", "beacon"], "438": ["n02815834", "beaker"], "439": ["n02817516", "bearskin"], "440": ["n02823428", "beer_bottle"], "441": ["n02823750", "beer_glass"], "442": ["n02825657", "bell_cote"], "443": ["n02834397", "bib"], "444": ["n02835271", "bicycle-built-for-two"], "445": ["n02837789", "bikini"], "446": ["n02840245", "binder"], "447": ["n02841315", "binoculars"], "448": ["n02843684", "birdhouse"], "449": ["n02859443", "boathouse"], "450": ["n02860847", "bobsled"], "451": ["n02865351", "bolo_tie"], "452": ["n02869837", "bonnet"], "453": ["n02870880", "bookcase"], "454": ["n02871525", "bookshop"], "455": ["n02877765", "bottlecap"], "456": ["n02879718", "bow"], "457": ["n02883205", "bow_tie"], "458": ["n02892201", "brass"], "459": ["n02892767", "brassiere"], "460": ["n02894605", "breakwater"], "461": ["n02895154", "breastplate"], "462": ["n02906734", "broom"], "463": ["n02909870", "bucket"], "464": ["n02910353", "buckle"], "465": ["n02916936", "bulletproof_vest"], "466": ["n02917067", "bullet_train"], "467": ["n02927161", "butcher_shop"], "468": ["n02930766", "cab"], "469": ["n02939185", "caldron"], "470": ["n02948072", "candle"], "471": ["n02950826", "cannon"], "472": ["n02951358", "canoe"], "473": ["n02951585", "can_opener"], "474": ["n02963159", "cardigan"], "475": ["n02965783", "car_mirror"], "476": ["n02966193", "carousel"], "477": ["n02966687", "carpenter's_kit"], "478": ["n02971356", "carton"], "479": ["n02974003", "car_wheel"], "480": ["n02977058", "cash_machine"], "481": ["n02978881", "cassette"], 
"482": ["n02979186", "cassette_player"], "483": ["n02980441", "castle"], "484": ["n02981792", "catamaran"], "485": ["n02988304", "CD_player"], "486": ["n02992211", "cello"], "487": ["n02992529", "cellular_telephone"], "488": ["n02999410", "chain"], "489": ["n03000134", "chainlink_fence"], "490": ["n03000247", "chain_mail"], "491": ["n03000684", "chain_saw"], "492": ["n03014705", "chest"], "493": ["n03016953", "chiffonier"], "494": ["n03017168", "chime"], "495": ["n03018349", "china_cabinet"], "496": ["n03026506", "Christmas_stocking"], "497": ["n03028079", "church"], "498": ["n03032252", "cinema"], "499": ["n03041632", "cleaver"], "500": ["n03042490", "cliff_dwelling"], "501": ["n03045698", "cloak"], "502": ["n03047690", "clog"], "503": ["n03062245", "cocktail_shaker"], "504": ["n03063599", "coffee_mug"], "505": ["n03063689", "coffeepot"], "506": ["n03065424", "coil"], "507": ["n03075370", "combination_lock"], "508": ["n03085013", "computer_keyboard"], "509": ["n03089624", "confectionery"], "510": ["n03095699", "container_ship"], "511": ["n03100240", "convertible"], "512": ["n03109150", "corkscrew"], "513": ["n03110669", "cornet"], "514": ["n03124043", "cowboy_boot"], "515": ["n03124170", "cowboy_hat"], "516": ["n03125729", "cradle"], "517": ["n03126707", "crane"], "518": ["n03127747", "crash_helmet"], "519": ["n03127925", "crate"], "520": ["n03131574", "crib"], "521": ["n03133878", "Crock_Pot"], "522": ["n03134739", "croquet_ball"], "523": ["n03141823", "crutch"], "524": ["n03146219", "cuirass"], "525": ["n03160309", "dam"], "526": ["n03179701", "desk"], "527": ["n03180011", "desktop_computer"], "528": ["n03187595", "dial_telephone"], "529": ["n03188531", "diaper"], "530": ["n03196217", "digital_clock"], "531": ["n03197337", "digital_watch"], "532": ["n03201208", "dining_table"], "533": ["n03207743", "dishrag"], "534": ["n03207941", "dishwasher"], "535": ["n03208938", "disk_brake"], "536": ["n03216828", "dock"], "537": ["n03218198", "dogsled"], "538": ["n03220513", "dome"], "539": ["n03223299", "doormat"], "540": ["n03240683", "drilling_platform"], "541": ["n03249569", "drum"], "542": ["n03250847", "drumstick"], "543": ["n03255030", "dumbbell"], "544": ["n03259280", "Dutch_oven"], "545": ["n03271574", "electric_fan"], "546": ["n03272010", "electric_guitar"], "547": ["n03272562", "electric_locomotive"], "548": ["n03290653", "entertainment_center"], "549": ["n03291819", "envelope"], "550": ["n03297495", "espresso_maker"], "551": ["n03314780", "face_powder"], "552": ["n03325584", "feather_boa"], "553": ["n03337140", "file"], "554": ["n03344393", "fireboat"], "555": ["n03345487", "fire_engine"], "556": ["n03347037", "fire_screen"], "557": ["n03355925", "flagpole"], "558": ["n03372029", "flute"], "559": ["n03376595", "folding_chair"], "560": ["n03379051", "football_helmet"], "561": ["n03384352", "forklift"], "562": ["n03388043", "fountain"], "563": ["n03388183", "fountain_pen"], "564": ["n03388549", "four-poster"], "565": ["n03393912", "freight_car"], "566": ["n03394916", "French_horn"], "567": ["n03400231", "frying_pan"], "568": ["n03404251", "fur_coat"], "569": ["n03417042", "garbage_truck"], "570": ["n03424325", "gasmask"], "571": ["n03425413", "gas_pump"], "572": ["n03443371", "goblet"], "573": ["n03444034", "go-kart"], "574": ["n03445777", "golf_ball"], "575": ["n03445924", "golfcart"], "576": ["n03447447", "gondola"], "577": ["n03447721", "gong"], "578": ["n03450230", "gown"], "579": ["n03452741", "grand_piano"], "580": ["n03457902", "greenhouse"], "581": ["n03459775", "grille"], "582": 
["n03461385", "grocery_store"], "583": ["n03467068", "guillotine"], "584": ["n03476684", "hair_slide"], "585": ["n03476991", "hair_spray"], "586": ["n03478589", "half_track"], "587": ["n03481172", "hammer"], "588": ["n03482405", "hamper"], "589": ["n03483316", "hand_blower"], "590": ["n03485407", "hand-held_computer"], "591": ["n03485794", "handkerchief"], "592": ["n03492542", "hard_disc"], "593": ["n03494278", "harmonica"], "594": ["n03495258", "harp"], "595": ["n03496892", "harvester"], "596": ["n03498962", "hatchet"], "597": ["n03527444", "holster"], "598": ["n03529860", "home_theater"], "599": ["n03530642", "honeycomb"], "600": ["n03532672", "hook"], "601": ["n03534580", "hoopskirt"], "602": ["n03535780", "horizontal_bar"], "603": ["n03538406", "horse_cart"], "604": ["n03544143", "hourglass"], "605": ["n03584254", "iPod"], "606": ["n03584829", "iron"], "607": ["n03590841", "jack-o'-lantern"], "608": ["n03594734", "jean"], "609": ["n03594945", "jeep"], "610": ["n03595614", "jersey"], "611": ["n03598930", "jigsaw_puzzle"], "612": ["n03599486", "jinrikisha"], "613": ["n03602883", "joystick"], "614": ["n03617480", "kimono"], "615": ["n03623198", "knee_pad"], "616": ["n03627232", "knot"], "617": ["n03630383", "lab_coat"], "618": ["n03633091", "ladle"], "619": ["n03637318", "lampshade"], "620": ["n03642806", "laptop"], "621": ["n03649909", "lawn_mower"], "622": ["n03657121", "lens_cap"], "623": ["n03658185", "letter_opener"], "624": ["n03661043", "library"], "625": ["n03662601", "lifeboat"], "626": ["n03666591", "lighter"], "627": ["n03670208", "limousine"], "628": ["n03673027", "liner"], "629": ["n03676483", "lipstick"], "630": ["n03680355", "Loafer"], "631": ["n03690938", "lotion"], "632": ["n03691459", "loudspeaker"], "633": ["n03692522", "loupe"], "634": ["n03697007", "lumbermill"], "635": ["n03706229", "magnetic_compass"], "636": ["n03709823", "mailbag"], "637": ["n03710193", "mailbox"], "638": ["n03710637", "maillot"], "639": ["n03710721", "maillot"], "640": ["n03717622", "manhole_cover"], "641": ["n03720891", "maraca"], "642": ["n03721384", "marimba"], "643": ["n03724870", "mask"], "644": ["n03729826", "matchstick"], "645": ["n03733131", "maypole"], "646": ["n03733281", "maze"], "647": ["n03733805", "measuring_cup"], "648": ["n03742115", "medicine_chest"], "649": ["n03743016", "megalith"], "650": ["n03759954", "microphone"], "651": ["n03761084", "microwave"], "652": ["n03763968", "military_uniform"], "653": ["n03764736", "milk_can"], "654": ["n03769881", "minibus"], "655": ["n03770439", "miniskirt"], "656": ["n03770679", "minivan"], "657": ["n03773504", "missile"], "658": ["n03775071", "mitten"], "659": ["n03775546", "mixing_bowl"], "660": ["n03776460", "mobile_home"], "661": ["n03777568", "Model_T"], "662": ["n03777754", "modem"], "663": ["n03781244", "monastery"], "664": ["n03782006", "monitor"], "665": ["n03785016", "moped"], "666": ["n03786901", "mortar"], "667": ["n03787032", "mortarboard"], "668": ["n03788195", "mosque"], "669": ["n03788365", "mosquito_net"], "670": ["n03791053", "motor_scooter"], "671": ["n03792782", "mountain_bike"], "672": ["n03792972", "mountain_tent"], "673": ["n03793489", "mouse"], "674": ["n03794056", "mousetrap"], "675": ["n03796401", "moving_van"], "676": ["n03803284", "muzzle"], "677": ["n03804744", "nail"], "678": ["n03814639", "neck_brace"], "679": ["n03814906", "necklace"], "680": ["n03825788", "nipple"], "681": ["n03832673", "notebook"], "682": ["n03837869", "obelisk"], "683": ["n03838899", "oboe"], "684": ["n03840681", "ocarina"], "685": 
["n03841143", "odometer"], "686": ["n03843555", "oil_filter"], "687": ["n03854065", "organ"], "688": ["n03857828", "oscilloscope"], "689": ["n03866082", "overskirt"], "690": ["n03868242", "oxcart"], "691": ["n03868863", "oxygen_mask"], "692": ["n03871628", "packet"], "693": ["n03873416", "paddle"], "694": ["n03874293", "paddlewheel"], "695": ["n03874599", "padlock"], "696": ["n03876231", "paintbrush"], "697": ["n03877472", "pajama"], "698": ["n03877845", "palace"], "699": ["n03884397", "panpipe"], "700": ["n03887697", "paper_towel"], "701": ["n03888257", "parachute"], "702": ["n03888605", "parallel_bars"], "703": ["n03891251", "park_bench"], "704": ["n03891332", "parking_meter"], "705": ["n03895866", "passenger_car"], "706": ["n03899768", "patio"], "707": ["n03902125", "pay-phone"], "708": ["n03903868", "pedestal"], "709": ["n03908618", "pencil_box"], "710": ["n03908714", "pencil_sharpener"], "711": ["n03916031", "perfume"], "712": ["n03920288", "Petri_dish"], "713": ["n03924679", "photocopier"], "714": ["n03929660", "pick"], "715": ["n03929855", "pickelhaube"], "716": ["n03930313", "picket_fence"], "717": ["n03930630", "pickup"], "718": ["n03933933", "pier"], "719": ["n03935335", "piggy_bank"], "720": ["n03937543", "pill_bottle"], "721": ["n03938244", "pillow"], "722": ["n03942813", "ping-pong_ball"], "723": ["n03944341", "pinwheel"], "724": ["n03947888", "pirate"], "725": ["n03950228", "pitcher"], "726": ["n03954731", "plane"], "727": ["n03956157", "planetarium"], "728": ["n03958227", "plastic_bag"], "729": ["n03961711", "plate_rack"], "730": ["n03967562", "plow"], "731": ["n03970156", "plunger"], "732": ["n03976467", "Polaroid_camera"], "733": ["n03976657", "pole"], "734": ["n03977966", "police_van"], "735": ["n03980874", "poncho"], "736": ["n03982430", "pool_table"], "737": ["n03983396", "pop_bottle"], "738": ["n03991062", "pot"], "739": ["n03992509", "potter's_wheel"], "740": ["n03995372", "power_drill"], "741": ["n03998194", "prayer_rug"], "742": ["n04004767", "printer"], "743": ["n04005630", "prison"], "744": ["n04008634", "projectile"], "745": ["n04009552", "projector"], "746": ["n04019541", "puck"], "747": ["n04023962", "punching_bag"], "748": ["n04026417", "purse"], "749": ["n04033901", "quill"], "750": ["n04033995", "quilt"], "751": ["n04037443", "racer"], "752": ["n04039381", "racket"], "753": ["n04040759", "radiator"], "754": ["n04041544", "radio"], "755": ["n04044716", "radio_telescope"], "756": ["n04049303", "rain_barrel"], "757": ["n04065272", "recreational_vehicle"], "758": ["n04067472", "reel"], "759": ["n04069434", "reflex_camera"], "760": ["n04070727", "refrigerator"], "761": ["n04074963", "remote_control"], "762": ["n04081281", "restaurant"], "763": ["n04086273", "revolver"], "764": ["n04090263", "rifle"], "765": ["n04099969", "rocking_chair"], "766": ["n04111531", "rotisserie"], "767": ["n04116512", "rubber_eraser"], "768": ["n04118538", "rugby_ball"], "769": ["n04118776", "rule"], "770": ["n04120489", "running_shoe"], "771": ["n04125021", "safe"], "772": ["n04127249", "safety_pin"], "773": ["n04131690", "saltshaker"], "774": ["n04133789", "sandal"], "775": ["n04136333", "sarong"], "776": ["n04141076", "sax"], "777": ["n04141327", "scabbard"], "778": ["n04141975", "scale"], "779": ["n04146614", "school_bus"], "780": ["n04147183", "schooner"], "781": ["n04149813", "scoreboard"], "782": ["n04152593", "screen"], "783": ["n04153751", "screw"], "784": ["n04154565", "screwdriver"], "785": ["n04162706", "seat_belt"], "786": ["n04179913", "sewing_machine"], "787": 
["n04192698", "shield"], "788": ["n04200800", "shoe_shop"], "789": ["n04201297", "shoji"], "790": ["n04204238", "shopping_basket"], "791": ["n04204347", "shopping_cart"], "792": ["n04208210", "shovel"], "793": ["n04209133", "shower_cap"], "794": ["n04209239", "shower_curtain"], "795": ["n04228054", "ski"], "796": ["n04229816", "ski_mask"], "797": ["n04235860", "sleeping_bag"], "798": ["n04238763", "slide_rule"], "799": ["n04239074", "sliding_door"], "800": ["n04243546", "slot"], "801": ["n04251144", "snorkel"], "802": ["n04252077", "snowmobile"], "803": ["n04252225", "snowplow"], "804": ["n04254120", "soap_dispenser"], "805": ["n04254680", "soccer_ball"], "806": ["n04254777", "sock"], "807": ["n04258138", "solar_dish"], "808": ["n04259630", "sombrero"], "809": ["n04263257", "soup_bowl"], "810": ["n04264628", "space_bar"], "811": ["n04265275", "space_heater"], "812": ["n04266014", "space_shuttle"], "813": ["n04270147", "spatula"], "814": ["n04273569", "speedboat"], "815": ["n04275548", "spider_web"], "816": ["n04277352", "spindle"], "817": ["n04285008", "sports_car"], "818": ["n04286575", "spotlight"], "819": ["n04296562", "stage"], "820": ["n04310018", "steam_locomotive"], "821": ["n04311004", "steel_arch_bridge"], "822": ["n04311174", "steel_drum"], "823": ["n04317175", "stethoscope"], "824": ["n04325704", "stole"], "825": ["n04326547", "stone_wall"], "826": ["n04328186", "stopwatch"], "827": ["n04330267", "stove"], "828": ["n04332243", "strainer"], "829": ["n04335435", "streetcar"], "830": ["n04336792", "stretcher"], "831": ["n04344873", "studio_couch"], "832": ["n04346328", "stupa"], "833": ["n04347754", "submarine"], "834": ["n04350905", "suit"], "835": ["n04355338", "sundial"], "836": ["n04355933", "sunglass"], "837": ["n04356056", "sunglasses"], "838": ["n04357314", "sunscreen"], "839": ["n04366367", "suspension_bridge"], "840": ["n04367480", "swab"], "841": ["n04370456", "sweatshirt"], "842": ["n04371430", "swimming_trunks"], "843": ["n04371774", "swing"], "844": ["n04372370", "switch"], "845": ["n04376876", "syringe"], "846": ["n04380533", "table_lamp"], "847": ["n04389033", "tank"], "848": ["n04392985", "tape_player"], "849": ["n04398044", "teapot"], "850": ["n04399382", "teddy"], "851": ["n04404412", "television"], "852": ["n04409515", "tennis_ball"], "853": ["n04417672", "thatch"], "854": ["n04418357", "theater_curtain"], "855": ["n04423845", "thimble"], "856": ["n04428191", "thresher"], "857": ["n04429376", "throne"], "858": ["n04435653", "tile_roof"], "859": ["n04442312", "toaster"], "860": ["n04443257", "tobacco_shop"], "861": ["n04447861", "toilet_seat"], "862": ["n04456115", "torch"], "863": ["n04458633", "totem_pole"], "864": ["n04461696", "tow_truck"], "865": ["n04462240", "toyshop"], "866": ["n04465501", "tractor"], "867": ["n04467665", "trailer_truck"], "868": ["n04476259", "tray"], "869": ["n04479046", "trench_coat"], "870": ["n04482393", "tricycle"], "871": ["n04483307", "trimaran"], "872": ["n04485082", "tripod"], "873": ["n04486054", "triumphal_arch"], "874": ["n04487081", "trolleybus"], "875": ["n04487394", "trombone"], "876": ["n04493381", "tub"], "877": ["n04501370", "turnstile"], "878": ["n04505470", "typewriter_keyboard"], "879": ["n04507155", "umbrella"], "880": ["n04509417", "unicycle"], "881": ["n04515003", "upright"], "882": ["n04517823", "vacuum"], "883": ["n04522168", "vase"], "884": ["n04523525", "vault"], "885": ["n04525038", "velvet"], "886": ["n04525305", "vending_machine"], "887": ["n04532106", "vestment"], "888": ["n04532670", "viaduct"], "889": 
["n04536866", "violin"], "890": ["n04540053", "volleyball"], "891": ["n04542943", "waffle_iron"], "892": ["n04548280", "wall_clock"], "893": ["n04548362", "wallet"], "894": ["n04550184", "wardrobe"], "895": ["n04552348", "warplane"], "896": ["n04553703", "washbasin"], "897": ["n04554684", "washer"], "898": ["n04557648", "water_bottle"], "899": ["n04560804", "water_jug"], "900": ["n04562935", "water_tower"], "901": ["n04579145", "whiskey_jug"], "902": ["n04579432", "whistle"], "903": ["n04584207", "wig"], "904": ["n04589890", "window_screen"], "905": ["n04590129", "window_shade"], "906": ["n04591157", "Windsor_tie"], "907": ["n04591713", "wine_bottle"], "908": ["n04592741", "wing"], "909": ["n04596742", "wok"], "910": ["n04597913", "wooden_spoon"], "911": ["n04599235", "wool"], "912": ["n04604644", "worm_fence"], "913": ["n04606251", "wreck"], "914": ["n04612504", "yawl"], "915": ["n04613696", "yurt"], "916": ["n06359193", "web_site"], "917": ["n06596364", "comic_book"], "918": ["n06785654", "crossword_puzzle"], "919": ["n06794110", "street_sign"], "920": ["n06874185", "traffic_light"], "921": ["n07248320", "book_jacket"], "922": ["n07565083", "menu"], "923": ["n07579787", "plate"], "924": ["n07583066", "guacamole"], "925": ["n07584110", "consomme"], "926": ["n07590611", "hot_pot"], "927": ["n07613480", "trifle"], "928": ["n07614500", "ice_cream"], "929": ["n07615774", "ice_lolly"], "930": ["n07684084", "French_loaf"], "931": ["n07693725", "bagel"], "932": ["n07695742", "pretzel"], "933": ["n07697313", "cheeseburger"], "934": ["n07697537", "hotdog"], "935": ["n07711569", "mashed_potato"], "936": ["n07714571", "head_cabbage"], "937": ["n07714990", "broccoli"], "938": ["n07715103", "cauliflower"], "939": ["n07716358", "zucchini"], "940": ["n07716906", "spaghetti_squash"], "941": ["n07717410", "acorn_squash"], "942": ["n07717556", "butternut_squash"], "943": ["n07718472", "cucumber"], "944": ["n07718747", "artichoke"], "945": ["n07720875", "bell_pepper"], "946": ["n07730033", "cardoon"], "947": ["n07734744", "mushroom"], "948": ["n07742313", "Granny_Smith"], "949": ["n07745940", "strawberry"], "950": ["n07747607", "orange"], "951": ["n07749582", "lemon"], "952": ["n07753113", "fig"], "953": ["n07753275", "pineapple"], "954": ["n07753592", "banana"], "955": ["n07754684", "jackfruit"], "956": ["n07760859", "custard_apple"], "957": ["n07768694", "pomegranate"], "958": ["n07802026", "hay"], "959": ["n07831146", "carbonara"], "960": ["n07836838", "chocolate_sauce"], "961": ["n07860988", "dough"], "962": ["n07871810", "meat_loaf"], "963": ["n07873807", "pizza"], "964": ["n07875152", "potpie"], "965": ["n07880968", "burrito"], "966": ["n07892512", "red_wine"], "967": ["n07920052", "espresso"], "968": ["n07930864", "cup"], "969": ["n07932039", "eggnog"], "970": ["n09193705", "alp"], "971": ["n09229709", "bubble"], "972": ["n09246464", "cliff"], "973": ["n09256479", "coral_reef"], "974": ["n09288635", "geyser"], "975": ["n09332890", "lakeside"], "976": ["n09399592", "promontory"], "977": ["n09421951", "sandbar"], "978": ["n09428293", "seashore"], "979": ["n09468604", "valley"], "980": ["n09472597", "volcano"], "981": ["n09835506", "ballplayer"], "982": ["n10148035", "groom"], "983": ["n10565667", "scuba_diver"], "984": ["n11879895", "rapeseed"], "985": ["n11939491", "daisy"], "986": ["n12057211", "yellow_lady's_slipper"], "987": ["n12144580", "corn"], "988": ["n12267677", "acorn"], "989": ["n12620546", "hip"], "990": ["n12768682", "buckeye"], "991": ["n12985857", "coral_fungus"], "992": ["n12998815", 
"agaric"], "993": ["n13037406", "gyromitra"], "994": ["n13040303", "stinkhorn"], "995": ["n13044778", "earthstar"], "996": ["n13052670", "hen-of-the-woods"], "997": ["n13054560", "bolete"], "998": ["n13133613", "ear"], "999": ["n15075141", "toilet_tissue"]} --------------------------------------------------------------------------------