├── jax_to_pytorch ├── weights │ └── .gitkeep ├── README.md ├── jax_weights │ └── download.sh ├── convert.py ├── explore-conversion-21k.ipynb └── explore-conversion.ipynb ├── examples ├── simple │ ├── img.jpg │ ├── img2.jpg │ ├── imagenet-21k-labels.py │ └── labels_map.txt └── imagenet │ ├── data │ └── README.md │ ├── README.md │ └── main.py ├── pytorch_pretrained_vit ├── __init__.py ├── configs.py ├── transformer.py ├── utils.py └── model.py ├── .gitignore ├── setup.py └── README.md /jax_to_pytorch/weights/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/simple/img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/PyTorch-Pretrained-ViT/HEAD/examples/simple/img.jpg -------------------------------------------------------------------------------- /examples/simple/img2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/PyTorch-Pretrained-ViT/HEAD/examples/simple/img2.jpg -------------------------------------------------------------------------------- /pytorch_pretrained_vit/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.7" 2 | 3 | from .model import ViT 4 | from .configs import * 5 | from .utils import load_pretrained_weights 6 | -------------------------------------------------------------------------------- /jax_to_pytorch/README.md: -------------------------------------------------------------------------------- 1 | ### Jax to PyTorch Conversion 2 | 3 | This directory is used to convert Jax weights to PyTorch. 4 | 5 | First `cd jax_weights` and download the models. Then return here and run `convert.py`. The weights will be saved to `weights/`. 6 | 7 | -------------------------------------------------------------------------------- /examples/imagenet/data/README.md: -------------------------------------------------------------------------------- 1 | ### ImageNet 2 | 3 | Download ImageNet and place it into `train` and `val` folders here. 4 | 5 | More details may be found with the official PyTorch ImageNet example [here](https://github.com/pytorch/examples/blob/master/imagenet). 6 | -------------------------------------------------------------------------------- /examples/imagenet/README.md: -------------------------------------------------------------------------------- 1 | ### Imagenet Evaluation 2 | 3 | Place your `train` and `val` directories in `data`. 
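For reference, `main.py` follows the official PyTorch ImageNet example, which loads each split with `torchvision.datasets.ImageFolder`, so `train` and `val` each need one subdirectory per class. A sketch of the expected layout (the synset folder names are only illustrative):

```
data/
├── train/
│   ├── n01440764/
│   │   └── *.JPEG
│   └── ...
└── val/
    ├── n01440764/
    │   └── *.JPEG
    └── ...
```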
4 | 5 | Example commands: 6 | ```bash 7 | # Evaluate ViT on CPU 8 | python main.py data -e -a 'B_16_imagenet1k' --vit --pretrained -b 16 --image_size 384 9 | 10 | # Evaluate ViT on GPU 11 | python main.py data -e -a 'B_16_imagenet1k' --vit --pretrained -b 16 --image_size 384 --gpu 0 12 | ``` -------------------------------------------------------------------------------- /examples/simple/imagenet-21k-labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | wget http://www.image-net.org/api/text/imagenet.synset.obtain_synset_list 3 | """ 4 | 5 | from pathlib import Path 6 | from nltk.corpus import wordnet 7 | 8 | classes_file = 'imagenet.synset.obtain_synset_list' 9 | output_file = 'labels_map_21k.txt' 10 | with open(Path(classes_file)) as f: 11 | classes = f.read().splitlines() 12 | classes = [c for c in classes if c != ''] 13 | assert len(classes) in [1000, 21_841] 14 | classes = [wordnet.synset_from_pos_and_offset('n', int(c[1:])).lemmas()[0].name() for c in classes] 15 | with open(output_file, 'w') as f: 16 | print('\n'.join(classes), file=f) 17 | -------------------------------------------------------------------------------- /jax_to_pytorch/jax_weights/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # B_16 4 | wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz 5 | 6 | # B_32 7 | wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_32.npz 8 | 9 | # L_32 10 | wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_32.npz 11 | 12 | # B_16_imagenet1k 13 | wget https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-B_16.npz 14 | 15 | # B_32_imagenet1k 16 | wget https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-B_32.npz 17 | 18 | # L_16_imagenet1k 19 | wget https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz 20 | 21 | # L_32_imagenet1k 22 | wget https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-L_32.npz 23 | 24 | # B_16_imagenet1k_224 25 | wget https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-B_16-224.npz 26 | 27 | # L_16_imagenet1k_224 28 | wget https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-L_16-224.npz 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | tmp 3 | *.pkl 4 | .vscode 5 | *.npy 6 | *.npz 7 | *.pth 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Environments 93 | .env 94 | .venv 95 | env/ 96 | venv/ 97 | ENV/ 98 | env.bak/ 99 | venv.bak/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | .DS_STORE 114 | 115 | # PyCharm 116 | .idea* 117 | *.xml 118 | 119 | # Custom 120 | tensorflow/ 121 | example/test* 122 | *.pth* 123 | examples/imagenet/data/ 124 | !examples/imagenet/data/README.md 125 | tmp 126 | tf_to_pytorch/pretrained_tensorflow 127 | !tf_to_pytorch/pretrained_tensorflow/download.sh 128 | examples/imagenet/run.sh 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /pytorch_pretrained_vit/configs.py: -------------------------------------------------------------------------------- 1 | """configs.py - ViT model configurations, based on: 2 | https://github.com/google-research/vision_transformer/blob/master/vit_jax/configs.py 3 | """ 4 | 5 | def get_base_config(): 6 | """Base ViT config ViT""" 7 | return dict( 8 | dim=768, 9 | ff_dim=3072, 10 | num_heads=12, 11 | num_layers=12, 12 | attention_dropout_rate=0.0, 13 | dropout_rate=0.1, 14 | representation_size=768, 15 | classifier='token' 16 | ) 17 | 18 | def get_b16_config(): 19 | """Returns the ViT-B/16 configuration.""" 20 | config = get_base_config() 21 | config.update(dict(patches=(16, 16))) 22 | return config 23 | 24 | def get_b32_config(): 25 | """Returns the ViT-B/32 configuration.""" 26 | config = get_b16_config() 27 | config.update(dict(patches=(32, 32))) 28 | return config 29 | 30 | def get_l16_config(): 31 | """Returns the ViT-L/16 configuration.""" 32 | config = get_base_config() 33 | config.update(dict( 34 | patches=(16, 16), 35 | dim=1024, 36 | ff_dim=4096, 37 | num_heads=16, 38 | num_layers=24, 39 | attention_dropout_rate=0.0, 40 | dropout_rate=0.1, 41 | representation_size=1024 42 | )) 43 | return config 44 | 45 | def get_l32_config(): 46 | """Returns the ViT-L/32 configuration.""" 47 | config = get_l16_config() 48 | config.update(dict(patches=(32, 32))) 49 | return config 50 | 51 | def drop_head_variant(config): 52 | config.update(dict(representation_size=None)) 53 | return config 54 | 55 | 56 | PRETRAINED_MODELS = { 57 | 'B_16': { 58 | 'config': get_b16_config(), 59 | 'num_classes': 21843, 60 | 'image_size': (224, 224), 61 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16.pth" 62 | }, 63 | 'B_32': { 64 | 'config': get_b32_config(), 65 | 'num_classes': 21843, 66 | 'image_size': (224, 224), 67 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32.pth" 68 | 
}, 69 | 'L_16': { 70 | 'config': get_l16_config(), 71 | 'num_classes': 21843, 72 | 'image_size': (224, 224), 73 | 'url': None 74 | }, 75 | 'L_32': { 76 | 'config': get_l32_config(), 77 | 'num_classes': 21843, 78 | 'image_size': (224, 224), 79 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32.pth" 80 | }, 81 | 'B_16_imagenet1k': { 82 | 'config': drop_head_variant(get_b16_config()), 83 | 'num_classes': 1000, 84 | 'image_size': (384, 384), 85 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16_imagenet1k.pth" 86 | }, 87 | 'B_32_imagenet1k': { 88 | 'config': drop_head_variant(get_b32_config()), 89 | 'num_classes': 1000, 90 | 'image_size': (384, 384), 91 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32_imagenet1k.pth" 92 | }, 93 | 'L_16_imagenet1k': { 94 | 'config': drop_head_variant(get_l16_config()), 95 | 'num_classes': 1000, 96 | 'image_size': (384, 384), 97 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_16_imagenet1k.pth" 98 | }, 99 | 'L_32_imagenet1k': { 100 | 'config': drop_head_variant(get_l32_config()), 101 | 'num_classes': 1000, 102 | 'image_size': (384, 384), 103 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32_imagenet1k.pth" 104 | }, 105 | } 106 | -------------------------------------------------------------------------------- /pytorch_pretrained_vit/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/lukemelas/simple-bert 3 | """ 4 | 5 | import numpy as np 6 | from torch import nn 7 | from torch import Tensor 8 | from torch.nn import functional as F 9 | 10 | 11 | def split_last(x, shape): 12 | "split the last dimension to given shape" 13 | shape = list(shape) 14 | assert shape.count(-1) <= 1 15 | if -1 in shape: 16 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) 17 | return x.view(*x.size()[:-1], *shape) 18 | 19 | 20 | def merge_last(x, n_dims): 21 | "merge the last n_dims to a dimension" 22 | s = x.size() 23 | assert n_dims > 1 and n_dims < len(s) 24 | return x.view(*s[:-n_dims], -1) 25 | 26 | 27 | class MultiHeadedSelfAttention(nn.Module): 28 | """Multi-Headed Dot Product Attention""" 29 | def __init__(self, dim, num_heads, dropout): 30 | super().__init__() 31 | self.proj_q = nn.Linear(dim, dim) 32 | self.proj_k = nn.Linear(dim, dim) 33 | self.proj_v = nn.Linear(dim, dim) 34 | self.drop = nn.Dropout(dropout) 35 | self.n_heads = num_heads 36 | self.scores = None # for visualization 37 | 38 | def forward(self, x, mask): 39 | """ 40 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) 41 | mask : (B(batch_size) x S(seq_len)) 42 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W 43 | """ 44 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 45 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) 46 | q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v]) 47 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) 48 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) 49 | if mask is not None: 50 | mask = mask[:, None, None, :].float() 51 | scores -= 10000.0 * (1.0 - mask) 52 | scores = self.drop(F.softmax(scores, dim=-1)) 53 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) 54 | h = (scores @ v).transpose(1, 2).contiguous() 55 | # -merge-> (B, S, D) 56 | h 
= merge_last(h, 2) 57 | self.scores = scores 58 | return h 59 | 60 | 61 | class PositionWiseFeedForward(nn.Module): 62 | """FeedForward Neural Networks for each position""" 63 | def __init__(self, dim, ff_dim): 64 | super().__init__() 65 | self.fc1 = nn.Linear(dim, ff_dim) 66 | self.fc2 = nn.Linear(ff_dim, dim) 67 | 68 | def forward(self, x): 69 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D) 70 | return self.fc2(F.gelu(self.fc1(x))) 71 | 72 | 73 | class Block(nn.Module): 74 | """Transformer Block""" 75 | def __init__(self, dim, num_heads, ff_dim, dropout): 76 | super().__init__() 77 | self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout) 78 | self.proj = nn.Linear(dim, dim) 79 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) 80 | self.pwff = PositionWiseFeedForward(dim, ff_dim) 81 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 82 | self.drop = nn.Dropout(dropout) 83 | 84 | def forward(self, x, mask): 85 | h = self.drop(self.proj(self.attn(self.norm1(x), mask))) 86 | x = x + h 87 | h = self.drop(self.pwff(self.norm2(x))) 88 | x = x + h 89 | return x 90 | 91 | 92 | class Transformer(nn.Module): 93 | """Transformer with Self-Attentive Blocks""" 94 | def __init__(self, num_layers, dim, num_heads, ff_dim, dropout): 95 | super().__init__() 96 | self.blocks = nn.ModuleList([ 97 | Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]) 98 | 99 | def forward(self, x, mask=None): 100 | for block in self.blocks: 101 | x = block(x, mask) 102 | return x 103 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pipenv install twine --dev 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 15 | NAME = 'pytorch-pretrained-vit' 16 | DESCRIPTION = 'Visual Transformers (ViT) in PyTorch.' 17 | URL = 'https://github.com/lukemelas/ViT-PyTorch' 18 | EMAIL = 'luke.melas@gmail.com' 19 | AUTHOR = 'Luke' 20 | REQUIRES_PYTHON = '>=3.5.0' 21 | VERSION = '0.0.7' 22 | 23 | # What packages are required for this module to be executed? 24 | REQUIRED = [ 25 | 'torch' 26 | ] 27 | 28 | # What packages are optional? 29 | EXTRAS = { 30 | # 'fancy feature': ['django'], 31 | } 32 | 33 | # The rest you shouldn't have to touch too much :) 34 | # ------------------------------------------------ 35 | # Except, perhaps the License and Trove Classifiers! 36 | # If you do change the License, remember to change the Trove Classifier for that! 37 | 38 | here = os.path.abspath(os.path.dirname(__file__)) 39 | 40 | # Import the README and use it as the long-description. 41 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 42 | try: 43 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 44 | long_description = '\n' + f.read() 45 | except FileNotFoundError: 46 | long_description = DESCRIPTION 47 | 48 | # Load the package's __version__.py module as a dictionary. 
49 | about = {} 50 | if not VERSION: 51 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 52 | with open(os.path.join(here, project_slug, '__version__.py')) as f: 53 | exec(f.read(), about) 54 | else: 55 | about['__version__'] = VERSION 56 | 57 | 58 | class UploadCommand(Command): 59 | """Support setup.py upload.""" 60 | 61 | description = 'Build and publish the package.' 62 | user_options = [] 63 | 64 | @staticmethod 65 | def status(s): 66 | """Prints things in bold.""" 67 | print('\033[1m{0}\033[0m'.format(s)) 68 | 69 | def initialize_options(self): 70 | pass 71 | 72 | def finalize_options(self): 73 | pass 74 | 75 | def run(self): 76 | try: 77 | self.status('Removing previous builds…') 78 | rmtree(os.path.join(here, 'dist')) 79 | except OSError: 80 | pass 81 | 82 | self.status('Building Source and Wheel (universal) distribution…') 83 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 84 | 85 | self.status('Uploading the package to PyPI via Twine…') 86 | os.system('twine upload dist/*') 87 | 88 | self.status('Pushing git tags…') 89 | os.system('git tag v{0}'.format(about['__version__'])) 90 | os.system('git push --tags') 91 | 92 | sys.exit() 93 | 94 | 95 | # Where the magic happens: 96 | setup( 97 | name=NAME, 98 | version=about['__version__'], 99 | description=DESCRIPTION, 100 | long_description=long_description, 101 | long_description_content_type='text/markdown', 102 | author=AUTHOR, 103 | author_email=EMAIL, 104 | python_requires=REQUIRES_PYTHON, 105 | url=URL, 106 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 107 | # py_modules=['model'], # If your package is a single module, use this instead of 'packages' 108 | install_requires=REQUIRED, 109 | extras_require=EXTRAS, 110 | include_package_data=True, 111 | license='Apache', 112 | classifiers=[ 113 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 114 | 'License :: OSI Approved :: Apache Software License', 115 | 'Programming Language :: Python', 116 | 'Programming Language :: Python :: 3', 117 | 'Programming Language :: Python :: 3.6', 118 | ], 119 | # $ setup.py publish support. 120 | cmdclass={ 121 | 'upload': UploadCommand, 122 | }, 123 | ) 124 | -------------------------------------------------------------------------------- /pytorch_pretrained_vit/utils.py: -------------------------------------------------------------------------------- 1 | """utils.py - Helper functions 2 | """ 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils import model_zoo 7 | 8 | from .configs import PRETRAINED_MODELS 9 | 10 | 11 | def load_pretrained_weights( 12 | model, 13 | model_name=None, 14 | weights_path=None, 15 | load_first_conv=True, 16 | load_fc=True, 17 | load_repr_layer=False, 18 | resize_positional_embedding=False, 19 | verbose=True, 20 | strict=True, 21 | ): 22 | """Loads pretrained weights from weights path or download using url. 23 | Args: 24 | model (Module): Full model (a nn.Module) 25 | model_name (str): Model name (e.g. B_16) 26 | weights_path (None or str): 27 | str: path to pretrained weights file on the local disk. 28 | None: use pretrained weights downloaded from the Internet. 29 | load_first_conv (bool): Whether to load patch embedding. 30 | load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. 
31 | resize_positional_embedding=False, 32 | verbose (bool): Whether to print on completion 33 | """ 34 | assert bool(model_name) ^ bool(weights_path), 'Expected exactly one of model_name or weights_path' 35 | 36 | # Load or download weights 37 | if weights_path is None: 38 | url = PRETRAINED_MODELS[model_name]['url'] 39 | if url: 40 | state_dict = model_zoo.load_url(url) 41 | else: 42 | raise ValueError(f'Pretrained model for {model_name} has not yet been released') 43 | else: 44 | state_dict = torch.load(weights_path) 45 | 46 | # Modifications to load partial state dict 47 | expected_missing_keys = [] 48 | if not load_first_conv and 'patch_embedding.weight' in state_dict: 49 | expected_missing_keys += ['patch_embedding.weight', 'patch_embedding.bias'] 50 | if not load_fc and 'fc.weight' in state_dict: 51 | expected_missing_keys += ['fc.weight', 'fc.bias'] 52 | if not load_repr_layer and 'pre_logits.weight' in state_dict: 53 | expected_missing_keys += ['pre_logits.weight', 'pre_logits.bias'] 54 | for key in expected_missing_keys: 55 | state_dict.pop(key) 56 | 57 | # Change size of positional embeddings 58 | if resize_positional_embedding: 59 | posemb = state_dict['positional_embedding.pos_embedding'] 60 | posemb_new = model.state_dict()['positional_embedding.pos_embedding'] 61 | state_dict['positional_embedding.pos_embedding'] = \ 62 | resize_positional_embedding_(posemb=posemb, posemb_new=posemb_new, 63 | has_class_token=hasattr(model, 'class_token')) 64 | maybe_print('Resized positional embeddings from {} to {}'.format( 65 | posemb.shape, posemb_new.shape), verbose) 66 | 67 | # Load state dict 68 | ret = model.load_state_dict(state_dict, strict=False) 69 | if strict: 70 | assert set(ret.missing_keys) == set(expected_missing_keys), \ 71 | 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) 72 | assert not ret.unexpected_keys, \ 73 | 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) 74 | maybe_print('Loaded pretrained weights.', verbose) 75 | else: 76 | maybe_print('Missing keys when loading pretrained weights: {}'.format(ret.missing_keys), verbose) 77 | maybe_print('Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys), verbose) 78 | return ret 79 | 80 | 81 | def maybe_print(s: str, flag: bool): 82 | if flag: 83 | print(s) 84 | 85 | 86 | def as_tuple(x): 87 | return x if isinstance(x, tuple) else (x, x) 88 | 89 | 90 | def resize_positional_embedding_(posemb, posemb_new, has_class_token=True): 91 | """Rescale the grid of position embeddings in a sensible manner""" 92 | from scipy.ndimage import zoom 93 | 94 | # Deal with class token 95 | ntok_new = posemb_new.shape[1] 96 | if has_class_token: # this means classifier == 'token' 97 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 98 | ntok_new -= 1 99 | else: 100 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 101 | 102 | # Get old and new grid sizes 103 | gs_old = int(np.sqrt(len(posemb_grid))) 104 | gs_new = int(np.sqrt(ntok_new)) 105 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 106 | 107 | # Rescale grid 108 | zoom_factor = (gs_new / gs_old, gs_new / gs_old, 1) 109 | posemb_grid = zoom(posemb_grid, zoom_factor, order=1) 110 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 111 | posemb_grid = torch.from_numpy(posemb_grid) 112 | 113 | # Deal with class token and return 114 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 115 | return posemb 116 | 117 | 
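# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not called anywhere in the package): load a
# released checkpoint into a model built for a non-default input resolution,
# letting `resize_positional_embedding_` interpolate the position grid. The
# model name and the 512-pixel image size below are example values.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from pytorch_pretrained_vit import ViT

    # Build the B_16_imagenet1k architecture without weights, at 512x512 instead
    # of the 384x384 resolution the checkpoint was trained at.
    model = ViT('B_16_imagenet1k', pretrained=False, image_size=512)

    # Downloads the checkpoint and resizes its (1, 577, 768) positional
    # embedding to the (1, 1025, 768) grid expected by the 512-pixel model.
    load_pretrained_weights(
        model,
        model_name='B_16_imagenet1k',
        resize_positional_embedding=True,
    )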
-------------------------------------------------------------------------------- /jax_to_pytorch/convert.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import numpy as np 4 | import torch 5 | from torchvision import transforms 6 | 7 | import pytorch_pretrained_vit 8 | 9 | 10 | npz_files = { 11 | 'B_16': 'jax_weights/ViT-B_16.npz', 12 | 'B_32': 'jax_weights/ViT-B_32.npz', 13 | # 'L_16': 'jax_weights/ViT-L_16.npz', # <-- not available 14 | 'L_32': 'jax_weights/ViT-L_32.npz', 15 | 'B_16_imagenet1k': 'jax_weights/ViT-B_16_imagenet1k.npz', 16 | 'B_32_imagenet1k': 'jax_weights/ViT-B_32_imagenet1k.npz', 17 | 'L_16_imagenet1k': 'jax_weights/ViT-L_16_imagenet1k.npz', 18 | 'L_32_imagenet1k': 'jax_weights/ViT-L_32_imagenet1k.npz', 19 | } 20 | 21 | 22 | def jax_to_pytorch(k): 23 | k = k.replace('Transformer/encoder_norm', 'norm') 24 | k = k.replace('LayerNorm_0', 'norm1') 25 | k = k.replace('LayerNorm_2', 'norm2') 26 | k = k.replace('MlpBlock_3/Dense_0', 'pwff.fc1') 27 | k = k.replace('MlpBlock_3/Dense_1', 'pwff.fc2') 28 | k = k.replace('MultiHeadDotProductAttention_1/out', 'proj') 29 | k = k.replace('MultiHeadDotProductAttention_1/query', 'attn.proj_q') 30 | k = k.replace('MultiHeadDotProductAttention_1/key', 'attn.proj_k') 31 | k = k.replace('MultiHeadDotProductAttention_1/value', 'attn.proj_v') 32 | k = k.replace('Transformer/posembed_input', 'positional_embedding') 33 | k = k.replace('encoderblock_', 'blocks.') 34 | k = 'patch_embedding.bias' if k == 'embedding/bias' else k 35 | k = 'patch_embedding.weight' if k == 'embedding/kernel' else k 36 | k = 'class_token' if k == 'cls' else k 37 | k = k.replace('head', 'fc') 38 | k = k.replace('kernel', 'weight') 39 | k = k.replace('scale', 'weight') 40 | k = k.replace('/', '.') 41 | k = k.lower() 42 | return k 43 | 44 | 45 | def convert(npz, state_dict): 46 | new_state_dict = {} 47 | pytorch_k2v = {jax_to_pytorch(k): v for k, v in npz.items()} 48 | for pytorch_k, pytorch_v in state_dict.items(): 49 | 50 | # Naming 51 | if 'self_attn.out_proj.weight' in pytorch_k: 52 | v = pytorch_k2v[pytorch_k] 53 | v = v.reshape(v.shape[0] * v.shape[1], v.shape[2]) 54 | elif 'self_attn.in_proj_' in pytorch_k: 55 | v = np.stack((pytorch_k2v[pytorch_k + '*q'], 56 | pytorch_k2v[pytorch_k + '*k'], 57 | pytorch_k2v[pytorch_k + '*v']), axis=0) 58 | else: 59 | if pytorch_k not in pytorch_k2v: 60 | print(pytorch_k, list(pytorch_k2v.keys())) 61 | assert False 62 | v = pytorch_k2v[pytorch_k] 63 | v = torch.from_numpy(v) 64 | 65 | # Sizing 66 | if '.weight' in pytorch_k: 67 | if len(pytorch_v.shape) == 2: 68 | v = v.transpose(0, 1) 69 | if len(pytorch_v.shape) == 4: 70 | v = v.permute(3, 2, 0, 1) 71 | if ('proj.weight' in pytorch_k): 72 | v = v.transpose(0, 1) 73 | v = v.reshape(-1, v.shape[-1]).T 74 | if ('attn.proj_' in pytorch_k and 'weight' in pytorch_k): 75 | v = v.permute(0, 2, 1) 76 | v = v.reshape(-1, v.shape[-1]) 77 | if 'attn.proj_' in pytorch_k and 'bias' in pytorch_k: 78 | v = v.reshape(-1) 79 | new_state_dict[pytorch_k] = v 80 | return new_state_dict 81 | 82 | 83 | def check_model(model, name): 84 | model.eval() 85 | img = Image.open('../examples/simple/img.jpg') 86 | img = transforms.Compose([transforms.Resize(model.image_size), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])(img).unsqueeze(0) 87 | if 'imagenet1k' in name: 88 | labels_file = '../examples/simple/labels_map.txt' 89 | labels_map = json.load(open(labels_file)) 90 | labels_map = [labels_map[str(i)] for i in 
range(1000)] 91 | print('-----\nShould be index 388 (panda) w/ high probability:') 92 | else: 93 | print('~ not checked ~') 94 | return # labels_map = open('../examples/simple/labels_map_21k.txt').read().splitlines() 95 | with torch.no_grad(): 96 | outputs = model(img).squeeze(0) 97 | for idx in torch.topk(outputs, k=3).indices.tolist(): 98 | prob = torch.softmax(outputs, -1)[idx].item() 99 | print('[{idx}] {label:<75} ({p:.2f}%)'.format(idx=idx, label=labels_map[idx], p=prob*100)) 100 | 101 | 102 | for name, filename in npz_files.items(): 103 | 104 | # Load Jax weights 105 | npz = np.load(filename) 106 | 107 | # Load PyTorch model 108 | model = pytorch_pretrained_vit.ViT(name=name, pretrained=False) 109 | 110 | # Convert weights 111 | new_state_dict = convert(npz, model.state_dict()) 112 | 113 | # Load into model and test 114 | model.load_state_dict(new_state_dict) 115 | print(f'Checking: {name}') 116 | check_model(model, name) 117 | 118 | # Save weights 119 | new_filename = f'weights/{name}.pth' 120 | torch.save(new_state_dict, new_filename, _use_new_zipfile_serialization=False) 121 | print(f"Converted {filename} and saved to {new_filename}") 122 | 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ViT PyTorch 2 | 3 | ### Quickstart 4 | 5 | Install with `pip install pytorch_pretrained_vit` and load a pretrained ViT with: 6 | ```python 7 | from pytorch_pretrained_vit import ViT 8 | model = ViT('B_16_imagenet1k', pretrained=True) 9 | ``` 10 | 11 | Or find a Google Colab example [here](https://colab.research.google.com/drive/1muZ4QFgVfwALgqmrfOkp7trAvqDemckO?usp=sharing). 12 | 13 | ### Overview 14 | This repository contains an op-for-op PyTorch reimplementation of the [Visual Transformer](https://openreview.net/forum?id=YicbFdNTTy) architecture from [Google](https://github.com/google-research/vision_transformer), along with pre-trained models and examples. 15 | 16 | The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects. 17 | 18 | At the moment, you can easily: 19 | * Load pretrained ViT models 20 | * Evaluate on ImageNet or your own data 21 | * Finetune ViT on your own dataset 22 | 23 | _(Upcoming features)_ Coming soon: 24 | * Train ViT from scratch on ImageNet (1K) 25 | * Export to ONNX for efficient inference 26 | 27 | ### Table of contents 28 | 1. [About ViT](#about-vit) 29 | 2. [About ViT-PyTorch](#about-vit-pytorch) 30 | 3. [Installation](#installation) 31 | 4. [Usage](#usage) 32 | * [Load pretrained models](#loading-pretrained-models) 33 | * [Example: Classify](#example-classification) 34 | 35 | 36 | 6. [Contributing](#contributing) 37 | 38 | ### About ViT 39 | 40 | Visual Transformers (ViT) are a straightforward application of the [transformer architecture](https://arxiv.org/abs/1706.03762) to image classification. Even in computer vision, it seems, attention is all you need. 41 | 42 | The ViT architecture works as follows: (1) it considers an image as a 1-dimensional sequence of patches, (2) it prepends a classification token to the sequence, (3) it passes these patches through a transformer encoder (like [BERT](https://arxiv.org/abs/1810.04805)), (4) it passes the first token of the output of the transformer through a small MLP to obtain the classification logits. 43 | ViT is trained on a large-scale dataset (ImageNet-21k) with a huge amount of compute. 44 | 45 |
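To make the four steps above concrete, here is a rough sketch of the same flow using this repository's `ViT` (it mirrors `ViT.forward`; the input tensor is random data, so the logits are meaningless and only the shapes matter):

```python
import torch
from pytorch_pretrained_vit import ViT

model = ViT('B_16_imagenet1k', pretrained=True).eval()

x = torch.randn(1, 3, 384, 384)                        # a batch with one 384x384 "image"
patches = model.patch_embedding(x)                     # (1) 16x16 patches -> (1, 768, 24, 24)
tokens = patches.flatten(2).transpose(1, 2)            #     sequence of 576 patch tokens
cls = model.class_token.expand(tokens.size(0), -1, -1)
tokens = torch.cat((cls, tokens), dim=1)               # (2) prepend the class token -> (1, 577, 768)
tokens = model.positional_embedding(tokens)            #     add learned positional embeddings
encoded = model.transformer(tokens)                    # (3) BERT-style transformer encoder
logits = model.fc(model.norm(encoded)[:, 0])           # (4) classify from the first (class) token
print(logits.shape)                                    # torch.Size([1, 1000])
```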
46 | 47 |
48 | 49 | 50 | ### About ViT-PyTorch 51 | 52 | ViT-PyTorch is a PyTorch re-implementation of ViT. It is consistent with the [original Jax implementation](https://github.com/google-research/vision_transformer), so that it's easy to load Jax-pretrained weights. 53 | 54 | At the same time, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible. 55 | 56 | ### Installation 57 | 58 | Install with pip: 59 | ```bash 60 | pip install pytorch_pretrained_vit 61 | ``` 62 | 63 | Or from source: 64 | ```bash 65 | git clone https://github.com/lukemelas/ViT-PyTorch 66 | cd ViT-Pytorch 67 | pip install -e . 68 | ``` 69 | 70 | ### Usage 71 | 72 | #### Loading pretrained models 73 | 74 | Loading a pretrained model is easy: 75 | ```python 76 | from pytorch_pretrained_vit import ViT 77 | model = ViT('B_16_imagenet1k', pretrained=True) 78 | ``` 79 | 80 | Details about the models are below: 81 | 82 | | *Name* |* Pretrained on *|*Finetuned on*|*Available? *| 83 | |:-----------------:|:---------------:|:------------:|:-----------:| 84 | | `B_16` | ImageNet-21k | - | ✓ | 85 | | `B_32` | ImageNet-21k | - | ✓ | 86 | | `L_16` | ImageNet-21k | - | - | 87 | | `L_32` | ImageNet-21k | - | ✓ | 88 | | `B_16_imagenet1k` | ImageNet-21k | ImageNet-1k | ✓ | 89 | | `B_32_imagenet1k` | ImageNet-21k | ImageNet-1k | ✓ | 90 | | `L_16_imagenet1k` | ImageNet-21k | ImageNet-1k | ✓ | 91 | | `L_32_imagenet1k` | ImageNet-21k | ImageNet-1k | ✓ | 92 | 93 | #### Custom ViT 94 | 95 | Loading custom configurations is just as easy: 96 | ```python 97 | from pytorch_pretrained_vit import ViT 98 | # The following is equivalent to ViT('B_16') 99 | config = dict(hidden_size=512, num_heads=8, num_layers=6) 100 | model = ViT.from_config(config) 101 | ``` 102 | 103 | #### Example: Classification 104 | 105 | Below is a simple, complete example. It may also be found as a Jupyter notebook in `examples/simple` or as a [Colab Notebook](). 106 | 107 | 108 | ```python 109 | import json 110 | from PIL import Image 111 | import torch 112 | from torchvision import transforms 113 | 114 | # Load ViT 115 | from pytorch_pretrained_vit import ViT 116 | model = ViT('B_16_imagenet1k', pretrained=True) 117 | model.eval() 118 | 119 | # Load image 120 | # NOTE: Assumes an image `img.jpg` exists in the current directory 121 | img = transforms.Compose([ 122 | transforms.Resize((384, 384)), 123 | transforms.ToTensor(), 124 | transforms.Normalize(0.5, 0.5), 125 | ])(Image.open('img.jpg')).unsqueeze(0) 126 | print(img.shape) # torch.Size([1, 3, 384, 384]) 127 | 128 | # Classify 129 | with torch.no_grad(): 130 | outputs = model(img) 131 | print(outputs.shape) # (1, 1000) 132 | ``` 133 | 134 | 147 | 148 | 163 | 164 | 165 | #### ImageNet 166 | 167 | See `examples/imagenet` for details about evaluating on ImageNet. 168 | 169 | #### Credit 170 | 171 | Other great repositories with this model include: 172 | - [Ross Wightman's repo](https://github.com/rwightman/pytorch-image-models) 173 | - [Phil Wang's repo](https://github.com/lucidrains/vit-pytorch) 174 | 175 | ### Contributing 176 | 177 | If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues. 178 | 179 | I look forward to seeing what the community does with these models! 
180 | -------------------------------------------------------------------------------- /pytorch_pretrained_vit/model.py: -------------------------------------------------------------------------------- 1 | """model.py - Model and module class for ViT. 2 | They are built to mirror those in the official Jax implementation. 3 | """ 4 | 5 | from typing import Optional 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from .transformer import Transformer 11 | from .utils import load_pretrained_weights, as_tuple 12 | from .configs import PRETRAINED_MODELS 13 | 14 | 15 | class PositionalEmbedding1D(nn.Module): 16 | """Adds (optionally learned) positional embeddings to the inputs.""" 17 | 18 | def __init__(self, seq_len, dim): 19 | super().__init__() 20 | self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim)) 21 | 22 | def forward(self, x): 23 | """Input has shape `(batch_size, seq_len, emb_dim)`""" 24 | return x + self.pos_embedding 25 | 26 | 27 | class ViT(nn.Module): 28 | """ 29 | Args: 30 | name (str): Model name, e.g. 'B_16' 31 | pretrained (bool): Load pretrained weights 32 | in_channels (int): Number of channels in input data 33 | num_classes (int): Number of classes, default 1000 34 | 35 | References: 36 | [1] https://openreview.net/forum?id=YicbFdNTTy 37 | """ 38 | 39 | def __init__( 40 | self, 41 | name: Optional[str] = None, 42 | pretrained: bool = False, 43 | patches: int = 16, 44 | dim: int = 768, 45 | ff_dim: int = 3072, 46 | num_heads: int = 12, 47 | num_layers: int = 12, 48 | attention_dropout_rate: float = 0.0, 49 | dropout_rate: float = 0.1, 50 | representation_size: Optional[int] = None, 51 | load_repr_layer: bool = False, 52 | classifier: str = 'token', 53 | positional_embedding: str = '1d', 54 | in_channels: int = 3, 55 | image_size: Optional[int] = None, 56 | num_classes: Optional[int] = None, 57 | ): 58 | super().__init__() 59 | 60 | # Configuration 61 | if name is None: 62 | check_msg = 'must specify name of pretrained model' 63 | assert not pretrained, check_msg 64 | assert not resize_positional_embedding, check_msg 65 | if num_classes is None: 66 | num_classes = 1000 67 | if image_size is None: 68 | image_size = 384 69 | else: # load pretrained model 70 | assert name in PRETRAINED_MODELS.keys(), \ 71 | 'name should be in: ' + ', '.join(PRETRAINED_MODELS.keys()) 72 | config = PRETRAINED_MODELS[name]['config'] 73 | patches = config['patches'] 74 | dim = config['dim'] 75 | ff_dim = config['ff_dim'] 76 | num_heads = config['num_heads'] 77 | num_layers = config['num_layers'] 78 | attention_dropout_rate = config['attention_dropout_rate'] 79 | dropout_rate = config['dropout_rate'] 80 | representation_size = config['representation_size'] 81 | classifier = config['classifier'] 82 | if image_size is None: 83 | image_size = PRETRAINED_MODELS[name]['image_size'] 84 | if num_classes is None: 85 | num_classes = PRETRAINED_MODELS[name]['num_classes'] 86 | self.image_size = image_size 87 | 88 | # Image and patch sizes 89 | h, w = as_tuple(image_size) # image sizes 90 | fh, fw = as_tuple(patches) # patch sizes 91 | gh, gw = h // fh, w // fw # number of patches 92 | seq_len = gh * gw 93 | 94 | # Patch embedding 95 | self.patch_embedding = nn.Conv2d(in_channels, dim, kernel_size=(fh, fw), stride=(fh, fw)) 96 | 97 | # Class token 98 | if classifier == 'token': 99 | self.class_token = nn.Parameter(torch.zeros(1, 1, dim)) 100 | seq_len += 1 101 | 102 | # Positional embedding 103 | if positional_embedding.lower() == '1d': 104 | 
self.positional_embedding = PositionalEmbedding1D(seq_len, dim) 105 | else: 106 | raise NotImplementedError() 107 | 108 | # Transformer 109 | self.transformer = Transformer(num_layers=num_layers, dim=dim, num_heads=num_heads, 110 | ff_dim=ff_dim, dropout=dropout_rate) 111 | 112 | # Representation layer 113 | if representation_size and load_repr_layer: 114 | self.pre_logits = nn.Linear(dim, representation_size) 115 | pre_logits_size = representation_size 116 | else: 117 | pre_logits_size = dim 118 | 119 | # Classifier head 120 | self.norm = nn.LayerNorm(pre_logits_size, eps=1e-6) 121 | self.fc = nn.Linear(pre_logits_size, num_classes) 122 | 123 | # Initialize weights 124 | self.init_weights() 125 | 126 | # Load pretrained model 127 | if pretrained: 128 | pretrained_num_channels = 3 129 | pretrained_num_classes = PRETRAINED_MODELS[name]['num_classes'] 130 | pretrained_image_size = PRETRAINED_MODELS[name]['image_size'] 131 | load_pretrained_weights( 132 | self, name, 133 | load_first_conv=(in_channels == pretrained_num_channels), 134 | load_fc=(num_classes == pretrained_num_classes), 135 | load_repr_layer=load_repr_layer, 136 | resize_positional_embedding=(image_size != pretrained_image_size), 137 | ) 138 | 139 | @torch.no_grad() 140 | def init_weights(self): 141 | def _init(m): 142 | if isinstance(m, nn.Linear): 143 | nn.init.xavier_uniform_(m.weight) # _trunc_normal(m.weight, std=0.02) # from .initialization import _trunc_normal 144 | if hasattr(m, 'bias') and m.bias is not None: 145 | nn.init.normal_(m.bias, std=1e-6) # nn.init.constant(m.bias, 0) 146 | self.apply(_init) 147 | nn.init.constant_(self.fc.weight, 0) 148 | nn.init.constant_(self.fc.bias, 0) 149 | nn.init.normal_(self.positional_embedding.pos_embedding, std=0.02) # _trunc_normal(self.positional_embedding.pos_embedding, std=0.02) 150 | nn.init.constant_(self.class_token, 0) 151 | 152 | def forward(self, x): 153 | """Breaks image into patches, applies transformer, applies MLP head. 154 | 155 | Args: 156 | x (tensor): `b,c,fh,fw` 157 | """ 158 | b, c, fh, fw = x.shape 159 | x = self.patch_embedding(x) # b,d,gh,gw 160 | x = x.flatten(2).transpose(1, 2) # b,gh*gw,d 161 | if hasattr(self, 'class_token'): 162 | x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1) # b,gh*gw+1,d 163 | if hasattr(self, 'positional_embedding'): 164 | x = self.positional_embedding(x) # b,gh*gw+1,d 165 | x = self.transformer(x) # b,gh*gw+1,d 166 | if hasattr(self, 'pre_logits'): 167 | x = self.pre_logits(x) 168 | x = torch.tanh(x) 169 | if hasattr(self, 'fc'): 170 | x = self.norm(x)[:, 0] # b,d 171 | x = self.fc(x) # b,num_classes 172 | return x 173 | 174 | -------------------------------------------------------------------------------- /jax_to_pytorch/explore-conversion-21k.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 21, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. 
To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%load_ext autoreload" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 22, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "%autoreload 2" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 23, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "from PIL import Image\n", 37 | "import numpy as np\n", 38 | "import torch\n", 39 | "from torchvision import transforms" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 24, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "text/plain": [ 50 | "" 51 | ] 52 | }, 53 | "execution_count": 24, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "from importlib import reload\n", 60 | "import pytorch_pretrained_vit\n", 61 | "reload(pytorch_pretrained_vit)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 25, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "name = 'L_32'\n", 71 | "filename = 'jax_weights/ViT-L_32.npz'\n", 72 | "# npz_files = {\n", 73 | "# 'B_16': 'jax_weights/ViT-B_16.npz',\n", 74 | "# 'B_32': 'jax_weights/ViT-B_32.npz',\n", 75 | "# 'L_16': 'jax_weights/ViT-L_16.npz',\n", 76 | "# 'L_32': 'jax_weights/ViT-L_32.npz',\n", 77 | "# 'B_16_imagenet1k': 'jax_weights/ViT-B_16_imagenet1k.npz',\n", 78 | "# 'B_32_imagenet1k': 'jax_weights/ViT-B_32_imagenet1k.npz',\n", 79 | "# 'L_16_imagenet1k': 'jax_weights/ViT-L_16_imagenet1k.npz',\n", 80 | "# 'L_32_imagenet1k': 'jax_weights/ViT-L_32_imagenet1k.npz',\n", 81 | "# }\n", 82 | "num_classes = 21843\n", 83 | "# num_classes = 1000" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 26, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "def jax_to_pytorch(k):\n", 93 | " k = k.replace('Transformer/encoder_norm', 'norm')\n", 94 | " k = k.replace('LayerNorm_0', 'norm1')\n", 95 | " k = k.replace('LayerNorm_2', 'norm2')\n", 96 | " k = k.replace('MlpBlock_3/Dense_0', 'pwff.fc1')\n", 97 | " k = k.replace('MlpBlock_3/Dense_1', 'pwff.fc2')\n", 98 | " k = k.replace('MultiHeadDotProductAttention_1/out', 'proj')\n", 99 | " k = k.replace('MultiHeadDotProductAttention_1/query', 'attn.proj_q')\n", 100 | " k = k.replace('MultiHeadDotProductAttention_1/key', 'attn.proj_k')\n", 101 | " k = k.replace('MultiHeadDotProductAttention_1/value', 'attn.proj_v')\n", 102 | " k = k.replace('Transformer/posembed_input', 'positional_embedding')\n", 103 | " k = k.replace('encoderblock_', 'blocks.')\n", 104 | " k = 'patch_embedding.bias' if k == 'embedding/bias' else k\n", 105 | " k = 'patch_embedding.weight' if k == 'embedding/kernel' else k\n", 106 | " k = 'class_token' if k == 'cls' else k\n", 107 | " k = k.replace('head', 'fc')\n", 108 | " k = k.replace('kernel', 'weight')\n", 109 | " k = k.replace('scale', 'weight')\n", 110 | " k = k.replace('/', '.')\n", 111 | " k = k.lower()\n", 112 | " return k\n", 113 | "\n", 114 | "\n", 115 | "def convert(npz, state_dict):\n", 116 | " new_state_dict = {}\n", 117 | " pytorch_k2v = {jax_to_pytorch(k): v for k, v in npz.items()}\n", 118 | " for pytorch_k, pytorch_v in state_dict.items():\n", 119 | " \n", 120 | " # Naming\n", 121 | " if 'self_attn.out_proj.weight' in pytorch_k:\n", 122 | " v = pytorch_k2v[pytorch_k]\n", 123 | " v = v.reshape(v.shape[0] * v.shape[1], v.shape[2])\n", 124 | " elif 'self_attn.in_proj_' in pytorch_k:\n", 125 | " v = np.stack((pytorch_k2v[pytorch_k + 
'*q'], \n", 126 | " pytorch_k2v[pytorch_k + '*k'], \n", 127 | " pytorch_k2v[pytorch_k + '*v']), axis=0)\n", 128 | " else:\n", 129 | " if pytorch_k not in pytorch_k2v:\n", 130 | " print(pytorch_k, list(pytorch_k2v.keys()))\n", 131 | " assert False\n", 132 | " v = pytorch_k2v[pytorch_k]\n", 133 | " v = torch.from_numpy(v)\n", 134 | " \n", 135 | " # Sizing\n", 136 | " if '.weight' in pytorch_k:\n", 137 | " if len(pytorch_v.shape) == 2:\n", 138 | " v = v.transpose(0, 1)\n", 139 | " if len(pytorch_v.shape) == 4:\n", 140 | " v = v.permute(3, 2, 0, 1)\n", 141 | " if ('proj.weight' in pytorch_k):\n", 142 | " v = v.transpose(0, 1)\n", 143 | " v = v.reshape(-1, v.shape[-1]).T\n", 144 | " if ('attn.proj_' in pytorch_k and 'weight' in pytorch_k):\n", 145 | " v = v.permute(0, 2, 1)\n", 146 | " v = v.reshape(-1, v.shape[-1])\n", 147 | " if 'attn.proj_' in pytorch_k and 'bias' in pytorch_k:\n", 148 | " v = v.reshape(-1)\n", 149 | " new_state_dict[pytorch_k] = v\n", 150 | " return new_state_dict\n", 151 | "\n", 152 | "\n", 153 | "def check_model(model, name):\n", 154 | " model.eval()\n", 155 | " img = Image.open('../examples/simple/img.jpg')\n", 156 | " img = transforms.Compose([transforms.Resize(model.image_size), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])(img).unsqueeze(0)\n", 157 | " if 'imagenet1k' in name:\n", 158 | " labels_file = '../examples/simple/labels_map.txt' \n", 159 | " labels_map = json.load(open(labels_file))\n", 160 | " labels_map = [labels_map[str(i)] for i in range(1000)]\n", 161 | " print('-----\\nShould be index 388 (panda) w/ high probability:')\n", 162 | " else:\n", 163 | " labels_map = open('../examples/simple/labels_map_21k.txt').read().splitlines()\n", 164 | " with torch.no_grad():\n", 165 | " outputs = model(img).squeeze(0)\n", 166 | " for idx in torch.topk(outputs, k=3).indices.tolist():\n", 167 | " prob = torch.softmax(outputs, -1)[idx].item()\n", 168 | " print('[{idx}] {label:<75} ({p:.2f}%)'.format(idx=idx, label=labels_map[idx], p=prob*100))" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 27, 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "data": { 178 | "text/plain": [ 179 | "" 180 | ] 181 | }, 182 | "execution_count": 27, 183 | "metadata": {}, 184 | "output_type": "execute_result" 185 | } 186 | ], 187 | "source": [ 188 | "# Load Jax weights\n", 189 | "npz = np.load(filename)\n", 190 | "\n", 191 | "# Load PyTorch model\n", 192 | "model = pytorch_pretrained_vit.ViT(name=name, pretrained=False, load_repr_layer=True)\n", 193 | "\n", 194 | "# Convert weights\n", 195 | "new_state_dict = convert(npz, model.state_dict())\n", 196 | "\n", 197 | "# Load into model and test\n", 198 | "model.load_state_dict(new_state_dict)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 35, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "[3690] yellow_mountain_saxifrage (0.07%)\n", 211 | "[228] red_fox (0.03%)\n", 212 | "[7705] amberjack (0.02%)\n" 213 | ] 214 | } 215 | ], 216 | "source": [ 217 | "check_model(model, name)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 30, 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "data": { 227 | "text/plain": [ 228 | "Parameter containing:\n", 229 | "tensor([[-0.1088, 0.0409, 0.0360, ..., -0.0716, -0.0018, 0.0037],\n", 230 | " [-0.0945, 0.0493, -0.1079, ..., -0.0731, -0.0225, -0.1013],\n", 231 | " [ 0.0356, -0.0351, 0.0510, ..., 0.0867, -0.0274, -0.0638],\n", 232 | " 
...,\n", 233 | " [ 0.0978, -0.1415, 0.0287, ..., 0.0058, -0.0644, 0.0968],\n", 234 | " [ 0.0014, 0.0605, -0.0371, ..., -0.1093, -0.1687, -0.0141],\n", 235 | " [ 0.0515, -0.1962, 0.0966, ..., -0.0401, -0.1180, -0.0609]],\n", 236 | " requires_grad=True)" 237 | ] 238 | }, 239 | "execution_count": 30, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "model.pre_logits.weight" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 31, 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "data": { 255 | "text/plain": [ 256 | "array([[-0.10884866, -0.09448911, 0.03560322, ..., 0.09780853,\n", 257 | " 0.00137385, 0.05146733],\n", 258 | " [ 0.04090862, 0.04934628, -0.03513117, ..., -0.14147781,\n", 259 | " 0.06050842, -0.19619823],\n", 260 | " [ 0.03600169, -0.10787308, 0.05104499, ..., 0.02869395,\n", 261 | " -0.03711622, 0.09662006],\n", 262 | " ...,\n", 263 | " [-0.071619 , -0.07306355, 0.08665656, ..., 0.00580501,\n", 264 | " -0.10929134, -0.04009511],\n", 265 | " [-0.00184033, -0.02247062, -0.02740383, ..., -0.06440727,\n", 266 | " -0.16867375, -0.11800908],\n", 267 | " [ 0.00365812, -0.10132378, -0.06381997, ..., 0.09679291,\n", 268 | " -0.01409991, -0.06087695]], dtype=float32)" 269 | ] 270 | }, 271 | "execution_count": 31, 272 | "metadata": {}, 273 | "output_type": "execute_result" 274 | } 275 | ], 276 | "source": [ 277 | "npz['pre_logits/kernel']" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [] 286 | } 287 | ], 288 | "metadata": { 289 | "kernelspec": { 290 | "display_name": "Python 3", 291 | "language": "python", 292 | "name": "python3" 293 | }, 294 | "language_info": { 295 | "codemirror_mode": { 296 | "name": "ipython", 297 | "version": 3 298 | }, 299 | "file_extension": ".py", 300 | "mimetype": "text/x-python", 301 | "name": "python", 302 | "nbconvert_exporter": "python", 303 | "pygments_lexer": "ipython3", 304 | "version": "3.8.3" 305 | } 306 | }, 307 | "nbformat": 4, 308 | "nbformat_minor": 4 309 | } 310 | -------------------------------------------------------------------------------- /jax_to_pytorch/explore-conversion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%autoreload 2" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 3, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from PIL import Image\n", 28 | "import torch\n", 29 | "from torchvision import transforms" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 4, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "data": { 39 | "text/plain": [ 40 | "torch.Size([1, 3, 384, 384])" 41 | ] 42 | }, 43 | "execution_count": 4, 44 | "metadata": {}, 45 | "output_type": "execute_result" 46 | } 47 | ], 48 | "source": [ 49 | "img = Image.open('../examples/simple/img.jpg')\n", 50 | "img = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])(img).unsqueeze(0)\n", 51 | "img.shape" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 5, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "data": { 61 | 
"text/plain": [ 62 | "" 63 | ] 64 | }, 65 | "execution_count": 5, 66 | "metadata": {}, 67 | "output_type": "execute_result" 68 | } 69 | ], 70 | "source": [ 71 | "from importlib import reload\n", 72 | "import pytorch_pretrained_vit\n", 73 | "reload(pytorch_pretrained_vit)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 6, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "# model = pytorch_pretrained_vit.ViT(name='B_16', pretrained=False, num_classes=21843)\n", 83 | "model = pytorch_pretrained_vit.ViT(name='B_16_imagenet1k', pretrained=False, num_classes=1000)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 7, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "# list(model.state_dict().keys())" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "### Jax" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 8, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "import numpy as np" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 9, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "# npz = np.load('imagenet21k_ViT-B_16.npz')\n", 118 | "npz = np.load('ViT-B_16.npz')" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 10, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "# npz.files" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 11, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "def convert(npz, state_dict):\n", 137 | " new_state_dict = {}\n", 138 | " pytorch_k2v = {jax_to_pytorch(k): v for k, v in npz.items()}\n", 139 | " for pytorch_k, pytorch_v in state_dict.items():\n", 140 | " \n", 141 | " # Naming\n", 142 | " if 'self_attn.out_proj.weight' in pytorch_k:\n", 143 | " v = pytorch_k2v[pytorch_k]\n", 144 | " v = v.reshape(v.shape[0] * v.shape[1], v.shape[2])\n", 145 | " elif 'self_attn.in_proj_' in pytorch_k:\n", 146 | " v = np.stack((pytorch_k2v[pytorch_k + '*q'], \n", 147 | " pytorch_k2v[pytorch_k + '*k'], \n", 148 | " pytorch_k2v[pytorch_k + '*v']), axis=0)\n", 149 | " else:\n", 150 | " if pytorch_k not in pytorch_k2v:\n", 151 | " print(pytorch_k, list(pytorch_k2v.keys()))\n", 152 | " assert False\n", 153 | " v = pytorch_k2v[pytorch_k]\n", 154 | " v = torch.from_numpy(v)\n", 155 | " \n", 156 | " # Sizing\n", 157 | " if '.weight' in pytorch_k:\n", 158 | " if len(pytorch_v.shape) == 2:\n", 159 | " v = v.transpose(0, 1)\n", 160 | " if len(pytorch_v.shape) == 4:\n", 161 | " v = v.permute(3, 2, 0, 1)\n", 162 | " if ('proj.weight' in pytorch_k):\n", 163 | " v = v.transpose(0, 1)\n", 164 | " v = v.reshape(-1, v.shape[-1]).T\n", 165 | " if ('proj.bias' in pytorch_k):\n", 166 | " print(pytorch_k, v.shape)\n", 167 | " if ('attn.proj_' in pytorch_k and 'weight' in pytorch_k):\n", 168 | " v = v.permute(0, 2, 1)\n", 169 | " v = v.reshape(-1, v.shape[-1])\n", 170 | " if 'attn.proj_' in pytorch_k and 'bias' in pytorch_k:\n", 171 | " v = v.reshape(-1)\n", 172 | " new_state_dict[pytorch_k] = v\n", 173 | " return new_state_dict" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 12, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "def jax_to_pytorch(k):\n", 183 | " k = k.replace('Transformer/encoder_norm', 'norm')\n", 184 | " k = k.replace('LayerNorm_0', 'norm1')\n", 185 | " k = k.replace('LayerNorm_2', 'norm2')\n", 186 | " k = k.replace('MlpBlock_3/Dense_0', 
'pwff.fc1')\n", 187 | " k = k.replace('MlpBlock_3/Dense_1', 'pwff.fc2')\n", 188 | " k = k.replace('MultiHeadDotProductAttention_1/out', 'proj')\n", 189 | " k = k.replace('MultiHeadDotProductAttention_1/query', 'attn.proj_q')\n", 190 | " k = k.replace('MultiHeadDotProductAttention_1/key', 'attn.proj_k')\n", 191 | " k = k.replace('MultiHeadDotProductAttention_1/value', 'attn.proj_v')\n", 192 | " k = k.replace('Transformer/posembed_input', 'positional_embedding')\n", 193 | " k = k.replace('encoderblock_', 'blocks.')\n", 194 | " k = 'patch_embedding.bias' if k == 'embedding/bias' else k\n", 195 | " k = 'patch_embedding.weight' if k == 'embedding/kernel' else k\n", 196 | " k = 'class_token' if k == 'cls' else k\n", 197 | " k = k.replace('head', 'fc')\n", 198 | " k = k.replace('kernel', 'weight')\n", 199 | " k = k.replace('scale', 'weight')\n", 200 | " k = k.replace('/', '.')\n", 201 | " k = k.lower()\n", 202 | " return k" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 13, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "transformer.blocks.0.proj.bias torch.Size([768])\n", 215 | "transformer.blocks.1.proj.bias torch.Size([768])\n", 216 | "transformer.blocks.2.proj.bias torch.Size([768])\n", 217 | "transformer.blocks.3.proj.bias torch.Size([768])\n", 218 | "transformer.blocks.4.proj.bias torch.Size([768])\n", 219 | "transformer.blocks.5.proj.bias torch.Size([768])\n", 220 | "transformer.blocks.6.proj.bias torch.Size([768])\n", 221 | "transformer.blocks.7.proj.bias torch.Size([768])\n", 222 | "transformer.blocks.8.proj.bias torch.Size([768])\n", 223 | "transformer.blocks.9.proj.bias torch.Size([768])\n", 224 | "transformer.blocks.10.proj.bias torch.Size([768])\n", 225 | "transformer.blocks.11.proj.bias torch.Size([768])\n" 226 | ] 227 | } 228 | ], 229 | "source": [ 230 | "new_state_dict = convert(npz, model.state_dict())" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 14, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "text/plain": [ 241 | "" 242 | ] 243 | }, 244 | "execution_count": 14, 245 | "metadata": {}, 246 | "output_type": "execute_result" 247 | } 248 | ], 249 | "source": [ 250 | "model.load_state_dict(new_state_dict)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 15, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "import json \n", 260 | "\n", 261 | "def check(M):\n", 262 | " labels_map = json.load(open('../examples/simple/labels_map.txt'))\n", 263 | " labels_map = [labels_map[str(i)] for i in range(1000)]\n", 264 | " with torch.no_grad():\n", 265 | " outputs = M(img)\n", 266 | " print('-----')\n", 267 | " for idx in torch.topk(outputs, k=5).indices.squeeze(0).tolist():\n", 268 | " prob = torch.softmax(outputs, dim=1)[0, idx].item()\n", 269 | " print('[{idx}] {label:<75} ({p:.2f}%)'.format(idx=idx, label=labels_map[idx], p=prob*100))" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 16, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "-----\n", 282 | "[388] giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca (99.51%)\n", 283 | "[387] lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens (0.16%)\n", 284 | "[297] sloth bear, Melursus ursinus, Ursus ursinus (0.05%)\n", 285 | "[295] American black bear, black bear, Ursus americanus, Euarctos americanus 
(0.03%)\n", 286 | "[296] ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus (0.03%)\n" 287 | ] 288 | } 289 | ], 290 | "source": [ 291 | "model.eval()\n", 292 | "check(model)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 17, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "def printhook(self, input, output):\n", 302 | " print('Inside ' + self.__class__.__name__ + ' forward')\n", 303 | " print('input: ', type(input))\n", 304 | " print('input[0]: ', type(input[0]))\n", 305 | " print('input size:', input[0].size())\n", 306 | " print('input norm:', input[0].norm())\n", 307 | " if isinstance(output, tuple):\n", 308 | " output = output[0]\n", 309 | " print('output size:', output.data.size())\n", 310 | " print('output norm:', output.data.norm())\n", 311 | " print('-----------\\n')" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 18, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "# h1 = m.blocks[0].attn.proj.register_forward_hook(printhook) # m.blocks[0].register_forward_hook(printhook) \n", 321 | "# h2 = model.transformer.blocks[0].proj.register_forward_hook(printhook) # model.transformer.layers[0].register_forward_hook(printhook) \n", 322 | "# m(img)\n", 323 | "# model(img)\n", 324 | "# h1.remove()\n", 325 | "# h2.remove()" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [] 334 | } 335 | ], 336 | "metadata": { 337 | "kernelspec": { 338 | "display_name": "Python 3", 339 | "language": "python", 340 | "name": "python3" 341 | }, 342 | "language_info": { 343 | "codemirror_mode": { 344 | "name": "ipython", 345 | "version": 3 346 | }, 347 | "file_extension": ".py", 348 | "mimetype": "text/x-python", 349 | "name": "python", 350 | "nbconvert_exporter": "python", 351 | "pygments_lexer": "ipython3", 352 | "version": "3.8.3" 353 | } 354 | }, 355 | "nbformat": 4, 356 | "nbformat_minor": 4 357 | } 358 | -------------------------------------------------------------------------------- /examples/imagenet/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import PIL 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | 22 | from pytorch_pretrained_vit import ViT, load_pretrained_weights 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 25 | parser.add_argument('data', metavar='DIR', 26 | help='path to dataset') 27 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 28 | help='model architecture (default: resnet18)') 29 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 30 | help='number of data loading workers (default: 4)') 31 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 32 | help='number of total epochs to run') 33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 34 | help='manual epoch number (useful on restarts)') 35 | parser.add_argument('-b', '--batch-size', 
default=256, type=int, 36 | metavar='N', 37 | help='mini-batch size (default: 256), this is the total ' 38 | 'batch size of all GPUs on the current node when ' 39 | 'using Data Parallel or Distributed Data Parallel') 40 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 41 | metavar='LR', help='initial learning rate', dest='lr') 42 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 43 | help='momentum') 44 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 45 | metavar='W', help='weight decay (default: 1e-4)', 46 | dest='weight_decay') 47 | parser.add_argument('-p', '--print-freq', default=10, type=int, 48 | metavar='N', help='print frequency (default: 10)') 49 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 50 | help='path to latest checkpoint (default: none)') 51 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 52 | help='evaluate model on validation set') 53 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 54 | help='use pre-trained model') 55 | parser.add_argument('--world-size', default=-1, type=int, 56 | help='number of nodes for distributed training') 57 | parser.add_argument('--rank', default=-1, type=int, 58 | help='node rank for distributed training') 59 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 60 | help='url used to set up distributed training') 61 | parser.add_argument('--dist-backend', default='nccl', type=str, 62 | help='distributed backend') 63 | parser.add_argument('--seed', default=None, type=int, 64 | help='seed for initializing training. ') 65 | parser.add_argument('--gpu', default=None, type=int, 66 | help='GPU id to use.') 67 | parser.add_argument('--image_size', default=224, type=int, 68 | help='image size') 69 | parser.add_argument('--vit', action='store_true', help='use ViT model') 70 | parser.add_argument('--multiprocessing-distributed', action='store_true', 71 | help='Use multi-processing distributed training to launch ' 72 | 'N processes per node, which has N GPUs. This is the ' 73 | 'fastest way to use PyTorch for either single node or ' 74 | 'multi node data parallel training') 75 | 76 | best_acc1 = 0 77 | 78 | 79 | def main(): 80 | args = parser.parse_args() 81 | 82 | if args.seed is not None: 83 | random.seed(args.seed) 84 | torch.manual_seed(args.seed) 85 | cudnn.deterministic = True 86 | warnings.warn('You have chosen to seed training. ' 87 | 'This will turn on the CUDNN deterministic setting, ' 88 | 'which can slow down your training considerably! ' 89 | 'You may see unexpected behavior when restarting ' 90 | 'from checkpoints.') 91 | 92 | if args.gpu is not None: 93 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 94 | 'disable data parallelism.') 95 | 96 | if args.dist_url == "env://" and args.world_size == -1: 97 | args.world_size = int(os.environ["WORLD_SIZE"]) 98 | 99 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 100 | 101 | ngpus_per_node = torch.cuda.device_count() 102 | if args.multiprocessing_distributed: 103 | # Since we have ngpus_per_node processes per node, the total world_size 104 | # needs to be adjusted accordingly 105 | args.world_size = ngpus_per_node * args.world_size 106 | # Use torch.multiprocessing.spawn to launch distributed processes: the 107 | # main_worker process function 108 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 109 | else: 110 | # Simply call main_worker function 111 | main_worker(args.gpu, ngpus_per_node, args) 112 | 113 | 114 | def main_worker(gpu, ngpus_per_node, args): 115 | global best_acc1 116 | args.gpu = gpu 117 | 118 | if args.gpu is not None: 119 | print("Use GPU: {} for training".format(args.gpu)) 120 | 121 | if args.distributed: 122 | if args.dist_url == "env://" and args.rank == -1: 123 | args.rank = int(os.environ["RANK"]) 124 | if args.multiprocessing_distributed: 125 | # For multiprocessing distributed training, rank needs to be the 126 | # global rank among all the processes 127 | args.rank = args.rank * ngpus_per_node + gpu 128 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 129 | world_size=args.world_size, rank=args.rank) 130 | 131 | # NEW 132 | if args.vit: 133 | model = ViT(args.arch, pretrained=args.pretrained) 134 | 135 | # # NOTE: This is for debugging 136 | # model = ViT('B_16_imagenet1k', pretrained=False) 137 | # load_pretrained_weights(model, weights_path='/home/luke/projects/experiments/ViT-PyTorch/jax_to_pytorch/weights/B_16_imagenet1k.pth') 138 | 139 | else: 140 | model = models.__dict__[args.arch](pretrained=args.pretrained) 141 | print("=> using model '{}' (pretrained={})".format(args.arch, args.pretrained)) 142 | 143 | if args.distributed: 144 | # For multiprocessing distributed, DistributedDataParallel constructor 145 | # should always set the single device scope, otherwise, 146 | # DistributedDataParallel will use all available devices. 
147 | if args.gpu is not None: 148 | torch.cuda.set_device(args.gpu) 149 | model.cuda(args.gpu) 150 | # When using a single GPU per process and per 151 | # DistributedDataParallel, we need to divide the batch size 152 | # ourselves based on the total number of GPUs we have 153 | args.batch_size = int(args.batch_size / ngpus_per_node) 154 | args.workers = int(args.workers / ngpus_per_node) 155 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 156 | else: 157 | model.cuda() 158 | # DistributedDataParallel will divide and allocate batch_size to all 159 | # available GPUs if device_ids are not set 160 | model = torch.nn.parallel.DistributedDataParallel(model) 161 | elif args.gpu is not None: 162 | torch.cuda.set_device(args.gpu) 163 | model = model.cuda(args.gpu) 164 | else: 165 | # DataParallel will divide and allocate batch_size to all available GPUs 166 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 167 | model.features = torch.nn.DataParallel(model.features) 168 | model.cuda() 169 | else: 170 | model = torch.nn.DataParallel(model).cuda() 171 | 172 | # define loss function (criterion) and optimizer 173 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 174 | 175 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 176 | momentum=args.momentum, 177 | weight_decay=args.weight_decay) 178 | 179 | # optionally resume from a checkpoint 180 | if args.resume: 181 | if os.path.isfile(args.resume): 182 | print("=> loading checkpoint '{}'".format(args.resume)) 183 | checkpoint = torch.load(args.resume) 184 | args.start_epoch = checkpoint['epoch'] 185 | best_acc1 = checkpoint['best_acc1'] 186 | if args.gpu is not None: 187 | # best_acc1 may be from a checkpoint from a different GPU 188 | best_acc1 = best_acc1.to(args.gpu) 189 | model.load_state_dict(checkpoint['state_dict']) 190 | optimizer.load_state_dict(checkpoint['optimizer']) 191 | print("=> loaded checkpoint '{}' (epoch {})" 192 | .format(args.resume, checkpoint['epoch'])) 193 | else: 194 | print("=> no checkpoint found at '{}'".format(args.resume)) 195 | 196 | cudnn.benchmark = True 197 | 198 | # Data loading code 199 | traindir = os.path.join(args.data, 'train') 200 | valdir = os.path.join(args.data, 'val') 201 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 202 | normalize = transforms.Normalize(0.5, 0.5) 203 | 204 | train_dataset = datasets.ImageFolder( 205 | traindir, 206 | transforms.Compose([ 207 | transforms.RandomResizedCrop(args.image_size), 208 | transforms.RandomHorizontalFlip(), 209 | transforms.ToTensor(), 210 | normalize, 211 | ])) 212 | 213 | if args.distributed: 214 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 215 | else: 216 | train_sampler = None 217 | 218 | train_loader = torch.utils.data.DataLoader( 219 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 220 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 221 | 222 | val_transforms = transforms.Compose([ 223 | transforms.Resize(args.image_size, interpolation=PIL.Image.BICUBIC), 224 | transforms.CenterCrop(args.image_size), 225 | transforms.ToTensor(), 226 | normalize, 227 | ]) 228 | print('Using image size', args.image_size) 229 | 230 | val_loader = torch.utils.data.DataLoader( 231 | datasets.ImageFolder(valdir, val_transforms), 232 | batch_size=args.batch_size, shuffle=False, 233 | num_workers=args.workers, pin_memory=True) 234 | 235 | if args.evaluate: 236 | res = validate(val_loader, model, 
criterion, args) 237 | with open('res.txt', 'w') as f: 238 | print(res, file=f) 239 | return 240 | 241 | for epoch in range(args.start_epoch, args.epochs): 242 | if args.distributed: 243 | train_sampler.set_epoch(epoch) 244 | adjust_learning_rate(optimizer, epoch, args) 245 | 246 | # train for one epoch 247 | train(train_loader, model, criterion, optimizer, epoch, args) 248 | 249 | # evaluate on validation set 250 | acc1 = validate(val_loader, model, criterion, args) 251 | 252 | # remember best acc@1 and save checkpoint 253 | is_best = acc1 > best_acc1 254 | best_acc1 = max(acc1, best_acc1) 255 | 256 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 257 | and args.rank % ngpus_per_node == 0): 258 | save_checkpoint({ 259 | 'epoch': epoch + 1, 260 | 'arch': args.arch, 261 | 'state_dict': model.state_dict(), 262 | 'best_acc1': best_acc1, 263 | 'optimizer' : optimizer.state_dict(), 264 | }, is_best) 265 | 266 | 267 | def train(train_loader, model, criterion, optimizer, epoch, args): 268 | batch_time = AverageMeter('Time', ':6.3f') 269 | data_time = AverageMeter('Data', ':6.3f') 270 | losses = AverageMeter('Loss', ':.4e') 271 | top1 = AverageMeter('Acc@1', ':6.2f') 272 | top5 = AverageMeter('Acc@5', ':6.2f') 273 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, 274 | top5, prefix="Epoch: [{}]".format(epoch)) 275 | 276 | # switch to train mode 277 | model.train() 278 | 279 | end = time.time() 280 | for i, (images, target) in enumerate(train_loader): 281 | # measure data loading time 282 | data_time.update(time.time() - end) 283 | 284 | if args.gpu is not None: 285 | images = images.cuda(args.gpu, non_blocking=True) 286 | target = target.cuda(args.gpu, non_blocking=True) 287 | 288 | # compute output 289 | output = model(images) 290 | loss = criterion(output, target) 291 | 292 | # measure accuracy and record loss 293 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 294 | losses.update(loss.item(), images.size(0)) 295 | top1.update(acc1[0], images.size(0)) 296 | top5.update(acc5[0], images.size(0)) 297 | 298 | # compute gradient and do SGD step 299 | optimizer.zero_grad() 300 | loss.backward() 301 | optimizer.step() 302 | 303 | # measure elapsed time 304 | batch_time.update(time.time() - end) 305 | end = time.time() 306 | 307 | if i % args.print_freq == 0: 308 | progress.print(i) 309 | 310 | 311 | def validate(val_loader, model, criterion, args): 312 | batch_time = AverageMeter('Time', ':6.3f') 313 | losses = AverageMeter('Loss', ':.4e') 314 | top1 = AverageMeter('Acc@1', ':6.2f') 315 | top5 = AverageMeter('Acc@5', ':6.2f') 316 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 317 | prefix='Test: ') 318 | 319 | # switch to evaluate mode 320 | model.eval() 321 | 322 | with torch.no_grad(): 323 | end = time.time() 324 | for i, (images, target) in enumerate(val_loader): 325 | if args.gpu is not None: 326 | images = images.cuda(args.gpu, non_blocking=True) 327 | target = target.cuda(args.gpu, non_blocking=True) 328 | 329 | # compute output 330 | output = model(images) 331 | loss = criterion(output, target) 332 | 333 | # measure accuracy and record loss 334 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 335 | losses.update(loss.item(), images.size(0)) 336 | top1.update(acc1[0], images.size(0)) 337 | top5.update(acc5[0], images.size(0)) 338 | 339 | # measure elapsed time 340 | batch_time.update(time.time() - end) 341 | end = time.time() 342 | 343 | if i % args.print_freq == 0: 344 | progress.print(i) 345 | 
346 | # TODO: this should also be done with the ProgressMeter 347 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 348 | .format(top1=top1, top5=top5)) 349 | 350 | return top1.avg 351 | 352 | 353 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 354 | torch.save(state, filename) 355 | if is_best: 356 | shutil.copyfile(filename, 'model_best.pth.tar') 357 | 358 | 359 | class AverageMeter(object): 360 | """Computes and stores the average and current value""" 361 | def __init__(self, name, fmt=':f'): 362 | self.name = name 363 | self.fmt = fmt 364 | self.reset() 365 | 366 | def reset(self): 367 | self.val = 0 368 | self.avg = 0 369 | self.sum = 0 370 | self.count = 0 371 | 372 | def update(self, val, n=1): 373 | self.val = val 374 | self.sum += val * n 375 | self.count += n 376 | self.avg = self.sum / self.count 377 | 378 | def __str__(self): 379 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 380 | return fmtstr.format(**self.__dict__) 381 | 382 | 383 | class ProgressMeter(object): 384 | def __init__(self, num_batches, *meters, prefix=""): 385 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 386 | self.meters = meters 387 | self.prefix = prefix 388 | 389 | def print(self, batch): 390 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 391 | entries += [str(meter) for meter in self.meters] 392 | print('\t'.join(entries)) 393 | 394 | def _get_batch_fmtstr(self, num_batches): 395 | num_digits = len(str(num_batches // 1)) 396 | fmt = '{:' + str(num_digits) + 'd}' 397 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 398 | 399 | 400 | def adjust_learning_rate(optimizer, epoch, args): 401 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 402 | lr = args.lr * (0.1 ** (epoch // 30)) 403 | for param_group in optimizer.param_groups: 404 | param_group['lr'] = lr 405 | 406 | 407 | def accuracy(output, target, topk=(1,)): 408 | """Computes the accuracy over the k top predictions for the specified values of k""" 409 | with torch.no_grad(): 410 | maxk = max(topk) 411 | batch_size = target.size(0) 412 | 413 | _, pred = output.topk(maxk, 1, True, True) 414 | pred = pred.t() 415 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 416 | 417 | res = [] 418 | for k in topk: 419 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 420 | res.append(correct_k.mul_(100.0 / batch_size)) 421 | return res 422 | 423 | 424 | if __name__ == '__main__': 425 | main() 426 | -------------------------------------------------------------------------------- /examples/simple/labels_map.txt: -------------------------------------------------------------------------------- 1 | {"0": "tench, Tinca tinca", "1": "goldfish, Carassius auratus", "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", "3": "tiger shark, Galeocerdo cuvieri", "4": "hammerhead, hammerhead shark", "5": "electric ray, crampfish, numbfish, torpedo", "6": "stingray", "7": "cock", "8": "hen", "9": "ostrich, Struthio camelus", "10": "brambling, Fringilla montifringilla", "11": "goldfinch, Carduelis carduelis", "12": "house finch, linnet, Carpodacus mexicanus", "13": "junco, snowbird", "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", "15": "robin, American robin, Turdus migratorius", "16": "bulbul", "17": "jay", "18": "magpie", "19": "chickadee", "20": "water ouzel, dipper", "21": "kite", "22": "bald eagle, American eagle, Haliaeetus leucocephalus", "23": "vulture", "24": "great grey owl, great gray 
owl, Strix nebulosa", "25": "European fire salamander, Salamandra salamandra", "26": "common newt, Triturus vulgaris", "27": "eft", "28": "spotted salamander, Ambystoma maculatum", "29": "axolotl, mud puppy, Ambystoma mexicanum", "30": "bullfrog, Rana catesbeiana", "31": "tree frog, tree-frog", "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", "33": "loggerhead, loggerhead turtle, Caretta caretta", "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", "35": "mud turtle", "36": "terrapin", "37": "box turtle, box tortoise", "38": "banded gecko", "39": "common iguana, iguana, Iguana iguana", "40": "American chameleon, anole, Anolis carolinensis", "41": "whiptail, whiptail lizard", "42": "agama", "43": "frilled lizard, Chlamydosaurus kingi", "44": "alligator lizard", "45": "Gila monster, Heloderma suspectum", "46": "green lizard, Lacerta viridis", "47": "African chameleon, Chamaeleo chamaeleon", "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", "49": "African crocodile, Nile crocodile, Crocodylus niloticus", "50": "American alligator, Alligator mississipiensis", "51": "triceratops", "52": "thunder snake, worm snake, Carphophis amoenus", "53": "ringneck snake, ring-necked snake, ring snake", "54": "hognose snake, puff adder, sand viper", "55": "green snake, grass snake", "56": "king snake, kingsnake", "57": "garter snake, grass snake", "58": "water snake", "59": "vine snake", "60": "night snake, Hypsiglena torquata", "61": "boa constrictor, Constrictor constrictor", "62": "rock python, rock snake, Python sebae", "63": "Indian cobra, Naja naja", "64": "green mamba", "65": "sea snake", "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", "68": "sidewinder, horned rattlesnake, Crotalus cerastes", "69": "trilobite", "70": "harvestman, daddy longlegs, Phalangium opilio", "71": "scorpion", "72": "black and gold garden spider, Argiope aurantia", "73": "barn spider, Araneus cavaticus", "74": "garden spider, Aranea diademata", "75": "black widow, Latrodectus mactans", "76": "tarantula", "77": "wolf spider, hunting spider", "78": "tick", "79": "centipede", "80": "black grouse", "81": "ptarmigan", "82": "ruffed grouse, partridge, Bonasa umbellus", "83": "prairie chicken, prairie grouse, prairie fowl", "84": "peacock", "85": "quail", "86": "partridge", "87": "African grey, African gray, Psittacus erithacus", "88": "macaw", "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", "90": "lorikeet", "91": "coucal", "92": "bee eater", "93": "hornbill", "94": "hummingbird", "95": "jacamar", "96": "toucan", "97": "drake", "98": "red-breasted merganser, Mergus serrator", "99": "goose", "100": "black swan, Cygnus atratus", "101": "tusker", "102": "echidna, spiny anteater, anteater", "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", "104": "wallaby, brush kangaroo", "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", "106": "wombat", "107": "jellyfish", "108": "sea anemone, anemone", "109": "brain coral", "110": "flatworm, platyhelminth", "111": "nematode, nematode worm, roundworm", "112": "conch", "113": "snail", "114": "slug", "115": "sea slug, nudibranch", "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", "117": "chambered nautilus, pearly nautilus, nautilus", "118": "Dungeness crab, Cancer magister", "119": "rock crab, Cancer 
irroratus", "120": "fiddler crab", "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", "124": "crayfish, crawfish, crawdad, crawdaddy", "125": "hermit crab", "126": "isopod", "127": "white stork, Ciconia ciconia", "128": "black stork, Ciconia nigra", "129": "spoonbill", "130": "flamingo", "131": "little blue heron, Egretta caerulea", "132": "American egret, great white heron, Egretta albus", "133": "bittern", "134": "crane", "135": "limpkin, Aramus pictus", "136": "European gallinule, Porphyrio porphyrio", "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", "138": "bustard", "139": "ruddy turnstone, Arenaria interpres", "140": "red-backed sandpiper, dunlin, Erolia alpina", "141": "redshank, Tringa totanus", "142": "dowitcher", "143": "oystercatcher, oyster catcher", "144": "pelican", "145": "king penguin, Aptenodytes patagonica", "146": "albatross, mollymawk", "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", "149": "dugong, Dugong dugon", "150": "sea lion", "151": "Chihuahua", "152": "Japanese spaniel", "153": "Maltese dog, Maltese terrier, Maltese", "154": "Pekinese, Pekingese, Peke", "155": "Shih-Tzu", "156": "Blenheim spaniel", "157": "papillon", "158": "toy terrier", "159": "Rhodesian ridgeback", "160": "Afghan hound, Afghan", "161": "basset, basset hound", "162": "beagle", "163": "bloodhound, sleuthhound", "164": "bluetick", "165": "black-and-tan coonhound", "166": "Walker hound, Walker foxhound", "167": "English foxhound", "168": "redbone", "169": "borzoi, Russian wolfhound", "170": "Irish wolfhound", "171": "Italian greyhound", "172": "whippet", "173": "Ibizan hound, Ibizan Podenco", "174": "Norwegian elkhound, elkhound", "175": "otterhound, otter hound", "176": "Saluki, gazelle hound", "177": "Scottish deerhound, deerhound", "178": "Weimaraner", "179": "Staffordshire bullterrier, Staffordshire bull terrier", "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", "181": "Bedlington terrier", "182": "Border terrier", "183": "Kerry blue terrier", "184": "Irish terrier", "185": "Norfolk terrier", "186": "Norwich terrier", "187": "Yorkshire terrier", "188": "wire-haired fox terrier", "189": "Lakeland terrier", "190": "Sealyham terrier, Sealyham", "191": "Airedale, Airedale terrier", "192": "cairn, cairn terrier", "193": "Australian terrier", "194": "Dandie Dinmont, Dandie Dinmont terrier", "195": "Boston bull, Boston terrier", "196": "miniature schnauzer", "197": "giant schnauzer", "198": "standard schnauzer", "199": "Scotch terrier, Scottish terrier, Scottie", "200": "Tibetan terrier, chrysanthemum dog", "201": "silky terrier, Sydney silky", "202": "soft-coated wheaten terrier", "203": "West Highland white terrier", "204": "Lhasa, Lhasa apso", "205": "flat-coated retriever", "206": "curly-coated retriever", "207": "golden retriever", "208": "Labrador retriever", "209": "Chesapeake Bay retriever", "210": "German short-haired pointer", "211": "vizsla, Hungarian pointer", "212": "English setter", "213": "Irish setter, red setter", "214": "Gordon setter", "215": "Brittany spaniel", "216": "clumber, clumber spaniel", "217": "English springer, English springer spaniel", "218": "Welsh springer spaniel", 
"219": "cocker spaniel, English cocker spaniel, cocker", "220": "Sussex spaniel", "221": "Irish water spaniel", "222": "kuvasz", "223": "schipperke", "224": "groenendael", "225": "malinois", "226": "briard", "227": "kelpie", "228": "komondor", "229": "Old English sheepdog, bobtail", "230": "Shetland sheepdog, Shetland sheep dog, Shetland", "231": "collie", "232": "Border collie", "233": "Bouvier des Flandres, Bouviers des Flandres", "234": "Rottweiler", "235": "German shepherd, German shepherd dog, German police dog, alsatian", "236": "Doberman, Doberman pinscher", "237": "miniature pinscher", "238": "Greater Swiss Mountain dog", "239": "Bernese mountain dog", "240": "Appenzeller", "241": "EntleBucher", "242": "boxer", "243": "bull mastiff", "244": "Tibetan mastiff", "245": "French bulldog", "246": "Great Dane", "247": "Saint Bernard, St Bernard", "248": "Eskimo dog, husky", "249": "malamute, malemute, Alaskan malamute", "250": "Siberian husky", "251": "dalmatian, coach dog, carriage dog", "252": "affenpinscher, monkey pinscher, monkey dog", "253": "basenji", "254": "pug, pug-dog", "255": "Leonberg", "256": "Newfoundland, Newfoundland dog", "257": "Great Pyrenees", "258": "Samoyed, Samoyede", "259": "Pomeranian", "260": "chow, chow chow", "261": "keeshond", "262": "Brabancon griffon", "263": "Pembroke, Pembroke Welsh corgi", "264": "Cardigan, Cardigan Welsh corgi", "265": "toy poodle", "266": "miniature poodle", "267": "standard poodle", "268": "Mexican hairless", "269": "timber wolf, grey wolf, gray wolf, Canis lupus", "270": "white wolf, Arctic wolf, Canis lupus tundrarum", "271": "red wolf, maned wolf, Canis rufus, Canis niger", "272": "coyote, prairie wolf, brush wolf, Canis latrans", "273": "dingo, warrigal, warragal, Canis dingo", "274": "dhole, Cuon alpinus", "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", "276": "hyena, hyaena", "277": "red fox, Vulpes vulpes", "278": "kit fox, Vulpes macrotis", "279": "Arctic fox, white fox, Alopex lagopus", "280": "grey fox, gray fox, Urocyon cinereoargenteus", "281": "tabby, tabby cat", "282": "tiger cat", "283": "Persian cat", "284": "Siamese cat, Siamese", "285": "Egyptian cat", "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", "287": "lynx, catamount", "288": "leopard, Panthera pardus", "289": "snow leopard, ounce, Panthera uncia", "290": "jaguar, panther, Panthera onca, Felis onca", "291": "lion, king of beasts, Panthera leo", "292": "tiger, Panthera tigris", "293": "cheetah, chetah, Acinonyx jubatus", "294": "brown bear, bruin, Ursus arctos", "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", "297": "sloth bear, Melursus ursinus, Ursus ursinus", "298": "mongoose", "299": "meerkat, mierkat", "300": "tiger beetle", "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", "302": "ground beetle, carabid beetle", "303": "long-horned beetle, longicorn, longicorn beetle", "304": "leaf beetle, chrysomelid", "305": "dung beetle", "306": "rhinoceros beetle", "307": "weevil", "308": "fly", "309": "bee", "310": "ant, emmet, pismire", "311": "grasshopper, hopper", "312": "cricket", "313": "walking stick, walkingstick, stick insect", "314": "cockroach, roach", "315": "mantis, mantid", "316": "cicada, cicala", "317": "leafhopper", "318": "lacewing, lacewing fly", "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter 
hawk", "320": "damselfly", "321": "admiral", "322": "ringlet, ringlet butterfly", "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", "324": "cabbage butterfly", "325": "sulphur butterfly, sulfur butterfly", "326": "lycaenid, lycaenid butterfly", "327": "starfish, sea star", "328": "sea urchin", "329": "sea cucumber, holothurian", "330": "wood rabbit, cottontail, cottontail rabbit", "331": "hare", "332": "Angora, Angora rabbit", "333": "hamster", "334": "porcupine, hedgehog", "335": "fox squirrel, eastern fox squirrel, Sciurus niger", "336": "marmot", "337": "beaver", "338": "guinea pig, Cavia cobaya", "339": "sorrel", "340": "zebra", "341": "hog, pig, grunter, squealer, Sus scrofa", "342": "wild boar, boar, Sus scrofa", "343": "warthog", "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", "345": "ox", "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", "347": "bison", "348": "ram, tup", "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", "350": "ibex, Capra ibex", "351": "hartebeest", "352": "impala, Aepyceros melampus", "353": "gazelle", "354": "Arabian camel, dromedary, Camelus dromedarius", "355": "llama", "356": "weasel", "357": "mink", "358": "polecat, fitch, foulmart, foumart, Mustela putorius", "359": "black-footed ferret, ferret, Mustela nigripes", "360": "otter", "361": "skunk, polecat, wood pussy", "362": "badger", "363": "armadillo", "364": "three-toed sloth, ai, Bradypus tridactylus", "365": "orangutan, orang, orangutang, Pongo pygmaeus", "366": "gorilla, Gorilla gorilla", "367": "chimpanzee, chimp, Pan troglodytes", "368": "gibbon, Hylobates lar", "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", "370": "guenon, guenon monkey", "371": "patas, hussar monkey, Erythrocebus patas", "372": "baboon", "373": "macaque", "374": "langur", "375": "colobus, colobus monkey", "376": "proboscis monkey, Nasalis larvatus", "377": "marmoset", "378": "capuchin, ringtail, Cebus capucinus", "379": "howler monkey, howler", "380": "titi, titi monkey", "381": "spider monkey, Ateles geoffroyi", "382": "squirrel monkey, Saimiri sciureus", "383": "Madagascar cat, ring-tailed lemur, Lemur catta", "384": "indri, indris, Indri indri, Indri brevicaudatus", "385": "Indian elephant, Elephas maximus", "386": "African elephant, Loxodonta africana", "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", "389": "barracouta, snoek", "390": "eel", "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "392": "rock beauty, Holocanthus tricolor", "393": "anemone fish", "394": "sturgeon", "395": "gar, garfish, garpike, billfish, Lepisosteus osseus", "396": "lionfish", "397": "puffer, pufferfish, blowfish, globefish", "398": "abacus", "399": "abaya", "400": "academic gown, academic robe, judge's robe", "401": "accordion, piano accordion, squeeze box", "402": "acoustic guitar", "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", "404": "airliner", "405": "airship, dirigible", "406": "altar", "407": "ambulance", "408": "amphibian, amphibious vehicle", "409": "analog clock", "410": "apiary, bee house", "411": "apron", "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", "413": "assault rifle, assault gun", "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", "415": "bakery, 
bakeshop, bakehouse", "416": "balance beam, beam", "417": "balloon", "418": "ballpoint, ballpoint pen, ballpen, Biro", "419": "Band Aid", "420": "banjo", "421": "bannister, banister, balustrade, balusters, handrail", "422": "barbell", "423": "barber chair", "424": "barbershop", "425": "barn", "426": "barometer", "427": "barrel, cask", "428": "barrow, garden cart, lawn cart, wheelbarrow", "429": "baseball", "430": "basketball", "431": "bassinet", "432": "bassoon", "433": "bathing cap, swimming cap", "434": "bath towel", "435": "bathtub, bathing tub, bath, tub", "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", "437": "beacon, lighthouse, beacon light, pharos", "438": "beaker", "439": "bearskin, busby, shako", "440": "beer bottle", "441": "beer glass", "442": "bell cote, bell cot", "443": "bib", "444": "bicycle-built-for-two, tandem bicycle, tandem", "445": "bikini, two-piece", "446": "binder, ring-binder", "447": "binoculars, field glasses, opera glasses", "448": "birdhouse", "449": "boathouse", "450": "bobsled, bobsleigh, bob", "451": "bolo tie, bolo, bola tie, bola", "452": "bonnet, poke bonnet", "453": "bookcase", "454": "bookshop, bookstore, bookstall", "455": "bottlecap", "456": "bow", "457": "bow tie, bow-tie, bowtie", "458": "brass, memorial tablet, plaque", "459": "brassiere, bra, bandeau", "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", "461": "breastplate, aegis, egis", "462": "broom", "463": "bucket, pail", "464": "buckle", "465": "bulletproof vest", "466": "bullet train, bullet", "467": "butcher shop, meat market", "468": "cab, hack, taxi, taxicab", "469": "caldron, cauldron", "470": "candle, taper, wax light", "471": "cannon", "472": "canoe", "473": "can opener, tin opener", "474": "cardigan", "475": "car mirror", "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", "477": "carpenter's kit, tool kit", "478": "carton", "479": "car wheel", "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", "481": "cassette", "482": "cassette player", "483": "castle", "484": "catamaran", "485": "CD player", "486": "cello, violoncello", "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", "488": "chain", "489": "chainlink fence", "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", "491": "chain saw, chainsaw", "492": "chest", "493": "chiffonier, commode", "494": "chime, bell, gong", "495": "china cabinet, china closet", "496": "Christmas stocking", "497": "church, church building", "498": "cinema, movie theater, movie theatre, movie house, picture palace", "499": "cleaver, meat cleaver, chopper", "500": "cliff dwelling", "501": "cloak", "502": "clog, geta, patten, sabot", "503": "cocktail shaker", "504": "coffee mug", "505": "coffeepot", "506": "coil, spiral, volute, whorl, helix", "507": "combination lock", "508": "computer keyboard, keypad", "509": "confectionery, confectionary, candy store", "510": "container ship, containership, container vessel", "511": "convertible", "512": "corkscrew, bottle screw", "513": "cornet, horn, trumpet, trump", "514": "cowboy boot", "515": "cowboy hat, ten-gallon hat", "516": "cradle", "517": "crane", "518": "crash helmet", "519": "crate", "520": "crib, cot", "521": "Crock Pot", "522": "croquet ball", "523": "crutch", "524": "cuirass", "525": "dam, dike, dyke", "526": "desk", "527": "desktop computer", "528": "dial telephone, dial phone", "529": 
"diaper, nappy, napkin", "530": "digital clock", "531": "digital watch", "532": "dining table, board", "533": "dishrag, dishcloth", "534": "dishwasher, dish washer, dishwashing machine", "535": "disk brake, disc brake", "536": "dock, dockage, docking facility", "537": "dogsled, dog sled, dog sleigh", "538": "dome", "539": "doormat, welcome mat", "540": "drilling platform, offshore rig", "541": "drum, membranophone, tympan", "542": "drumstick", "543": "dumbbell", "544": "Dutch oven", "545": "electric fan, blower", "546": "electric guitar", "547": "electric locomotive", "548": "entertainment center", "549": "envelope", "550": "espresso maker", "551": "face powder", "552": "feather boa, boa", "553": "file, file cabinet, filing cabinet", "554": "fireboat", "555": "fire engine, fire truck", "556": "fire screen, fireguard", "557": "flagpole, flagstaff", "558": "flute, transverse flute", "559": "folding chair", "560": "football helmet", "561": "forklift", "562": "fountain", "563": "fountain pen", "564": "four-poster", "565": "freight car", "566": "French horn, horn", "567": "frying pan, frypan, skillet", "568": "fur coat", "569": "garbage truck, dustcart", "570": "gasmask, respirator, gas helmet", "571": "gas pump, gasoline pump, petrol pump, island dispenser", "572": "goblet", "573": "go-kart", "574": "golf ball", "575": "golfcart, golf cart", "576": "gondola", "577": "gong, tam-tam", "578": "gown", "579": "grand piano, grand", "580": "greenhouse, nursery, glasshouse", "581": "grille, radiator grille", "582": "grocery store, grocery, food market, market", "583": "guillotine", "584": "hair slide", "585": "hair spray", "586": "half track", "587": "hammer", "588": "hamper", "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", "590": "hand-held computer, hand-held microcomputer", "591": "handkerchief, hankie, hanky, hankey", "592": "hard disc, hard disk, fixed disk", "593": "harmonica, mouth organ, harp, mouth harp", "594": "harp", "595": "harvester, reaper", "596": "hatchet", "597": "holster", "598": "home theater, home theatre", "599": "honeycomb", "600": "hook, claw", "601": "hoopskirt, crinoline", "602": "horizontal bar, high bar", "603": "horse cart, horse-cart", "604": "hourglass", "605": "iPod", "606": "iron, smoothing iron", "607": "jack-o'-lantern", "608": "jean, blue jean, denim", "609": "jeep, landrover", "610": "jersey, T-shirt, tee shirt", "611": "jigsaw puzzle", "612": "jinrikisha, ricksha, rickshaw", "613": "joystick", "614": "kimono", "615": "knee pad", "616": "knot", "617": "lab coat, laboratory coat", "618": "ladle", "619": "lampshade, lamp shade", "620": "laptop, laptop computer", "621": "lawn mower, mower", "622": "lens cap, lens cover", "623": "letter opener, paper knife, paperknife", "624": "library", "625": "lifeboat", "626": "lighter, light, igniter, ignitor", "627": "limousine, limo", "628": "liner, ocean liner", "629": "lipstick, lip rouge", "630": "Loafer", "631": "lotion", "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", "633": "loupe, jeweler's loupe", "634": "lumbermill, sawmill", "635": "magnetic compass", "636": "mailbag, postbag", "637": "mailbox, letter box", "638": "maillot", "639": "maillot, tank suit", "640": "manhole cover", "641": "maraca", "642": "marimba, xylophone", "643": "mask", "644": "matchstick", "645": "maypole", "646": "maze, labyrinth", "647": "measuring cup", "648": "medicine chest, medicine cabinet", "649": "megalith, megalithic structure", "650": "microphone, mike", "651": "microwave, microwave 
oven", "652": "military uniform", "653": "milk can", "654": "minibus", "655": "miniskirt, mini", "656": "minivan", "657": "missile", "658": "mitten", "659": "mixing bowl", "660": "mobile home, manufactured home", "661": "Model T", "662": "modem", "663": "monastery", "664": "monitor", "665": "moped", "666": "mortar", "667": "mortarboard", "668": "mosque", "669": "mosquito net", "670": "motor scooter, scooter", "671": "mountain bike, all-terrain bike, off-roader", "672": "mountain tent", "673": "mouse, computer mouse", "674": "mousetrap", "675": "moving van", "676": "muzzle", "677": "nail", "678": "neck brace", "679": "necklace", "680": "nipple", "681": "notebook, notebook computer", "682": "obelisk", "683": "oboe, hautboy, hautbois", "684": "ocarina, sweet potato", "685": "odometer, hodometer, mileometer, milometer", "686": "oil filter", "687": "organ, pipe organ", "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", "689": "overskirt", "690": "oxcart", "691": "oxygen mask", "692": "packet", "693": "paddle, boat paddle", "694": "paddlewheel, paddle wheel", "695": "padlock", "696": "paintbrush", "697": "pajama, pyjama, pj's, jammies", "698": "palace", "699": "panpipe, pandean pipe, syrinx", "700": "paper towel", "701": "parachute, chute", "702": "parallel bars, bars", "703": "park bench", "704": "parking meter", "705": "passenger car, coach, carriage", "706": "patio, terrace", "707": "pay-phone, pay-station", "708": "pedestal, plinth, footstall", "709": "pencil box, pencil case", "710": "pencil sharpener", "711": "perfume, essence", "712": "Petri dish", "713": "photocopier", "714": "pick, plectrum, plectron", "715": "pickelhaube", "716": "picket fence, paling", "717": "pickup, pickup truck", "718": "pier", "719": "piggy bank, penny bank", "720": "pill bottle", "721": "pillow", "722": "ping-pong ball", "723": "pinwheel", "724": "pirate, pirate ship", "725": "pitcher, ewer", "726": "plane, carpenter's plane, woodworking plane", "727": "planetarium", "728": "plastic bag", "729": "plate rack", "730": "plow, plough", "731": "plunger, plumber's helper", "732": "Polaroid camera, Polaroid Land camera", "733": "pole", "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", "735": "poncho", "736": "pool table, billiard table, snooker table", "737": "pop bottle, soda bottle", "738": "pot, flowerpot", "739": "potter's wheel", "740": "power drill", "741": "prayer rug, prayer mat", "742": "printer", "743": "prison, prison house", "744": "projectile, missile", "745": "projector", "746": "puck, hockey puck", "747": "punching bag, punch bag, punching ball, punchball", "748": "purse", "749": "quill, quill pen", "750": "quilt, comforter, comfort, puff", "751": "racer, race car, racing car", "752": "racket, racquet", "753": "radiator", "754": "radio, wireless", "755": "radio telescope, radio reflector", "756": "rain barrel", "757": "recreational vehicle, RV, R.V.", "758": "reel", "759": "reflex camera", "760": "refrigerator, icebox", "761": "remote control, remote", "762": "restaurant, eating house, eating place, eatery", "763": "revolver, six-gun, six-shooter", "764": "rifle", "765": "rocking chair, rocker", "766": "rotisserie", "767": "rubber eraser, rubber, pencil eraser", "768": "rugby ball", "769": "rule, ruler", "770": "running shoe", "771": "safe", "772": "safety pin", "773": "saltshaker, salt shaker", "774": "sandal", "775": "sarong", "776": "sax, saxophone", "777": "scabbard", "778": "scale, weighing machine", "779": "school bus", "780": "schooner", "781": 
"scoreboard", "782": "screen, CRT screen", "783": "screw", "784": "screwdriver", "785": "seat belt, seatbelt", "786": "sewing machine", "787": "shield, buckler", "788": "shoe shop, shoe-shop, shoe store", "789": "shoji", "790": "shopping basket", "791": "shopping cart", "792": "shovel", "793": "shower cap", "794": "shower curtain", "795": "ski", "796": "ski mask", "797": "sleeping bag", "798": "slide rule, slipstick", "799": "sliding door", "800": "slot, one-armed bandit", "801": "snorkel", "802": "snowmobile", "803": "snowplow, snowplough", "804": "soap dispenser", "805": "soccer ball", "806": "sock", "807": "solar dish, solar collector, solar furnace", "808": "sombrero", "809": "soup bowl", "810": "space bar", "811": "space heater", "812": "space shuttle", "813": "spatula", "814": "speedboat", "815": "spider web, spider's web", "816": "spindle", "817": "sports car, sport car", "818": "spotlight, spot", "819": "stage", "820": "steam locomotive", "821": "steel arch bridge", "822": "steel drum", "823": "stethoscope", "824": "stole", "825": "stone wall", "826": "stopwatch, stop watch", "827": "stove", "828": "strainer", "829": "streetcar, tram, tramcar, trolley, trolley car", "830": "stretcher", "831": "studio couch, day bed", "832": "stupa, tope", "833": "submarine, pigboat, sub, U-boat", "834": "suit, suit of clothes", "835": "sundial", "836": "sunglass", "837": "sunglasses, dark glasses, shades", "838": "sunscreen, sunblock, sun blocker", "839": "suspension bridge", "840": "swab, swob, mop", "841": "sweatshirt", "842": "swimming trunks, bathing trunks", "843": "swing", "844": "switch, electric switch, electrical switch", "845": "syringe", "846": "table lamp", "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", "848": "tape player", "849": "teapot", "850": "teddy, teddy bear", "851": "television, television system", "852": "tennis ball", "853": "thatch, thatched roof", "854": "theater curtain, theatre curtain", "855": "thimble", "856": "thresher, thrasher, threshing machine", "857": "throne", "858": "tile roof", "859": "toaster", "860": "tobacco shop, tobacconist shop, tobacconist", "861": "toilet seat", "862": "torch", "863": "totem pole", "864": "tow truck, tow car, wrecker", "865": "toyshop", "866": "tractor", "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", "868": "tray", "869": "trench coat", "870": "tricycle, trike, velocipede", "871": "trimaran", "872": "tripod", "873": "triumphal arch", "874": "trolleybus, trolley coach, trackless trolley", "875": "trombone", "876": "tub, vat", "877": "turnstile", "878": "typewriter keyboard", "879": "umbrella", "880": "unicycle, monocycle", "881": "upright, upright piano", "882": "vacuum, vacuum cleaner", "883": "vase", "884": "vault", "885": "velvet", "886": "vending machine", "887": "vestment", "888": "viaduct", "889": "violin, fiddle", "890": "volleyball", "891": "waffle iron", "892": "wall clock", "893": "wallet, billfold, notecase, pocketbook", "894": "wardrobe, closet, press", "895": "warplane, military plane", "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", "897": "washer, automatic washer, washing machine", "898": "water bottle", "899": "water jug", "900": "water tower", "901": "whiskey jug", "902": "whistle", "903": "wig", "904": "window screen", "905": "window shade", "906": "Windsor tie", "907": "wine bottle", "908": "wing", "909": "wok", "910": "wooden spoon", "911": "wool, woolen, woollen", "912": "worm fence, snake fence, snake-rail fence, Virginia fence", 
"913": "wreck", "914": "yawl", "915": "yurt", "916": "web site, website, internet site, site", "917": "comic book", "918": "crossword puzzle, crossword", "919": "street sign", "920": "traffic light, traffic signal, stoplight", "921": "book jacket, dust cover, dust jacket, dust wrapper", "922": "menu", "923": "plate", "924": "guacamole", "925": "consomme", "926": "hot pot, hotpot", "927": "trifle", "928": "ice cream, icecream", "929": "ice lolly, lolly, lollipop, popsicle", "930": "French loaf", "931": "bagel, beigel", "932": "pretzel", "933": "cheeseburger", "934": "hotdog, hot dog, red hot", "935": "mashed potato", "936": "head cabbage", "937": "broccoli", "938": "cauliflower", "939": "zucchini, courgette", "940": "spaghetti squash", "941": "acorn squash", "942": "butternut squash", "943": "cucumber, cuke", "944": "artichoke, globe artichoke", "945": "bell pepper", "946": "cardoon", "947": "mushroom", "948": "Granny Smith", "949": "strawberry", "950": "orange", "951": "lemon", "952": "fig", "953": "pineapple, ananas", "954": "banana", "955": "jackfruit, jak, jack", "956": "custard apple", "957": "pomegranate", "958": "hay", "959": "carbonara", "960": "chocolate sauce, chocolate syrup", "961": "dough", "962": "meat loaf, meatloaf", "963": "pizza, pizza pie", "964": "potpie", "965": "burrito", "966": "red wine", "967": "espresso", "968": "cup", "969": "eggnog", "970": "alp", "971": "bubble", "972": "cliff, drop, drop-off", "973": "coral reef", "974": "geyser", "975": "lakeside, lakeshore", "976": "promontory, headland, head, foreland", "977": "sandbar, sand bar", "978": "seashore, coast, seacoast, sea-coast", "979": "valley, vale", "980": "volcano", "981": "ballplayer, baseball player", "982": "groom, bridegroom", "983": "scuba diver", "984": "rapeseed", "985": "daisy", "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", "987": "corn", "988": "acorn", "989": "hip, rose hip, rosehip", "990": "buckeye, horse chestnut, conker", "991": "coral fungus", "992": "agaric", "993": "gyromitra", "994": "stinkhorn, carrion fungus", "995": "earthstar", "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", "997": "bolete", "998": "ear, spike, capitulum", "999": "toilet tissue, toilet paper, bathroom tissue"} --------------------------------------------------------------------------------