├── .gitignore
├── LICENSE
├── README.md
├── images
│   ├── Balloons.jpg
│   ├── Driving.jpg
│   └── Family-cooking.jpg
└── persian-image-captioning.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Hamtech-ai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Persian-Image-Captioning
2 |
3 | [Open in Hugging Face Spaces](https://huggingface.co/spaces/MahsaShahidi/Persian-Image-Captioning)
4 |
5 | We fine-tune the [Vision Encoder Decoder Model](https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/vision-encoder-decoder#transformers.VisionEncoderDecoderModel) for image captioning on the [coco-flickr-farsi](https://www.kaggle.com/navidkanaani/coco-flickr-farsi) dataset. Our model is implemented in PyTorch with the Hugging Face (🤗) Transformers library.
6 |
7 | You can use any pretrained vision model as the encoder and any pretrained language model as the decoder of a Vision Encoder Decoder model. Here we use [ViT](https://huggingface.co/google/vit-base-patch16-224-in21k) as the encoder and [ParsBERT (v2.0)](https://huggingface.co/HooshvareLab/bert-fa-base-uncased-clf-persiannews) as the decoder. The encoder and decoder are loaded separately with the `from_pretrained()` function, and randomly initialized cross-attention layers are added to the decoder.
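The new cross-attention layers attend over the sequence of vectors that the ViT encoder produces for each image. The `google/vit-base-patch16-224-in21k` checkpoint name encodes 16×16 patches on a 224×224 input, and ViT-Base uses a hidden size of 768. The short sketch below computes the resulting sequence shape, assuming a single prepended `[CLS]` token as in the ViT paper; the helper name `vit_sequence_shape` is illustrative only.

```python
def vit_sequence_shape(image_size: int = 224, patch_size: int = 16,
                       hidden_size: int = 768) -> tuple:
    """Return (sequence_length, hidden_size) of a ViT encoder's output."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    num_patches = patches_per_side ** 2          # 14 * 14 = 196
    return num_patches + 1, hidden_size          # +1 for the [CLS] token


seq_len, dim = vit_sequence_shape()
print(f"Each image becomes {seq_len} vectors of size {dim}")
# Each image becomes 197 vectors of size 768
```

Since BERT-base models such as ParsBERT also use a hidden size of 768, the decoder's cross-attention should be able to consume these vectors directly; when the two hidden sizes differ, `VisionEncoderDecoderModel` inserts a linear projection between encoder and decoder.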
8 |
9 | See the [Vision Encoder Decoder Model documentation](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder) for more information.
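The fine-tuning described above can be sketched as a single teacher-forced training step. This is not the project's training code (that lives in `persian-image-captioning.ipynb`): to run offline, the sketch builds a tiny, randomly initialized ViT + BERT pair from configs, and the sizes, token ids, optimizer, and learning rate are all assumptions. With the real checkpoints, `VisionEncoderDecoderModel.from_encoder_decoder_pretrained('google/vit-base-patch16-224-in21k', 'HooshvareLab/bert-fa-base-uncased-clf-persiannews')` would load pretrained encoder and decoder weights, leaving only the cross-attention layers randomly initialized.

```python
import torch
from transformers import (BertConfig, ViTConfig, VisionEncoderDecoderConfig,
                          VisionEncoderDecoderModel)

torch.manual_seed(0)

# Tiny stand-in configs (hypothetical sizes) so the sketch runs offline;
# the project itself starts from the pretrained ViT and ParsBERT checkpoints.
encoder_config = ViTConfig(image_size=32, patch_size=16, hidden_size=32,
                           num_hidden_layers=1, num_attention_heads=2,
                           intermediate_size=64)
decoder_config = BertConfig(vocab_size=50, hidden_size=32, num_hidden_layers=1,
                            num_attention_heads=2, intermediate_size=64)

# Marks the decoder as causal (is_decoder=True) and adds cross-attention layers
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config, decoder_config)
config.decoder_start_token_id = 2  # hypothetical [CLS] id that starts each caption
config.pad_token_id = 0            # hypothetical [PAD] id
model = VisionEncoderDecoderModel(config=config)  # all weights random here

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)  # assumed optimizer and lr
pixel_values = torch.randn(2, 3, 32, 32)  # batch of 2 random "images"
labels = torch.randint(5, 50, (2, 8))     # hypothetical caption token ids

model.train()
# Given labels, the model shifts them right (prepending decoder_start_token_id),
# runs the decoder with teacher forcing, and returns the cross-entropy loss.
outputs = model(pixel_values=pixel_values, labels=labels)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
print(f"loss: {outputs.loss.item():.3f}")
```

In a real training loop, padded positions in `labels` are usually set to `-100` so that the cross-entropy loss ignores them.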
10 |
11 | ## How to use
12 | You can generate a caption for an image with this model using the code below:
13 | ```python
14 | import torch
15 | import urllib.request
16 | from PIL import Image
17 | import matplotlib.pyplot as plt
18 | from transformers import ViTFeatureExtractor, AutoTokenizer, \
19 |     VisionEncoderDecoderModel
20 |
21 | def show_img(image):
22 |     # Display the image without axes
23 |     plt.axis("off")
24 |     plt.imshow(image)
25 |     plt.show()
26 |
27 | device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU if available
28 |
29 | # Pass the URL of any image to generate a caption for it
30 | urllib.request.urlretrieve("https://images.unsplash.com/photo-1628191011227-522c7c3f0af9?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=870&q=80", "sample.png")
31 | image = Image.open("sample.png").convert("RGB")  # ViT expects 3-channel input
32 |
33 | # Load the fine-tuned checkpoint for inference
34 | model_checkpoint = 'MahsaShahidi/Persian-Image-Captioning'
35 | model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
36 |
37 | # Image preprocessor and tokenizer of the pretrained encoder and decoder
38 | # (newer transformers releases replace ViTFeatureExtractor with ViTImageProcessor)
39 | feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
40 | tokenizer = AutoTokenizer.from_pretrained('HooshvareLab/bert-fa-base-uncased-clf-persiannews')
41 |
42 | # Resize and normalize the image, then generate at most 30 tokens
43 | sample = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
44 | caption_ids = model.generate(sample, max_length=30)[0]
45 | # Decode the token IDs to Persian text, dropping special tokens such as [CLS] and [SEP]
46 | caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
47 | print(caption_text)
48 | show_img(image)
49 | ```
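`model.generate(sample, max_length=30)` builds the caption one token at a time. The toy sketch below shows the simplest strategy, greedy decoding, with a hypothetical scripted scorer (`toy_scores`) standing in for the real decoder; the vocabulary and token ids are made up, and the checkpoint's own generation settings may select a different strategy such as beam search. In this sketch `max_length` counts the start token, decoding stops early at the end-of-sequence token, and `decode` mirrors `skip_special_tokens=True` by dropping special tokens before joining.

```python
from typing import Callable, List, Set


def greedy_decode(next_token_scores: Callable[[List[int]], List[float]],
                  start_id: int, eos_id: int, max_length: int = 30) -> List[int]:
    """Append the highest-scoring token until EOS or max_length tokens."""
    ids = [start_id]  # decoding starts from the decoder start token
    while len(ids) < max_length:
        scores = next_token_scores(ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids


def decode(ids: List[int], vocab: List[str], special_ids: Set[int]) -> str:
    """Join tokens, skipping special ones (like skip_special_tokens=True)."""
    return " ".join(vocab[i] for i in ids if i not in special_ids)


# Hypothetical toy vocabulary and scripted "decoder" standing in for the model
VOCAB = ["[PAD]", "[CLS]", "[SEP]", "a", "woman", "cooks"]
SCRIPT = [3, 4, 5, 2]  # "a woman cooks [SEP]"


def toy_scores(ids: List[int]) -> List[float]:
    # One-hot scores for the next scripted token
    target = SCRIPT[min(len(ids) - 1, len(SCRIPT) - 1)]
    return [1.0 if i == target else 0.0 for i in range(len(VOCAB))]


ids = greedy_decode(toy_scores, start_id=1, eos_id=2)
print(ids, decode(ids, VOCAB, special_ids={0, 1, 2}))
# [1, 3, 4, 5, 2] a woman cooks
```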
50 |
51 | ## Inference
52 | Below are the captions generated for three free stock photos after 2 epochs of training.
53 | Image | Caption
54 | --- | ---
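55 | ![Family cooking](images/Family-cooking.jpg) | **Generated Caption:** زنی در آشپزخانه در حال اماده کردن غذا است. (English: A woman is preparing food in the kitchen.)
56 | ![Balloons](images/Balloons.jpg) | **Generated Caption:** گروهی از مردم در حال پرواز بادبادک در یک زمین چمنزار. (English: A group of people flying kites in a grassy field.)
57 | ![Driving](images/Driving.jpg) | **Generated Caption:** مردی در ماشین نشسته و به ماشین نگاه می کند. (English: A man is sitting in a car and looking at the car.)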
58 |
59 |
60 |
61 | ## Credits
62 | A huge thanks to Kaggle for providing free access to GPUs, and to the creators of Hugging Face, ViT, and ParsBERT!
63 |
64 |
65 | ## References
66 | [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
67 |
--------------------------------------------------------------------------------
/images/Balloons.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hamtech-ai/Persian-Image-Captioning/3b12d38a09dd62b21cae1d8321b0ad938b118df0/images/Balloons.jpg
--------------------------------------------------------------------------------
/images/Driving.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hamtech-ai/Persian-Image-Captioning/3b12d38a09dd62b21cae1d8321b0ad938b118df0/images/Driving.jpg
--------------------------------------------------------------------------------
/images/Family-cooking.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hamtech-ai/Persian-Image-Captioning/3b12d38a09dd62b21cae1d8321b0ad938b118df0/images/Family-cooking.jpg
--------------------------------------------------------------------------------