├── requirements.txt
├── README.md
├── run.py
├── .gitignore
└── img_model.py

/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
pytorch-lightning
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# huggingface-vit-finetune

Hugging Face does images now!

Well... they will soon. For now, we gotta install `transformers` from master.

```
pip install -r requirements.txt
pip install git+https://github.com/huggingface/transformers.git@master --upgrade
python run.py
```

## Using trained models w/ `transformers`

Currently, the following models are available:

- nateraw/vit-base-patch16-224-cifar10

```python
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests

url = 'https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png'
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained('nateraw/vit-base-patch16-224-cifar10')
model = ViTForImageClassification.from_pretrained('nateraw/vit-base-patch16-224-cifar10')

inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
preds = outputs.logits.argmax(dim=1)

classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'
]
classes[preds[0]]
```
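
If you want class probabilities rather than just the top prediction, you can push the logits through a softmax. This continues from the snippet above, using only standard tensor ops (the top-3 cutoff is arbitrary):

```python
probs = outputs.logits.softmax(dim=1)
top_probs, top_idxs = probs.topk(3, dim=1)
for p, i in zip(top_probs[0].tolist(), top_idxs[0].tolist()):
    print(f'{classes[i]}: {p:.3f}')
```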
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Resize, Compose

import torch
from torchvision.datasets import CIFAR10
import pytorch_lightning as pl

from img_model import ImageClassifier


class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.tensor(transposed_data[1])

    # Custom memory pinning method on a custom batch type. With pin_memory=True,
    # the DataLoader calls this and uses its return value as the batch, so the
    # LightningModule receives a dict it can unpack via `self(**batch)`.
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return {'pixel_values': self.inp, 'labels': self.tgt}


def my_collate(batch):
    return SimpleCustomBatch(batch)


class ViTFeatureExtractorTransforms:
    """Torchvision pipeline mirroring the ViT feature extractor's preprocessing
    (resize + normalize), usable directly as a CIFAR10 transform."""

    def __init__(self, model_name_or_path):
        feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
        transform = []

        if feature_extractor.do_resize:
            transform.append(Resize(feature_extractor.size))

        transform.append(ToTensor())

        if feature_extractor.do_normalize:
            transform.append(Normalize(feature_extractor.image_mean, feature_extractor.image_std))

        self.transform = Compose(transform)

    def __call__(self, x):
        return self.transform(x)


if __name__ == '__main__':

    model_name_or_path = 'google/vit-base-patch16-224-in21k'
    num_labels = 10
    batch_size = 24
    num_workers = 2
    max_epochs = 4

    train_loader = DataLoader(
        CIFAR10('./', download=True, transform=ViTFeatureExtractorTransforms(model_name_or_path)),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=my_collate
    )
    val_loader = DataLoader(
        CIFAR10('./', download=True, train=False, transform=ViTFeatureExtractorTransforms(model_name_or_path)),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=my_collate
    )

    model = ImageClassifier(model_name_or_path, num_labels=num_labels)

    # HACK - the LR scheduler needs the total number of training steps;
    # put this somewhere better (e.g. compute it inside the LightningModule).
    model.total_steps = (len(train_loader.dataset) // batch_size) * max_epochs

    # limit_train_batches=5 is a quick smoke-test setting; remove it to train
    # on the full dataset.
    trainer = pl.Trainer(gpus=1, max_epochs=max_epochs, precision=16, limit_train_batches=5)
    trainer.fit(model, train_loader, val_loader)
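
    # Optional: export the fine-tuned weights in `transformers` format so they
    # can be loaded back with ViTForImageClassification.from_pretrained, as
    # shown in the README. The directory name here is just an example
    # (`vit*` is already in .gitignore).
    model.save_pretrained('vit-base-patch16-224-cifar10')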
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# Project-specific: CIFAR-10 data, Lightning logs, exported ViT checkpoints
cifar-10*
lightning_logs/
vit*
--------------------------------------------------------------------------------
/img_model.py:
--------------------------------------------------------------------------------
import warnings

import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import rank_zero_only
from transformers import (AdamW, ViTForImageClassification,
                          get_linear_schedule_with_warmup)

warnings.filterwarnings('ignore')


class ImageClassifier(pl.LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        weight_decay: float = 0.0,
        warmup_steps: int = 0,
        predictions_file: str = 'predictions.pt',
        num_labels: int = 10
    ):
        super().__init__()
        self.save_hyperparameters()

        self.model = ViTForImageClassification.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.accuracy_metric = pl.metrics.Accuracy()

    def metric(self, preds, labels, mode='val'):
        a = self.accuracy_metric(preds, labels)
        return {f'{mode}_acc': a}

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]
        preds = torch.argmax(logits, dim=1)
        metric_dict = self.metric(preds, batch['labels'])
        self.log_dict(metric_dict, prog_bar=True, on_step=False, on_epoch=True)
        self.log('val_loss', val_loss, prog_bar=True)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        test_loss, logits = outputs[:2]
        preds = torch.argmax(logits, dim=1)
        self.write_prediction('preds', preds, self.hparams.predictions_file)
        self.write_prediction('labels', batch['labels'], self.hparams.predictions_file)
        metric_dict = self.metric(preds, batch['labels'], mode='test')
        self.log_dict(metric_dict, prog_bar=True, on_step=False, on_epoch=True)
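
    # Standard transformer fine-tuning recipe below: biases and LayerNorm
    # weights are excluded from weight decay, and the linear warmup/decay
    # schedule is stepped once per optimizer step ('interval': 'step').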
    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)."""
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step', 'frequency': 1}
        return [optimizer], [scheduler]

    @rank_zero_only
    def save_pretrained(self, save_dir):
        self.hparams.save_dir = save_dir
        self.model.save_pretrained(self.hparams.save_dir)
        # self.tokenizer.save_pretrained(self.hparams.save_dir)
--------------------------------------------------------------------------------