├── .gitignore
├── requirements.txt
├── hub_utilities
│   ├── README.md
│   ├── export_to_hub.py
│   └── generate_doc.py
├── i1k_eval
│   ├── README.md
│   ├── eval.ipynb
│   └── imagenet_class_index.json
├── models
│   ├── model_configs.py
│   ├── convnext_tf.py
│   └── convnext.py
├── convert_all_available_models.py
├── README.md
├── notebooks
│   ├── classification.ipynb
│   └── finetune.ipynb
├── convert.py
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | .DS_Store
3 | .idea
4 | *.pb
5 | *.pyc
6 | saved_models/
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==2.7.0
2 | torch==1.10.1
3 | torchvision==0.11.2
4 | timm==0.4.12
5 | ml_collections==0.1.0
--------------------------------------------------------------------------------
/hub_utilities/README.md:
--------------------------------------------------------------------------------
1 | The scripts contained in this directory are utilities for publishing the converted models to TF-Hub. The following utilities are supported:
2 |
3 | * `export_to_hub.py`: Bulk-exports the SavedModels as the `tar.gz` archives needed by TF-Hub.
4 | * `generate_doc.py`: Bulk-generates documentation for the models.
5 |
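6 | Once exported and published, a module can be consumed with `tensorflow_hub`. A
7 | minimal sketch (mirroring the usage in this repository's notebooks):
8 |
9 | ```py
10 | import tensorflow as tf
11 | import tensorflow_hub as hub
12 |
13 | model = tf.keras.Sequential(
14 |     [
15 |         tf.keras.layers.InputLayer((224, 224, 3)),
16 |         hub.KerasLayer("https://tfhub.dev/sayakpaul/convnext_tiny_1k_224/1"),
17 |     ]
18 | )
19 | ```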
--------------------------------------------------------------------------------
/i1k_eval/README.md:
--------------------------------------------------------------------------------
1 | This directory provides a notebook and an ImageNet-1k class mapping file to run
2 | evaluation on the ImageNet-1k `val` split using the TF/Keras converted ConvNeXt
3 | models. The notebook assumes the following files are present in your working
4 | directory:
5 |
6 | * The `val` split directory of ImageNet-1k.
7 | * The class mapping file (`imagenet_class_index.json`).
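8 |
9 | The class mapping file maps each ImageNet-1k label index to a WordNet ID and a
10 | human-readable class name. A minimal sketch of how the notebook consumes it:
11 |
12 | ```py
13 | import json
14 |
15 | with open("imagenet_class_index.json", "r") as f:
16 |     imagenet_labels = json.load(f)
17 |
18 | # Each entry looks like: "0": ["n01440764", "tench"].
19 | wnid_to_label = {wnid: int(idx) for idx, (wnid, _) in imagenet_labels.items()}
20 | ```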
--------------------------------------------------------------------------------
/models/model_configs.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuratiionns for different ConvNeXt variants.
3 |
4 | Referred from: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
5 | """
6 |
7 |
8 | import ml_collections
9 |
10 |
11 | def convnext_tiny_config() -> ml_collections.ConfigDict:
12 | configs = ml_collections.ConfigDict()
13 | configs.depths = [3, 3, 9, 3]
14 | configs.dims = [96, 192, 384, 768]
15 | return configs
16 |
17 |
18 | def convnext_small_config() -> ml_collections.ConfigDict:
19 | configs = convnext_tiny_config()
20 | configs.depths = [3, 3, 27, 3]
21 | return configs
22 |
23 |
24 | def convnext_base_config() -> ml_collections.ConfigDict:
25 | configs = convnext_small_config()
26 | configs.dims = [128, 256, 512, 1024]
27 | return configs
28 |
29 |
30 | def convnext_large_config() -> ml_collections.ConfigDict:
31 | configs = convnext_base_config()
32 | configs.dims = [192, 384, 768, 1536]
33 | return configs
34 |
35 |
36 | def convnext_xlarge_config() -> ml_collections.ConfigDict:
37 | configs = convnext_large_config()
38 | configs.dims = [256, 512, 1024, 2048]
39 | return configs
40 |
41 |
42 | def get_model_config(model_name: str) -> ml_collections.ConfigDict:
43 | if model_name == "convnext_tiny":
44 | return convnext_tiny_config()
45 | elif model_name == "convnext_small":
46 | return convnext_small_config()
47 | elif model_name == "convnext_base":
48 | return convnext_base_config()
49 | elif model_name == "convnext_large":
50 | return convnext_large_config()
51 | else:
52 | return convnext_xlarge_config()
53 |
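54 |
55 | # Example usage (a quick sanity check, not part of the conversion pipeline):
56 | #
57 | #     config = get_model_config("convnext_base")
58 | #     print(config.depths)  # [3, 3, 27, 3]
59 | #     print(config.dims)    # [128, 256, 512, 1024]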
--------------------------------------------------------------------------------
/hub_utilities/export_to_hub.py:
--------------------------------------------------------------------------------
1 | """Generates .tar.gz archives from SavedModels and serializes them."""
2 |
3 |
4 | from typing import List
5 | import tensorflow as tf
6 | import os
7 |
8 |
9 | TF_MODEL_ROOT = "gs://convnext/saved_models"
10 | TAR_ARCHIVES = os.path.join(TF_MODEL_ROOT, "tars/")
11 |
12 |
13 | def generate_fe(model: tf.keras.Model) -> tf.keras.Model:
14 | """Generates a feature extractor from a classifier."""
15 | feature_extractor = tf.keras.Model(model.inputs, model.layers[-2].output)
16 | return feature_extractor
17 |
18 |
19 | def prepare_archive(model_name: str) -> None:
20 | """Prepares a tar archive."""
21 | archive_name = f"{model_name}.tar.gz"
22 | print(f"Archiving to {archive_name}.")
23 | archive_command = f"cd {model_name} && tar -czvf ../{archive_name} *"
24 | os.system(archive_command)
25 | os.system(f"rm -rf {model_name}")
26 |
27 |
28 | def save_to_gcs(model_paths: List[str]) -> None:
29 | """Prepares tar archives and saves them inside a GCS bucket."""
30 | for path in model_paths:
31 | print(f"Preparing classification model: {path}.")
32 | model_name = path.strip("/")
33 | abs_model_path = os.path.join(TF_MODEL_ROOT, model_name)
34 |
35 | print(f"Copying from {abs_model_path}.")
36 | os.system(f"gsutil cp -r {abs_model_path} .")
37 | prepare_archive(model_name)
38 |
39 | print("Preparing feature extractor.")
40 | model = tf.keras.models.load_model(abs_model_path)
41 | fe_model = generate_fe(model)
42 | fe_model_name = f"{model_name}_fe"
43 | fe_model.save(fe_model_name)
44 | prepare_archive(fe_model_name)
45 |
46 | os.system(f"gsutil -m cp -r *.tar.gz {TAR_ARCHIVES}")
47 | os.system("rm -rf *.tar.gz")
48 |
49 |
50 | if __name__ == "__main__":
51 |     model_paths = tf.io.gfile.listdir(TF_MODEL_ROOT)
52 |     print(f"Total models: {len(model_paths)}.")
53 |
54 |     print("Preparing archives for the classification and feature extractor models.")
55 |     save_to_gcs(model_paths)
55 |
--------------------------------------------------------------------------------
/convert_all_available_models.py:
--------------------------------------------------------------------------------
1 | """
2 | Details about these checkpoints are available here:
3 | https://github.com/facebookresearch/ConvNeXt#results-and-pre-trained-models.
4 | """
5 |
6 | import os
7 | from tqdm import tqdm
8 |
9 | imagenet_1k_224 = {
10 | "convnext_tiny": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
11 | "convnext_small": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
12 | "convnext_base": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
13 | "convnext_large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
14 | }
15 |
16 | imagenet_1k_384 = {
17 | "convnext_base": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth",
18 | "convnext_large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth",
19 | }
20 |
21 | imagenet_21k_224 = {
22 | "convnext_base": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
23 | "convnext_large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
24 | "convnext_xlarge": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
25 | }
26 |
27 | imagenet_21k_1k_224 = {
28 | "convnext_base": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth",
29 | "convnext_large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth",
30 | "convnext_xlarge": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth",
31 | }
32 |
33 | imagenet_21k_1k_384 = {
34 | "convnext_base": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth",
35 | "convnext_large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth",
36 | "convnext_xlarge": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth",
37 | }
38 |
39 | print("Converting 224x224 resolution ImageNet-1k models.")
40 | for model in tqdm(imagenet_1k_224):
41 | print(f"Converting {model}.")
42 | command = f"python convert.py -m {model} -c {imagenet_1k_224[model]}"
43 | os.system(command)
44 |
45 |
46 | print("Converting 384x384 resolution ImageNet-1k models.")
47 | for model in tqdm(imagenet_1k_384):
48 | print(f"Converting {model}.")
49 | command = f"python convert.py -m {model} -c {imagenet_1k_384[model]} -r 384"
50 | os.system(command)
51 |
52 |
53 | print("Converting 224x224 resolution ImageNet-21k models.")
54 | for model in tqdm(imagenet_21k_224):
55 | print(f"Converting {model}.")
56 | command = f"python convert.py -d imagenet-21k -m {model} -c {imagenet_21k_224[model]} -r 224"
57 | os.system(command)
58 |
59 |
60 | print(
61 | "Converting 224x224 resolution ImageNet-21k trained ImageNet-1k fine-tuned models."
62 | )
63 | for model in tqdm(imagenet_21k_1k_224):
64 | print(f"Converting {model}.")
65 | command = f"python convert.py -m {model} -c {imagenet_21k_1k_224[model]} -r 224"
66 | os.system(command)
67 |
68 |
69 | print(
70 | "Converting 384x384 resolution ImageNet-21k trained ImageNet-1k fine-tuned models."
71 | )
72 | for model in tqdm(imagenet_21k_1k_384):
73 | print(f"Converting {model}.")
74 | command = f"python convert.py -m {model} -c {imagenet_21k_1k_384[model]} -r 384"
75 | os.system(command)
76 |
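77 | # To convert a single checkpoint manually, `convert.py` can be invoked directly
78 | # (an illustrative command, mirroring the loops above):
79 | #
80 | #     python convert.py -m convnext_tiny \
81 | #         -c https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth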
--------------------------------------------------------------------------------
/hub_utilities/generate_doc.py:
--------------------------------------------------------------------------------
1 | """Generates model documentation for ConvNeXt-TF models.
2 |
3 | Credits: Willi Gierke
4 | """
5 |
6 | from string import Template
7 | import attr
8 | import os
9 |
10 | template = Template(
11 | """# Module $HANDLE
12 |
13 | Fine-tunable ConvNeXt model pre-trained on the $DATASET_DESCRIPTION.
14 |
15 | <!-- asset-path: https://storage.googleapis.com/convnext/saved_models/tars/$ARCHIVE_NAME.tar.gz -->
16 | <!-- task: image-feature-vector -->
17 | <!-- network-architecture: convnext -->
18 | <!-- format: saved_model_2 -->
19 | <!-- fine-tunable: true -->
20 | <!-- license: apache-2.0 -->
21 |
22 |
23 | ## Overview
24 |
25 | This model is a ConvNeXt [1] model pre-trained on the $DATASET_DESCRIPTION. You can find the complete
26 | collection of ConvNeXt models on TF-Hub on [this page](https://tfhub.dev/sayakpaul/collections/convnext/1).
27 |
28 | You can use this model for feature extraction and fine-tuning. Please refer to
29 | the Colab Notebook linked on this page for more details.
30 |
31 | ## Notes
32 |
33 | * The original model weights are provided in [2]. They were ported to Keras models
34 | (`tf.keras.Model`) and then serialized as TensorFlow SavedModels. The porting
35 | steps are available in [3].
36 | * The model can be unrolled into a standard Keras model, and you can inspect its topology.
37 | To do so, first download the model from TF-Hub and then load it using `tf.keras.models.load_model`,
38 | providing the path to the downloaded model folder.
39 |
40 | ## References
41 |
42 | [1] [A ConvNet for the 2020s by Liu et al.](https://arxiv.org/abs/2201.03545)
43 | [2] [ConvNeXt GitHub](https://github.com/facebookresearch/ConvNeXt)
44 | [3] [ConvNeXt-TF GitHub](https://github.com/sayakpaul/ConvNeXt-TF)
45 |
46 | ## Acknowledgements
47 |
48 | * [Vasudev Gupta](https://github.com/vasudevgupta7)
49 | * [Gus](https://twitter.com/gusthema)
50 | * [Willi](https://ch.linkedin.com/in/willi-gierke)
51 | * [ML-GDE program](https://developers.google.com/programs/experts/)
52 |
53 | """
54 | )
55 |
56 |
57 | @attr.s
58 | class Config:
59 | size = attr.ib(type=str)
60 | dataset = attr.ib(type=str)
61 | single_resolution = attr.ib(type=int)
62 |
63 | def two_d_resolution(self):
64 | return f"{self.single_resolution}x{self.single_resolution}"
65 |
66 | def gcs_folder_name(self):
67 | return f"convnext_{self.size}_{self.dataset}_{self.single_resolution}_fe"
68 |
69 | def handle(self):
70 | return f"sayakpaul/{self.gcs_folder_name()}/1"
71 |
72 | def rel_doc_file_path(self):
73 | """Relative to the tfhub.dev directory."""
74 | return f"assets/docs/{self.handle()}.md"
75 |
76 |
77 | for c in [
78 | Config("tiny", "1k", 224),
79 | Config("small", "1k", 224),
80 | Config("base", "1k", 224),
81 | Config("base", "1k", 384),
82 | Config("large", "1k", 224),
83 | Config("large", "1k", 384),
84 | Config("base", "21k_1k", 224),
85 | Config("base", "21k_1k", 384),
86 | Config("large", "21k_1k", 224),
87 | Config("large", "21k_1k", 384),
88 | Config("xlarge", "21k_1k", 224),
89 | Config("xlarge", "21k_1k", 384),
90 | Config("base", "21k", 224),
91 | Config("large", "21k", 224),
92 | Config("xlarge", "21k", 224),
93 | ]:
94 | if c.dataset == "1k":
95 | dataset_text = "ImageNet-1k dataset"
96 | elif c.dataset == "21k":
97 | dataset_text = "ImageNet-21k dataset"
98 | else:
99 | dataset_text = (
100 | "ImageNet-21k"
101 | " dataset and"
102 | " was then "
103 | "fine-tuned "
104 | "on the "
105 | "ImageNet-1k "
106 | "dataset"
107 | )
108 |
109 | save_path = os.path.join(
110 | "/Users/sayakpaul/Downloads/", "tfhub.dev", c.rel_doc_file_path()
111 | )
112 | model_folder = save_path.split("/")[-2]
113 | model_abs_path = "/".join(save_path.split("/")[:-1])
114 |
115 | if not os.path.exists(model_abs_path):
116 | os.makedirs(model_abs_path, exist_ok=True)
117 |
118 | with open(save_path, "w") as f:
119 | f.write(
120 | template.substitute(
121 | HANDLE=c.handle(),
122 | DATASET_DESCRIPTION=dataset_text,
123 | INPUT_RESOLUTION=c.two_d_resolution(),
124 | ARCHIVE_NAME=c.gcs_folder_name(),
125 | )
126 | )
127 |
--------------------------------------------------------------------------------
/models/convnext_tf.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tensorflow import keras
4 | from tensorflow.keras import layers
5 |
6 |
7 | class StochasticDepth(layers.Layer):
8 | """Stochastic Depth module.
9 |
10 | It is also referred to as Drop Path in `timm`.
11 | References:
12 |         (1) https://github.com/rwightman/pytorch-image-models
13 | """
14 |
15 | def __init__(self, drop_path, **kwargs):
16 | super(StochasticDepth, self).__init__(**kwargs)
17 | self.drop_path = drop_path
18 |
19 | def call(self, x, training=None):
20 | if training:
21 | keep_prob = 1 - self.drop_path
22 |             shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)  # e.g., (batch_size, 1, 1, 1)
23 | random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
24 | random_tensor = tf.floor(random_tensor)
25 | return (x / keep_prob) * random_tensor
26 | return x
27 |
28 |
29 | class Block(tf.keras.Model):
30 | """ConvNeXt block.
31 |
32 | References:
33 | (1) https://arxiv.org/abs/2201.03545
34 | (2) https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
35 | """
36 |
37 | def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6, **kwargs):
38 | super(Block, self).__init__(**kwargs)
39 | self.dim = dim
40 | if layer_scale_init_value > 0:
41 | self.gamma = tf.Variable(layer_scale_init_value * tf.ones((dim,)))
42 | else:
43 | self.gamma = None
44 | self.dw_conv_1 = layers.Conv2D(
45 | filters=dim, kernel_size=7, padding="same", groups=dim
46 | )
47 | self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
48 | self.pw_conv_1 = layers.Dense(4 * dim)
49 | self.act_fn = layers.Activation("gelu")
50 | self.pw_conv_2 = layers.Dense(dim)
51 | self.drop_path = (
52 | StochasticDepth(drop_path)
53 | if drop_path > 0.0
54 | else layers.Activation("linear")
55 | )
56 |
57 | def call(self, inputs):
58 | x = inputs
59 |
60 | x = self.dw_conv_1(x)
61 | x = self.layer_norm(x)
62 | x = self.pw_conv_1(x)
63 | x = self.act_fn(x)
64 | x = self.pw_conv_2(x)
65 |
66 | if self.gamma is not None:
67 | x = self.gamma * x
68 |
69 | return inputs + self.drop_path(x)
70 |
71 |
72 | def get_convnext_model(
73 | model_name="convnext_tiny_1k",
74 | input_shape=(224, 224, 3),
75 | num_classes=1000,
76 | depths=[3, 3, 9, 3],
77 | dims=[96, 192, 384, 768],
78 | drop_path_rate=0.0,
79 | layer_scale_init_value=1e-6,
80 | ) -> keras.Model:
81 | """Implements ConvNeXt family of models given a configuration.
82 |
83 | References:
84 | (1) https://arxiv.org/abs/2201.03545
85 | (2) https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
86 |
87 |     Note: `predict()` fails on CPUs because of the grouped convolutions. A fix was
88 |     merged only recently at the time of development: https://github.com/keras-team/keras/pull/15868.
89 |     It's recommended to use a GPU / TPU.
90 | """
91 |
92 | inputs = layers.Input(input_shape)
93 | stem = keras.Sequential(
94 | [
95 | layers.Conv2D(dims[0], kernel_size=4, strides=4),
96 | layers.LayerNormalization(epsilon=1e-6),
97 | ],
98 | name="stem",
99 | )
100 |
101 | downsample_layers = []
102 | downsample_layers.append(stem)
103 | for i in range(3):
104 | downsample_layer = keras.Sequential(
105 | [
106 | layers.LayerNormalization(epsilon=1e-6),
107 | layers.Conv2D(dims[i + 1], kernel_size=2, strides=2),
108 | ],
109 | name=f"downsampling_block_{i}",
110 | )
111 | downsample_layers.append(downsample_layer)
112 |
113 | stages = []
114 | dp_rates = [x for x in tf.linspace(0.0, drop_path_rate, sum(depths))]
115 | cur = 0
116 | for i in range(4):
117 | stage = keras.Sequential(
118 | [
119 | *[
120 | Block(
121 | dim=dims[i],
122 | drop_path=dp_rates[cur + j],
123 | layer_scale_init_value=layer_scale_init_value,
124 | name=f"convnext_block_{i}_{j}",
125 | )
126 | for j in range(depths[i])
127 | ]
128 | ],
129 | name=f"convnext_stage_{i}",
130 | )
131 | stages.append(stage)
132 | cur += depths[i]
133 |
134 | x = inputs
135 | for i in range(len(stages)):
136 | x = downsample_layers[i](x)
137 | x = stages[i](x)
138 |
139 | x = layers.GlobalAvgPool2D()(x)
140 | x = layers.LayerNormalization(epsilon=1e-6)(x)
141 |
142 | outputs = layers.Dense(num_classes, name="classification_head")(x)
143 |
144 | return keras.Model(inputs, outputs, name=model_name)
145 |
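146 |
147 | if __name__ == "__main__":
148 |     # Quick sanity check (illustrative only, not used by the conversion
149 |     # pipeline): build a randomly initialized ConvNeXt-Tiny with the default
150 |     # arguments and report its size.
151 |     model = get_convnext_model()
152 |     print(f"{model.name}: {model.count_params():,} parameters")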
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ConvNeXt-TF
2 |
3 | This repository provides TensorFlow / Keras implementations of different ConvNeXt
4 | [1] variants. It also provides the TensorFlow / Keras models that have been
5 | populated with the original ConvNeXt pre-trained weights available from [2]. These
6 | models are not blackbox SavedModels, i.e., they can be fully expanded into `tf.keras.Model`
7 | objects, and one can call all the utility functions on them (for example, `.summary()`).
8 |
9 | As of today, all the TensorFlow / Keras variants of the models listed
10 | [here](https://github.com/facebookresearch/ConvNeXt#results-and-pre-trained-models)
11 | are available in this repository except for the
12 | [isotropic ones](https://github.com/facebookresearch/ConvNeXt#imagenet-1k-trained-models-isotropic).
13 | This list includes the ImageNet-1k as well as ImageNet-21k models.
14 |
15 | Refer to the ["Using the models"](https://github.com/sayakpaul/ConvNeXt-TF#using-the-models)
16 | section to get started. Additionally, here's a [related blog post](https://sayak.dev/convnext-tfhub/)
17 | that jots down my experience.
18 |
19 | ## Conversion
20 |
21 | TensorFlow / Keras implementations are available in `models/convnext_tf.py`.
22 | Conversion utilities are in `convert.py`.
23 |
24 | ## Models
25 |
26 | The converted models are available on [TF-Hub](https://tfhub.dev/sayakpaul/collections/convnext/1).
27 |
28 | There are a total of 15 different models, each having two variants: classifier and
29 | feature extractor. You can load any model and get started like so:
30 |
31 | ```py
32 | import tensorflow as tf
33 |
34 | model_gcs_path = "gs://tfhub-modules/sayakpaul/convnext_tiny_1k_224/1/uncompressed"
35 | model = tf.keras.models.load_model(model_gcs_path)
36 | print(model.summary(expand_nested=True))
37 | ```
38 |
39 | The model names are interpreted as follows:
40 |
41 | * `convnext_large_21k_1k_384`: This means that the model was first pre-trained
42 | on the ImageNet-21k dataset and was then fine-tuned on the ImageNet-1k dataset,
43 | with a fine-tuning resolution of 384x384. `large` denotes the topology of the
44 | underlying model.
45 | * `convnext_large_1k_224`: This means that the model was pre-trained on the
46 | ImageNet-1k dataset at a resolution of 224x224.
47 |
48 | ## Results
49 |
50 | Results are on ImageNet-1k validation set (top-1 accuracy).
51 |
52 | | name | original acc@1 | keras acc@1 |
53 | |:---:|:---:|:---:|
54 | | convnext_tiny_1k_224 | 82.1 | 81.312 |
55 | | convnext_small_1k_224 | 83.1 | 82.392 |
56 | | convnext_base_1k_224 | 83.8 | 83.28 |
57 | | convnext_base_1k_384 | 85.1 | 84.876 |
58 | | convnext_large_1k_224 | 84.3 | 83.844 |
59 | | convnext_large_1k_384 | 85.5 | 85.376 |
60 | | | | |
61 | | convnext_base_21k_1k_224 | 85.8 | 85.364 |
62 | | convnext_base_21k_1k_384 | 86.8 | 86.79 |
63 | | convnext_large_21k_1k_224 | 86.6 | 86.36 |
64 | | convnext_large_21k_1k_384 | 87.5 | 87.504 |
65 | | convnext_xlarge_21k_1k_224 | 87.0 | 86.732 |
66 | | convnext_xlarge_21k_1k_384 | 87.8 | 87.68 |
67 |
68 | Differences in the results stem primarily from differences in the library
69 | implementations, especially in how image resizing is implemented in PyTorch and
70 | TensorFlow. Results can be verified with the code in `i1k_eval`. Logs
71 | are available at [this URL](https://tensorboard.dev/experiment/odN7OPCqQvGYCRpJP1GhRQ/).
72 |
73 | ## Using the models
74 |
75 | **Pre-trained models**:
76 |
77 | * Off-the-shelf classification: [Colab Notebook](https://colab.research.google.com/github/sayakpaul/ConvNeXt-TF/blob/main/notebooks/classification.ipynb)
78 | * Fine-tuning: [Colab Notebook](https://colab.research.google.com/github/sayakpaul/ConvNeXt-TF/blob/main/notebooks/finetune.ipynb)
79 |
80 | **Randomly initialized models**:
81 |
82 | ```py
83 | from models.convnext_tf import get_convnext_model
84 |
85 | convnext_tiny = get_convnext_model()
86 | print(convnext_tiny.summary(expand_nested=True))
87 | ```
88 |
89 | To view different model configurations, refer [here](https://github.com/sayakpaul/ConvNeXt-TF/blob/main/models/model_configs.py).
90 |
91 | ## Upcoming (contributions welcome)
92 |
93 | - [ ] Align layer initializers (useful if someone wanted to train the models
94 | from scratch)
95 | - [ ] Allow the models to accept arbitrary shapes (useful for downstream tasks)
96 | - [ ] Convert the [isotropic models](https://github.com/facebookresearch/ConvNeXt#imagenet-1k-trained-models-isotropic) as well
97 | - [x] Fine-tuning notebook (thanks to [awsaf49](https://github.com/awsaf49))
98 | - [x] Off-the-shelf-classification notebook
99 | - [x] Publish models on TF-Hub
100 |
101 | ## References
102 |
103 | [1] ConvNeXt paper: https://arxiv.org/abs/2201.03545
104 |
105 | [2] Official ConvNeXt code: https://github.com/facebookresearch/ConvNeXt
106 |
107 | ## Acknowledgements
108 |
109 | * [Vasudev Gupta](https://github.com/vasudevgupta7)
110 | * [Gus](https://twitter.com/gusthema)
111 | * [Willi](https://ch.linkedin.com/in/willi-gierke)
112 | * [ML-GDE program](https://developers.google.com/programs/experts/)
113 |
--------------------------------------------------------------------------------
/notebooks/classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "zpiEiO2BUeO5",
6 | "metadata": {
7 | "id": "zpiEiO2BUeO5"
8 | },
9 | "source": [
10 | "# Off-the-shelf image classification with ConvNeXt models on TF-Hub\n",
11 | "\n",
12 | "
"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "id": "661e6538",
28 | "metadata": {
29 | "id": "661e6538"
30 | },
31 | "source": [
32 | "## Setup"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "id": "f2b73e50-6538-4af5-9878-ed99489409f5",
39 | "metadata": {
40 | "id": "f2b73e50-6538-4af5-9878-ed99489409f5"
41 | },
42 | "outputs": [],
43 | "source": [
44 | "!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "id": "43974820-4eeb-4b3a-90b4-9ddfa00d1cb9",
51 | "metadata": {
52 | "id": "43974820-4eeb-4b3a-90b4-9ddfa00d1cb9"
53 | },
54 | "outputs": [],
55 | "source": [
56 | "import tensorflow as tf\n",
57 | "import tensorflow_hub as hub\n",
58 | "from tensorflow import keras\n",
59 | "\n",
60 | "\n",
61 | "from PIL import Image\n",
62 | "from io import BytesIO\n",
63 | "\n",
64 | "import matplotlib.pyplot as plt\n",
65 | "import numpy as np\n",
66 | "import requests"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "id": "z5l1cRpiSavW",
72 | "metadata": {
73 | "id": "z5l1cRpiSavW"
74 | },
75 | "source": [
76 | "## Select a ConvNeXt ImageNet-1k model"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "id": "a0wM8idaSaOq",
83 | "metadata": {
84 | "cellView": "form",
85 | "id": "a0wM8idaSaOq"
86 | },
87 | "outputs": [],
88 | "source": [
89 | "model_name = \"convnext_tiny_1k_224\" #@param [\"convnext_tiny_1k_224\", \"convnext_small_1k_224\", \"convnext_base_1k_224\", \"convnext_base_1k_384\", \"convnext_large_1k_224\", \"convnext_large_1k_384\", \"convnext_base_21k_1k_224\", \"convnext_base_21k_1k_384\", \"convnext_large_21k_1k_224\", \"convnext_large_21k_1k_384\", \"convnext_xlarge_21k_1k_224\", \"convnext_xlarge_21k_1k_384\"]\n",
90 | "\n",
91 | "model_handle_map ={\n",
92 | " \"convnext_tiny_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_tiny_1k_224/1\",\n",
93 | " \"convnext_small_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_small_1k_224/1\",\n",
94 | " \"convnext_base_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_base_1k_224/1\",\n",
95 | " \"convnext_base_1k_384\": \"https://tfhub.dev/sayakpaul/convnext_base_1k_384/1\",\n",
96 | " \"convnext_large_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_large_1k_224/1\",\n",
97 | " \"convnext_large_1k_384\": \"https://tfhub.dev/sayakpaul/convnext_large_1k_384/1\",\n",
98 | " \"convnext_base_21k_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_base_21k_1k_224/1\",\n",
99 | " \"convnext_base_21k_1k_384\": \"https://tfhub.dev/sayakpaul/convnext_base_21k_1k_384/1\",\n",
100 | " \"convnext_large_21k_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_large_21k_1k_224/1\",\n",
101 | " \"convnext_large_21k_1k_384\": \"https://tfhub.dev/sayakpaul/convnext_large_21k_1k_384/1\",\n",
102 | " \"convnext_xlarge_21k_1k_224\": \"https://tfhub.dev/sayakpaul/convnext_xlarge_21k_1k_224/1\",\n",
103 | " \"convnext_xlarge_21k_1k_384\": \"https://tfhub.dev/sayakpaul/convnext_xlarge_21k_1k_384/1\",\n",
104 | "\n",
105 | "}\n",
106 | "\n",
107 | "input_resolution = int(model_name.split(\"_\")[-1])\n",
108 | "model_handle = model_handle_map[model_name]\n",
109 | "print(f\"Input resolution: {input_resolution} x {input_resolution} x 3.\")\n",
110 | "print(f\"TF-Hub handle: {model_handle}.\")"
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "id": "441b5361",
116 | "metadata": {
117 | "id": "441b5361"
118 | },
119 | "source": [
120 | "## Image preprocessing utilities "
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "id": "63e76ff1-e1e0-4c6a-91b2-4114aad60e5b",
127 | "metadata": {
128 | "id": "63e76ff1-e1e0-4c6a-91b2-4114aad60e5b"
129 | },
130 | "outputs": [],
131 | "source": [
132 | "crop_layer = keras.layers.CenterCrop(224, 224)\n",
133 | "norm_layer = keras.layers.Normalization(\n",
134 | " mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],\n",
135 | " variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],\n",
136 | ")\n",
137 | "\n",
138 | "\n",
139 | "def preprocess_image(image, size=input_resolution):\n",
140 | " image = np.array(image)\n",
141 | " image_resized = tf.expand_dims(image, 0)\n",
142 | " \n",
143 | " if size == 224:\n",
144 | " image_resized = tf.image.resize(image_resized, (256, 256), method=\"bicubic\")\n",
145 | " image_resized = crop_layer(image_resized)\n",
146 | " elif size == 384:\n",
147 | " image_resized = tf.image.resize(image, (size, size), method=\"bicubic\")\n",
148 | " \n",
149 | " return norm_layer(image_resized).numpy()\n",
150 | " \n",
151 | "\n",
152 | "def load_image_from_url(url):\n",
153 | " # Credit: Willi Gierke\n",
154 | " response = requests.get(url)\n",
155 | " image = Image.open(BytesIO(response.content))\n",
156 | " preprocessed_image = preprocess_image(image)\n",
157 | " return image, preprocessed_image"
158 | ]
159 | },
160 | {
161 | "cell_type": "markdown",
162 | "id": "8b961e14",
163 | "metadata": {
164 | "id": "8b961e14"
165 | },
166 | "source": [
167 | "## Load ImageNet-1k labels and a demo image"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "id": "8dc9250a-5eb6-4547-8893-dd4c746ab53b",
174 | "metadata": {
175 | "id": "8dc9250a-5eb6-4547-8893-dd4c746ab53b"
176 | },
177 | "outputs": [],
178 | "source": [
179 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n",
180 | " lines = f.readlines()\n",
181 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n",
182 | "\n",
183 | "img_url = \"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n",
184 | "image, preprocessed_image = load_image_from_url(img_url)\n",
185 | "\n",
186 | "plt.imshow(image)\n",
187 | "plt.show()"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "id": "9006a643",
193 | "metadata": {
194 | "id": "9006a643"
195 | },
196 | "source": [
197 | "## Run inference"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": null,
203 | "id": "8dfd2c7d-e454-48da-a40b-cd5d6f6c4908",
204 | "metadata": {
205 | "id": "8dfd2c7d-e454-48da-a40b-cd5d6f6c4908"
206 | },
207 | "outputs": [],
208 | "source": [
209 | "classification_model = tf.keras.Sequential(\n",
210 | " [hub.KerasLayer(model_handle)]\n",
211 | ") \n",
212 | "predictions = classification_model.predict(preprocessed_image)\n",
213 | "predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]\n",
214 | "print(predicted_label)"
215 | ]
216 | }
217 | ],
218 | "metadata": {
219 | "accelerator": "GPU",
220 | "colab": {
221 | "machine_shape": "hm",
222 | "name": "classification.ipynb",
223 | "provenance": []
224 | },
225 | "environment": {
226 | "kernel": "python3",
227 | "name": "tf2-gpu.2-7.m84",
228 | "type": "gcloud",
229 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m84"
230 | },
231 | "kernelspec": {
232 | "display_name": "Python 3 (ipykernel)",
233 | "language": "python",
234 | "name": "python3"
235 | },
236 | "language_info": {
237 | "codemirror_mode": {
238 | "name": "ipython",
239 | "version": 3
240 | },
241 | "file_extension": ".py",
242 | "mimetype": "text/x-python",
243 | "name": "python",
244 | "nbconvert_exporter": "python",
245 | "pygments_lexer": "ipython3",
246 | "version": "3.8.2"
247 | }
248 | },
249 | "nbformat": 4,
250 | "nbformat_minor": 5
251 | }
252 |
--------------------------------------------------------------------------------
/models/convnext.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """
9 | Originally from: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
10 |
11 | This script has been slightly modified to support the conversion.
12 | Contact: spsayakpaul@gmail.com
13 |
14 | """
15 |
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | from timm.models.layers import trunc_normal_, DropPath
21 | from timm.models.registry import register_model
22 |
23 |
24 | class Block(nn.Module):
25 | r"""ConvNeXt Block. There are two equivalent implementations:
26 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
27 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
28 | We use (2) as we find it slightly faster in PyTorch
29 |
30 | Args:
31 | dim (int): Number of input channels.
32 | drop_path (float): Stochastic depth rate. Default: 0.0
33 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
34 | """
35 |
36 | def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
37 | super().__init__()
38 | self.dwconv = nn.Conv2d(
39 | dim, dim, kernel_size=7, padding=3, groups=dim
40 | ) # depthwise conv
41 | self.norm = LayerNorm(dim, eps=1e-6)
42 | self.pwconv1 = nn.Linear(
43 | dim, 4 * dim
44 | ) # pointwise/1x1 convs, implemented with linear layers
45 | self.act = nn.GELU()
46 | self.pwconv2 = nn.Linear(4 * dim, dim)
47 | self.gamma = (
48 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
49 | if layer_scale_init_value > 0
50 | else None
51 | )
52 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
53 |
54 | def forward(self, x):
55 | input = x
56 | x = self.dwconv(x)
57 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
58 | x = self.norm(x)
59 | x = self.pwconv1(x)
60 | x = self.act(x)
61 | x = self.pwconv2(x)
62 | if self.gamma is not None:
63 | x = self.gamma * x
64 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
65 |
66 | x = input + self.drop_path(x)
67 | return x
68 |
69 |
70 | class ConvNeXt(nn.Module):
71 | r"""ConvNeXt
72 | A PyTorch impl of : `A ConvNet for the 2020s` -
73 | https://arxiv.org/pdf/2201.03545.pdf
74 |
75 | Args:
76 | in_chans (int): Number of input image channels. Default: 3
77 | num_classes (int): Number of classes for classification head. Default: 1000
78 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
79 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
80 | drop_path_rate (float): Stochastic depth rate. Default: 0.
81 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
82 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
83 | """
84 |
85 | def __init__(
86 | self,
87 | in_chans=3,
88 | num_classes=1000,
89 | depths=[3, 3, 9, 3],
90 | dims=[96, 192, 384, 768],
91 | drop_path_rate=0.0,
92 | layer_scale_init_value=1e-6,
93 | head_init_scale=1.0,
94 | ):
95 | super().__init__()
96 |
97 | self.downsample_layers = (
98 | nn.ModuleList()
99 | ) # stem and 3 intermediate downsampling conv layers
100 | stem = nn.Sequential(
101 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
102 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
103 | )
104 | self.downsample_layers.append(stem)
105 | for i in range(3):
106 | downsample_layer = nn.Sequential(
107 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
108 | nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
109 | )
110 | self.downsample_layers.append(downsample_layer)
111 |
112 | self.stages = (
113 | nn.ModuleList()
114 | ) # 4 feature resolution stages, each consisting of multiple residual blocks
115 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
116 | cur = 0
117 | for i in range(4):
118 | stage = nn.Sequential(
119 | *[
120 | Block(
121 | dim=dims[i],
122 | drop_path=dp_rates[cur + j],
123 | layer_scale_init_value=layer_scale_init_value,
124 | )
125 | for j in range(depths[i])
126 | ]
127 | )
128 | self.stages.append(stage)
129 | cur += depths[i]
130 |
131 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
132 | self.head = nn.Linear(dims[-1], num_classes)
133 |
134 | self.apply(self._init_weights)
135 | self.head.weight.data.mul_(head_init_scale)
136 | self.head.bias.data.mul_(head_init_scale)
137 |
138 | def _init_weights(self, m):
139 | if isinstance(m, (nn.Conv2d, nn.Linear)):
140 | trunc_normal_(m.weight, std=0.02)
141 | nn.init.constant_(m.bias, 0)
142 |
143 | def forward_features(self, x):
144 | for i in range(4):
145 | x = self.downsample_layers[i](x)
146 | x = self.stages[i](x)
147 | return self.norm(
148 | x.mean([-2, -1])
149 | ) # global average pooling, (N, C, H, W) -> (N, C)
150 |
151 | def forward(self, x):
152 | x = self.forward_features(x)
153 | x = self.head(x)
154 | return x
155 |
156 |
157 | class LayerNorm(nn.Module):
158 | r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
159 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
160 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
161 | with shape (batch_size, channels, height, width).
162 | """
163 |
164 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
165 | super().__init__()
166 | self.weight = nn.Parameter(torch.ones(normalized_shape))
167 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
168 | self.eps = eps
169 | self.data_format = data_format
170 | if self.data_format not in ["channels_last", "channels_first"]:
171 | raise NotImplementedError
172 | self.normalized_shape = (normalized_shape,)
173 |
174 | def forward(self, x):
175 | if self.data_format == "channels_last":
176 | return F.layer_norm(
177 | x, self.normalized_shape, self.weight, self.bias, self.eps
178 | )
179 | elif self.data_format == "channels_first":
180 | u = x.mean(1, keepdim=True)
181 | s = (x - u).pow(2).mean(1, keepdim=True)
182 | x = (x - u) / torch.sqrt(s + self.eps)
183 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
184 | return x
185 |
186 |
187 | @register_model
188 | def convnext_tiny(url, **kwargs):
189 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
190 | checkpoint = torch.hub.load_state_dict_from_url(
191 | url=url, map_location="cpu", check_hash=True
192 | )
193 | model.load_state_dict(checkpoint["model"])
194 | return model
195 |
196 |
197 | @register_model
198 | def convnext_small(url, **kwargs):
199 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
200 | checkpoint = torch.hub.load_state_dict_from_url(
201 | url=url, map_location="cpu", check_hash=True
202 | )
203 | model.load_state_dict(checkpoint["model"])
204 | return model
205 |
206 |
207 | @register_model
208 | def convnext_base(url, **kwargs):
209 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
210 | checkpoint = torch.hub.load_state_dict_from_url(
211 | url=url, map_location="cpu", check_hash=True
212 | )
213 | model.load_state_dict(checkpoint["model"])
214 | return model
215 |
216 |
217 | @register_model
218 | def convnext_large(url, **kwargs):
219 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
220 | checkpoint = torch.hub.load_state_dict_from_url(
221 | url=url, map_location="cpu", check_hash=True
222 | )
223 | model.load_state_dict(checkpoint["model"])
224 | return model
225 |
226 |
227 | @register_model
228 | def convnext_xlarge(url, **kwargs):
229 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
230 | checkpoint = torch.hub.load_state_dict_from_url(
231 | url=url, map_location="cpu", check_hash=True
232 | )
233 | model.load_state_dict(checkpoint["model"])
234 | return model
235 |
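236 |
237 | # Example (illustrative only): this mirrors how `convert.py` instantiates the
238 | # models, passing the checkpoint URL as the first argument.
239 | #
240 | #     model = convnext_tiny(
241 | #         "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"
242 | #     )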
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | from models.convnext_tf import get_convnext_model
2 | from models.model_configs import get_model_config
3 | from models import convnext
4 |
5 | from tensorflow.keras import layers
6 | import tensorflow as tf
7 | import torch
8 |
9 | import os
10 | import argparse
11 | import numpy as np
12 |
13 | torch.set_grad_enabled(False)
14 |
15 | DATASET_TO_CLASSES = {
16 | "imagenet-1k": 1000,
17 | "imagenet-21k": 21841,
18 | }
19 | MODEL_TO_METHOD = {
20 | "convnext_tiny": convnext.convnext_tiny,
21 | "convnext_small": convnext.convnext_small,
22 | "convnext_base": convnext.convnext_base,
23 | "convnext_large": convnext.convnext_large,
24 | "convnext_xlarge": convnext.convnext_xlarge,
25 | }
26 | TF_MODEL_ROOT = "saved_models"
27 |
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser(
31 | description="Conversion of the PyTorch pre-trained ConvNeXt weights to TensorFlow."
32 | )
33 | parser.add_argument(
34 | "-d",
35 | "--dataset",
36 | default="imagenet-1k",
37 | type=str,
38 | required=False,
39 | choices=["imagenet-1k", "imagenet-21k"],
40 | help="Name of the pretraining dataset.",
41 | )
42 | parser.add_argument(
43 | "-m",
44 | "--model-name",
45 | default="convnext_tiny",
46 | type=str,
47 | required=False,
48 | choices=[
49 | "convnext_tiny",
50 | "convnext_small",
51 | "convnext_base",
52 | "convnext_large",
53 | "convnext_xlarge",
54 | ],
55 | help="Name of the ConvNeXt model variant.",
56 | )
57 | parser.add_argument(
58 | "-r",
59 | "--resolution",
60 | default=224,
61 | type=int,
62 | required=False,
63 | choices=[224, 384],
64 | help="Image resolution.",
65 | )
66 | parser.add_argument(
67 | "-c",
68 | "--checkpoint-path",
69 | default="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
70 | type=str,
71 | required=False,
72 | help="URL of the checkpoint to be loaded.",
73 | )
74 | return vars(parser.parse_args())
75 |
76 |
77 | def main(args):
78 | print(f'Model: {args["model_name"]}')
79 | print(f'Image resolution: {args["resolution"]}')
80 | print(f'Dataset: {args["dataset"]}')
81 | print(f'Checkpoint URL: {args["checkpoint_path"]}')
82 |
83 | print("Instantiating PyTorch model and populating weights...")
84 | model_method = MODEL_TO_METHOD[args["model_name"]]
85 | convnext_model_pt = model_method(
86 | args["checkpoint_path"], num_classes=DATASET_TO_CLASSES[args["dataset"]]
87 | )
88 | convnext_model_pt.eval()
89 |
90 | print("Instantiating TensorFlow model...")
91 | model_config = get_model_config(args["model_name"])
92 |
93 | if "22k_1k" not in args["checkpoint_path"]:
94 | model_name = (
95 | f'{args["model_name"]}_1k'
96 | if args["dataset"] == "imagenet-1k"
97 | else f'{args["model_name"]}_21k'
98 | )
99 | else:
100 | model_name = f'{args["model_name"]}_21k_1k'
101 |
102 | convnext_model_tf = get_convnext_model(
103 | model_name=model_name,
104 | input_shape=(args["resolution"], args["resolution"], 3),
105 | num_classes=DATASET_TO_CLASSES[args["dataset"]],
106 | depths=model_config.depths,
107 | dims=model_config.dims,
108 | )
109 | assert convnext_model_tf.count_params() == sum(
110 | p.numel() for p in convnext_model_pt.parameters()
111 | )
112 | print("TensorFlow model instantiated, populating pretrained weights...")
113 |
114 | # Fetch the pretrained parameters.
115 | param_list = list(convnext_model_pt.parameters())
116 | model_states = convnext_model_pt.state_dict()
117 | state_list = list(model_states.keys())
118 |
119 | # Stem block.
120 | stem_block = convnext_model_tf.get_layer("stem")
121 |
122 | for layer in stem_block.layers:
123 | if isinstance(layer, layers.Conv2D):
124 | layer.kernel.assign(
125 |                 tf.Variable(param_list[0].numpy().transpose(2, 3, 1, 0))  # PT (O, I, H, W) -> TF (H, W, I, O)
126 | )
127 | layer.bias.assign(tf.Variable(param_list[1].numpy()))
128 | elif isinstance(layer, layers.LayerNormalization):
129 | layer.gamma.assign(tf.Variable(param_list[2].numpy()))
130 | layer.beta.assign(tf.Variable(param_list[3].numpy()))
131 |
132 | # Downsampling layers.
133 | for i in range(3):
134 | downsampling_block = convnext_model_tf.get_layer(f"downsampling_block_{i}")
135 | pytorch_layer_prefix = f"downsample_layers.{i + 1}"
136 |
137 | for l in downsampling_block.layers:
138 | if isinstance(l, layers.LayerNormalization):
139 | l.gamma.assign(
140 | tf.Variable(
141 | model_states[f"{pytorch_layer_prefix}.0.weight"].numpy()
142 | )
143 | )
144 | l.beta.assign(
145 | tf.Variable(model_states[f"{pytorch_layer_prefix}.0.bias"].numpy())
146 | )
147 | elif isinstance(l, layers.Conv2D):
148 | l.kernel.assign(
149 | tf.Variable(
150 | model_states[f"{pytorch_layer_prefix}.1.weight"]
151 | .numpy()
152 | .transpose(2, 3, 1, 0)
153 | )
154 | )
155 | l.bias.assign(
156 | tf.Variable(model_states[f"{pytorch_layer_prefix}.1.bias"].numpy())
157 | )
158 |
159 | # ConvNeXt stages.
160 | num_stages = 4
161 |
162 | for m in range(num_stages):
163 | stage_name = f"convnext_stage_{m}"
164 | num_blocks = len(convnext_model_tf.get_layer(stage_name).layers)
165 |
166 | for i in range(num_blocks):
167 | stage_block = convnext_model_tf.get_layer(stage_name).get_layer(
168 | f"convnext_block_{m}_{i}"
169 | )
170 | stage_prefix = f"stages.{m}.{i}"
171 |
172 | for j, layer in enumerate(stage_block.layers):
173 | if isinstance(layer, layers.Conv2D):
174 | layer.kernel.assign(
175 | tf.Variable(
176 | model_states[f"{stage_prefix}.dwconv.weight"]
177 | .numpy()
178 | .transpose(2, 3, 1, 0)
179 | )
180 | )
181 | layer.bias.assign(
182 | tf.Variable(model_states[f"{stage_prefix}.dwconv.bias"].numpy())
183 | )
184 | elif isinstance(layer, layers.Dense):
185 | if j == 2:
186 | layer.kernel.assign(
187 | tf.Variable(
188 | model_states[f"{stage_prefix}.pwconv1.weight"]
189 | .numpy()
190 | .transpose()
191 | )
192 | )
193 | layer.bias.assign(
194 | tf.Variable(
195 | model_states[f"{stage_prefix}.pwconv1.bias"].numpy()
196 | )
197 | )
198 | elif j == 4:
199 | layer.kernel.assign(
200 | tf.Variable(
201 | model_states[f"{stage_prefix}.pwconv2.weight"]
202 | .numpy()
203 | .transpose()
204 | )
205 | )
206 | layer.bias.assign(
207 | tf.Variable(
208 | model_states[f"{stage_prefix}.pwconv2.bias"].numpy()
209 | )
210 | )
211 | elif isinstance(layer, layers.LayerNormalization):
212 | layer.gamma.assign(
213 | tf.Variable(model_states[f"{stage_prefix}.norm.weight"].numpy())
214 | )
215 | layer.beta.assign(
216 | tf.Variable(model_states[f"{stage_prefix}.norm.bias"].numpy())
217 | )
218 |
219 | stage_block.gamma.assign(
220 | tf.Variable(model_states[f"{stage_prefix}.gamma"].numpy())
221 | )
222 |
223 | # Final LayerNormalization layer and classifier head.
224 | convnext_model_tf.layers[-2].gamma.assign(
225 | tf.Variable(model_states[state_list[-4]].numpy())
226 | )
227 | convnext_model_tf.layers[-2].beta.assign(
228 | tf.Variable(model_states[state_list[-3]].numpy())
229 | )
230 |
231 | convnext_model_tf.layers[-1].kernel.assign(
232 | tf.Variable(model_states[state_list[-2]].numpy().transpose())
233 | )
234 | convnext_model_tf.layers[-1].bias.assign(
235 | tf.Variable(model_states[state_list[-1]].numpy())
236 | )
237 | print("Weight population successful, serializing TensorFlow model...")
238 |
239 | model_name = f'{model_name}_{args["resolution"]}'
240 | save_path = os.path.join(TF_MODEL_ROOT, model_name)
241 | convnext_model_tf.save(save_path)
242 | print(f"TensorFlow model serialized to: {save_path}...")
243 |
244 |
245 | if __name__ == "__main__":
246 | args = parse_args()
247 | main(args)
248 |
--------------------------------------------------------------------------------
/i1k_eval/eval.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "926bae46",
6 | "metadata": {},
7 | "source": [
8 | "## Imports"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "bae636ae-24d1-4523-9997-696731318a81",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from tensorflow.keras import layers\n",
19 | "from tensorflow import keras\n",
20 | "import tensorflow_hub as hub\n",
21 | "import tensorflow as tf\n",
22 | "\n",
23 | "from imutils import paths\n",
24 | "import json\n",
25 | "import re"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "id": "96c4f0a2",
31 | "metadata": {},
32 | "source": [
33 | "## Constants"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 2,
39 | "id": "f8238055-08bf-44e1-8f3b-98e7768f1603",
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "AUTO = tf.data.AUTOTUNE\n",
44 | "BATCH_SIZE = 256\n",
45 | "IMAGE_SIZE = 224\n",
46 | "TF_MODEL_ROOT = \"gs://convnext/saved_models\""
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "id": "74edcf20",
52 | "metadata": {},
53 | "source": [
54 | "## Set up ImageNet-1k labels"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 3,
60 | "id": "334993ee-0d91-4572-9721-03e67af28cb3",
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "with open(\"imagenet_class_index.json\", \"r\") as read_file:\n",
65 | " imagenet_labels = json.load(read_file)\n",
66 | "\n",
67 | "MAPPING_DICT = {}\n",
68 | "LABEL_NAMES = {}\n",
69 | "for label_id in list(imagenet_labels.keys()):\n",
70 | " MAPPING_DICT[imagenet_labels[label_id][0]] = int(label_id)\n",
71 | " LABEL_NAMES[int(label_id)] = imagenet_labels[label_id][1]"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 4,
77 | "id": "01ad5447-3e28-4c86-941f-f64b45be603a",
78 | "metadata": {},
79 | "outputs": [
80 | {
81 | "data": {
82 | "text/plain": [
83 | "(['val/n01751748/ILSVRC2012_val_00031060.JPEG',\n",
84 | " 'val/n01751748/ILSVRC2012_val_00013492.JPEG',\n",
85 | " 'val/n01751748/ILSVRC2012_val_00033108.JPEG',\n",
86 | " 'val/n01751748/ILSVRC2012_val_00021437.JPEG',\n",
87 | " 'val/n01751748/ILSVRC2012_val_00025096.JPEG'],\n",
88 | " [65, 65, 65, 65, 65])"
89 | ]
90 | },
91 | "execution_count": 4,
92 | "metadata": {},
93 | "output_type": "execute_result"
94 | }
95 | ],
96 | "source": [
97 | "all_val_paths = list(paths.list_images(\"val\"))\n",
98 | "all_val_labels = [MAPPING_DICT[x.split(\"/\")[1]] for x in all_val_paths]\n",
99 | "\n",
100 | "all_val_paths[:5], all_val_labels[:5]"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "id": "1124817d",
106 | "metadata": {},
107 | "source": [
108 | "## Preprocessing utilities"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 5,
114 | "id": "5a4f03d8-25d1-4660-9858-1b197425d5d9",
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "def load_and_prepare(path, label):\n",
119 | " image = tf.io.read_file(path)\n",
120 | " image = tf.image.decode_png(image, channels=3)\n",
121 | " image = tf.image.resize(image, (256, 256), method=\"bicubic\")\n",
122 | " return image, label"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": 6,
128 | "id": "d0a2f457-ad4d-4cbd-8dfa-db544c7f6531",
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "# Reference: https://github.com/facebookresearch/ConvNeXt/blob/main/datasets.py\n",
133 | "def get_preprocessing_model(input_size=224):\n",
134 | " preprocessing_model = keras.Sequential()\n",
135 | "\n",
136 | " preprocessing_model.add(layers.CenterCrop(input_size, input_size))\n",
137 | " preprocessing_model.add(layers.Normalization(\n",
138 | " mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],\n",
139 | " variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],\n",
140 | " ))\n",
141 | "\n",
142 | " return preprocessing_model"
143 | ]
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "id": "56d33240",
148 | "metadata": {},
149 | "source": [
150 | "## Prepare `tf.data.Dataset`"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 7,
156 | "id": "f3518397-2ab0-4d79-adea-ae5f1cb66add",
157 | "metadata": {},
158 | "outputs": [
159 | {
160 | "name": "stderr",
161 | "output_type": "stream",
162 | "text": [
163 | "2022-01-31 03:20:05.306146: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n",
164 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
165 | "2022-01-31 03:20:05.828828: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38444 MB memory: -> device: 0, name: A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0\n"
166 | ]
167 | },
168 | {
169 | "data": {
170 | "text/plain": [
171 | "(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None),\n",
172 | " TensorSpec(shape=(None,), dtype=tf.int32, name=None))"
173 | ]
174 | },
175 | "execution_count": 7,
176 | "metadata": {},
177 | "output_type": "execute_result"
178 | }
179 | ],
180 | "source": [
181 | "preprocessor = get_preprocessing_model()\n",
182 | "\n",
183 | "dataset = tf.data.Dataset.from_tensor_slices((all_val_paths, all_val_labels))\n",
184 | "dataset = dataset.map(load_and_prepare, num_parallel_calls=AUTO).batch(BATCH_SIZE)\n",
185 | "dataset = dataset.map(lambda x, y: (preprocessor(x), y), num_parallel_calls=AUTO)\n",
186 | "dataset = dataset.prefetch(AUTO)\n",
187 | "dataset.element_spec"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "id": "ea42076a",
193 | "metadata": {},
194 | "source": [
195 | "## Fetch model paths and filter the 224x224 models"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 8,
201 | "id": "13a7e46e-31b2-48b9-9a57-2873fe27397a",
202 | "metadata": {},
203 | "outputs": [
204 | {
205 | "name": "stdout",
206 | "output_type": "stream",
207 | "text": [
208 | "['convnext_base_1k_224/', 'convnext_base_21k_1k_224/', 'convnext_large_1k_224/', 'convnext_large_21k_1k_224/', 'convnext_small_1k_224/', 'convnext_tiny_1k_224/', 'convnext_xlarge_21k_1k_224/']\n"
209 | ]
210 | }
211 | ],
212 | "source": [
213 | "model_paths = tf.io.gfile.listdir(TF_MODEL_ROOT)\n",
214 | "models_res_224 = [model_path for model_path in model_paths if str(IMAGE_SIZE) in model_path]\n",
215 | "p = re.compile('.*_21k_224')\n",
216 | "i1k_paths = [path for path in models_res_224 if not p.match(path)]\n",
217 | "\n",
218 | "print(i1k_paths)"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "id": "2b7e8e68",
224 | "metadata": {},
225 | "source": [
226 | "## Run evaluation"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": 13,
232 | "id": "63a3da22-a60f-48b8-a0b0-02e54b2d012f",
233 | "metadata": {},
234 | "outputs": [],
235 | "source": [
236 | "def get_model(model_url):\n",
237 | " classification_model = tf.keras.Sequential(\n",
238 | " [\n",
239 | " layers.InputLayer((224, 224, 3)),\n",
240 | " hub.KerasLayer(model_url),\n",
241 | " ]\n",
242 | " )\n",
243 | " return classification_model\n",
244 | "\n",
245 | "\n",
246 | "def evaluate_model(model_name):\n",
247 | " tb_callback = tf.keras.callbacks.TensorBoard(log_dir=f\"logs_{model_name}\")\n",
248 | " model_url = TF_MODEL_ROOT + \"/\" + model_name\n",
249 | " \n",
250 | " model = get_model(model_url)\n",
251 | " model.compile(metrics=[\"accuracy\"])\n",
252 | " _, accuracy = model.evaluate(dataset, callbacks=[tb_callback])\n",
253 | " accuracy = round(accuracy * 100, 4)\n",
254 | " print(f\"{model_name}: {accuracy}%.\", file=open(f\"{model_name.strip('/')}.txt\", \"w\"))"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": 14,
260 | "id": "4db846db-86bc-4ae8-a699-acb69331d93c",
261 | "metadata": {},
262 | "outputs": [
263 | {
264 | "name": "stdout",
265 | "output_type": "stream",
266 | "text": [
267 | "Evaluating convnext_base_1k_224/.\n",
268 | "196/196 [==============================] - 118s 586ms/step - loss: 0.0000e+00 - accuracy: 0.8328\n",
269 | "Evaluating convnext_base_21k_1k_224/.\n",
270 | "196/196 [==============================] - 118s 585ms/step - loss: 0.0000e+00 - accuracy: 0.8536\n",
271 | "Evaluating convnext_large_1k_224/.\n",
272 | "196/196 [==============================] - 177s 879ms/step - loss: 0.0000e+00 - accuracy: 0.8384\n",
273 | "Evaluating convnext_large_21k_1k_224/.\n",
274 | "196/196 [==============================] - 175s 876ms/step - loss: 0.0000e+00 - accuracy: 0.8636\n",
275 | "Evaluating convnext_small_1k_224/.\n",
276 | "196/196 [==============================] - 92s 451ms/step - loss: 0.0000e+00 - accuracy: 0.8239\n",
277 | "Evaluating convnext_tiny_1k_224/.\n",
278 | "196/196 [==============================] - 60s 293ms/step - loss: 0.0000e+00 - accuracy: 0.8131\n",
279 | "Evaluating convnext_xlarge_21k_1k_224/.\n",
280 | "196/196 [==============================] - 241s 1s/step - loss: 0.0000e+00 - accuracy: 0.8673\n"
281 | ]
282 | }
283 | ],
284 | "source": [
285 | "for i1k_path in i1k_paths:\n",
286 | " print(f\"Evaluating {i1k_path}.\")\n",
287 | " evaluate_model(i1k_path)"
288 | ]
289 | }
290 | ],
291 | "metadata": {
292 | "environment": {
293 | "kernel": "python3",
294 | "name": "tf2-gpu.2-7.m84",
295 | "type": "gcloud",
296 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m84"
297 | },
298 | "kernelspec": {
299 | "display_name": "Python 3",
300 | "language": "python",
301 | "name": "python3"
302 | },
303 | "language_info": {
304 | "codemirror_mode": {
305 | "name": "ipython",
306 | "version": 3
307 | },
308 | "file_extension": ".py",
309 | "mimetype": "text/x-python",
310 | "name": "python",
311 | "nbconvert_exporter": "python",
312 | "pygments_lexer": "ipython3",
313 | "version": "3.7.12"
314 | }
315 | },
316 | "nbformat": 4,
317 | "nbformat_minor": 5
318 | }
319 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022 Meta Platforms Inc. and Sayak Paul
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/notebooks/finetune.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | "# Fine-tuning for image classification with ConvNeXt models on TF-Hub\n",
11 | "\n",
12 | ""
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {
28 | "id": "89B27-TGiDNB"
29 | },
30 | "source": [
31 | "## Imports"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {
38 | "id": "9u3d4Z7uQsmp"
39 | },
40 | "outputs": [],
41 | "source": [
42 | "from tensorflow import keras\n",
43 | "import tensorflow as tf\n",
44 | "import tensorflow_hub as hub\n",
45 | "import tensorflow_datasets as tfds\n",
46 | "\n",
47 | "tfds.disable_progress_bar()\n",
48 | "\n",
49 | "import os\n",
50 | "import sys\n",
51 | "import math\n",
52 | "import numpy as np\n",
53 | "import pandas as pd\n",
54 | "import matplotlib.pyplot as plt"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {
60 | "id": "mPo10cahZXXQ"
61 | },
62 | "source": [
63 | "## TPU/GPU detection"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {
70 | "id": "FpvUOuC3j27n"
71 | },
72 | "outputs": [],
73 | "source": [
74 | "try: # detect TPUs\n",
75 | " tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # TPU detection\n",
76 | " strategy = tf.distribute.TPUStrategy(tpu)\n",
77 | "except ValueError: # detect GPUs\n",
78 | " tpu = False\n",
79 | " strategy = (\n",
80 | " tf.distribute.get_strategy()\n",
81 | " ) # default strategy that works on CPU and single GPU\n",
82 | "print(\"Number of Accelerators: \", strategy.num_replicas_in_sync)"
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {
88 | "id": "w9S3uKC_iXY5"
89 | },
90 | "source": [
91 | "## Configuration"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {
98 | "id": "kCc6tdUGnD4C"
99 | },
100 | "outputs": [],
101 | "source": [
102 | "# Model\n",
103 | "IMAGE_SIZE = [224, 224]\n",
104 | "MODEL_PATH = \"https://tfhub.dev/sayakpaul/convnext_tiny_1k_224_fe/1\"\n",
105 | "\n",
106 | "# TPU\n",
107 | "if tpu:\n",
108 | " BATCH_SIZE = (\n",
109 | " 16 * strategy.num_replicas_in_sync\n",
110 | " ) # a TPU has 8 cores so this will be 128\n",
111 | "else:\n",
112 | " BATCH_SIZE = 64 # on Colab/GPU, a higher batch size may throw(OOM)\n",
113 | "\n",
114 | "# Dataset\n",
115 | "CLASSES = [\n",
116 | " \"dandelion\",\n",
117 | " \"daisy\",\n",
118 | " \"tulips\",\n",
119 | " \"sunflowers\",\n",
120 | " \"roses\",\n",
121 | "] # don't change the order\n",
122 | "\n",
123 | "# Other constants\n",
124 | "MEAN = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255]) # imagenet mean\n",
125 | "STD = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255]) # imagenet std\n",
126 | "AUTO = tf.data.AUTOTUNE"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "metadata": {
132 | "id": "9iTImGI5qMQT"
133 | },
134 | "source": [
135 | "# Data Pipeline"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {
142 | "id": "h29TLx7gqN_7"
143 | },
144 | "outputs": [],
145 | "source": [
146 | "def make_dataset(dataset: tf.data.Dataset, train: bool, image_size: int = IMAGE_SIZE):\n",
147 | " def preprocess(image, label):\n",
148 | " # for training, do augmentation\n",
149 | " if train:\n",
150 | " if tf.random.uniform(shape=[]) > 0.5:\n",
151 | " image = tf.image.flip_left_right(image)\n",
152 | " image = tf.image.resize(image, size=image_size, method=\"bicubic\")\n",
153 | " image = (image - MEAN) / STD # normalization\n",
154 | " return image, label\n",
155 | "\n",
156 | " if train:\n",
157 | " dataset = dataset.shuffle(BATCH_SIZE * 10)\n",
158 | "\n",
159 | " return dataset.map(preprocess, AUTO).batch(BATCH_SIZE).prefetch(AUTO)"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {
165 | "id": "AMQ3Qs9_pddU"
166 | },
167 | "source": [
168 | "# Flower Dataset"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "metadata": {
175 | "id": "M3G-2aUBQJ-H"
176 | },
177 | "outputs": [],
178 | "source": [
179 | "train_dataset, val_dataset = tfds.load(\n",
180 | " \"tf_flowers\",\n",
181 | " split=[\"train[:90%]\", \"train[90%:]\"],\n",
182 | " as_supervised=True,\n",
183 | " try_gcs=False, # gcs_path is necessary for tpu,\n",
184 | ")\n",
185 | "\n",
186 | "num_train = tf.data.experimental.cardinality(train_dataset)\n",
187 | "num_val = tf.data.experimental.cardinality(val_dataset)\n",
188 | "print(f\"Number of training examples: {num_train}\")\n",
189 | "print(f\"Number of validation examples: {num_val}\")\n"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "metadata": {
195 | "id": "l2X7sE3oRLXN"
196 | },
197 | "source": [
198 | "## Prepare dataset"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": null,
204 | "metadata": {
205 | "id": "oftrfYw1qXei"
206 | },
207 | "outputs": [],
208 | "source": [
209 | "train_dataset = make_dataset(train_dataset, True)\n",
210 | "val_dataset = make_dataset(val_dataset, False)"
211 | ]
212 | },
213 | {
214 | "cell_type": "markdown",
215 | "metadata": {
216 | "id": "kNyCCM6PRM8I"
217 | },
218 | "source": [
219 | "## Visualize"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "metadata": {
226 | "id": "IaGzFUUVqjaC"
227 | },
228 | "outputs": [],
229 | "source": [
230 | "sample_images, sample_labels = next(iter(train_dataset))\n",
231 | "\n",
232 | "plt.figure(figsize=(5 * 3, 3 * 3))\n",
233 | "for n in range(15):\n",
234 | " ax = plt.subplot(3, 5, n + 1)\n",
235 | " image = (sample_images[n] * STD + MEAN).numpy()\n",
236 | " image = (image - image.min()) / (\n",
237 | " image.max() - image.min()\n",
238 | " ) # convert to [0, 1] for avoiding matplotlib warning\n",
239 | " plt.imshow(image)\n",
240 | " plt.title(CLASSES[sample_labels[n]])\n",
241 | " plt.axis(\"off\")\n",
242 | "plt.tight_layout()\n",
243 | "plt.show()\n"
244 | ]
245 | },
246 | {
247 | "cell_type": "markdown",
248 | "metadata": {
249 | "id": "Qf6u_7tt8BYy"
250 | },
251 | "source": [
252 | "# LR Scheduler Utility"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {
259 | "id": "oVTbnkJL79T_"
260 | },
261 | "outputs": [],
262 | "source": [
263 | "# Reference:\n",
264 | "# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2\n",
265 | "\n",
266 | "\n",
267 | "class WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):\n",
268 | " def __init__(\n",
269 | " self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps\n",
270 | " ):\n",
271 | " super(WarmUpCosine, self).__init__()\n",
272 | "\n",
273 | " self.learning_rate_base = learning_rate_base\n",
274 | " self.total_steps = total_steps\n",
275 | " self.warmup_learning_rate = warmup_learning_rate\n",
276 | " self.warmup_steps = warmup_steps\n",
277 | " self.pi = tf.constant(np.pi)\n",
278 | "\n",
279 | " def __call__(self, step):\n",
280 | " if self.total_steps < self.warmup_steps:\n",
281 | " raise ValueError(\"Total_steps must be larger or equal to warmup_steps.\")\n",
282 | " learning_rate = (\n",
283 | " 0.5\n",
284 | " * self.learning_rate_base\n",
285 | " * (\n",
286 | " 1\n",
287 | " + tf.cos(\n",
288 | " self.pi\n",
289 | " * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
290 | " / float(self.total_steps - self.warmup_steps)\n",
291 | " )\n",
292 | " )\n",
293 | " )\n",
294 | "\n",
295 | " if self.warmup_steps > 0:\n",
296 | " if self.learning_rate_base < self.warmup_learning_rate:\n",
297 | " raise ValueError(\n",
298 | " \"Learning_rate_base must be larger or equal to \"\n",
299 | " \"warmup_learning_rate.\"\n",
300 | " )\n",
301 | " slope = (\n",
302 | " self.learning_rate_base - self.warmup_learning_rate\n",
303 | " ) / self.warmup_steps\n",
304 | " warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n",
305 | " learning_rate = tf.where(\n",
306 | " step < self.warmup_steps, warmup_rate, learning_rate\n",
307 | " )\n",
308 | " return tf.where(\n",
309 | " step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
310 | " )"
311 | ]
312 | },
313 | {
314 | "cell_type": "markdown",
315 | "metadata": {
316 | "id": "ALtRUlxhw8Vt"
317 | },
318 | "source": [
319 | "# Model Utility"
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": null,
325 | "metadata": {
326 | "id": "JD9SI_Q9JdAB"
327 | },
328 | "outputs": [],
329 | "source": [
330 | "def get_model(model_path=MODEL_PATH, res=224, num_classes=5):\n",
331 | " hub_layer = hub.KerasLayer(model_path, trainable=True)\n",
332 | "\n",
333 | " model = keras.Sequential(\n",
334 | " [\n",
335 | " keras.layers.InputLayer((res, res, 3)),\n",
336 | " hub_layer,\n",
337 | " keras.layers.Dense(num_classes, activation=\"softmax\"),\n",
338 | " ]\n",
339 | " )\n",
340 | " return model"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": null,
346 | "metadata": {
347 | "id": "wpZApp9u9_Y-"
348 | },
349 | "outputs": [],
350 | "source": [
351 | "get_model().summary()"
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {
357 | "id": "dMfenMQcxAAb"
358 | },
359 | "source": [
360 | "# Training Hyperparameters"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": null,
366 | "metadata": {
367 | "id": "1D7Iu7oD8WzX"
368 | },
369 | "outputs": [],
370 | "source": [
371 | "EPOCHS = 10\n",
372 | "WARMUP_STEPS = 10\n",
373 | "INIT_LR = 0.03\n",
374 | "WAMRUP_LR = 0.006\n",
375 | "\n",
376 | "TOTAL_STEPS = int((num_train / BATCH_SIZE) * EPOCHS)"
377 | ]
378 | },
379 | {
380 | "cell_type": "code",
381 | "execution_count": null,
382 | "metadata": {
383 | "id": "zmolHMov8als"
384 | },
385 | "outputs": [],
386 | "source": [
387 | "scheduled_lrs = WarmUpCosine(\n",
388 | " learning_rate_base=INIT_LR,\n",
389 | " total_steps=TOTAL_STEPS,\n",
390 | " warmup_learning_rate=WAMRUP_LR,\n",
391 | " warmup_steps=WARMUP_STEPS,\n",
392 | ")\n",
393 | "\n",
394 | "lrs = [scheduled_lrs(step) for step in range(TOTAL_STEPS)]\n",
395 | "plt.figure(figsize=(10, 6))\n",
396 | "plt.plot(lrs)\n",
397 | "plt.xlabel(\"Step\", fontsize=14)\n",
398 | "plt.ylabel(\"LR\", fontsize=14)\n",
399 | "plt.show()\n"
400 | ]
401 | },
402 | {
403 | "cell_type": "code",
404 | "execution_count": null,
405 | "metadata": {
406 | "id": "M-ID7vP5mIKs"
407 | },
408 | "outputs": [],
409 | "source": [
410 | "optimizer = keras.optimizers.SGD(scheduled_lrs)\n",
411 | "loss = keras.losses.SparseCategoricalCrossentropy()"
412 | ]
413 | },
414 | {
415 | "cell_type": "markdown",
416 | "metadata": {
417 | "id": "E9p4ymNh9y7d"
418 | },
419 | "source": [
420 | "# Training & Validation"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": null,
426 | "metadata": {
427 | "id": "VnZTSd8K90Mq"
428 | },
429 | "outputs": [],
430 | "source": [
431 | "with strategy.scope(): # this line is all that is needed to run on TPU (or multi-GPU, ...)\n",
432 | " model = get_model(MODEL_PATH)\n",
433 | " model.compile(loss=loss, optimizer=optimizer, metrics=[\"accuracy\"])\n",
434 | "\n",
435 | "history = model.fit(train_dataset, validation_data=val_dataset, epochs=EPOCHS)"
436 | ]
437 | },
438 | {
439 | "cell_type": "code",
440 | "execution_count": null,
441 | "metadata": {
442 | "id": "jc7LMVz5Cbx6"
443 | },
444 | "outputs": [],
445 | "source": [
446 | "result = pd.DataFrame(history.history)\n",
447 | "fig, ax = plt.subplots(2, 1, figsize=(10, 10))\n",
448 | "result[[\"accuracy\", \"val_accuracy\"]].plot(xlabel=\"epoch\", ylabel=\"score\", ax=ax[0])\n",
449 | "result[[\"loss\", \"val_loss\"]].plot(xlabel=\"epoch\", ylabel=\"score\", ax=ax[1])\n"
450 | ]
451 | },
452 | {
453 | "cell_type": "markdown",
454 | "metadata": {
455 | "id": "MKFMWzh0Yxsq"
456 | },
457 | "source": [
458 | "# Predictions"
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": null,
464 | "metadata": {
465 | "id": "yMEsR851VDZb"
466 | },
467 | "outputs": [],
468 | "source": [
469 | "sample_images, sample_labels = next(iter(val_dataset))\n",
470 | "\n",
471 | "predictions = model.predict(sample_images, batch_size=16).argmax(axis=-1)\n",
472 | "evaluations = model.evaluate(sample_images, sample_labels, batch_size=16)\n",
473 | "\n",
474 | "print(\"[val_loss, val_acc]\", evaluations)"
475 | ]
476 | },
477 | {
478 | "cell_type": "code",
479 | "execution_count": null,
480 | "metadata": {
481 | "id": "qzCCDL1CZFx6"
482 | },
483 | "outputs": [],
484 | "source": [
485 | "plt.figure(figsize=(5 * 3, 3 * 3))\n",
486 | "for n in range(15):\n",
487 | " ax = plt.subplot(3, 5, n + 1)\n",
488 | " image = (sample_images[n] * STD + MEAN).numpy()\n",
489 | " image = (image - image.min()) / (\n",
490 | " image.max() - image.min()\n",
491 | " ) # convert to [0, 1] for avoiding matplotlib warning\n",
492 | " plt.imshow(image)\n",
493 | " target = CLASSES[sample_labels[n]]\n",
494 | " pred = CLASSES[predictions[n]]\n",
495 | " plt.title(\"{} ({})\".format(target, pred))\n",
496 | " plt.axis(\"off\")\n",
497 | "plt.tight_layout()\n",
498 | "plt.show()\n"
499 | ]
500 | },
501 | {
502 | "cell_type": "markdown",
503 | "metadata": {
504 | "id": "2e5oy9zmNNID"
505 | },
506 | "source": [
507 | "# Reference\n",
508 | "* [ConvNeXt-TF](https://github.com/sayakpaul/ConvNeXt-TF)\n",
509 | "* [Keras Flowers on TPU (solution)](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_solution.ipynb)"
510 | ]
511 | }
512 | ],
513 | "metadata": {
514 | "accelerator": "GPU",
515 | "colab": {
516 | "collapsed_sections": [],
517 | "include_colab_link": true,
518 | "machine_shape": "hm",
519 | "name": "ConvNext-TF: Flower Classification (TPU/GPU).ipynb",
520 | "provenance": [],
521 | "toc_visible": true
522 | },
523 | "environment": {
524 | "name": "tf22-cpu.2-2.m47",
525 | "type": "gcloud",
526 | "uri": "gcr.io/deeplearning-platform-release/tf22-cpu.2-2:m47"
527 | },
528 | "kernelspec": {
529 | "display_name": "Python 3 (ipykernel)",
530 | "language": "python",
531 | "name": "python3"
532 | },
533 | "language_info": {
534 | "codemirror_mode": {
535 | "name": "ipython",
536 | "version": 3
537 | },
538 | "file_extension": ".py",
539 | "mimetype": "text/x-python",
540 | "name": "python",
541 | "nbconvert_exporter": "python",
542 | "pygments_lexer": "ipython3",
543 | "version": "3.8.2"
544 | }
545 | },
546 | "nbformat": 4,
547 | "nbformat_minor": 1
548 | }
549 |
--------------------------------------------------------------------------------
/i1k_eval/imagenet_class_index.json:
--------------------------------------------------------------------------------
1 | {"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], "2": ["n01484850", "great_white_shark"], "3": ["n01491361", "tiger_shark"], "4": ["n01494475", "hammerhead"], "5": ["n01496331", "electric_ray"], "6": ["n01498041", "stingray"], "7": ["n01514668", "cock"], "8": ["n01514859", "hen"], "9": ["n01518878", "ostrich"], "10": ["n01530575", "brambling"], "11": ["n01531178", "goldfinch"], "12": ["n01532829", "house_finch"], "13": ["n01534433", "junco"], "14": ["n01537544", "indigo_bunting"], "15": ["n01558993", "robin"], "16": ["n01560419", "bulbul"], "17": ["n01580077", "jay"], "18": ["n01582220", "magpie"], "19": ["n01592084", "chickadee"], "20": ["n01601694", "water_ouzel"], "21": ["n01608432", "kite"], "22": ["n01614925", "bald_eagle"], "23": ["n01616318", "vulture"], "24": ["n01622779", "great_grey_owl"], "25": ["n01629819", "European_fire_salamander"], "26": ["n01630670", "common_newt"], "27": ["n01631663", "eft"], "28": ["n01632458", "spotted_salamander"], "29": ["n01632777", "axolotl"], "30": ["n01641577", "bullfrog"], "31": ["n01644373", "tree_frog"], "32": ["n01644900", "tailed_frog"], "33": ["n01664065", "loggerhead"], "34": ["n01665541", "leatherback_turtle"], "35": ["n01667114", "mud_turtle"], "36": ["n01667778", "terrapin"], "37": ["n01669191", "box_turtle"], "38": ["n01675722", "banded_gecko"], "39": ["n01677366", "common_iguana"], "40": ["n01682714", "American_chameleon"], "41": ["n01685808", "whiptail"], "42": ["n01687978", "agama"], "43": ["n01688243", "frilled_lizard"], "44": ["n01689811", "alligator_lizard"], "45": ["n01692333", "Gila_monster"], "46": ["n01693334", "green_lizard"], "47": ["n01694178", "African_chameleon"], "48": ["n01695060", "Komodo_dragon"], "49": ["n01697457", "African_crocodile"], "50": ["n01698640", "American_alligator"], "51": ["n01704323", "triceratops"], "52": ["n01728572", "thunder_snake"], "53": ["n01728920", "ringneck_snake"], "54": ["n01729322", "hognose_snake"], "55": ["n01729977", "green_snake"], "56": ["n01734418", "king_snake"], "57": ["n01735189", "garter_snake"], "58": ["n01737021", "water_snake"], "59": ["n01739381", "vine_snake"], "60": ["n01740131", "night_snake"], "61": ["n01742172", "boa_constrictor"], "62": ["n01744401", "rock_python"], "63": ["n01748264", "Indian_cobra"], "64": ["n01749939", "green_mamba"], "65": ["n01751748", "sea_snake"], "66": ["n01753488", "horned_viper"], "67": ["n01755581", "diamondback"], "68": ["n01756291", "sidewinder"], "69": ["n01768244", "trilobite"], "70": ["n01770081", "harvestman"], "71": ["n01770393", "scorpion"], "72": ["n01773157", "black_and_gold_garden_spider"], "73": ["n01773549", "barn_spider"], "74": ["n01773797", "garden_spider"], "75": ["n01774384", "black_widow"], "76": ["n01774750", "tarantula"], "77": ["n01775062", "wolf_spider"], "78": ["n01776313", "tick"], "79": ["n01784675", "centipede"], "80": ["n01795545", "black_grouse"], "81": ["n01796340", "ptarmigan"], "82": ["n01797886", "ruffed_grouse"], "83": ["n01798484", "prairie_chicken"], "84": ["n01806143", "peacock"], "85": ["n01806567", "quail"], "86": ["n01807496", "partridge"], "87": ["n01817953", "African_grey"], "88": ["n01818515", "macaw"], "89": ["n01819313", "sulphur-crested_cockatoo"], "90": ["n01820546", "lorikeet"], "91": ["n01824575", "coucal"], "92": ["n01828970", "bee_eater"], "93": ["n01829413", "hornbill"], "94": ["n01833805", "hummingbird"], "95": ["n01843065", "jacamar"], "96": ["n01843383", "toucan"], "97": ["n01847000", "drake"], "98": ["n01855032", "red-breasted_merganser"], "99": ["n01855672", "goose"], 
"100": ["n01860187", "black_swan"], "101": ["n01871265", "tusker"], "102": ["n01872401", "echidna"], "103": ["n01873310", "platypus"], "104": ["n01877812", "wallaby"], "105": ["n01882714", "koala"], "106": ["n01883070", "wombat"], "107": ["n01910747", "jellyfish"], "108": ["n01914609", "sea_anemone"], "109": ["n01917289", "brain_coral"], "110": ["n01924916", "flatworm"], "111": ["n01930112", "nematode"], "112": ["n01943899", "conch"], "113": ["n01944390", "snail"], "114": ["n01945685", "slug"], "115": ["n01950731", "sea_slug"], "116": ["n01955084", "chiton"], "117": ["n01968897", "chambered_nautilus"], "118": ["n01978287", "Dungeness_crab"], "119": ["n01978455", "rock_crab"], "120": ["n01980166", "fiddler_crab"], "121": ["n01981276", "king_crab"], "122": ["n01983481", "American_lobster"], "123": ["n01984695", "spiny_lobster"], "124": ["n01985128", "crayfish"], "125": ["n01986214", "hermit_crab"], "126": ["n01990800", "isopod"], "127": ["n02002556", "white_stork"], "128": ["n02002724", "black_stork"], "129": ["n02006656", "spoonbill"], "130": ["n02007558", "flamingo"], "131": ["n02009229", "little_blue_heron"], "132": ["n02009912", "American_egret"], "133": ["n02011460", "bittern"], "134": ["n02012849", "crane"], "135": ["n02013706", "limpkin"], "136": ["n02017213", "European_gallinule"], "137": ["n02018207", "American_coot"], "138": ["n02018795", "bustard"], "139": ["n02025239", "ruddy_turnstone"], "140": ["n02027492", "red-backed_sandpiper"], "141": ["n02028035", "redshank"], "142": ["n02033041", "dowitcher"], "143": ["n02037110", "oystercatcher"], "144": ["n02051845", "pelican"], "145": ["n02056570", "king_penguin"], "146": ["n02058221", "albatross"], "147": ["n02066245", "grey_whale"], "148": ["n02071294", "killer_whale"], "149": ["n02074367", "dugong"], "150": ["n02077923", "sea_lion"], "151": ["n02085620", "Chihuahua"], "152": ["n02085782", "Japanese_spaniel"], "153": ["n02085936", "Maltese_dog"], "154": ["n02086079", "Pekinese"], "155": ["n02086240", "Shih-Tzu"], "156": ["n02086646", "Blenheim_spaniel"], "157": ["n02086910", "papillon"], "158": ["n02087046", "toy_terrier"], "159": ["n02087394", "Rhodesian_ridgeback"], "160": ["n02088094", "Afghan_hound"], "161": ["n02088238", "basset"], "162": ["n02088364", "beagle"], "163": ["n02088466", "bloodhound"], "164": ["n02088632", "bluetick"], "165": ["n02089078", "black-and-tan_coonhound"], "166": ["n02089867", "Walker_hound"], "167": ["n02089973", "English_foxhound"], "168": ["n02090379", "redbone"], "169": ["n02090622", "borzoi"], "170": ["n02090721", "Irish_wolfhound"], "171": ["n02091032", "Italian_greyhound"], "172": ["n02091134", "whippet"], "173": ["n02091244", "Ibizan_hound"], "174": ["n02091467", "Norwegian_elkhound"], "175": ["n02091635", "otterhound"], "176": ["n02091831", "Saluki"], "177": ["n02092002", "Scottish_deerhound"], "178": ["n02092339", "Weimaraner"], "179": ["n02093256", "Staffordshire_bullterrier"], "180": ["n02093428", "American_Staffordshire_terrier"], "181": ["n02093647", "Bedlington_terrier"], "182": ["n02093754", "Border_terrier"], "183": ["n02093859", "Kerry_blue_terrier"], "184": ["n02093991", "Irish_terrier"], "185": ["n02094114", "Norfolk_terrier"], "186": ["n02094258", "Norwich_terrier"], "187": ["n02094433", "Yorkshire_terrier"], "188": ["n02095314", "wire-haired_fox_terrier"], "189": ["n02095570", "Lakeland_terrier"], "190": ["n02095889", "Sealyham_terrier"], "191": ["n02096051", "Airedale"], "192": ["n02096177", "cairn"], "193": ["n02096294", "Australian_terrier"], "194": ["n02096437", 
"Dandie_Dinmont"], "195": ["n02096585", "Boston_bull"], "196": ["n02097047", "miniature_schnauzer"], "197": ["n02097130", "giant_schnauzer"], "198": ["n02097209", "standard_schnauzer"], "199": ["n02097298", "Scotch_terrier"], "200": ["n02097474", "Tibetan_terrier"], "201": ["n02097658", "silky_terrier"], "202": ["n02098105", "soft-coated_wheaten_terrier"], "203": ["n02098286", "West_Highland_white_terrier"], "204": ["n02098413", "Lhasa"], "205": ["n02099267", "flat-coated_retriever"], "206": ["n02099429", "curly-coated_retriever"], "207": ["n02099601", "golden_retriever"], "208": ["n02099712", "Labrador_retriever"], "209": ["n02099849", "Chesapeake_Bay_retriever"], "210": ["n02100236", "German_short-haired_pointer"], "211": ["n02100583", "vizsla"], "212": ["n02100735", "English_setter"], "213": ["n02100877", "Irish_setter"], "214": ["n02101006", "Gordon_setter"], "215": ["n02101388", "Brittany_spaniel"], "216": ["n02101556", "clumber"], "217": ["n02102040", "English_springer"], "218": ["n02102177", "Welsh_springer_spaniel"], "219": ["n02102318", "cocker_spaniel"], "220": ["n02102480", "Sussex_spaniel"], "221": ["n02102973", "Irish_water_spaniel"], "222": ["n02104029", "kuvasz"], "223": ["n02104365", "schipperke"], "224": ["n02105056", "groenendael"], "225": ["n02105162", "malinois"], "226": ["n02105251", "briard"], "227": ["n02105412", "kelpie"], "228": ["n02105505", "komondor"], "229": ["n02105641", "Old_English_sheepdog"], "230": ["n02105855", "Shetland_sheepdog"], "231": ["n02106030", "collie"], "232": ["n02106166", "Border_collie"], "233": ["n02106382", "Bouvier_des_Flandres"], "234": ["n02106550", "Rottweiler"], "235": ["n02106662", "German_shepherd"], "236": ["n02107142", "Doberman"], "237": ["n02107312", "miniature_pinscher"], "238": ["n02107574", "Greater_Swiss_Mountain_dog"], "239": ["n02107683", "Bernese_mountain_dog"], "240": ["n02107908", "Appenzeller"], "241": ["n02108000", "EntleBucher"], "242": ["n02108089", "boxer"], "243": ["n02108422", "bull_mastiff"], "244": ["n02108551", "Tibetan_mastiff"], "245": ["n02108915", "French_bulldog"], "246": ["n02109047", "Great_Dane"], "247": ["n02109525", "Saint_Bernard"], "248": ["n02109961", "Eskimo_dog"], "249": ["n02110063", "malamute"], "250": ["n02110185", "Siberian_husky"], "251": ["n02110341", "dalmatian"], "252": ["n02110627", "affenpinscher"], "253": ["n02110806", "basenji"], "254": ["n02110958", "pug"], "255": ["n02111129", "Leonberg"], "256": ["n02111277", "Newfoundland"], "257": ["n02111500", "Great_Pyrenees"], "258": ["n02111889", "Samoyed"], "259": ["n02112018", "Pomeranian"], "260": ["n02112137", "chow"], "261": ["n02112350", "keeshond"], "262": ["n02112706", "Brabancon_griffon"], "263": ["n02113023", "Pembroke"], "264": ["n02113186", "Cardigan"], "265": ["n02113624", "toy_poodle"], "266": ["n02113712", "miniature_poodle"], "267": ["n02113799", "standard_poodle"], "268": ["n02113978", "Mexican_hairless"], "269": ["n02114367", "timber_wolf"], "270": ["n02114548", "white_wolf"], "271": ["n02114712", "red_wolf"], "272": ["n02114855", "coyote"], "273": ["n02115641", "dingo"], "274": ["n02115913", "dhole"], "275": ["n02116738", "African_hunting_dog"], "276": ["n02117135", "hyena"], "277": ["n02119022", "red_fox"], "278": ["n02119789", "kit_fox"], "279": ["n02120079", "Arctic_fox"], "280": ["n02120505", "grey_fox"], "281": ["n02123045", "tabby"], "282": ["n02123159", "tiger_cat"], "283": ["n02123394", "Persian_cat"], "284": ["n02123597", "Siamese_cat"], "285": ["n02124075", "Egyptian_cat"], "286": ["n02125311", "cougar"], "287": 
["n02127052", "lynx"], "288": ["n02128385", "leopard"], "289": ["n02128757", "snow_leopard"], "290": ["n02128925", "jaguar"], "291": ["n02129165", "lion"], "292": ["n02129604", "tiger"], "293": ["n02130308", "cheetah"], "294": ["n02132136", "brown_bear"], "295": ["n02133161", "American_black_bear"], "296": ["n02134084", "ice_bear"], "297": ["n02134418", "sloth_bear"], "298": ["n02137549", "mongoose"], "299": ["n02138441", "meerkat"], "300": ["n02165105", "tiger_beetle"], "301": ["n02165456", "ladybug"], "302": ["n02167151", "ground_beetle"], "303": ["n02168699", "long-horned_beetle"], "304": ["n02169497", "leaf_beetle"], "305": ["n02172182", "dung_beetle"], "306": ["n02174001", "rhinoceros_beetle"], "307": ["n02177972", "weevil"], "308": ["n02190166", "fly"], "309": ["n02206856", "bee"], "310": ["n02219486", "ant"], "311": ["n02226429", "grasshopper"], "312": ["n02229544", "cricket"], "313": ["n02231487", "walking_stick"], "314": ["n02233338", "cockroach"], "315": ["n02236044", "mantis"], "316": ["n02256656", "cicada"], "317": ["n02259212", "leafhopper"], "318": ["n02264363", "lacewing"], "319": ["n02268443", "dragonfly"], "320": ["n02268853", "damselfly"], "321": ["n02276258", "admiral"], "322": ["n02277742", "ringlet"], "323": ["n02279972", "monarch"], "324": ["n02280649", "cabbage_butterfly"], "325": ["n02281406", "sulphur_butterfly"], "326": ["n02281787", "lycaenid"], "327": ["n02317335", "starfish"], "328": ["n02319095", "sea_urchin"], "329": ["n02321529", "sea_cucumber"], "330": ["n02325366", "wood_rabbit"], "331": ["n02326432", "hare"], "332": ["n02328150", "Angora"], "333": ["n02342885", "hamster"], "334": ["n02346627", "porcupine"], "335": ["n02356798", "fox_squirrel"], "336": ["n02361337", "marmot"], "337": ["n02363005", "beaver"], "338": ["n02364673", "guinea_pig"], "339": ["n02389026", "sorrel"], "340": ["n02391049", "zebra"], "341": ["n02395406", "hog"], "342": ["n02396427", "wild_boar"], "343": ["n02397096", "warthog"], "344": ["n02398521", "hippopotamus"], "345": ["n02403003", "ox"], "346": ["n02408429", "water_buffalo"], "347": ["n02410509", "bison"], "348": ["n02412080", "ram"], "349": ["n02415577", "bighorn"], "350": ["n02417914", "ibex"], "351": ["n02422106", "hartebeest"], "352": ["n02422699", "impala"], "353": ["n02423022", "gazelle"], "354": ["n02437312", "Arabian_camel"], "355": ["n02437616", "llama"], "356": ["n02441942", "weasel"], "357": ["n02442845", "mink"], "358": ["n02443114", "polecat"], "359": ["n02443484", "black-footed_ferret"], "360": ["n02444819", "otter"], "361": ["n02445715", "skunk"], "362": ["n02447366", "badger"], "363": ["n02454379", "armadillo"], "364": ["n02457408", "three-toed_sloth"], "365": ["n02480495", "orangutan"], "366": ["n02480855", "gorilla"], "367": ["n02481823", "chimpanzee"], "368": ["n02483362", "gibbon"], "369": ["n02483708", "siamang"], "370": ["n02484975", "guenon"], "371": ["n02486261", "patas"], "372": ["n02486410", "baboon"], "373": ["n02487347", "macaque"], "374": ["n02488291", "langur"], "375": ["n02488702", "colobus"], "376": ["n02489166", "proboscis_monkey"], "377": ["n02490219", "marmoset"], "378": ["n02492035", "capuchin"], "379": ["n02492660", "howler_monkey"], "380": ["n02493509", "titi"], "381": ["n02493793", "spider_monkey"], "382": ["n02494079", "squirrel_monkey"], "383": ["n02497673", "Madagascar_cat"], "384": ["n02500267", "indri"], "385": ["n02504013", "Indian_elephant"], "386": ["n02504458", "African_elephant"], "387": ["n02509815", "lesser_panda"], "388": ["n02510455", "giant_panda"], "389": ["n02514041", 
"barracouta"], "390": ["n02526121", "eel"], "391": ["n02536864", "coho"], "392": ["n02606052", "rock_beauty"], "393": ["n02607072", "anemone_fish"], "394": ["n02640242", "sturgeon"], "395": ["n02641379", "gar"], "396": ["n02643566", "lionfish"], "397": ["n02655020", "puffer"], "398": ["n02666196", "abacus"], "399": ["n02667093", "abaya"], "400": ["n02669723", "academic_gown"], "401": ["n02672831", "accordion"], "402": ["n02676566", "acoustic_guitar"], "403": ["n02687172", "aircraft_carrier"], "404": ["n02690373", "airliner"], "405": ["n02692877", "airship"], "406": ["n02699494", "altar"], "407": ["n02701002", "ambulance"], "408": ["n02704792", "amphibian"], "409": ["n02708093", "analog_clock"], "410": ["n02727426", "apiary"], "411": ["n02730930", "apron"], "412": ["n02747177", "ashcan"], "413": ["n02749479", "assault_rifle"], "414": ["n02769748", "backpack"], "415": ["n02776631", "bakery"], "416": ["n02777292", "balance_beam"], "417": ["n02782093", "balloon"], "418": ["n02783161", "ballpoint"], "419": ["n02786058", "Band_Aid"], "420": ["n02787622", "banjo"], "421": ["n02788148", "bannister"], "422": ["n02790996", "barbell"], "423": ["n02791124", "barber_chair"], "424": ["n02791270", "barbershop"], "425": ["n02793495", "barn"], "426": ["n02794156", "barometer"], "427": ["n02795169", "barrel"], "428": ["n02797295", "barrow"], "429": ["n02799071", "baseball"], "430": ["n02802426", "basketball"], "431": ["n02804414", "bassinet"], "432": ["n02804610", "bassoon"], "433": ["n02807133", "bathing_cap"], "434": ["n02808304", "bath_towel"], "435": ["n02808440", "bathtub"], "436": ["n02814533", "beach_wagon"], "437": ["n02814860", "beacon"], "438": ["n02815834", "beaker"], "439": ["n02817516", "bearskin"], "440": ["n02823428", "beer_bottle"], "441": ["n02823750", "beer_glass"], "442": ["n02825657", "bell_cote"], "443": ["n02834397", "bib"], "444": ["n02835271", "bicycle-built-for-two"], "445": ["n02837789", "bikini"], "446": ["n02840245", "binder"], "447": ["n02841315", "binoculars"], "448": ["n02843684", "birdhouse"], "449": ["n02859443", "boathouse"], "450": ["n02860847", "bobsled"], "451": ["n02865351", "bolo_tie"], "452": ["n02869837", "bonnet"], "453": ["n02870880", "bookcase"], "454": ["n02871525", "bookshop"], "455": ["n02877765", "bottlecap"], "456": ["n02879718", "bow"], "457": ["n02883205", "bow_tie"], "458": ["n02892201", "brass"], "459": ["n02892767", "brassiere"], "460": ["n02894605", "breakwater"], "461": ["n02895154", "breastplate"], "462": ["n02906734", "broom"], "463": ["n02909870", "bucket"], "464": ["n02910353", "buckle"], "465": ["n02916936", "bulletproof_vest"], "466": ["n02917067", "bullet_train"], "467": ["n02927161", "butcher_shop"], "468": ["n02930766", "cab"], "469": ["n02939185", "caldron"], "470": ["n02948072", "candle"], "471": ["n02950826", "cannon"], "472": ["n02951358", "canoe"], "473": ["n02951585", "can_opener"], "474": ["n02963159", "cardigan"], "475": ["n02965783", "car_mirror"], "476": ["n02966193", "carousel"], "477": ["n02966687", "carpenter's_kit"], "478": ["n02971356", "carton"], "479": ["n02974003", "car_wheel"], "480": ["n02977058", "cash_machine"], "481": ["n02978881", "cassette"], "482": ["n02979186", "cassette_player"], "483": ["n02980441", "castle"], "484": ["n02981792", "catamaran"], "485": ["n02988304", "CD_player"], "486": ["n02992211", "cello"], "487": ["n02992529", "cellular_telephone"], "488": ["n02999410", "chain"], "489": ["n03000134", "chainlink_fence"], "490": ["n03000247", "chain_mail"], "491": ["n03000684", "chain_saw"], "492": ["n03014705", 
"chest"], "493": ["n03016953", "chiffonier"], "494": ["n03017168", "chime"], "495": ["n03018349", "china_cabinet"], "496": ["n03026506", "Christmas_stocking"], "497": ["n03028079", "church"], "498": ["n03032252", "cinema"], "499": ["n03041632", "cleaver"], "500": ["n03042490", "cliff_dwelling"], "501": ["n03045698", "cloak"], "502": ["n03047690", "clog"], "503": ["n03062245", "cocktail_shaker"], "504": ["n03063599", "coffee_mug"], "505": ["n03063689", "coffeepot"], "506": ["n03065424", "coil"], "507": ["n03075370", "combination_lock"], "508": ["n03085013", "computer_keyboard"], "509": ["n03089624", "confectionery"], "510": ["n03095699", "container_ship"], "511": ["n03100240", "convertible"], "512": ["n03109150", "corkscrew"], "513": ["n03110669", "cornet"], "514": ["n03124043", "cowboy_boot"], "515": ["n03124170", "cowboy_hat"], "516": ["n03125729", "cradle"], "517": ["n03126707", "crane"], "518": ["n03127747", "crash_helmet"], "519": ["n03127925", "crate"], "520": ["n03131574", "crib"], "521": ["n03133878", "Crock_Pot"], "522": ["n03134739", "croquet_ball"], "523": ["n03141823", "crutch"], "524": ["n03146219", "cuirass"], "525": ["n03160309", "dam"], "526": ["n03179701", "desk"], "527": ["n03180011", "desktop_computer"], "528": ["n03187595", "dial_telephone"], "529": ["n03188531", "diaper"], "530": ["n03196217", "digital_clock"], "531": ["n03197337", "digital_watch"], "532": ["n03201208", "dining_table"], "533": ["n03207743", "dishrag"], "534": ["n03207941", "dishwasher"], "535": ["n03208938", "disk_brake"], "536": ["n03216828", "dock"], "537": ["n03218198", "dogsled"], "538": ["n03220513", "dome"], "539": ["n03223299", "doormat"], "540": ["n03240683", "drilling_platform"], "541": ["n03249569", "drum"], "542": ["n03250847", "drumstick"], "543": ["n03255030", "dumbbell"], "544": ["n03259280", "Dutch_oven"], "545": ["n03271574", "electric_fan"], "546": ["n03272010", "electric_guitar"], "547": ["n03272562", "electric_locomotive"], "548": ["n03290653", "entertainment_center"], "549": ["n03291819", "envelope"], "550": ["n03297495", "espresso_maker"], "551": ["n03314780", "face_powder"], "552": ["n03325584", "feather_boa"], "553": ["n03337140", "file"], "554": ["n03344393", "fireboat"], "555": ["n03345487", "fire_engine"], "556": ["n03347037", "fire_screen"], "557": ["n03355925", "flagpole"], "558": ["n03372029", "flute"], "559": ["n03376595", "folding_chair"], "560": ["n03379051", "football_helmet"], "561": ["n03384352", "forklift"], "562": ["n03388043", "fountain"], "563": ["n03388183", "fountain_pen"], "564": ["n03388549", "four-poster"], "565": ["n03393912", "freight_car"], "566": ["n03394916", "French_horn"], "567": ["n03400231", "frying_pan"], "568": ["n03404251", "fur_coat"], "569": ["n03417042", "garbage_truck"], "570": ["n03424325", "gasmask"], "571": ["n03425413", "gas_pump"], "572": ["n03443371", "goblet"], "573": ["n03444034", "go-kart"], "574": ["n03445777", "golf_ball"], "575": ["n03445924", "golfcart"], "576": ["n03447447", "gondola"], "577": ["n03447721", "gong"], "578": ["n03450230", "gown"], "579": ["n03452741", "grand_piano"], "580": ["n03457902", "greenhouse"], "581": ["n03459775", "grille"], "582": ["n03461385", "grocery_store"], "583": ["n03467068", "guillotine"], "584": ["n03476684", "hair_slide"], "585": ["n03476991", "hair_spray"], "586": ["n03478589", "half_track"], "587": ["n03481172", "hammer"], "588": ["n03482405", "hamper"], "589": ["n03483316", "hand_blower"], "590": ["n03485407", "hand-held_computer"], "591": ["n03485794", "handkerchief"], "592": ["n03492542", 
"hard_disc"], "593": ["n03494278", "harmonica"], "594": ["n03495258", "harp"], "595": ["n03496892", "harvester"], "596": ["n03498962", "hatchet"], "597": ["n03527444", "holster"], "598": ["n03529860", "home_theater"], "599": ["n03530642", "honeycomb"], "600": ["n03532672", "hook"], "601": ["n03534580", "hoopskirt"], "602": ["n03535780", "horizontal_bar"], "603": ["n03538406", "horse_cart"], "604": ["n03544143", "hourglass"], "605": ["n03584254", "iPod"], "606": ["n03584829", "iron"], "607": ["n03590841", "jack-o'-lantern"], "608": ["n03594734", "jean"], "609": ["n03594945", "jeep"], "610": ["n03595614", "jersey"], "611": ["n03598930", "jigsaw_puzzle"], "612": ["n03599486", "jinrikisha"], "613": ["n03602883", "joystick"], "614": ["n03617480", "kimono"], "615": ["n03623198", "knee_pad"], "616": ["n03627232", "knot"], "617": ["n03630383", "lab_coat"], "618": ["n03633091", "ladle"], "619": ["n03637318", "lampshade"], "620": ["n03642806", "laptop"], "621": ["n03649909", "lawn_mower"], "622": ["n03657121", "lens_cap"], "623": ["n03658185", "letter_opener"], "624": ["n03661043", "library"], "625": ["n03662601", "lifeboat"], "626": ["n03666591", "lighter"], "627": ["n03670208", "limousine"], "628": ["n03673027", "liner"], "629": ["n03676483", "lipstick"], "630": ["n03680355", "Loafer"], "631": ["n03690938", "lotion"], "632": ["n03691459", "loudspeaker"], "633": ["n03692522", "loupe"], "634": ["n03697007", "lumbermill"], "635": ["n03706229", "magnetic_compass"], "636": ["n03709823", "mailbag"], "637": ["n03710193", "mailbox"], "638": ["n03710637", "maillot"], "639": ["n03710721", "maillot"], "640": ["n03717622", "manhole_cover"], "641": ["n03720891", "maraca"], "642": ["n03721384", "marimba"], "643": ["n03724870", "mask"], "644": ["n03729826", "matchstick"], "645": ["n03733131", "maypole"], "646": ["n03733281", "maze"], "647": ["n03733805", "measuring_cup"], "648": ["n03742115", "medicine_chest"], "649": ["n03743016", "megalith"], "650": ["n03759954", "microphone"], "651": ["n03761084", "microwave"], "652": ["n03763968", "military_uniform"], "653": ["n03764736", "milk_can"], "654": ["n03769881", "minibus"], "655": ["n03770439", "miniskirt"], "656": ["n03770679", "minivan"], "657": ["n03773504", "missile"], "658": ["n03775071", "mitten"], "659": ["n03775546", "mixing_bowl"], "660": ["n03776460", "mobile_home"], "661": ["n03777568", "Model_T"], "662": ["n03777754", "modem"], "663": ["n03781244", "monastery"], "664": ["n03782006", "monitor"], "665": ["n03785016", "moped"], "666": ["n03786901", "mortar"], "667": ["n03787032", "mortarboard"], "668": ["n03788195", "mosque"], "669": ["n03788365", "mosquito_net"], "670": ["n03791053", "motor_scooter"], "671": ["n03792782", "mountain_bike"], "672": ["n03792972", "mountain_tent"], "673": ["n03793489", "mouse"], "674": ["n03794056", "mousetrap"], "675": ["n03796401", "moving_van"], "676": ["n03803284", "muzzle"], "677": ["n03804744", "nail"], "678": ["n03814639", "neck_brace"], "679": ["n03814906", "necklace"], "680": ["n03825788", "nipple"], "681": ["n03832673", "notebook"], "682": ["n03837869", "obelisk"], "683": ["n03838899", "oboe"], "684": ["n03840681", "ocarina"], "685": ["n03841143", "odometer"], "686": ["n03843555", "oil_filter"], "687": ["n03854065", "organ"], "688": ["n03857828", "oscilloscope"], "689": ["n03866082", "overskirt"], "690": ["n03868242", "oxcart"], "691": ["n03868863", "oxygen_mask"], "692": ["n03871628", "packet"], "693": ["n03873416", "paddle"], "694": ["n03874293", "paddlewheel"], "695": ["n03874599", "padlock"], "696": 
["n03876231", "paintbrush"], "697": ["n03877472", "pajama"], "698": ["n03877845", "palace"], "699": ["n03884397", "panpipe"], "700": ["n03887697", "paper_towel"], "701": ["n03888257", "parachute"], "702": ["n03888605", "parallel_bars"], "703": ["n03891251", "park_bench"], "704": ["n03891332", "parking_meter"], "705": ["n03895866", "passenger_car"], "706": ["n03899768", "patio"], "707": ["n03902125", "pay-phone"], "708": ["n03903868", "pedestal"], "709": ["n03908618", "pencil_box"], "710": ["n03908714", "pencil_sharpener"], "711": ["n03916031", "perfume"], "712": ["n03920288", "Petri_dish"], "713": ["n03924679", "photocopier"], "714": ["n03929660", "pick"], "715": ["n03929855", "pickelhaube"], "716": ["n03930313", "picket_fence"], "717": ["n03930630", "pickup"], "718": ["n03933933", "pier"], "719": ["n03935335", "piggy_bank"], "720": ["n03937543", "pill_bottle"], "721": ["n03938244", "pillow"], "722": ["n03942813", "ping-pong_ball"], "723": ["n03944341", "pinwheel"], "724": ["n03947888", "pirate"], "725": ["n03950228", "pitcher"], "726": ["n03954731", "plane"], "727": ["n03956157", "planetarium"], "728": ["n03958227", "plastic_bag"], "729": ["n03961711", "plate_rack"], "730": ["n03967562", "plow"], "731": ["n03970156", "plunger"], "732": ["n03976467", "Polaroid_camera"], "733": ["n03976657", "pole"], "734": ["n03977966", "police_van"], "735": ["n03980874", "poncho"], "736": ["n03982430", "pool_table"], "737": ["n03983396", "pop_bottle"], "738": ["n03991062", "pot"], "739": ["n03992509", "potter's_wheel"], "740": ["n03995372", "power_drill"], "741": ["n03998194", "prayer_rug"], "742": ["n04004767", "printer"], "743": ["n04005630", "prison"], "744": ["n04008634", "projectile"], "745": ["n04009552", "projector"], "746": ["n04019541", "puck"], "747": ["n04023962", "punching_bag"], "748": ["n04026417", "purse"], "749": ["n04033901", "quill"], "750": ["n04033995", "quilt"], "751": ["n04037443", "racer"], "752": ["n04039381", "racket"], "753": ["n04040759", "radiator"], "754": ["n04041544", "radio"], "755": ["n04044716", "radio_telescope"], "756": ["n04049303", "rain_barrel"], "757": ["n04065272", "recreational_vehicle"], "758": ["n04067472", "reel"], "759": ["n04069434", "reflex_camera"], "760": ["n04070727", "refrigerator"], "761": ["n04074963", "remote_control"], "762": ["n04081281", "restaurant"], "763": ["n04086273", "revolver"], "764": ["n04090263", "rifle"], "765": ["n04099969", "rocking_chair"], "766": ["n04111531", "rotisserie"], "767": ["n04116512", "rubber_eraser"], "768": ["n04118538", "rugby_ball"], "769": ["n04118776", "rule"], "770": ["n04120489", "running_shoe"], "771": ["n04125021", "safe"], "772": ["n04127249", "safety_pin"], "773": ["n04131690", "saltshaker"], "774": ["n04133789", "sandal"], "775": ["n04136333", "sarong"], "776": ["n04141076", "sax"], "777": ["n04141327", "scabbard"], "778": ["n04141975", "scale"], "779": ["n04146614", "school_bus"], "780": ["n04147183", "schooner"], "781": ["n04149813", "scoreboard"], "782": ["n04152593", "screen"], "783": ["n04153751", "screw"], "784": ["n04154565", "screwdriver"], "785": ["n04162706", "seat_belt"], "786": ["n04179913", "sewing_machine"], "787": ["n04192698", "shield"], "788": ["n04200800", "shoe_shop"], "789": ["n04201297", "shoji"], "790": ["n04204238", "shopping_basket"], "791": ["n04204347", "shopping_cart"], "792": ["n04208210", "shovel"], "793": ["n04209133", "shower_cap"], "794": ["n04209239", "shower_curtain"], "795": ["n04228054", "ski"], "796": ["n04229816", "ski_mask"], "797": ["n04235860", "sleeping_bag"], "798": 
["n04238763", "slide_rule"], "799": ["n04239074", "sliding_door"], "800": ["n04243546", "slot"], "801": ["n04251144", "snorkel"], "802": ["n04252077", "snowmobile"], "803": ["n04252225", "snowplow"], "804": ["n04254120", "soap_dispenser"], "805": ["n04254680", "soccer_ball"], "806": ["n04254777", "sock"], "807": ["n04258138", "solar_dish"], "808": ["n04259630", "sombrero"], "809": ["n04263257", "soup_bowl"], "810": ["n04264628", "space_bar"], "811": ["n04265275", "space_heater"], "812": ["n04266014", "space_shuttle"], "813": ["n04270147", "spatula"], "814": ["n04273569", "speedboat"], "815": ["n04275548", "spider_web"], "816": ["n04277352", "spindle"], "817": ["n04285008", "sports_car"], "818": ["n04286575", "spotlight"], "819": ["n04296562", "stage"], "820": ["n04310018", "steam_locomotive"], "821": ["n04311004", "steel_arch_bridge"], "822": ["n04311174", "steel_drum"], "823": ["n04317175", "stethoscope"], "824": ["n04325704", "stole"], "825": ["n04326547", "stone_wall"], "826": ["n04328186", "stopwatch"], "827": ["n04330267", "stove"], "828": ["n04332243", "strainer"], "829": ["n04335435", "streetcar"], "830": ["n04336792", "stretcher"], "831": ["n04344873", "studio_couch"], "832": ["n04346328", "stupa"], "833": ["n04347754", "submarine"], "834": ["n04350905", "suit"], "835": ["n04355338", "sundial"], "836": ["n04355933", "sunglass"], "837": ["n04356056", "sunglasses"], "838": ["n04357314", "sunscreen"], "839": ["n04366367", "suspension_bridge"], "840": ["n04367480", "swab"], "841": ["n04370456", "sweatshirt"], "842": ["n04371430", "swimming_trunks"], "843": ["n04371774", "swing"], "844": ["n04372370", "switch"], "845": ["n04376876", "syringe"], "846": ["n04380533", "table_lamp"], "847": ["n04389033", "tank"], "848": ["n04392985", "tape_player"], "849": ["n04398044", "teapot"], "850": ["n04399382", "teddy"], "851": ["n04404412", "television"], "852": ["n04409515", "tennis_ball"], "853": ["n04417672", "thatch"], "854": ["n04418357", "theater_curtain"], "855": ["n04423845", "thimble"], "856": ["n04428191", "thresher"], "857": ["n04429376", "throne"], "858": ["n04435653", "tile_roof"], "859": ["n04442312", "toaster"], "860": ["n04443257", "tobacco_shop"], "861": ["n04447861", "toilet_seat"], "862": ["n04456115", "torch"], "863": ["n04458633", "totem_pole"], "864": ["n04461696", "tow_truck"], "865": ["n04462240", "toyshop"], "866": ["n04465501", "tractor"], "867": ["n04467665", "trailer_truck"], "868": ["n04476259", "tray"], "869": ["n04479046", "trench_coat"], "870": ["n04482393", "tricycle"], "871": ["n04483307", "trimaran"], "872": ["n04485082", "tripod"], "873": ["n04486054", "triumphal_arch"], "874": ["n04487081", "trolleybus"], "875": ["n04487394", "trombone"], "876": ["n04493381", "tub"], "877": ["n04501370", "turnstile"], "878": ["n04505470", "typewriter_keyboard"], "879": ["n04507155", "umbrella"], "880": ["n04509417", "unicycle"], "881": ["n04515003", "upright"], "882": ["n04517823", "vacuum"], "883": ["n04522168", "vase"], "884": ["n04523525", "vault"], "885": ["n04525038", "velvet"], "886": ["n04525305", "vending_machine"], "887": ["n04532106", "vestment"], "888": ["n04532670", "viaduct"], "889": ["n04536866", "violin"], "890": ["n04540053", "volleyball"], "891": ["n04542943", "waffle_iron"], "892": ["n04548280", "wall_clock"], "893": ["n04548362", "wallet"], "894": ["n04550184", "wardrobe"], "895": ["n04552348", "warplane"], "896": ["n04553703", "washbasin"], "897": ["n04554684", "washer"], "898": ["n04557648", "water_bottle"], "899": ["n04560804", "water_jug"], "900": 
["n04562935", "water_tower"], "901": ["n04579145", "whiskey_jug"], "902": ["n04579432", "whistle"], "903": ["n04584207", "wig"], "904": ["n04589890", "window_screen"], "905": ["n04590129", "window_shade"], "906": ["n04591157", "Windsor_tie"], "907": ["n04591713", "wine_bottle"], "908": ["n04592741", "wing"], "909": ["n04596742", "wok"], "910": ["n04597913", "wooden_spoon"], "911": ["n04599235", "wool"], "912": ["n04604644", "worm_fence"], "913": ["n04606251", "wreck"], "914": ["n04612504", "yawl"], "915": ["n04613696", "yurt"], "916": ["n06359193", "web_site"], "917": ["n06596364", "comic_book"], "918": ["n06785654", "crossword_puzzle"], "919": ["n06794110", "street_sign"], "920": ["n06874185", "traffic_light"], "921": ["n07248320", "book_jacket"], "922": ["n07565083", "menu"], "923": ["n07579787", "plate"], "924": ["n07583066", "guacamole"], "925": ["n07584110", "consomme"], "926": ["n07590611", "hot_pot"], "927": ["n07613480", "trifle"], "928": ["n07614500", "ice_cream"], "929": ["n07615774", "ice_lolly"], "930": ["n07684084", "French_loaf"], "931": ["n07693725", "bagel"], "932": ["n07695742", "pretzel"], "933": ["n07697313", "cheeseburger"], "934": ["n07697537", "hotdog"], "935": ["n07711569", "mashed_potato"], "936": ["n07714571", "head_cabbage"], "937": ["n07714990", "broccoli"], "938": ["n07715103", "cauliflower"], "939": ["n07716358", "zucchini"], "940": ["n07716906", "spaghetti_squash"], "941": ["n07717410", "acorn_squash"], "942": ["n07717556", "butternut_squash"], "943": ["n07718472", "cucumber"], "944": ["n07718747", "artichoke"], "945": ["n07720875", "bell_pepper"], "946": ["n07730033", "cardoon"], "947": ["n07734744", "mushroom"], "948": ["n07742313", "Granny_Smith"], "949": ["n07745940", "strawberry"], "950": ["n07747607", "orange"], "951": ["n07749582", "lemon"], "952": ["n07753113", "fig"], "953": ["n07753275", "pineapple"], "954": ["n07753592", "banana"], "955": ["n07754684", "jackfruit"], "956": ["n07760859", "custard_apple"], "957": ["n07768694", "pomegranate"], "958": ["n07802026", "hay"], "959": ["n07831146", "carbonara"], "960": ["n07836838", "chocolate_sauce"], "961": ["n07860988", "dough"], "962": ["n07871810", "meat_loaf"], "963": ["n07873807", "pizza"], "964": ["n07875152", "potpie"], "965": ["n07880968", "burrito"], "966": ["n07892512", "red_wine"], "967": ["n07920052", "espresso"], "968": ["n07930864", "cup"], "969": ["n07932039", "eggnog"], "970": ["n09193705", "alp"], "971": ["n09229709", "bubble"], "972": ["n09246464", "cliff"], "973": ["n09256479", "coral_reef"], "974": ["n09288635", "geyser"], "975": ["n09332890", "lakeside"], "976": ["n09399592", "promontory"], "977": ["n09421951", "sandbar"], "978": ["n09428293", "seashore"], "979": ["n09468604", "valley"], "980": ["n09472597", "volcano"], "981": ["n09835506", "ballplayer"], "982": ["n10148035", "groom"], "983": ["n10565667", "scuba_diver"], "984": ["n11879895", "rapeseed"], "985": ["n11939491", "daisy"], "986": ["n12057211", "yellow_lady's_slipper"], "987": ["n12144580", "corn"], "988": ["n12267677", "acorn"], "989": ["n12620546", "hip"], "990": ["n12768682", "buckeye"], "991": ["n12985857", "coral_fungus"], "992": ["n12998815", "agaric"], "993": ["n13037406", "gyromitra"], "994": ["n13040303", "stinkhorn"], "995": ["n13044778", "earthstar"], "996": ["n13052670", "hen-of-the-woods"], "997": ["n13054560", "bolete"], "998": ["n13133613", "ear"], "999": ["n15075141", "toilet_tissue"]}
--------------------------------------------------------------------------------