├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── Hugging_Face_tutorials.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
└── vit.py

/.gitignore:
--------------------------------------------------------------------------------
test-cifar-10/*

--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml

--------------------------------------------------------------------------------
/.idea/Hugging_Face_tutorials.iml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hugging Face tutorials

Article: [A complete Hugging Face tutorial: how to build and train a vision transformer](https://theaisummer.com/hugging-face-vit/)

Code: [Vision Transformer](https://github.com/The-AI-Summer/Hugging_Face_tutorials/blob/master/vit.py)

--------------------------------------------------------------------------------
/vit.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import numpy as np

from transformers import ViTFeatureExtractor, ViTModel, ViTForImageClassification, TrainingArguments, Trainer, \
    default_data_collator, EarlyStoppingCallback
from transformers.modeling_outputs import SequenceClassifierOutput
from datasets import load_dataset, load_metric, Features, ClassLabel, Array3D

# Load a small subset of CIFAR-10 and carve a validation set out of the training split.
train_ds, test_ds = load_dataset('cifar10', split=['train[:5000]', 'test[:2000]'])
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

# The feature extractor resizes and normalizes images into the format the ViT checkpoint expects.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
data_collator = default_data_collator


def preprocess_images(examples):
    """Convert PIL images to channel-first uint8 arrays and run the feature extractor."""
    images = examples['img']
    images = [np.array(image, dtype=np.uint8) for image in images]
    images = [np.moveaxis(image, source=-1, destination=0) for image in images]
    inputs = feature_extractor(images=images)
    examples['pixel_values'] = inputs['pixel_values']

    return examples


# Declare the output schema so `map` stores `pixel_values` as a fixed-shape float array.
features = Features({
    'label': ClassLabel(
        names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']),
    'img': Array3D(dtype="int64", shape=(3, 32, 32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

preprocessed_train_ds = train_ds.map(preprocess_images, batched=True, features=features)
preprocessed_val_ds = val_ds.map(preprocess_images, batched=True, features=features)
preprocessed_test_ds = test_ds.map(preprocess_images, batched=True, features=features)


class ViTForImageClassification2(nn.Module):
    """ViT backbone with a custom linear classification head on the [CLS] token."""

    def __init__(self, num_labels=10):
        super(ViTForImageClassification2, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the hidden state of the [CLS] token (position 0).
        logits = self.classifier(outputs.last_hidden_state[:, 0])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


args = TrainingArguments(
    "test-cifar-10",
    evaluation_strategy="epoch",
    save_strategy="epoch",  # must match evaluation_strategy when load_best_model_at_end=True
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir='logs',
)

# The stock ViTForImageClassification head could be used instead of the custom module above.
# model = ViTForImageClassification()
model = ViTForImageClassification2()


def compute_metrics(eval_pred):
    """Compute top-1 accuracy from the logits returned at evaluation time."""
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return load_metric("accuracy").compute(predictions=predictions, references=labels)


trainer = Trainer(
    model,
    args,
    train_dataset=preprocessed_train_ds,
    eval_dataset=preprocessed_val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

outputs = trainer.predict(preprocessed_test_ds)

--------------------------------------------------------------------------------
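The script stops at `trainer.predict` without inspecting the result. A minimal follow-up sketch (not part of the repository) could unpack the returned `PredictionOutput` for the test-set accuracy and per-image predictions, and save the fine-tuned weights. It assumes the objects defined in vit.py (`outputs`, `trainer`, `feature_extractor`, `np`) are still in scope; the `./vit-cifar10-finetuned` directory name is an illustrative choice.

# Sketch: Trainer.predict prefixes metric names with "test_", so the accuracy
# computed by compute_metrics appears under the "test_accuracy" key.
test_accuracy = outputs.metrics["test_accuracy"]
print(f"Test accuracy: {test_accuracy:.4f}")

# Per-image predicted class ids, useful for a confusion matrix or error analysis.
predicted_labels = np.argmax(outputs.predictions, axis=1)
print(predicted_labels[:10], outputs.label_ids[:10])

# Save the weights and the preprocessing config for later reuse
# ("./vit-cifar10-finetuned" is a hypothetical output directory; since the model
# is a plain nn.Module rather than a PreTrainedModel, only its state dict is saved).
trainer.save_model("./vit-cifar10-finetuned")
feature_extractor.save_pretrained("./vit-cifar10-finetuned")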