├── .DS_Store
├── .idea
│   └── workspace.xml
├── Images
│   ├── COCO_val2014_000000060623.jpg
│   ├── COCO_val2014_000000165547.jpg
│   ├── COCO_val2014_000000354533.jpg
│   ├── COCO_val2014_000000386164.jpg
│   ├── COCO_val2014_000000562207.jpg
│   ├── COCO_val2014_000000579664.jpg
│   ├── CONCEPTUAL_01.jpg
│   ├── CONCEPTUAL_02.jpg
│   ├── CONCEPTUAL_03.jpg
│   ├── CONCEPTUAL_04.jpg
│   ├── CONCEPTUAL_05.jpg
│   └── CONCEPTUAL_06.jpg
├── LICENSE
├── README.md
├── cog.yaml
├── data
│   ├── .DS_Store
│   └── coco
│       └── .DS_Store
├── environment.yml
├── notebooks
│   ├── clip_prefix_captioning_inference.ipynb
│   └── transformer_inference.ipynb
├── parse_coco.py
├── parse_conceptual.py
├── predict.py
└── train.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/.DS_Store
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Images/COCO_val2014_000000060623.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/COCO_val2014_000000060623.jpg
--------------------------------------------------------------------------------
/Images/COCO_val2014_000000165547.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/COCO_val2014_000000165547.jpg
--------------------------------------------------------------------------------
/Images/COCO_val2014_000000354533.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/COCO_val2014_000000354533.jpg
--------------------------------------------------------------------------------
/Images/COCO_val2014_000000386164.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/COCO_val2014_000000386164.jpg
--------------------------------------------------------------------------------
/Images/COCO_val2014_000000562207.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/COCO_val2014_000000562207.jpg
--------------------------------------------------------------------------------
/Images/COCO_val2014_000000579664.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/COCO_val2014_000000579664.jpg
--------------------------------------------------------------------------------
/Images/CONCEPTUAL_01.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/CONCEPTUAL_01.jpg
--------------------------------------------------------------------------------
/Images/CONCEPTUAL_02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/CONCEPTUAL_02.jpg
--------------------------------------------------------------------------------
/Images/CONCEPTUAL_03.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/CONCEPTUAL_03.jpg
--------------------------------------------------------------------------------
/Images/CONCEPTUAL_04.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/CONCEPTUAL_04.jpg
--------------------------------------------------------------------------------
/Images/CONCEPTUAL_05.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/CONCEPTUAL_05.jpg
--------------------------------------------------------------------------------
/Images/CONCEPTUAL_06.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/Images/CONCEPTUAL_06.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 rmokady
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CLIP prefix captioning.
2 |
3 |
4 | Inference Notebook: [Open in Colab](https://colab.research.google.com/drive/1tuoAC5F4sC7qid56Z0ap-stR3rwdk0ZV?usp=sharing)
5 |
6 |
7 |
8 |
9 |
10 | ## Official implementation for the paper ["ClipCap: CLIP Prefix for Image Captioning"](https://arxiv.org/abs/2111.09734)
11 |
12 |
13 |
14 |
15 | ## Description
16 | Image captioning is a complicated task, where usually a pretrained detection network is used, which requires additional supervision in the form of object annotations. We present a new approach that does not require additional information (i.e., it requires only images and captions) and thus can be applied to any data. In addition, our model's training time is much faster than that of similar methods, while achieving results comparable to the state of the art, even on the Conceptual Captions dataset, which contains over 3M images.
17 |
18 | In our work, we use the [CLIP](https://github.com/openai/CLIP) model, which was already trained over an extremely large number of images and is therefore capable of generating semantic encodings for arbitrary images without additional supervision. To produce meaningful sentences we fine-tune a pretrained language model, which has proven successful for other natural language tasks. The key idea is to use the CLIP encoding as a prefix to the textual captions, by employing a simple mapping network over the raw encoding and then fine-tuning the language model to generate a valid caption. In addition, we present another variant, where we utilize a transformer architecture for the mapping network and avoid fine-tuning GPT-2. Still, our lighter model achieves results comparable to the state of the art on the nocaps dataset.
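
In code, the core of the method is small. The following is a minimal sketch (simplified from `ClipCaptionModel` in `train.py`); the linear `clip_project` stands in for the MLP or transformer mapping network, and the CLIP embedding and caption tokens are random stand-ins:
```
import torch
from torch import nn
from transformers import GPT2LMHeadModel

gpt = GPT2LMHeadModel.from_pretrained("gpt2")
prefix_length, clip_dim = 10, 512                      # defaults in train.py for ViT-B/32
gpt_dim = gpt.transformer.wte.weight.shape[1]          # 768 for GPT-2

# mapping network: CLIP embedding -> prefix_length pseudo token embeddings
clip_project = nn.Linear(clip_dim, prefix_length * gpt_dim)

clip_embedding = torch.randn(1, clip_dim)              # CLIP image encoding (stand-in)
tokens = torch.randint(0, 50257, (1, 12))              # caption token ids (stand-in)

prefix = clip_project(clip_embedding).view(1, prefix_length, gpt_dim)
text_embeddings = gpt.transformer.wte(tokens)
inputs_embeds = torch.cat((prefix, text_embeddings), dim=1)  # prefix + caption
outputs = gpt(inputs_embeds=inputs_embeds)             # GPT-2 consumes the concatenated sequence
```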
19 |
20 | ## COCO Examples
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 | A couple of people standing next to an elephant.
30 | A wooden table sitting in front of a window.
31 | A bunch of bananas sitting on top of a table.
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 | A woman holding a plate with a piece of cake in front of her face.
43 | A wooden table topped with lots of wooden utensils.
44 | A red motorcycle parked on top of a dirt field.
45 |
46 |
47 |
48 |
49 | ## Conceptual Captions Examples
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 | 3D render of a man holding a globe.
59 | Students enjoing the cherry blossoms
60 | Green leaf of lettuce on a white plate.
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 | The hotel and casino on the waterfront.
72 | The triangle is a symbol of the soul.
73 | Cartoon boy in the bath.
74 |
75 |
76 |
77 |
78 | ## Inference Notebooks
79 | To help visualize the results we provide a Colab notebook found in `notebooks/clip_prefix_captioning_inference.ipynb`.
80 | The notebook will download the pretrained models and run inference on sample images or
81 | on images of your choosing. It is recommended to run this in [Google Colab](https://colab.research.google.com/drive/1tuoAC5F4sC7qid56Z0ap-stR3rwdk0ZV?usp=sharing).
82 | An inference notebook for the **transformer mapping network (without fine-tuning GPT-2)** can be found [here](https://colab.research.google.com/drive/180L3rMFmGujudwO1EJNF-lHIpAsAZ5xq?usp=sharing) for the COCO model (also in `notebooks/transformer_inference.ipynb`).
83 |
84 |
85 |
86 | Both [COCO](https://drive.google.com/file/d/1IdaBtMSvtyzF0ByVaBHtvM0JYSXRExRX/view?usp=sharing) and [Conceptual Captions](https://drive.google.com/file/d/14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT/view?usp=sharing) pretrained models are available for the MLP mapping network. For the transformer mapping network (without fine-tuning GPT-2) we provide a [COCO](https://drive.google.com/file/d/1GYPToCqFREwi285wPLhuVExlz7DDUDfJ/view?usp=sharing) pretrained model.
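
For reference, here is a minimal local inference sketch. It is not an official script: it reuses `ClipCaptionModel` and `generate2` from `predict.py` (importing that module also requires the `cog` package; alternatively, copy those definitions from the inference notebook) and assumes the downloaded COCO checkpoint was saved as `coco_weights.pt`:
```
import clip
import torch
import PIL.Image
from transformers import GPT2Tokenizer
from predict import ClipCaptionModel, generate2  # predict.py also imports cog

device = "cuda" if torch.cuda.is_available() else "cpu"
prefix_length = 10

clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load("coco_weights.pt", map_location="cpu"))
model = model.eval().to(device)

image = preprocess(PIL.Image.open("Images/COCO_val2014_000000060623.jpg")).unsqueeze(0).to(device)
with torch.no_grad():
    prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
print(generate2(model, tokenizer, embed=prefix_embed))
```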
87 |
88 |
89 |
90 | ## Inference GUI
91 | 1. Run it [in the browser](https://replicate.ai/rmokady/clip_prefix_caption) using the replicate.ai UI.
92 | 2. Integrated into [Hugging Face Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio); see the [demo](https://huggingface.co/spaces/akhaliq/CLIP_prefix_captioning) (beam search is currently not supported there).
93 |
94 |
95 | ## Training prerequisites
96 |
97 | [comment]: <> (Dependencies can be found at the [Inference notebook](https://colab.research.google.com/drive/1tuoAC5F4sC7qid56Z0ap-stR3rwdk0ZV?usp=sharing) )
98 | Clone the repository, create the environment, and install the dependencies:
99 | ```
100 | git clone https://github.com/rmokady/CLIP_prefix_caption && cd CLIP_prefix_caption
101 | conda env create -f environment.yml
102 | conda activate clip_prefix_caption
103 | ```
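
Optionally, verify the environment before moving on (a quick sanity check, not part of the official instructions):
```
import clip
import torch

print(torch.__version__, clip.available_models())
```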
104 |
105 | ## COCO training
106 |
107 | Download [train_captions](https://drive.google.com/file/d/1D3EzUK1d1lNhD2hAvRiKPThidiVbP2K_/view?usp=sharing) to `data/coco/annotations`.
108 |
109 | Download the [training images](http://images.cocodataset.org/zips/train2014.zip) and [validation images](http://images.cocodataset.org/zips/val2014.zip) and unzip them (we use the Karpathy et al. split).
110 |
111 | Extract the CLIP features (the output is saved to `data/coco/oscar_split_ViT-B_32_train.pkl`):
112 | ```
113 | python parse_coco.py --clip_model_type ViT-B/32
114 | ```
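
To sanity-check the extracted features before training (an optional step), note that the pickle simply stores the stacked CLIP embeddings together with the caption records:
```
import pickle
import torch  # torch must be importable to unpickle the embedding tensor

with open("./data/coco/oscar_split_ViT-B_32_train.pkl", "rb") as f:
    data = pickle.load(f)

print(data["clip_embedding"].shape)    # (num_captions, 512) for ViT-B/32
print(len(data["captions"]))           # one record per caption
print(data["captions"][0]["caption"])  # each record keeps its caption, image_id and embedding index
```
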
115 | Train with fine-tuning of GPT2:
116 | ```
117 | python train.py --data ./data/coco/oscar_split_ViT-B_32_train.pkl --out_dir ./coco_train/
118 | ```
119 |
120 | Train only transformer mapping network:
121 | ```
122 | python train.py --only_prefix --data ./data/coco/oscar_split_ViT-B_32_train.pkl --out_dir ./coco_train/ --mapping_type transformer --num_layers 8 --prefix_length 40 --prefix_length_clip 40
123 | ```
124 |
125 | **If you wish to use ResNet-based CLIP:**
126 |
127 | ```
128 | python parse_coco.py --clip_model_type RN50x4
129 | ```
130 | ```
131 | python train.py --only_prefix --data ./data/coco/oscar_split_RN50x4_train.pkl --out_dir ./coco_train/ --mapping_type transformer --num_layers 8 --prefix_length 40 --prefix_length_clip 40 --is_rn
132 | ```
133 |
134 | ## Conceptual training
135 |
136 | Download the .TSV train/val files from [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/download) and place them under your data root directory (the `--data_root` passed to `parse_conceptual.py`, `./data/conceptual` by default).
137 |
138 | Download the images and extract the CLIP features using the command below (the outputs are `conceptual_clip_ViT-B_32_train.pkl` and `conceptual_clip_ViT-B_32_val.pkl` under the data root):
139 | ```
140 | python parse_conceptual.py --clip_model_type ViT-B/32 --data_root ./data/conceptual --num_threads 16
141 | ```
142 | Note that downloading the images might take a few days.
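
Since the download runs across several threads and writes images under the data root, you can check its progress by counting the fetched files (a small optional helper; `parse_conceptual.py` saves images as `<data_root>/<split>/<index>.jpg`):
```
import glob

for split in ("train", "val"):
    n = len(glob.glob(f"./data/conceptual/{split}/*.jpg"))
    print(f"{split}: {n} images downloaded so far")
```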
143 |
144 | Train with fine-tuning of GPT2:
145 | ```
146 | python train.py --data ./data/conceptual/conceptual_clip_ViT-B_32_train.pkl --out_dir ./conceptual_train/
147 | ```
148 | As with the COCO training, you can train a transformer mapping network and/or parse the images using a ResNet-based CLIP.
149 |
150 | ## Citation
151 | If you use this code for your research, please cite:
152 | ```
153 | @article{mokady2021clipcap,
154 | title={ClipCap: CLIP Prefix for Image Captioning},
155 | author={Mokady, Ron and Hertz, Amir and Bermano, Amit H},
156 | journal={arXiv preprint arXiv:2111.09734},
157 | year={2021}
158 | }
159 | ```
160 |
161 |
162 |
163 |
164 | ## Acknowledgments
165 | This repository is heavily based on the [CLIP](https://github.com/openai/CLIP) and [Hugging Face transformers](https://github.com/huggingface/transformers) repositories.
166 | For training we used data from the [COCO dataset](https://cocodataset.org/#home) and [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/).
167 |
168 | ## Contact
169 | For any inquiries, please contact us at ron.mokady@gmail.com or amirhertz@mail.tau.ac.il.
170 |
171 |
172 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | image: r8.im/rmokady/clip_prefix_caption
5 |
6 | build:
7 | # set to true if your model requires a GPU
8 | gpu: true
9 | cuda: "10.2"
10 |
11 | # a list of ubuntu apt packages to install
12 | system_packages:
13 | - "libgl1-mesa-glx"
14 | - "libglib2.0-0"
15 |
16 | # python version in the form '3.8' or '3.8.12'
17 | python_version: "3.8"
18 |
19 |   # a list of packages in the format package==version
20 | python_packages:
21 | - "transformers==4.11.3"
22 | - "git+https://github.com/openai/CLIP.git"
23 | - "torch==1.9.1"
24 | - "numpy==1.19.5"
25 | - "pillow==8.3.2"
26 | - "scikit-image==0.16.2"
27 |
28 |   # commands run after the environment is set up
29 | run:
30 | # - "echo env is ready!"
31 | # - "echo another command if needed"
32 |
33 | # predict.py defines how predictions are run on your model
34 | predict: "predict.py:Predictor"
35 |
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/data/.DS_Store
--------------------------------------------------------------------------------
/data/coco/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmokady/CLIP_prefix_caption/1ad805a844a62ab2e5480479aa021bccf0d4d12a/data/coco/.DS_Store
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: clip_prefix_caption
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python>=3.8
6 | - pip>=21
7 | - scikit-image=0.18.1
8 | - pip:
9 | - transformers~=4.10.2
10 | - ftfy
11 | - regex
12 | - tqdm
13 | - git+https://github.com/openai/CLIP.git
--------------------------------------------------------------------------------
/parse_coco.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import skimage.io as io
3 | import clip
4 | from PIL import Image
5 | import pickle
6 | import json
7 | import os
8 | from tqdm import tqdm
9 | import argparse
10 |
11 |
12 | def main(clip_model_type: str):
13 | device = torch.device('cuda:0')
14 | clip_model_name = clip_model_type.replace('/', '_')
15 | out_path = f"./data/coco/oscar_split_{clip_model_name}_train.pkl"
16 | clip_model, preprocess = clip.load(clip_model_type, device=device, jit=False)
17 | with open('./data/coco/annotations/train_caption.json', 'r') as f:
18 | data = json.load(f)
19 | print("%0d captions loaded from json " % len(data))
20 | all_embeddings = []
21 | all_captions = []
22 | for i in tqdm(range(len(data))):
23 | d = data[i]
24 | img_id = d["image_id"]
25 | filename = f"./data/coco/train2014/COCO_train2014_{int(img_id):012d}.jpg"
26 | if not os.path.isfile(filename):
27 | filename = f"./data/coco/val2014/COCO_val2014_{int(img_id):012d}.jpg"
28 | image = io.imread(filename)
29 | image = preprocess(Image.fromarray(image)).unsqueeze(0).to(device)
30 | with torch.no_grad():
31 | prefix = clip_model.encode_image(image).cpu()
32 | d["clip_embedding"] = i
33 | all_embeddings.append(prefix)
34 | all_captions.append(d)
35 | if (i + 1) % 10000 == 0:
36 | with open(out_path, 'wb') as f:
37 | pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)
38 |
39 | with open(out_path, 'wb') as f:
40 | pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)
41 |
42 | print('Done')
43 | print("%0d embeddings saved " % len(all_embeddings))
44 | return 0
45 |
46 |
47 | if __name__ == '__main__':
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument('--clip_model_type', default="ViT-B/32", choices=('RN50', 'RN101', 'RN50x4', 'ViT-B/32'))
50 | args = parser.parse_args()
51 | exit(main(args.clip_model_type))
52 |
--------------------------------------------------------------------------------
/parse_conceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import clip
3 | from torch.utils.data import DataLoader, Dataset
4 | from PIL import Image
5 | import pickle
6 | from tqdm import tqdm
7 | import os
8 | import csv
9 | import threading
10 | import requests
11 | import shutil
12 | import PIL
13 | import json
14 | from typing import List, Tuple, Optional
15 | import argparse
16 |
17 |
18 | class ConceptualDS(Dataset):
19 |
20 | @staticmethod
21 | def get_all_data(data_root: str, suffix: str):
22 | data = []
23 | for i in range(16):
24 | out_data_path = f"{data_root}/conceptual_{suffix}_{i:02d}.pkl"
25 | if os.path.isfile(out_data_path):
26 | with open(out_data_path, 'rb') as f:
27 | raw_data = pickle.load(f)["info"]
28 | data.append(raw_data)
29 |
30 | return data
31 |
32 | @staticmethod
33 | def collect(data_root: str, suffix: str):
34 | raw_data = ConceptualDS.get_all_data(data_root, suffix)
35 | data = []
36 | for thread_data in raw_data:
37 | for item in thread_data:
38 | data.append((item, thread_data[item]["caption"]))
39 | return data
40 |
41 | def __len__(self):
42 | return len(self.data)
43 |
44 | def __getitem__(self, item: int):
45 | image_name, caption = self.data[item]
46 | image_path = f"{self.data_root}/{self.suffix}/{image_name}.jpg"
47 | is_error = False
48 | image = self.dummy
49 | try:
50 | image = self.preprocess(Image.open(image_path))
51 | except PIL.UnidentifiedImageError:
52 | is_error = True
53 | except OSError:
54 | is_error = True
55 | except BaseException:
56 | is_error = True
57 | if is_error:
58 | return image, "", image_name
59 | return image, caption, image_name
60 |
61 | def __init__(self, data_root: str, preprocess, suffix: str):
62 | self.suffix = suffix
63 | self.data_root = data_root
64 | self.data = self.collect(data_root, suffix)
65 | self.preprocess = preprocess
66 | self.dummy = torch.zeros(3, 288, 288)
67 |
68 |
69 | def save_pickle(data, out_path: str, recover_index: Optional[int] = None):
70 | if os.path.isfile(out_path) and recover_index is not None:
71 | recover_path = f'{out_path[:-4]}_{recover_index:02d}.pkl'
72 | shutil.copyfile(out_path, recover_path)
73 | with open(out_path, 'wb') as f:
74 | pickle.dump(data, f)
75 |
76 |
77 | def get_image(url: str, out_path: str, timeout=10):
78 | try:
79 | r = requests.get(url, stream=True, timeout=timeout)
80 | if r.status_code == 200:
81 | with open(out_path, 'wb') as f:
82 | r.raw.decode_content = True
83 | shutil.copyfileobj(r.raw, f)
84 | return True
85 | return False
86 | except BaseException:
87 | return False
88 |
89 |
90 | def thread(urls: List[Tuple[List[str], int]], thread_id: int, progress: tqdm, lock: Optional[threading.Lock],
91 | suffix: str, conceptual_root: str):
92 | out_root = f"{conceptual_root}/{suffix}"
93 | out_data_path = f"{conceptual_root}/conceptual_{suffix}_{thread_id:02d}.pkl"
94 | recover_index = 0
95 | if os.path.isfile(out_data_path):
96 | with open(out_data_path, 'rb') as f:
97 | data = pickle.load(f)
98 | parsed = data['parsed']
99 | info = data['info']
100 | else:
101 | parsed = set()
102 | info = {}
103 | for i in range(0, len(urls)):
104 | (caption, url), ind = urls[i]
105 | name = f"{ind:08d}"
106 | out_path = f"{out_root}/{name}.jpg"
107 | if url not in parsed and not os.path.isfile(out_path) and get_image(url, out_path):
108 | parsed.add(url)
109 | info[name] = {"url": url, "caption": caption}
110 | if lock is not None:
111 | lock.acquire()
112 | try:
113 | progress.update()
114 | finally:
115 | lock.release()
116 | else:
117 | progress.update()
118 | if (i + 1) % 1000 == 0:
119 | save_pickle({'parsed': parsed, 'info': info}, out_data_path, recover_index)
120 | recover_index = 1 - recover_index
121 | save_pickle({'parsed': parsed, 'info': info}, out_data_path, 2)
122 | return 0
123 |
124 |
125 | def download_conceptual(conceptual_root: str, num_threads: int):
126 | urls = []
127 | for suffix in ("val", "train"):
128 | if suffix == "train":
129 | tsv_path = f"{conceptual_root}/Train_GCC-training.tsv"
130 | else:
131 | tsv_path = f"{conceptual_root}/Validation_GCC-1.1.0-Validation.tsv"
132 | with open(tsv_path) as f:
133 | read_tsv = csv.reader(f, delimiter="\t")
134 | for i, row in enumerate(read_tsv):
135 | urls.append((row, i))
136 | progress = tqdm(total=len(urls))
137 | if num_threads == 1:
138 | thread(urls, 0, progress, None, suffix, conceptual_root)
139 | else:
140 | groups = []
141 | threads = []
142 | lock = threading.Lock()
143 | split_size = len(urls) // num_threads
144 | for i in range(num_threads):
145 | if i < num_threads - 1:
146 | groups.append(urls[i * split_size: (i + 1) * split_size])
147 | else:
148 | groups.append(urls[i * split_size:])
149 | for i in range(num_threads):
150 | threads.append(threading.Thread(target=thread, args=(groups[i], i, progress, lock, suffix, conceptual_root)))
151 | for i in range(num_threads):
152 | threads[i].start()
153 | for i in range(num_threads):
154 | threads[i].join()
155 | progress.close()
156 |
157 |
158 | def add_period(caption: str):
159 | caption = caption.strip()
160 | if caption[-1] != '.':
161 | caption = caption + '.'
162 | elif caption[-2] == ' ':
163 | caption = caption[:-2] + '.'
164 | return caption
165 |
166 |
167 | def create_clip_embeddings(conceptual_root: str, clip_model_type: str):
168 | all_embeddings = []
169 | all_captions = []
170 | for suffix in ("val", "train"):
171 | device = torch.device("cuda:0")
172 | clip_model, preprocess = clip.load(clip_model_type, device=device, jit=False)
173 | clip_model = clip_model.eval()
174 | ds = ConceptualDS(conceptual_root, preprocess, suffix)
175 | dl = DataLoader(ds, batch_size=200, shuffle=False, num_workers=8, drop_last=False)
176 | progress = tqdm(total=len(dl))
177 | counter = 0
178 | clip_model_name = clip_model_type.replace('/', '_')
179 | out_data_path = f"{conceptual_root}/conceptual_clip_{clip_model_name}_{suffix}.pkl"
180 | recover_index = 0
181 | for i, data in enumerate(dl):
182 | images, captions, image_names = data
183 | images = images.to(device)
184 | with torch.no_grad():
185 | prefix = clip_model.encode_image(images).cpu()
186 | is_valid = list(map(lambda x: x != "", captions))
187 | mask = torch.tensor(is_valid)
188 | all_embeddings.append(prefix[mask])
189 | captions = [caption for j, caption in enumerate(captions) if is_valid[j]]
190 | image_names = [image_name for j, image_name in enumerate(image_names) if is_valid[j]]
191 | all_captions.extend([{"caption": add_period(caption), "clip_embedding": counter + j, "image_id": image_name}
192 | for j, (caption, image_name) in enumerate(zip(captions, image_names))])
193 | progress.update()
194 | counter += len(captions)
195 | if (i + 1) % 1000 == 0:
196 | save_pickle({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, out_data_path, recover_index)
197 | recover_index = 1 - recover_index
198 | save_pickle({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, out_data_path, 2)
199 | progress.close()
200 |
201 | return 0
202 |
203 |
204 | def main():
205 | parser = argparse.ArgumentParser()
206 | parser.add_argument('--data_root', default='./data/conceptual')
207 | parser.add_argument('--clip_model_type', default="ViT-B/32", choices=('RN50', 'RN101', 'RN50x4', 'ViT-B/32'))
208 | parser.add_argument('--num_threads', type=int, default=16)
209 | args = parser.parse_args()
210 | download_conceptual(args.data_root, args.num_threads)
211 | create_clip_embeddings(args.data_root, args.clip_model_type)
212 |
213 |
214 | if __name__ == '__main__':
215 | main()
216 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # Prediction interface for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/python.md
3 |
4 | import clip
5 | import os
6 | from torch import nn
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as nnf
10 | import sys
11 | from typing import Tuple, List, Union, Optional
12 | from transformers import (
13 | GPT2Tokenizer,
14 | GPT2LMHeadModel,
15 | AdamW,
16 | get_linear_schedule_with_warmup,
17 | )
18 | import skimage.io as io
19 | import PIL.Image
20 |
21 | import cog
22 |
23 | # import torch
24 |
25 | N = type(None)
26 | V = np.array
27 | ARRAY = np.ndarray
28 | ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
29 | VS = Union[Tuple[V, ...], List[V]]
30 | VN = Union[V, N]
31 | VNS = Union[VS, N]
32 | T = torch.Tensor
33 | TS = Union[Tuple[T, ...], List[T]]
34 | TN = Optional[T]
35 | TNS = Union[Tuple[TN, ...], List[TN]]
36 | TSN = Optional[TS]
37 | TA = Union[T, ARRAY]
38 |
39 | WEIGHTS_PATHS = {
40 | "coco": "coco_weights.pt",
41 | "conceptual-captions": "conceptual_weights.pt",
42 | }
43 |
44 | D = torch.device
45 | CPU = torch.device("cpu")
46 |
47 |
48 | class Predictor(cog.Predictor):
49 | def setup(self):
50 | """Load the model into memory to make running multiple predictions efficient"""
51 | self.device = torch.device("cuda")
52 | self.clip_model, self.preprocess = clip.load(
53 | "ViT-B/32", device=self.device, jit=False
54 | )
55 | self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
56 |
57 | self.models = {}
58 | self.prefix_length = 10
59 | for key, weights_path in WEIGHTS_PATHS.items():
60 | model = ClipCaptionModel(self.prefix_length)
61 | model.load_state_dict(torch.load(weights_path, map_location=CPU))
62 | model = model.eval()
63 | model = model.to(self.device)
64 | self.models[key] = model
65 |
66 | @cog.input("image", type=cog.Path, help="Input image")
67 | @cog.input(
68 | "model",
69 | type=str,
70 | options=WEIGHTS_PATHS.keys(),
71 | default="coco",
72 | help="Model to use",
73 | )
74 | @cog.input(
75 | "use_beam_search",
76 | type=bool,
77 | default=False,
78 | help="Whether to apply beam search to generate the output text",
79 | )
80 | def predict(self, image, model, use_beam_search):
81 | """Run a single prediction on the model"""
82 | image = io.imread(image)
83 | model = self.models[model]
84 | pil_image = PIL.Image.fromarray(image)
85 | image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
86 | with torch.no_grad():
87 | prefix = self.clip_model.encode_image(image).to(
88 | self.device, dtype=torch.float32
89 | )
90 | prefix_embed = model.clip_project(prefix).reshape(1, self.prefix_length, -1)
91 | if use_beam_search:
92 | return generate_beam(model, self.tokenizer, embed=prefix_embed)[0]
93 | else:
94 | return generate2(model, self.tokenizer, embed=prefix_embed)
95 |
96 |
97 | class MLP(nn.Module):
98 | def forward(self, x: T) -> T:
99 | return self.model(x)
100 |
101 | def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
102 | super(MLP, self).__init__()
103 | layers = []
104 | for i in range(len(sizes) - 1):
105 | layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
106 | if i < len(sizes) - 2:
107 | layers.append(act())
108 | self.model = nn.Sequential(*layers)
109 |
110 |
111 | class ClipCaptionModel(nn.Module):
112 |
113 | # @functools.lru_cache #FIXME
114 | def get_dummy_token(self, batch_size: int, device: D) -> T:
115 | return torch.zeros(
116 | batch_size, self.prefix_length, dtype=torch.int64, device=device
117 | )
118 |
119 | def forward(
120 | self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None
121 | ):
122 | embedding_text = self.gpt.transformer.wte(tokens)
123 | prefix_projections = self.clip_project(prefix).view(
124 | -1, self.prefix_length, self.gpt_embedding_size
125 | )
126 | # print(embedding_text.size()) #torch.Size([5, 67, 768])
127 | # print(prefix_projections.size()) #torch.Size([5, 1, 768])
128 | embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
129 | if labels is not None:
130 | dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
131 | labels = torch.cat((dummy_token, tokens), dim=1)
132 | out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
133 | return out
134 |
135 | def __init__(self, prefix_length: int, prefix_size: int = 512):
136 | super(ClipCaptionModel, self).__init__()
137 | self.prefix_length = prefix_length
138 | self.gpt = GPT2LMHeadModel.from_pretrained("gpt2")
139 | self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
140 | if prefix_length > 10: # not enough memory
141 | self.clip_project = nn.Linear(
142 | prefix_size, self.gpt_embedding_size * prefix_length
143 | )
144 | else:
145 | self.clip_project = MLP(
146 | (
147 | prefix_size,
148 | (self.gpt_embedding_size * prefix_length) // 2,
149 | self.gpt_embedding_size * prefix_length,
150 | )
151 | )
152 |
153 |
154 | class ClipCaptionPrefix(ClipCaptionModel):
155 | def parameters(self, recurse: bool = True):
156 | return self.clip_project.parameters()
157 |
158 | def train(self, mode: bool = True):
159 | super(ClipCaptionPrefix, self).train(mode)
160 | self.gpt.eval()
161 | return self
162 |
163 |
164 | def generate_beam(
165 | model,
166 | tokenizer,
167 | beam_size: int = 5,
168 | prompt=None,
169 | embed=None,
170 | entry_length=67,
171 | temperature=1.0,
172 | stop_token: str = ".",
173 | ):
174 |
175 | model.eval()
176 | stop_token_index = tokenizer.encode(stop_token)[0]
177 | tokens = None
178 | scores = None
179 | device = next(model.parameters()).device
180 | seq_lengths = torch.ones(beam_size, device=device)
181 | is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
182 | with torch.no_grad():
183 | if embed is not None:
184 | generated = embed
185 | else:
186 | if tokens is None:
187 | tokens = torch.tensor(tokenizer.encode(prompt))
188 | tokens = tokens.unsqueeze(0).to(device)
189 | generated = model.gpt.transformer.wte(tokens)
190 | for i in range(entry_length):
191 | outputs = model.gpt(inputs_embeds=generated)
192 | logits = outputs.logits
193 | logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
194 | logits = logits.softmax(-1).log()
195 | if scores is None:
196 | scores, next_tokens = logits.topk(beam_size, -1)
197 | generated = generated.expand(beam_size, *generated.shape[1:])
198 | next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
199 | if tokens is None:
200 | tokens = next_tokens
201 | else:
202 | tokens = tokens.expand(beam_size, *tokens.shape[1:])
203 | tokens = torch.cat((tokens, next_tokens), dim=1)
204 | else:
205 | logits[is_stopped] = -float(np.inf)
206 | logits[is_stopped, 0] = 0
207 | scores_sum = scores[:, None] + logits
208 | seq_lengths[~is_stopped] += 1
209 | scores_sum_average = scores_sum / seq_lengths[:, None]
210 | scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
211 | beam_size, -1
212 | )
213 | next_tokens_source = next_tokens // scores_sum.shape[1]
214 | seq_lengths = seq_lengths[next_tokens_source]
215 | next_tokens = next_tokens % scores_sum.shape[1]
216 | next_tokens = next_tokens.unsqueeze(1)
217 | tokens = tokens[next_tokens_source]
218 | tokens = torch.cat((tokens, next_tokens), dim=1)
219 | generated = generated[next_tokens_source]
220 | scores = scores_sum_average * seq_lengths
221 | is_stopped = is_stopped[next_tokens_source]
222 | next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
223 | generated.shape[0], 1, -1
224 | )
225 | generated = torch.cat((generated, next_token_embed), dim=1)
226 | is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
227 | if is_stopped.all():
228 | break
229 | scores = scores / seq_lengths
230 | output_list = tokens.cpu().numpy()
231 | output_texts = [
232 | tokenizer.decode(output[: int(length)])
233 | for output, length in zip(output_list, seq_lengths)
234 | ]
235 | order = scores.argsort(descending=True)
236 | output_texts = [output_texts[i] for i in order]
237 | return output_texts
238 |
239 |
240 | def generate2(
241 | model,
242 | tokenizer,
243 | tokens=None,
244 | prompt=None,
245 | embed=None,
246 | entry_count=1,
247 | entry_length=67, # maximum number of words
248 | top_p=0.8,
249 | temperature=1.0,
250 | stop_token: str = ".",
251 | ):
252 | model.eval()
253 | generated_num = 0
254 | generated_list = []
255 | stop_token_index = tokenizer.encode(stop_token)[0]
256 | filter_value = -float("Inf")
257 | device = next(model.parameters()).device
258 |
259 | with torch.no_grad():
260 |
261 | for entry_idx in range(entry_count):
262 | if embed is not None:
263 | generated = embed
264 | else:
265 | if tokens is None:
266 | tokens = torch.tensor(tokenizer.encode(prompt))
267 | tokens = tokens.unsqueeze(0).to(device)
268 |
269 | generated = model.gpt.transformer.wte(tokens)
270 |
271 | for i in range(entry_length):
272 |
273 | outputs = model.gpt(inputs_embeds=generated)
274 | logits = outputs.logits
275 | logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
276 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
277 | cumulative_probs = torch.cumsum(
278 | nnf.softmax(sorted_logits, dim=-1), dim=-1
279 | )
280 | sorted_indices_to_remove = cumulative_probs > top_p
281 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
282 | ..., :-1
283 | ].clone()
284 | sorted_indices_to_remove[..., 0] = 0
285 |
286 | indices_to_remove = sorted_indices[sorted_indices_to_remove]
287 | logits[:, indices_to_remove] = filter_value
288 | next_token = torch.argmax(logits, -1).unsqueeze(0)
289 | next_token_embed = model.gpt.transformer.wte(next_token)
290 | if tokens is None:
291 | tokens = next_token
292 | else:
293 | tokens = torch.cat((tokens, next_token), dim=1)
294 | generated = torch.cat((generated, next_token_embed), dim=1)
295 | if stop_token_index == next_token.item():
296 | break
297 |
298 | output_list = list(tokens.squeeze().cpu().numpy())
299 | output_text = tokenizer.decode(output_list)
300 | generated_list.append(output_text)
301 |
302 | return generated_list[0]
303 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as nnf
4 | from torch.utils.data import Dataset, DataLoader
5 | from enum import Enum
6 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
7 | from tqdm import tqdm
8 | import os
9 | import pickle
10 | import sys
11 | import argparse
12 | import json
13 | from typing import Tuple, Optional, Union
14 |
15 |
16 | class MappingType(Enum):
17 | MLP = 'mlp'
18 | Transformer = 'transformer'
19 |
20 |
21 | class ClipCocoDataset(Dataset):
22 |
23 | def __len__(self) -> int:
24 | return len(self.captions_tokens)
25 |
26 | def pad_tokens(self, item: int):
27 | tokens = self.captions_tokens[item]
28 | padding = self.max_seq_len - tokens.shape[0]
29 | if padding > 0:
30 | tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
31 | self.captions_tokens[item] = tokens
32 | elif padding < 0:
33 | tokens = tokens[:self.max_seq_len]
34 | self.captions_tokens[item] = tokens
35 |         mask = tokens.ge(0)  # mask is zero where we are out of the sequence
36 | tokens[~mask] = 0
37 | mask = mask.float()
38 | mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0) # adding prefix mask
39 | return tokens, mask
40 |
41 | def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
42 | tokens, mask = self.pad_tokens(item)
43 | prefix = self.prefixes[self.caption2embedding[item]]
44 | if self.normalize_prefix:
45 | prefix = prefix.float()
46 | prefix = prefix / prefix.norm(2, -1)
47 | return tokens, mask, prefix
48 |
49 | def __init__(self, data_path: str, prefix_length: int, gpt2_type: str = "gpt2",
50 | normalize_prefix=False):
51 | self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
52 | self.prefix_length = prefix_length
53 | self.normalize_prefix = normalize_prefix
54 | with open(data_path, 'rb') as f:
55 | all_data = pickle.load(f)
56 | print("Data size is %0d" % len(all_data["clip_embedding"]))
57 | sys.stdout.flush()
58 | self.prefixes = all_data["clip_embedding"]
59 | captions_raw = all_data["captions"]
60 | self.image_ids = [caption["image_id"] for caption in captions_raw]
61 | self.captions = [caption['caption'] for caption in captions_raw]
62 | if os.path.isfile(f"{data_path[:-4]}_tokens.pkl"):
63 | with open(f"{data_path[:-4]}_tokens.pkl", 'rb') as f:
64 | self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f)
65 | else:
66 | self.captions_tokens = []
67 | self.caption2embedding = []
68 | max_seq_len = 0
69 | for caption in captions_raw:
70 | self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption['caption']), dtype=torch.int64))
71 | self.caption2embedding.append(caption["clip_embedding"])
72 | max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0])
73 | # self.max_seq_len = max_seq_len
74 | with open(f"{data_path[:-4]}_tokens.pkl", 'wb') as f:
75 | pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f)
76 | all_len = torch.tensor([len(self.captions_tokens[i]) for i in range(len(self))]).float()
77 | self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max()))
78 |
79 |
80 | class MLP(nn.Module):
81 |
82 | def forward(self, x: torch.Tensor) -> torch.Tensor:
83 | return self.model(x)
84 |
85 | def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
86 | super(MLP, self).__init__()
87 | layers = []
88 | for i in range(len(sizes) - 1):
89 | layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
90 | if i < len(sizes) - 2:
91 | layers.append(act())
92 | self.model = nn.Sequential(*layers)
93 |
94 |
95 | class MlpTransformer(nn.Module):
96 | def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
97 | super().__init__()
98 | out_d = out_d if out_d is not None else in_dim
99 | self.fc1 = nn.Linear(in_dim, h_dim)
100 | self.act = act
101 | self.fc2 = nn.Linear(h_dim, out_d)
102 | self.dropout = nn.Dropout(dropout)
103 |
104 | def forward(self, x):
105 | x = self.fc1(x)
106 | x = self.act(x)
107 | x = self.dropout(x)
108 | x = self.fc2(x)
109 | x = self.dropout(x)
110 | return x
111 |
112 | class MultiHeadAttention(nn.Module):
113 |
114 | def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
115 | super().__init__()
116 | self.num_heads = num_heads
117 | head_dim = dim_self // num_heads
118 | self.scale = head_dim ** -0.5
119 | self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
120 | self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
121 | self.project = nn.Linear(dim_self, dim_self)
122 | self.dropout = nn.Dropout(dropout)
123 |
124 | def forward(self, x, y=None, mask=None):
125 | y = y if y is not None else x
126 | b, n, c = x.shape
127 | _, m, d = y.shape
128 | # b n h dh
129 | queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
130 | # b m 2 h dh
131 | keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
132 | keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
133 | attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
134 | if mask is not None:
135 | if mask.dim() == 2:
136 | mask = mask.unsqueeze(1)
137 | attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
138 | attention = attention.softmax(dim=2)
139 | out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
140 | out = self.project(out)
141 | return out, attention
142 |
143 |
144 | class TransformerLayer(nn.Module):
145 |
146 | def forward_with_attention(self, x, y=None, mask=None):
147 | x_, attention = self.attn(self.norm1(x), y, mask)
148 | x = x + x_
149 | x = x + self.mlp(self.norm2(x))
150 | return x, attention
151 |
152 | def forward(self, x, y=None, mask=None):
153 | x = x + self.attn(self.norm1(x), y, mask)[0]
154 | x = x + self.mlp(self.norm2(x))
155 | return x
156 |
157 | def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
158 | norm_layer: nn.Module = nn.LayerNorm):
159 | super().__init__()
160 | self.norm1 = norm_layer(dim_self)
161 | self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
162 | self.norm2 = norm_layer(dim_self)
163 | self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
164 |
165 |
166 | class Transformer(nn.Module):
167 |
168 | def forward_with_attention(self, x, y=None, mask=None):
169 | attentions = []
170 | for layer in self.layers:
171 | x, att = layer.forward_with_attention(x, y, mask)
172 | attentions.append(att)
173 | return x, attentions
174 |
175 | def forward(self, x, y=None, mask=None):
176 | for i, layer in enumerate(self.layers):
177 | if i % 2 == 0 and self.enc_dec: # cross
178 | x = layer(x, y)
179 | elif self.enc_dec: # self
180 | x = layer(x, x, mask)
181 | else: # self or cross
182 | x = layer(x, y, mask)
183 | return x
184 |
185 | def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
186 | mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
187 | super(Transformer, self).__init__()
188 | dim_ref = dim_ref if dim_ref is not None else dim_self
189 | self.enc_dec = enc_dec
190 | if enc_dec:
191 | num_layers = num_layers * 2
192 | layers = []
193 | for i in range(num_layers):
194 | if i % 2 == 0 and enc_dec: # cross
195 | layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
196 | elif enc_dec: # self
197 | layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
198 | else: # self or cross
199 | layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
200 | self.layers = nn.ModuleList(layers)
201 |
202 |
203 | class TransformerMapper(nn.Module):
204 |
205 | def forward(self, x):
206 | x = self.linear(x).view(x.shape[0], self.clip_length, -1)
207 | prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
208 | prefix = torch.cat((x, prefix), dim=1)
209 | out = self.transformer(prefix)[:, self.clip_length:]
210 | return out
211 |
212 | def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
213 | super(TransformerMapper, self).__init__()
214 | self.clip_length = clip_length
215 | self.transformer = Transformer(dim_embedding, 8, num_layers)
216 | self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
217 | self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
218 |
219 |
220 | class ClipCaptionModel(nn.Module):
221 |
222 | def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
223 | return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
224 |
225 | def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
226 | labels: Optional[torch.Tensor] = None):
227 | embedding_text = self.gpt.transformer.wte(tokens)
228 | prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
229 | embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
230 | if labels is not None:
231 | dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
232 | labels = torch.cat((dummy_token, tokens), dim=1)
233 | out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
234 | return out
235 |
236 | def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
237 | num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):
238 | super(ClipCaptionModel, self).__init__()
239 | self.prefix_length = prefix_length
240 | self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
241 | self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
242 | if mapping_type == MappingType.MLP:
243 | self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
244 | self.gpt_embedding_size * prefix_length))
245 | else:
246 | self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
247 | clip_length, num_layers)
248 |
249 |
250 | class ClipCaptionPrefix(ClipCaptionModel):
251 |
252 | def parameters(self, recurse: bool = True):
253 | return self.clip_project.parameters()
254 |
255 | def train(self, mode: bool = True):
256 | super(ClipCaptionPrefix, self).train(mode)
257 | self.gpt.eval()
258 | return self
259 |
260 |
261 | def save_config(args: argparse.Namespace):
262 | config = {}
263 | for key, item in args._get_kwargs():
264 | config[key] = item
265 | out_path = os.path.join(args.out_dir, f"{args.prefix}.json")
266 | with open(out_path, 'w') as outfile:
267 | json.dump(config, outfile)
268 |
269 |
270 | def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
271 | with open(config_path) as f:
272 | config = json.load(f)
273 | parser = argparse.ArgumentParser()
274 | parser.set_defaults(**config)
275 | args = parser.parse_args()
276 | if type(epoch_or_latest) is int:
277 | epoch_or_latest = f"-{epoch_or_latest:03d}"
278 | model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt")
279 | if args.only_prefix:
280 | model = ClipCaptionPrefix(args.prefix_length)
281 | else:
282 | model = ClipCaptionModel(args.prefix_length)
283 | if os.path.isfile(model_path):
284 | print(f"loading model from {model_path}")
285 | model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
286 | else:
287 |         print(f"{model_path} does not exist")
288 | return model, parser
289 |
290 |
291 | def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
292 | lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = ""):
293 |
294 | device = torch.device('cuda:0')
295 | batch_size = args.bs
296 | epochs = args.epochs
297 | if not os.path.exists(output_dir):
298 | os.makedirs(output_dir)
299 | model = model.to(device)
300 | model.train()
301 | optimizer = AdamW(model.parameters(), lr=lr)
302 | train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
303 | scheduler = get_linear_schedule_with_warmup(
304 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
305 | )
306 | # save_config(args)
307 | for epoch in range(epochs):
308 | print(f">>> Training epoch {epoch}")
309 | sys.stdout.flush()
310 | progress = tqdm(total=len(train_dataloader), desc=output_prefix)
311 | for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
312 | model.zero_grad()
313 | tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
314 | outputs = model(tokens, prefix, mask)
315 | logits = outputs.logits[:, dataset.prefix_length - 1: -1]
316 | loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)
317 | loss.backward()
318 | optimizer.step()
319 | scheduler.step()
320 | optimizer.zero_grad()
321 | progress.set_postfix({"loss": loss.item()})
322 | progress.update()
323 | if (idx + 1) % 10000 == 0:
324 | torch.save(
325 | model.state_dict(),
326 | os.path.join(output_dir, f"{output_prefix}_latest.pt"),
327 | )
328 | progress.close()
329 | if epoch % args.save_every == 0 or epoch == epochs - 1:
330 | torch.save(
331 | model.state_dict(),
332 | os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
333 | )
334 | return model
335 |
336 |
337 | def main():
338 | parser = argparse.ArgumentParser()
339 | parser.add_argument('--data', default='./data/coco/oscar_split_train.pkl')
340 | parser.add_argument('--out_dir', default='./checkpoints')
341 | parser.add_argument('--prefix', default='coco_prefix', help='prefix for saved filenames')
342 | parser.add_argument('--epochs', type=int, default=10)
343 | parser.add_argument('--save_every', type=int, default=1)
344 | parser.add_argument('--prefix_length', type=int, default=10)
345 | parser.add_argument('--prefix_length_clip', type=int, default=10)
346 | parser.add_argument('--bs', type=int, default=40)
347 | parser.add_argument('--only_prefix', dest='only_prefix', action='store_true')
348 | parser.add_argument('--mapping_type', type=str, default='mlp', help='mlp/transformer')
349 | parser.add_argument('--num_layers', type=int, default=8)
350 | parser.add_argument('--is_rn', dest='is_rn', action='store_true')
351 | parser.add_argument('--normalize_prefix', dest='normalize_prefix', action='store_true')
352 | args = parser.parse_args()
353 | prefix_length = args.prefix_length
354 | dataset = ClipCocoDataset(args.data, prefix_length, normalize_prefix=args.normalize_prefix)
355 | prefix_dim = 640 if args.is_rn else 512
356 | args.mapping_type = {'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}[args.mapping_type]
357 | if args.only_prefix:
358 | model = ClipCaptionPrefix(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim,
359 | num_layers=args.num_layers, mapping_type=args.mapping_type)
360 | print("Train only prefix")
361 | else:
362 | model = ClipCaptionModel(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim,
363 | num_layers=args.num_layers, mapping_type=args.mapping_type)
364 | print("Train both prefix and GPT")
365 | sys.stdout.flush()
366 | train(dataset, model, args, output_dir=args.out_dir, output_prefix=args.prefix)
367 |
368 |
369 | if __name__ == '__main__':
370 | main()
371 |
--------------------------------------------------------------------------------