├── CITATION.cff
├── predict.py
├── dataset.py
├── .gitignore
├── README.md
└── train.py

/CITATION.cff:
--------------------------------------------------------------------------------
# This CITATION.cff file was generated with cffinit.
# Visit https://bit.ly/cffinit to generate yours today!

cff-version: 1.2.0
title: Neural Transformers for Hindi Image Captioning
message: 'If you use this software, please cite it as below.'
type: software
authors:
  - given-names: Sean Benhur
    email: seanbenhur@gmail.com
    orcid: 'https://orcid.org/0000-0002-8022-1668'
    affiliation: PSG College of Arts and Science
  - given-names: Herumb Shandilya
    affiliation: Jaypee Institute of Information Technology
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
import requests
import torch
from PIL import Image
from transformers import (AutoTokenizer, VisionEncoderDecoderModel,
                          ViTFeatureExtractor)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

url = "https://shorturl.at/fvxEQ"
image = Image.open(requests.get(url, stream=True).raw)

encoder_checkpoint = "google/vit-base-patch16-224"
decoder_checkpoint = "surajp/gpt2-hindi"
model_checkpoint = "team-indain-image-caption/hindi-image-captioning"
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)


def predict(image):
    clean_text = lambda x: x.replace("<|endoftext|>", "").split("\n")[0]
    sample = feature_extractor(image, return_tensors="pt").pixel_values.to(device)  # inference
    caption_ids = model.generate(sample, max_length=50)[0]
    caption_text = clean_text(tokenizer.decode(caption_ids))
    return caption_text


if __name__ == "__main__":
    # caption the sample image downloaded above
    print(predict(image))
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset


class Image_Caption_Dataset(Dataset):
    def __init__(
        self, root_dir, df, feature_extractor, tokenizer, max_target_length=512
    ):
        self.root_dir = root_dir
        self.df = df
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.max_length = max_target_length

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        # look up the image path and its caption
        image_path = self.df["images"][idx]
        text = self.df["text"][idx]
        # prepare the image
        image = Image.open(self.root_dir + "/" + image_path).convert("RGB")
        pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values
        # encode the caption text as label ids
        captions = self.tokenizer(
            text, padding="max_length", max_length=self.max_length
        ).input_ids
        # replace pad token ids with -100 so they are ignored by the loss
        captions = [
            caption if caption != self.tokenizer.pad_token_id else -100
            for caption in captions
        ]
        encoding = {
            "pixel_values": pixel_values.squeeze(),
            "labels": torch.tensor(captions),
        }
        return encoding


def load_dataset(root_dir, df, feature_extractor, tokenizer,
                 max_target_length=512):
    # split the dataframe into train and validation sets
    train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)
    train_dataset = Image_Caption_Dataset(
        root_dir, train_df, feature_extractor, tokenizer, max_target_length
    )
    val_dataset = Image_Caption_Dataset(
        root_dir, val_df, feature_extractor, tokenizer, max_target_length
    )
    return train_dataset, val_dataset
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hindi Image Captioning Model

[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/team-indain-image-caption/Hindi-image-captioning)

This is an encoder-decoder image captioning model built with a [ViT](https://huggingface.co/google/vit-base-patch16-224-in21k) encoder and a [GPT2-Hindi](https://huggingface.co/surajp/gpt2-hindi) decoder. It is a first attempt at using ViT + GPT2-Hindi for a Hindi image captioning task. We used the Flickr8k Hindi Dataset, available on Kaggle, to train the model.

This model was trained during the Hugging Face course community week, organized by Hugging Face. The trained weights are available [here](https://huggingface.co/team-indain-image-caption/hindi-image-captioning).

## How to use

Here is how to use this model to caption an image from the Flickr8k dataset:
```python
import torch
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, \
                         VisionEncoderDecoderModel

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

image_path = 'sample.jpg'
image = Image.open(image_path)

encoder_checkpoint = 'google/vit-base-patch16-224'
decoder_checkpoint = 'surajp/gpt2-hindi'
model_checkpoint = 'team-indain-image-caption/hindi-image-captioning'
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)

# Inference
sample = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
clean_text = lambda x: x.replace('<|endoftext|>', '').split('\n')[0]

caption_ids = model.generate(sample, max_length=50)[0]
caption_text = clean_text(tokenizer.decode(caption_ids))
print(caption_text)
```

## Training data
We used the Flickr8k Hindi Dataset, a translated version of the original Flickr8k Dataset available on [Kaggle](https://www.kaggle.com/bhushanpatilnew/hindi-caption), to train the model.

## Training procedure
This model was trained during the Hugging Face course community week, organized by Hugging Face. Training was done on a Kaggle GPU.

## Evaluation Results

Due to the long inference time, we sampled around 3000 captions from the test dataset and computed METEOR and BLEU scores on them.

- BLEU - 0.137
- METEOR - 0.320
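
For reference, the snippet below is a minimal sketch of how such a corpus-level BLEU score can be computed; it is not part of this repository, it uses `nltk`, assumes whitespace tokenization for the Hindi captions, and stands in toy strings for the real model outputs and references. METEOR can be computed analogously with `nltk.translate.meteor_score` after downloading the WordNet data.

```python
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu

# Toy data standing in for the sampled test captions:
# one generated caption per image, with one (or more) reference captions each.
predictions = ["एक कुत्ता घास में दौड़ रहा है"]
references = [["एक भूरा कुत्ता हरी घास में दौड़ रहा है"]]

# Whitespace-tokenize, since the Hindi captions are space-separated words.
hyp_tokens = [pred.split() for pred in predictions]
ref_tokens = [[ref.split() for ref in refs] for refs in references]

bleu = corpus_bleu(ref_tokens, hyp_tokens,
                   smoothing_function=SmoothingFunction().method1)
print(f"BLEU: {bleu:.3f}")
```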

## Team Members
- [Sean Benhur](https://www.linkedin.com/in/seanbenhur/)
- [Herumb Shandilya](https://www.linkedin.com/in/herumb-s-740163131/)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import pandas as pd
from dataset import load_dataset
from transformers import (AutoTokenizer, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments, VisionEncoderDecoderModel,
                          ViTFeatureExtractor, default_data_collator)

captions_path = "./Flickr8k-Hindi.txt"
root_dir = "../input/flickr8k/Images"

encoder_checkpoint = "google/vit-base-patch16-224"
decoder_checkpoint = "surajp/gpt2-hindi"
output_dir = "./image_captioning_checkpoint"
# load feature extractor and tokenizer
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)

# parse the captions file: each line holds an image id followed by its caption
with open(captions_path) as f:
    data = []

    for i in f.readlines():
        sp = i.strip().split(" ")
        data.append([sp[0] + ".jpg", " ".join(sp[1:])])

hindi = pd.DataFrame(data, columns=["images", "text"])
# drop the caption whose image file is not present in the image directory
hindi = hindi[hindi["images"] != "2258277193_586949ec62.jpg"]
# load_dataset expects (root_dir, df, feature_extractor, tokenizer)
train_dataset, val_dataset = load_dataset(root_dir, hindi, feature_extractor, tokenizer)


# initialize a ViT-GPT2 model from a pretrained ViT and a pretrained GPT2 model
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_checkpoint, decoder_checkpoint
)
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = tokenizer.sep_token_id
model.config.max_length = 512
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
model.decoder.resize_token_embeddings(len(tokenizer))

# freeze the encoder
for param in model.encoder.parameters():
    param.requires_grad = False


training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    overwrite_output_dir=True,
    fp16=True,
    run_name="first_run",
    load_best_model_at_end=True,
    output_dir=output_dir,
    logging_steps=2000,
    save_steps=2000,
    eval_steps=2000,
)


if __name__ == "__main__":
    # instantiate the trainer; the feature extractor is passed as `tokenizer`
    # so that it is saved alongside the model checkpoints
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=feature_extractor,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=default_data_collator,
    )
    trainer.train()
--------------------------------------------------------------------------------
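A note on the captions file: `train.py` assumes each line of `Flickr8k-Hindi.txt` contains an image identifier (without the `.jpg` extension) followed by its Hindi caption, separated by the first space. The snippet below illustrates that assumed line format; the identifier and caption shown are hypothetical and not taken from the dataset.

```python
# Hypothetical example of the line format train.py expects in Flickr8k-Hindi.txt.
sample_line = "1000268201_693b08cb0e एक बच्चा लकड़ी की सीढ़ियाँ चढ़ रहा है\n"

sp = sample_line.strip().split(" ")
row = [sp[0] + ".jpg", " ".join(sp[1:])]
print(row)
# ['1000268201_693b08cb0e.jpg', 'एक बच्चा लकड़ी की सीढ़ियाँ चढ़ रहा है']
```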