├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── Hugging_Face_tutorials.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
└── vit.py
/.gitignore:
--------------------------------------------------------------------------------
1 | test-cifar-10/*
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/Hugging_Face_tutorials.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hugging Face tutorials
2 |
3 | Article: [A complete Hugging Face tutorial: how to build and train a vision transformer](https://theaisummer.com/hugging-face-vit/)
4 |
5 | Code: [Vision Transformer](https://github.com/The-AI-Summer/Hugging_Face_tutorials/blob/master/vit.py)
6 |
--------------------------------------------------------------------------------
/vit.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 |
4 | from transformers import ViTFeatureExtractor, ViTModel, ViTForImageClassification, TrainingArguments, Trainer, \
5 |     default_data_collator, EarlyStoppingCallback
6 | from transformers.modeling_outputs import SequenceClassifierOutput
7 | from datasets import load_dataset, load_metric, Features, ClassLabel, Array3D
8 |
9 | train_ds, test_ds = load_dataset('cifar10', split=['train[:5000]', 'test[:2000]'])
10 | splits = train_ds.train_test_split(test_size=0.1)
11 | train_ds = splits['train']
12 | val_ds = splits['test']
13 |
14 | feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
15 | data_collator = default_data_collator
16 |
17 |
18 | def preprocess_images(examples):
19 |     images = examples['img']
20 |     images = [np.array(image, dtype=np.uint8) for image in images]
21 |     images = [np.moveaxis(image, source=-1, destination=0) for image in images]
22 |     inputs = feature_extractor(images=images)
23 |     examples['pixel_values'] = inputs['pixel_values']
24 |
25 |     return examples
26 |
27 |
28 | features = Features({
29 |     'label': ClassLabel(
30 |         names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']),
31 |     'img': Array3D(dtype="int64", shape=(3, 32, 32)),
32 |     'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
33 | })
34 |
35 | preprocessed_train_ds = train_ds.map(preprocess_images, batched=True, features=features)
36 | preprocessed_val_ds = val_ds.map(preprocess_images, batched=True, features=features)
37 | preprocessed_test_ds = test_ds.map(preprocess_images, batched=True, features=features)
38 |
39 |
40 | class ViTForImageClassification2(nn.Module):
41 |     def __init__(self, num_labels=10):
42 |         super(ViTForImageClassification2, self).__init__()
43 |         self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
44 |         self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
45 |         self.num_labels = num_labels
46 |
47 |     def forward(self, pixel_values, labels=None):
48 |         outputs = self.vit(pixel_values=pixel_values)
49 |         logits = self.classifier(outputs.last_hidden_state[:, 0])
50 |
51 |         loss = None
52 |         if labels is not None:
53 |             loss_fct = nn.CrossEntropyLoss()
54 |             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
55 |
56 |         return SequenceClassifierOutput(
57 |             loss=loss,
58 |             logits=logits,
59 |             hidden_states=outputs.hidden_states,
60 |             attentions=outputs.attentions,
61 |         )
62 |
63 |
64 | args = TrainingArguments(
65 |     f"test-cifar-10",
66 |     evaluation_strategy="epoch",
67 |     learning_rate=2e-5,
68 |     per_device_train_batch_size=10,
69 |     per_device_eval_batch_size=4,
70 |     num_train_epochs=3,
71 |     weight_decay=0.01,
72 |     load_best_model_at_end=True,
73 |     metric_for_best_model="accuracy",
74 |     logging_dir='logs',
75 | )
76 |
77 | # model = ViTForImageClassification()
78 | model = ViTForImageClassification2()
79 |
80 |
81 | def compute_metrics(eval_pred):
82 |     predictions, labels = eval_pred
83 |     predictions = np.argmax(predictions, axis=1)
84 |     return load_metric("accuracy").compute(predictions=predictions, references=labels)
85 |
86 |
87 | trainer = Trainer(
88 |     model,
89 |     args,
90 |     train_dataset=preprocessed_train_ds,
91 |     eval_dataset=preprocessed_val_ds,
92 |     data_collator=data_collator,
93 |     compute_metrics=compute_metrics,
94 | )
95 |
96 | trainer.train()
97 |
98 | outputs = trainer.predict(preprocessed_test_ds)
99 |
--------------------------------------------------------------------------------
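
vit.py ends by calling trainer.predict on the preprocessed test split without inspecting the result. As a minimal follow-up sketch (not part of the repository), the PredictionOutput returned by Trainer.predict carries the classifier logits, the label ids, and the metrics produced by compute_metrics, so test accuracy can be read out directly. Note that recent transformers releases also expect save_strategy="epoch" alongside load_best_model_at_end=True when evaluation runs once per epoch.

import numpy as np

# Hypothetical follow-up to vit.py: score the test-set predictions.
# outputs.predictions holds the classifier logits, outputs.label_ids the true labels.
test_preds = np.argmax(outputs.predictions, axis=1)
test_accuracy = (test_preds == outputs.label_ids).mean()
print(f"Test accuracy: {test_accuracy:.4f}")

# Because compute_metrics was passed to the Trainer, the same figure is also
# reported under the "test_accuracy" key of outputs.metrics.
print(outputs.metrics)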