├── assets
│   └── cover.png
├── environment.yml
├── modules
│   ├── checkpoint.py
│   ├── train_si.py
│   ├── train_nrp.py
│   └── train_gpp.py
├── .gitignore
├── models
│   ├── model.py
│   ├── si_models.py
│   ├── nrp_models.py
│   └── gpp_models.py
├── utils
│   ├── misc.py
│   └── metric_logger.py
├── README.md
├── preprocess
│   └── download_images.py
├── EXPERIMENT.md
├── main_si.py
├── main_nrp.py
├── main_gpp.py
├── DATASET.md
├── LICENSE
└── data
    └── mpchat_nrp.py
/assets/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahnjaewoo/MPCHAT/HEAD/assets/cover.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mpchat 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.7.10 6 | - pip=21.0.1 7 | - cudatoolkit=10.1 8 | - pytorch::pytorch=1.7.1 9 | - torchvision=0.8.2 10 | - torchaudio=0.7.2 11 | - pip: 12 | - tqdm==4.64.0 13 | - nltk==3.7 14 | - transformers==4.17.0 15 | - protobuf==3.20.0 16 | - gallery-dl==1.21.2 17 | - tensorflow-hub==0.7.0 18 | - tensorflow==2.0.0 19 | - pillow==6.1.0 20 | - praw==7.5.0 21 | - sentence-transformers==2.2.0 22 | - emoji==1.7.0 23 | - redditcleaner==1.1.2 24 | - tensorboardx==2.5 25 | - tensorboard==2.0.2 26 | - timm==0.4.9 27 | - ipython==7.34.0 28 | - ftfy==6.1.1 29 | -------------------------------------------------------------------------------- /modules/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from transformers import WEIGHTS_NAME 5 | 6 | def load_checkpoint_args(args, logger): 7 | recover_args = {'global_step': 0, 'step': 0, 'last_checkpoint_dir': None, 8 | 'last_best_checkpoint_dir': None, 'last_best_score': None} 9 | 10 | if os.path.exists(args.output_dir): 11 | save_file = os.path.join(args.output_dir, "last_checkpoint") 12 | try: 13 | with open(save_file, "r") as f: 14 | texts = f.read().split('\n') 15 | last_saved = texts[0] 16 | last_saved = last_saved.strip() 17 | last_best_saved = texts[1].split('best: ')[-1].strip() 18 | last_best_score = json.loads(texts[2]) 19 | 20 | except IOError: 21 | # if file doesn't exist, maybe because it has just been 22 | # deleted by a separate process 23 | last_saved = "" 24 | if last_saved: 25 | folder_name = os.path.splitext(last_saved.split('/')[0])[0] # in the form of checkpoint-{epoch}-{global_step} or checkpoint-{epoch}-{global_step}/pytorch_model.bin 26 | recover_args['last_checkpoint_dir'] = os.path.join(args.output_dir, folder_name) 27 | recover_args['epoch'] = int(folder_name.split('-')[1]) 28 | recover_args['global_step'] = int(folder_name.split('-')[2]) 29 | recover_args['last_best_checkpoint_dir'] = os.path.join(args.output_dir, last_best_saved) 30 | recover_args['last_best_score'] = last_best_score 31 | assert os.path.isfile(os.path.join(recover_args['last_checkpoint_dir'], WEIGHTS_NAME)), "Last_checkpoint detected, but file not found!"
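# NOTE (inferred from the parsing above and from the best_scores dict built in modules/train_si.py,
# so treat the exact layout as an assumption): "last_checkpoint" is expected to hold three lines, e.g.
#   checkpoint-3-4500/pytorch_model.bin
#   best: checkpoint-2-3000
#   {"epoch": 2, "global_step": 3000, "scores": {"recall@1": 0.45}}
# i.e. the latest checkpoint path, the best checkpoint folder, and a JSON dict of the best
# validation scores (values here are illustrative); the checkpoint-{epoch}-{global_step}
# folder name is what the split('-') calls above rely on.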
32 | 33 | if recover_args['last_checkpoint_dir'] is not None: # recovery 34 | args.model_name_or_path = recover_args['last_checkpoint_dir'] 35 | logger.info(" -> Recovering model from {}".format(recover_args['last_checkpoint_dir'])) 36 | 37 | return recover_args 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import ( 4 | CLIPModel, 5 | AutoModel, 6 | ) 7 | 8 | # contrastive loss function, adapted from 9 | # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html 10 | def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: 11 | return torch.nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) 12 | 13 | def clip_loss(similarity: torch.Tensor) -> torch.Tensor: 14 | caption_loss = contrastive_loss(similarity) 15 | image_loss = contrastive_loss(similarity.t()) 16 | return (caption_loss + image_loss) / 2.0 17 | 18 | #Mean Pooling - Take attention mask into account for correct averaging 19 | def mean_pooling(model_output, attention_mask): 20 | token_embeddings = model_output[0] #First element of model_output contains all token embeddings 21 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 22 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 23 | 24 | class BaseModel(torch.nn.Module): 25 | def __init__(self, args, clip_processor): 26 | super().__init__() 27 | self.args = args 28 | self.clip_processor = clip_processor 29 | 30 | def forward(self): 31 | pass 32 | 33 | class ClipClipModel(BaseModel): 34 | def __init__(self, args, clip_processor): 35 | super().__init__(args, clip_processor) 36 | 37 | clip_models = [CLIPModel.from_pretrained(args.clip_model_name) for x in range(3)] 38 | self.context_image_encoder = clip_models[0].vision_model 39 | self.context_image_projection = clip_models[0].visual_projection 40 | self.context_text_encoder = clip_models[0].text_model 41 | self.context_text_projection = clip_models[0].text_projection 42 | self.response_encoder = clip_models[1].text_model 43 | self.response_projection = clip_models[1].text_projection 44 | self.persona_text_encoder = clip_models[2].text_model 45 | self.persona_text_projection = clip_models[2].text_projection 46 | self.persona_image_encoder = clip_models[1].vision_model 47 | self.persona_image_projection = clip_models[1].visual_projection 48 | 49 | self.logit_scale = clip_models[0].logit_scale 50 | 51 | class ClipSbertModel(BaseModel): 52 | def __init__(self, args, clip_processor): 53 | super().__init__(args, clip_processor) 54 | 55 | sbert_models = [AutoModel.from_pretrained(args.sbert_model_name) for x in range(3)] 56 | self.context_text_encoder = sbert_models[0] 57 | self.persona_text_encoder = sbert_models[1] 58 | self.response_encoder = sbert_models[2] 59 | clip_models = [CLIPModel.from_pretrained(args.clip_model_name) for x in range(2)] 60 | self.context_image_encoder = clip_models[0].vision_model 61 | self.context_image_projection = 
torch.nn.Linear(self.context_image_encoder.config.hidden_size, self.context_text_encoder.config.hidden_size) 62 | self.persona_image_encoder = clip_models[1].vision_model 63 | self.persona_image_projection = torch.nn.Linear(self.context_image_encoder.config.hidden_size, self.context_text_encoder.config.hidden_size) 64 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Microsoft Corporation. Licensed under the MIT license. 2 | 3 | import errno 4 | import os 5 | import os.path as op 6 | import yaml 7 | import random 8 | import torch 9 | import numpy as np 10 | import torch.distributed as dist 11 | 12 | from PIL import Image 13 | 14 | def mkdir(path): 15 | # if it is the current folder, skip. 16 | if path == '': 17 | return 18 | try: 19 | os.makedirs(path) 20 | except OSError as e: 21 | if e.errno != errno.EEXIST: 22 | raise 23 | 24 | 25 | def set_seed(seed, n_gpu): 26 | random.seed(seed) 27 | np.random.seed(seed) 28 | torch.manual_seed(seed) 29 | if n_gpu > 0: 30 | torch.cuda.manual_seed_all(seed) 31 | 32 | 33 | def load_from_yaml_file(yaml_file): 34 | with open(yaml_file, 'r') as fp: 35 | return yaml.load(fp) 36 | 37 | 38 | def find_file_path_in_yaml(fname, root): 39 | if fname is not None: 40 | if op.isfile(fname): 41 | return fname 42 | elif op.isfile(op.join(root, fname)): 43 | return op.join(root, fname) 44 | else: 45 | raise FileNotFoundError( 46 | errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname) 47 | ) 48 | 49 | 50 | def get_rank(): 51 | if not dist.is_available(): 52 | return 0 53 | if not dist.is_initialized(): 54 | return 0 55 | return dist.get_rank() 56 | 57 | 58 | def is_main_process(): 59 | return get_rank() == 0 60 | 61 | 62 | def get_world_size(): 63 | if not dist.is_available(): 64 | return 1 65 | if not dist.is_initialized(): 66 | return 1 67 | return dist.get_world_size() 68 | 69 | def compute_metrics_from_logits(logits, targets): 70 | """ 71 | recall@k for N candidates 72 | 73 | logits: (batch_size, num_candidates) 74 | targets: (batch_size, ) 75 | """ 76 | batch_size, num_candidates = logits.shape 77 | 78 | sorted_indices = logits.sort(descending=True)[1] 79 | targets = targets.tolist() 80 | 81 | recall_k = dict() 82 | if num_candidates <= 10: 83 | ks = [1, max(1, round(num_candidates*0.2)), max(1, round(num_candidates*0.5))] 84 | elif num_candidates <= 100: 85 | ks = [1, max(1, round(num_candidates*0.1)), max(1, round(num_candidates*0.5))] 86 | else: 87 | raise ValueError("num_candidates: {0} is not proper".format(num_candidates)) 88 | for k in ks: 89 | # sorted_indices[:,:k]: (batch_size, k) 90 | num_ok = 0 91 | for tgt, topk in zip(targets, sorted_indices[:,:k].tolist()): 92 | if tgt in topk: 93 | num_ok += 1 94 | recall_k[f'recall@{k}'] = (num_ok/batch_size) 95 | 96 | # MRR 97 | MRR = 0 98 | for tgt, topk in zip(targets, sorted_indices.tolist()): 99 | rank = topk.index(tgt)+1 100 | MRR += 1/rank 101 | MRR = MRR/batch_size 102 | return recall_k, MRR 103 | 104 | def pil_loader(path: str) -> Image.Image: 105 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 106 | with open(path, 'rb') as f: 107 | img = Image.open(f) 108 | return img.convert('RGB') 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MPCHAT 2 | 3 | Welcome! 
👋🏻\ 4 | This is the official repository of our ACL 2023 paper: \ 5 | **[MPCHAT: Towards Multimodal Persona-Grounded Conversation.](https://aclanthology.org/2023.acl-long.189/)** [[Poster](https://drive.google.com/file/d/1THxA0WcQzuQIM-omBNYiFCytBNftlIOz/view?usp=sharing)] [[Slides](https://drive.google.com/file/d/1ZplE-6PM45A23KYYceV9geMRx_itfG8k/view?usp=drive_link)] 6 | 7 | ![dialogue illustration](assets/cover.png) 8 | 9 | Please cite our work if you find the resources in this repository useful: 10 | 11 | ```bib 12 | @inproceedings{ahn2023mpchat, 13 | title={MPCHAT: Towards Multimodal Persona-Grounded Conversation}, 14 | author={Jaewoo Ahn and Yeda Song and Sangdoo Yun and Gunhee Kim}, 15 | booktitle={ACL}, 16 | year={2023} 17 | } 18 | ``` 19 | 20 | ## Dataset 21 | 22 | ### Note for dataset users 23 | 24 | **Terms of use:** Uses of MPCHAT are subject to the [Reddit API terms](https://www.reddit.com/wiki/api-terms/). Users must comply with the [Reddit User Agreement, Content Policy, and Privacy Policy](https://www.redditinc.com/policies). 25 | 26 | **Usage restrictions**: MPCHAT should only be used for non-commercial research. Any commercial or for-profit use of MPCHAT is restricted – it should not be used to train models that will be deployed in production systems as part of a product offered by businesses or government agencies. 27 | 28 | Refer to the Ethics Statement in the paper for more details. 29 | 30 | ### Dataset download 31 | You can download our dialog dataset directly by clicking this [link](https://drive.google.com/file/d/18bur87ayw_8NkQsCz_UCml0mHfuytlVu/view?usp=drive_link). 32 | 33 | Please check [DATASET.md](DATASET.md) for a description of the dataset's structure and attributes. 34 | 35 | **Note**: We do not distribute image files, as we do not legally own them. The annotation files contain image URLs – you can download the images as follows (note that this is somewhat slow): 36 | ```bash 37 | python preprocess/download_images.py \ 38 | --dialog_image_url_directory "./dialog_image_urls.json" \ 39 | --persona_image_url_directory "./persona_image_urls.json" \ 40 | --save_dialog_image_directory "./images/dialog/" \ 41 | --save_persona_image_directory "./images/persona/" 42 | ``` 43 | or, alternatively, by using the [redcaps-downloader tool](https://github.com/redcaps-dataset/redcaps-downloader). 44 | 45 | ### Image removal request 46 | 47 | Did you find any problematic image in MPCHAT that should be removed? Report it to us using this [link](https://docs.google.com/forms/d/e/1FAIpQLSdz6q2IlE3Npr-ZidJmCAW-xMzu48m5-Jcta4r6FXEHwRBGYQ/viewform?usp=sf_link)! We will review all requests and remove problematic images in the next version release. 48 | 49 | ## Model training and inference 50 | 51 | ### Environment setup 52 | 53 | We recommend using Anaconda. The following command will create a new conda environment `mpchat` with all the dependencies: 54 | ```bash 55 | conda env create -f environment.yml 56 | ``` 57 | 58 | To activate the environment: 59 | ```bash 60 | conda activate mpchat 61 | ``` 62 | 63 | Please check [EXPERIMENT.md](EXPERIMENT.md) for model training and inference for each task. 64 | 65 | ## Have any questions? 66 | 67 | Please contact [Jaewoo Ahn](https://ahnjaewoo.github.io) at jaewoo.ahn at vl.snu.ac.kr 68 | 69 | ## License 70 | 71 | This repository is CC-BY 4.0 licensed. See the [LICENSE](https://github.com/ahnjaewoo/mpchat/blob/main/LICENSE) file for details.
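## Checking downloaded images

Since some image URLs will inevitably have gone stale, it can be useful to check how many images were actually saved. The snippet below is a minimal sketch (not part of the released code) that assumes the `{image_id: image_url}` JSON layout and the `{image_id}_{url basename}` filename pattern used by `preprocess/download_images.py`; a non-zero number of missing images is expected (e.g., images removed by their hosts).

```python
import json
import os

# Assumed inputs/outputs: the same paths used in the download command above.
url_file = "./dialog_image_urls.json"
image_dir = "./images/dialog/"

with open(url_file, "r") as fp:
    id2url = json.load(fp)  # {image_id: image_url}, as consumed by download_images.py

missing = []
for image_id, image_url in id2url.items():
    # download_images.py saves each image as "{image_id}_{basename of the URL}"
    fname = f"{image_id}_{image_url.split('/')[-1]}"
    if not os.path.isfile(os.path.join(image_dir, fname)):
        missing.append(image_id)

print(f"{len(missing)} of {len(id2url)} dialog images are missing")
```

The same check works for the persona images by pointing `url_file` and `image_dir` at `./persona_image_urls.json` and `./images/persona/`.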
72 | -------------------------------------------------------------------------------- /preprocess/download_images.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import time 4 | import json 5 | import argparse 6 | import requests 7 | 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | def save_image(id2image_url, args, mode): 12 | assert mode in ['dialog', 'persona'] 13 | for image_id, image_url in tqdm(id2image_url.items(), ncols=50): 14 | fname = f"{image_id}_{image_url.split('/')[-1]}" 15 | request_trial = 0 16 | while True: 17 | try: 18 | response = requests.get(image_url) 19 | # check if image was downloaded (response must be 200). one exception: 20 | # imgur gives response 200 with "removed.png" image if not found. 21 | if response.status_code != 200: 22 | print(f'Wrong status_code = {response.status_code}') 23 | if response.status_code == 429: 24 | time.sleep(10) 25 | elif "removed.png" in response.url: 26 | print(f"Removed image: {image_url}") 27 | break 28 | else: 29 | # Write image to disk if it was downloaded successfully. 30 | pil_image = Image.open(io.BytesIO(response.content)).convert("RGB") 31 | image_width, image_height = pil_image.size 32 | scale = args.longer_resize / float(max(image_width, image_height)) 33 | if scale != 1.0: 34 | new_width, new_height = tuple( 35 | int(round(d * scale)) for d in (image_width, image_height) 36 | ) 37 | pil_image = pil_image.resize((new_width, new_height)) 38 | if mode == 'dialog': 39 | pil_image.save(os.path.join(args.save_dialog_image_directory, fname)) 40 | else: 41 | pil_image.save(os.path.join(args.save_persona_image_directory, fname)) 42 | break 43 | except: 44 | print('Something wrong...') 45 | 46 | request_trial += 1 47 | print(f'{request_trial}-th Request trial...') 48 | if request_trial > args.max_trial: 49 | break 50 | return 51 | 52 | def main(): 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--max_trial", default=5, type=int) 55 | parser.add_argument("--dialog_image_url_directory", default=None, required=True, type=str) 56 | parser.add_argument("--persona_image_url_directory", default=None, required=True, type=str) 57 | parser.add_argument("--save_dialog_image_directory", default=None, required=True, type=str) 58 | parser.add_argument("--save_persona_image_directory", default=None, required=True, type=str) 59 | parser.add_argument("--longer_resize", default=512, type=int, 60 | help="Resize the longer edge of image to this size before \ 61 | saving to disk (preserve aspect ratio). Set to -1 to avoid any resizing. 
\ 62 | Defaults to 512.") 63 | args = parser.parse_args() 64 | 65 | with open(args.dialog_image_url_directory, 'r') as fp: 66 | id2dialog_image_url = json.load(fp) 67 | 68 | with open(args.persona_image_url_directory, 'r') as fp: 69 | id2persona_image_url = json.load(fp) 70 | 71 | os.makedirs(args.save_dialog_image_directory, exist_ok=True) 72 | os.makedirs(args.save_persona_image_directory, exist_ok=True) 73 | 74 | save_image(id2dialog_image_url, args, 'dialog') 75 | save_image(id2persona_image_url, args, 'persona') 76 | 77 | print('Good Job Computer!') 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /EXPERIMENT.md: -------------------------------------------------------------------------------- 1 | ## Next Response Prediction 2 | 3 | **CLIP-CLIP** 4 | ```bash 5 | python main_nrp.py \ 6 | --model_type "clip-clip" \ 7 | --dialog_data_dir "." \ 8 | --dialog_image_data_dir "./images/dialog/" \ 9 | --persona_image_data_dir "./images/persona/" \ 10 | --output_dir "outputs/clip-clip/nrp/full_inputs" \ 11 | --seed 202 \ 12 | --sum_persona_images \ 13 | --remove_empty_images \ 14 | --do_train \ 15 | --do_test \ 16 | --per_gpu_train_batch_size 8 \ 17 | --per_gpu_eval_batch_size 4 \ 18 | --max_num_responses 100 \ 19 | --learning_rate 3e-06 \ 20 | --weight_decay 0.05 \ 21 | --num_train_epochs 5 \ 22 | --save_epoch 1 \ 23 | --num_workers 12 24 | ``` 25 | 26 | **CLIP-SBERT** 27 | ```bash 28 | python main_nrp.py \ 29 | --model_type "clip-sbert" \ 30 | --dialog_data_dir "." \ 31 | --dialog_image_data_dir "./images/dialog/" \ 32 | --persona_image_data_dir "./images/persona/" \ 33 | --output_dir "outputs/clip-sbert/nrp/full_inputs" \ 34 | --seed 202 \ 35 | --freeze_image_encoder \ 36 | --sum_persona_images \ 37 | --remove_empty_images \ 38 | --do_train \ 39 | --do_test \ 40 | --per_gpu_train_batch_size 8 \ 41 | --per_gpu_eval_batch_size 4 \ 42 | --max_num_responses 100 \ 43 | --learning_rate 1e-05 \ 44 | --max_seq_length 128 \ 45 | --weight_decay 0.05 \ 46 | --num_train_epochs 5 \ 47 | --save_epoch 1 \ 48 | --num_workers 12 49 | ``` 50 | 51 | ## Grounding Persona Prediction 52 | 53 | **CLIP-CLIP (no-response)** 54 | ```bash 55 | python main_gpp.py \ 56 | --model_type "clip-clip" \ 57 | --dialog_data_dir "." \ 58 | --dialog_image_data_dir "./images/dialog/" \ 59 | --persona_image_data_dir "./images/persona/" \ 60 | --output_dir "outputs/clip-clip/gpp-context/full_inputs" \ 61 | --seed 202 \ 62 | --sum_persona_images \ 63 | --remove_empty_images \ 64 | --do_train \ 65 | --do_test \ 66 | --per_gpu_train_batch_size 8 \ 67 | --per_gpu_eval_batch_size 4 \ 68 | --max_num_candidate_persona_elements 100 \ 69 | --learning_rate 3e-06 \ 70 | --weight_decay 0.05 \ 71 | --num_train_epochs 5 \ 72 | --save_epoch 1 \ 73 | --num_workers 12 74 | ``` 75 | 76 | **CLIP-CLIP (response)** 77 | ```bash 78 | python main_gpp.py \ 79 | --model_type "clip-clip" \ 80 | --dialog_data_dir "." 
\ 81 | --dialog_image_data_dir "./images/dialog/" \ 82 | --persona_image_data_dir "./images/persona/" \ 83 | --output_dir "outputs/clip-clip/gpp-response/full_inputs" \ 84 | --seed 202 \ 85 | --sum_persona_images \ 86 | --remove_empty_images \ 87 | --use_response \ 88 | --do_train \ 89 | --do_test \ 90 | --per_gpu_train_batch_size 8 \ 91 | --per_gpu_eval_batch_size 4 \ 92 | --max_num_candidate_persona_elements 100 \ 93 | --learning_rate 3e-06 \ 94 | --weight_decay 0.05 \ 95 | --num_train_epochs 5 \ 96 | --save_epoch 1 \ 97 | --num_workers 12 98 | ``` 99 | 100 | **CLIP-SBERT (no-response)** 101 | ```bash 102 | python main_gpp.py \ 103 | --model_type "clip-sbert" \ 104 | --dialog_data_dir "." \ 105 | --dialog_image_data_dir "./images/dialog/" \ 106 | --persona_image_data_dir "./images/persona/" \ 107 | --output_dir "outputs/clip-sbert/gpp-context/full_inputs" \ 108 | --seed 202 \ 109 | --freeze_image_encoder \ 110 | --sum_persona_images \ 111 | --remove_empty_images \ 112 | --do_train \ 113 | --do_test \ 114 | --per_gpu_train_batch_size 8 \ 115 | --per_gpu_eval_batch_size 4 \ 116 | --max_num_candidate_persona_elements 100 \ 117 | --learning_rate 1e-05 \ 118 | --max_seq_length 128 \ 119 | --weight_decay 0.05 \ 120 | --num_train_epochs 5 \ 121 | --save_epoch 1 \ 122 | --num_workers 12 123 | ``` 124 | 125 | **CLIP-SBERT (response)** 126 | ```bash 127 | python main_gpp.py \ 128 | --model_type "clip-sbert" \ 129 | --dialog_data_dir "." \ 130 | --dialog_image_data_dir "./images/dialog/" \ 131 | --persona_image_data_dir "./images/persona/" \ 132 | --output_dir "outputs/clip-sbert/gpp-response/full_inputs" \ 133 | --seed 202 \ 134 | --freeze_image_encoder \ 135 | --sum_persona_images \ 136 | --remove_empty_images \ 137 | --use_response \ 138 | --do_train \ 139 | --do_test \ 140 | --per_gpu_train_batch_size 8 \ 141 | --per_gpu_eval_batch_size 4 \ 142 | --max_num_candidate_persona_elements 100 \ 143 | --learning_rate 1e-05 \ 144 | --max_seq_length 128 \ 145 | --weight_decay 0.05 \ 146 | --num_train_epochs 5 \ 147 | --save_epoch 1 \ 148 | --num_workers 12 149 | ``` 150 | 151 | ## Speaker Identification 152 | 153 | **CLIP-CLIP** 154 | ```bash 155 | python main_si.py \ 156 | --model_type "clip-clip" \ 157 | --dialog_data_dir "." \ 158 | --dialog_image_data_dir "./images/dialog/" \ 159 | --persona_image_data_dir "./images/persona/" \ 160 | --output_dir "outputs/clip-clip/si/full_inputs" \ 161 | --seed 202 \ 162 | --sum_persona_images \ 163 | --remove_empty_images \ 164 | --do_train \ 165 | --do_test \ 166 | --per_gpu_train_batch_size 8 \ 167 | --per_gpu_eval_batch_size 4 \ 168 | --max_num_candidate_authors 100 \ 169 | --learning_rate 3e-06 \ 170 | --weight_decay 0.05 \ 171 | --num_train_epochs 5 \ 172 | --save_epoch 1 \ 173 | --num_workers 12 174 | ``` 175 | 176 | **CLIP-SBERT** 177 | ```bash 178 | python main_si.py \ 179 | --model_type "clip-sbert" \ 180 | --dialog_data_dir "." 
\ 181 | --dialog_image_data_dir "./images/dialog/" \ 182 | --persona_image_data_dir "./images/persona/" \ 183 | --output_dir "outputs/clip-sbert/si/full_inputs" \ 184 | --seed 202 \ 185 | --freeze_image_encoder \ 186 | --sum_persona_images \ 187 | --remove_empty_images \ 188 | --do_train \ 189 | --do_test \ 190 | --per_gpu_train_batch_size 8 \ 191 | --per_gpu_eval_batch_size 4 \ 192 | --max_num_candidate_authors 100 \ 193 | --learning_rate 2e-05 \ 194 | --max_seq_length 128 \ 195 | --weight_decay 0.05 \ 196 | --num_train_epochs 5 \ 197 | --save_epoch 1 \ 198 | --num_workers 12 199 | ``` 200 | -------------------------------------------------------------------------------- /utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import defaultdict 3 | from collections import deque 4 | import os 5 | 6 | import torch 7 | 8 | from .misc import is_main_process 9 | 10 | 11 | class SmoothedValue(object): 12 | """Track a series of values and provide access to smoothed values over a 13 | window or the global series average. 14 | """ 15 | 16 | def __init__(self, window_size=20): 17 | self.deque = deque(maxlen=window_size) 18 | # self.series = [] 19 | self.total = 0.0 20 | self.count = 0 21 | 22 | def update(self, value): 23 | self.deque.append(value) 24 | # self.series.append(value) 25 | self.count += 1 26 | self.total += value 27 | 28 | @property 29 | def median(self): 30 | d = torch.tensor(list(self.deque)) 31 | return d.median().item() 32 | 33 | @property 34 | def avg(self): 35 | d = torch.tensor(list(self.deque)) 36 | return d.mean().item() 37 | 38 | @property 39 | def global_avg(self): 40 | return self.total / self.count 41 | 42 | @property 43 | def last_value(self): 44 | return self.deque[-1] 45 | 46 | 47 | class MetricLogger(object): 48 | def __init__(self, delimiter="\t"): 49 | self.meters = {} 50 | self.params = {} 51 | self.delimiter = delimiter 52 | 53 | def update_params(self, update_dict): 54 | for param_group, group_dict in update_dict.items(): 55 | if param_group not in self.params: 56 | self.params[param_group] = {} 57 | for param_name, param_value in group_dict.items(): 58 | # skipping parameters if they start with '_' 59 | if param_name.startswith('_'): 60 | continue 61 | if isinstance(param_value, torch.Tensor): 62 | param_value = param_value.item() 63 | assert isinstance(param_value, (float, int)) 64 | self.params[param_group][param_name] = param_value 65 | 66 | def update_metrics(self, update_dict): 67 | for metric_group, group_dict in update_dict.items(): 68 | if metric_group not in self.meters: 69 | self.meters[metric_group] = defaultdict(SmoothedValue) 70 | for metric_name, metric_value in group_dict.items(): 71 | # skipping metrics if they start with '_' 72 | if metric_name.startswith('_'): 73 | continue 74 | if isinstance(metric_value, torch.Tensor): 75 | metric_value = metric_value.item() 76 | assert isinstance(metric_value, (float, int)) 77 | self.meters[metric_group][metric_name].update(metric_value) 78 | 79 | def get_logs(self, iteration): 80 | return_str = [] 81 | if len(self.meters) > 0: 82 | offset_m = max([len(group_name) for group_name in self.meters.keys()]) 83 | else: 84 | offset_m = 0 85 | if len(self.params) > 0: 86 | offset_p = max([len(group_name) for group_name in self.params.keys()]) 87 | else: 88 | offset_p = 0 89 | offset = max(offset_m, offset_p) 90 | 91 | for group_name, values in sorted(self.meters.items(), 
92 | key=lambda x: x[0]): 93 | loss_str = [] 94 | for name, meter in values.items(): 95 | loss_str.append("{}: {:.4f} ({:.4f})".format( 96 | name, meter.median, meter.global_avg, 97 | )) 98 | return_str.append( 99 | "{:{offset}s} - {}".format( 100 | group_name, self.delimiter.join(loss_str), offset=offset, 101 | ), 102 | ) 103 | for group_name, values in self.params.items(): 104 | loss_str = [] 105 | for name, param in values.items(): 106 | loss_str.append("{}: {:.6f}".format(name, param)) 107 | return_str.append( 108 | "{:{offset}s} - {}".format( 109 | group_name, self.delimiter.join(loss_str), offset=offset, 110 | ), 111 | ) 112 | return "\n ".join(return_str) 113 | 114 | 115 | class TensorboardLogger(MetricLogger): 116 | def __init__(self, 117 | log_dir, 118 | delimiter='\t'): 119 | super(TensorboardLogger, self).__init__(delimiter) 120 | try: 121 | from tensorboardX import SummaryWriter 122 | except ImportError: 123 | raise ImportError( 124 | 'To use tensorboard please install tensorboardX ' 125 | '[ pip install tensorboardx ].' 126 | ) 127 | self.philly_tb_logger = None 128 | self.philly_tb_logger_avg = None 129 | self.philly_tb_logger_med = None 130 | if is_main_process(): 131 | self.tb_logger = SummaryWriter(log_dir) 132 | self.tb_logger_avg = SummaryWriter(os.path.join(log_dir, 'avg')) 133 | self.tb_logger_med = SummaryWriter(os.path.join(log_dir, 'med')) 134 | else: 135 | self.tb_logger = None 136 | self.tb_logger_avg = None 137 | self.tb_logger_med = None 138 | 139 | def get_logs(self, iteration): 140 | if self.tb_logger: 141 | for group_name, values in self.meters.items(): 142 | for name, meter in values.items(): 143 | self.tb_logger.add_scalar( 144 | '{}/{}'.format(group_name, name), 145 | meter.last_value, iteration, 146 | ) 147 | self.tb_logger_avg.add_scalar( 148 | '{}/{}'.format(group_name, name), 149 | meter.avg, iteration, 150 | ) 151 | self.tb_logger_med.add_scalar( 152 | '{}/{}'.format(group_name, name), 153 | meter.median, iteration, 154 | ) 155 | if self.philly_tb_logger: 156 | self.philly_tb_logger.add_scalar( 157 | '{}/{}'.format(group_name, name), 158 | meter.last_value, iteration, 159 | ) 160 | self.philly_tb_logger_avg.add_scalar( 161 | '{}/{}'.format(group_name, name), 162 | meter.avg, iteration, 163 | ) 164 | self.philly_tb_logger_med.add_scalar( 165 | '{}/{}'.format(group_name, name), 166 | meter.median, iteration, 167 | ) 168 | for group_name, values in self.params.items(): 169 | for name, param in values.items(): 170 | self.tb_logger.add_scalar( 171 | '{}/{}'.format(group_name, name), 172 | param, iteration, 173 | ) 174 | if self.philly_tb_logger: 175 | self.philly_tb_logger.add_scalar( 176 | '{}/{}'.format(group_name, name), 177 | param, iteration, 178 | ) 179 | return super(TensorboardLogger, self).get_logs(iteration) 180 | 181 | def close(self): 182 | if is_main_process(): 183 | self.tb_logger.close() 184 | self.tb_logger_avg.close() 185 | self.tb_logger_med.close() 186 | -------------------------------------------------------------------------------- /main_si.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import json 4 | import glob 5 | import logging 6 | import argparse 7 | 8 | import torch 9 | from utils.misc import ( 10 | set_seed, 11 | ) 12 | from utils.metric_logger import TensorboardLogger 13 | from data.mpchat_si import MpchatClipClipSiDataset, MpchatClipSbertSiDataset 14 | from models.si_models import ClipClipSi, ClipSbertSi 15 | from modules.checkpoint import load_checkpoint_args 16 | 
from modules.train_si import train, evaluate 17 | 18 | from transformers import ( 19 | AutoTokenizer, 20 | CLIPProcessor, 21 | WEIGHTS_NAME, 22 | ) 23 | 24 | MODEL_CLASSES = { 25 | 'clip-clip': (ClipClipSi, MpchatClipClipSiDataset), 26 | 'clip-sbert': (ClipSbertSi, MpchatClipSbertSiDataset), 27 | } 28 | logger = logging.getLogger(__name__) 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | 33 | ## Required (or pre-defined) params 34 | parser.add_argument("--dialog_data_dir", default=None, type=str, required=True, help="The dialogue data dir") 35 | parser.add_argument("--dialog_image_data_dir", default=None, type=str, required=True, help="The dialogue image data dir") 36 | parser.add_argument("--persona_image_data_dir", default=None, type=str, required=True, help="The persona image data dir") 37 | parser.add_argument("--output_dir", default=None, type=str, required=True, 38 | help="The output directory where the model checkpoints will be written.") 39 | parser.add_argument("--model_type", default='clip-clip', choices=['clip-clip', 'clip-sbert']) 40 | parser.add_argument("--model_name_or_path", default='', type=str, 41 | help="Path to pre-trained model or shortcut name") 42 | 43 | ## Configs 44 | parser.add_argument("--no_cuda", action='store_true', help="Avoid using CUDA when available") 45 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 46 | parser.add_argument("--num_workers", default=6, type=int) 47 | parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") 48 | parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit (mixed) precision instead of 32-bit") 49 | parser.add_argument("--do_train", action='store_true', help="Whether to run training.") 50 | parser.add_argument("--do_test", action='store_true', help="Whether to run test on the test set.") 51 | parser.add_argument("--freeze_image_encoder", action='store_true', help="Whether to freeze image encoder or not") 52 | parser.add_argument("--freeze_text_encoder", action='store_true', help="Whether to freeze image encoder or not") 53 | parser.add_argument("--remove_empty_images", action='store_true', help="Whether to remove empty images or not") 54 | parser.add_argument("--sum_persona_images", action='store_true', help="Whether to sum persona images or not") 55 | 56 | # Misc: other params (model, input, etc) 57 | parser.add_argument("--clip_model_name", default='openai/clip-vit-base-patch32', type=str, help="CLIP model name") 58 | parser.add_argument("--sbert_model_name", default='sentence-transformers/multi-qa-distilbert-cos-v1', type=str, help="SBERT model name") 59 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") 60 | parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation.") 61 | parser.add_argument("--max_num_candidate_authors", type=int, default=100, help="maximum number of candidate authors") 62 | parser.add_argument("--max_seq_length", type=int, default=77) 63 | parser.add_argument("--max_num_imgs", type=int, default=5) 64 | parser.add_argument("--img_size", type=int, default=224) 65 | parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") 66 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 67 | help="Number of updates steps to accumulate before performing a backward/update pass.") 68 | 
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.") 69 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 70 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 71 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 72 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 73 | parser.add_argument('--logging_steps', type=int, default=100, help="Log every X updates steps.") 74 | parser.add_argument('--save_epoch', type=int, default=5, help="Save checkpoint every X epochs.") 75 | parser.add_argument('--save_after_epoch', type=int, default=-1, help="Save checkpoint after epoch.") 76 | 77 | args = parser.parse_args() 78 | 79 | # Setup CUDA, GPU & distributed training 80 | if args.local_rank == -1 or args.no_cuda: 81 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 82 | args.n_gpu = torch.cuda.device_count() 83 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 84 | raise NotImplementedError 85 | args.device = device 86 | 87 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 88 | datefmt = '%m/%d/%Y %H:%M:%S', level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 89 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 90 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 91 | 92 | # set seed 93 | set_seed(args.seed, args.n_gpu) 94 | 95 | # Output config 96 | os.makedirs(args.output_dir, exist_ok=True) 97 | 98 | # Load saved checkpoint 99 | recover_args = load_checkpoint_args(args, logger) 100 | 101 | # tokenizer 102 | tokenizer = AutoTokenizer.from_pretrained(args.sbert_model_name) 103 | clip_processor = CLIPProcessor.from_pretrained(args.clip_model_name) 104 | model_class, dataset_class = MODEL_CLASSES[args.model_type] 105 | 106 | # Prepare model 107 | model = model_class(args, clip_processor) 108 | if recover_args['last_checkpoint_dir'] is not None or args.model_name_or_path != '': # recovery 109 | model_logging = model.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'pytorch_model.bin'))) 110 | logger.info(f"{model_logging}") 111 | 112 | # Freeze model 113 | if args.freeze_image_encoder: 114 | for param in model.context_image_encoder.parameters(): 115 | param.requires_grad = False 116 | for param in model.persona_image_encoder.parameters(): 117 | param.requires_grad = False 118 | 119 | if args.freeze_text_encoder: 120 | for param in model.context_text_encoder.parameters(): 121 | param.requires_grad = False 122 | for param in model.persona_text_encoder.parameters(): 123 | param.requires_grad = False 124 | for param in model.response_encoder.parameters(): 125 | param.requires_grad = False 126 | 127 | total_params = sum(p.numel() for p in model.parameters()) 128 | logger.info('Total Parameters: {}'.format(total_params)) 129 | 130 | model.to(args.device) 131 | 132 | logger.info("Training/evaluation parameters %s", args) 133 | 134 | # load eval dataset 135 | eval_dataset = dataset_class(args, tokenizer, clip_processor, 'val') 136 | 137 | # load tensorboard 138 | tb_log_dir = os.path.join(args.output_dir, 'train_logs') 139 | meters = TensorboardLogger( 140 | log_dir=tb_log_dir, 141 | delimiter=" ", 142 | ) 143 | 144 | # 
training 145 | if args.do_train: 146 | train_dataset = dataset_class(args, tokenizer, clip_processor, 'train') 147 | global_step, tr_loss = train(args, train_dataset, eval_dataset, model, meters, recover_args, logger) 148 | logger.info("global_step = %s, average loss = %s", global_step, tr_loss) 149 | 150 | # test 151 | if args.do_test: 152 | test_dataset = dataset_class(args, tokenizer, clip_processor, 'test') 153 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 154 | 155 | try: 156 | with open(os.path.join(args.output_dir, "last_checkpoint"), "r") as f: 157 | texts = f.read().split('\n') 158 | best_saved = texts[1].split('best: ')[-1].strip() 159 | checkpoints = [ckpt for ckpt in checkpoints if best_saved in ckpt] 160 | except: 161 | logger.info("Cannot load checkpoint!") 162 | pass 163 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 164 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 165 | 166 | test_log_json = [] 167 | for checkpoint in checkpoints: 168 | epoch = checkpoint.split('-')[-2] 169 | global_step = checkpoint.split('-')[-1] 170 | model.load_state_dict(torch.load(os.path.join(checkpoint, 'pytorch_model.bin'))) 171 | model.to(args.device) 172 | test_scores = evaluate(args, model, test_dataset, 'test', logger, prefix=global_step) 173 | 174 | epoch_log = {'epoch': epoch, 'test_scores': test_scores} 175 | test_log_json.append(epoch_log) 176 | 177 | if args.local_rank in [-1, 0]: 178 | with open(args.output_dir + '/test_logs.json', 'w') as fp: 179 | json.dump(test_log_json, fp) 180 | 181 | # close the tb logger 182 | meters.close() 183 | logger.info("Good Job Computer!") 184 | 185 | if __name__ == '__main__': 186 | main() 187 | -------------------------------------------------------------------------------- /main_nrp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import json 4 | import glob 5 | import logging 6 | import argparse 7 | 8 | import torch 9 | from utils.misc import ( 10 | set_seed, 11 | ) 12 | from utils.metric_logger import TensorboardLogger 13 | from data.mpchat_nrp import MpchatClipClipNrpDataset, MpchatClipSbertNrpDataset 14 | from models.nrp_models import ClipClipNrp, ClipSbertNrp 15 | from modules.checkpoint import load_checkpoint_args 16 | from modules.train_nrp import train, evaluate 17 | 18 | from transformers import ( 19 | AutoTokenizer, 20 | CLIPProcessor, 21 | WEIGHTS_NAME, 22 | ) 23 | 24 | MODEL_CLASSES = { 25 | 'clip-clip': (ClipClipNrp, MpchatClipClipNrpDataset), 26 | 'clip-sbert': (ClipSbertNrp, MpchatClipSbertNrpDataset), 27 | } 28 | logger = logging.getLogger(__name__) 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | 33 | ## Required (or pre-defined) params 34 | parser.add_argument("--dialog_data_dir", default=None, type=str, required=True, help="The dialogue data dir") 35 | parser.add_argument("--dialog_image_data_dir", default=None, type=str, required=True, help="The dialogue image data dir") 36 | parser.add_argument("--persona_image_data_dir", default=None, type=str, required=True, help="The persona image data dir") 37 | parser.add_argument("--output_dir", default=None, type=str, required=True, 38 | help="The output directory where the model checkpoints will be written.") 39 | parser.add_argument("--model_type", default='clip-clip', choices=['clip-clip', 'clip-sbert']) 40 | parser.add_argument("--model_name_or_path", default='', 
type=str, 41 | help="Path to pre-trained model or shortcut name") 42 | 43 | ## Configs 44 | parser.add_argument("--no_cuda", action='store_true', help="Avoid using CUDA when available") 45 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 46 | parser.add_argument("--num_workers", default=6, type=int) 47 | parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") 48 | parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit (mixed) precision instead of 32-bit") 49 | parser.add_argument("--do_train", action='store_true', help="Whether to run training.") 50 | parser.add_argument("--do_test", action='store_true', help="Whether to run test on the test set.") 51 | parser.add_argument("--freeze_image_encoder", action='store_true', help="Whether to freeze image encoder or not") 52 | parser.add_argument("--freeze_text_encoder", action='store_true', help="Whether to freeze image encoder or not") 53 | parser.add_argument("--remove_empty_images", action='store_true', help="Whether to remove empty images or not") 54 | parser.add_argument("--sum_persona_images", action='store_true', help="Whether to sum persona images or not") 55 | 56 | # Misc: other params (model, input, etc) 57 | parser.add_argument("--clip_model_name", default='openai/clip-vit-base-patch32', type=str, help="CLIP model name") 58 | parser.add_argument("--sbert_model_name", default='sentence-transformers/multi-qa-distilbert-cos-v1', type=str, help="SBERT model name") 59 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") 60 | parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation.") 61 | parser.add_argument("--max_num_responses", type=int, default=100, help="maximum number of multimodal personas") 62 | parser.add_argument("--max_seq_length", type=int, default=77) 63 | parser.add_argument("--max_num_imgs", type=int, default=5) 64 | parser.add_argument("--img_size", type=int, default=224) 65 | parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") 66 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 67 | help="Number of updates steps to accumulate before performing a backward/update pass.") 68 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.") 69 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 70 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 71 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 72 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 73 | parser.add_argument('--logging_steps', type=int, default=100, help="Log every X updates steps.") 74 | parser.add_argument('--save_epoch', type=int, default=5, help="Save checkpoint every X epochs.") 75 | parser.add_argument('--save_after_epoch', type=int, default=-1, help="Save checkpoint after epoch.") 76 | 77 | args = parser.parse_args() 78 | 79 | # Setup CUDA, GPU & distributed training 80 | if args.local_rank == -1 or args.no_cuda: 81 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 82 | args.n_gpu = torch.cuda.device_count() 83 | else: # Initializes the distributed 
backend which will take care of sychronizing nodes/GPUs 84 | raise NotImplementedError 85 | args.device = device 86 | 87 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 88 | datefmt = '%m/%d/%Y %H:%M:%S', level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 89 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 90 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 91 | 92 | # set seed 93 | set_seed(args.seed, args.n_gpu) 94 | 95 | # Output config 96 | os.makedirs(args.output_dir, exist_ok=True) 97 | 98 | # Load saved checkpoint 99 | recover_args = load_checkpoint_args(args, logger) 100 | 101 | # tokenizer 102 | tokenizer = AutoTokenizer.from_pretrained(args.sbert_model_name) 103 | clip_processor = CLIPProcessor.from_pretrained(args.clip_model_name) 104 | model_class, dataset_class = MODEL_CLASSES[args.model_type] 105 | 106 | # Prepare model 107 | model = model_class(args, clip_processor) 108 | if recover_args['last_checkpoint_dir'] is not None or args.model_name_or_path != '': # recovery 109 | model_logging = model.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'pytorch_model.bin'))) 110 | logger.info(f"{model_logging}") 111 | 112 | # Freeze model 113 | if args.freeze_image_encoder: 114 | for param in model.context_image_encoder.parameters(): 115 | param.requires_grad = False 116 | for param in model.persona_image_encoder.parameters(): 117 | param.requires_grad = False 118 | 119 | if args.freeze_text_encoder: 120 | for param in model.context_text_encoder.parameters(): 121 | param.requires_grad = False 122 | for param in model.persona_text_encoder.parameters(): 123 | param.requires_grad = False 124 | for param in model.response_encoder.parameters(): 125 | param.requires_grad = False 126 | 127 | total_params = sum(p.numel() for p in model.parameters()) 128 | logger.info('Total Parameters: {}'.format(total_params)) 129 | 130 | model.to(args.device) 131 | 132 | logger.info("Training/evaluation parameters %s", args) 133 | 134 | # load eval dataset 135 | eval_dataset = dataset_class(args, tokenizer, clip_processor, 'val') 136 | 137 | # load tensorboard 138 | tb_log_dir = os.path.join(args.output_dir, 'train_logs') 139 | meters = TensorboardLogger( 140 | log_dir=tb_log_dir, 141 | delimiter=" ", 142 | ) 143 | 144 | # training 145 | if args.do_train: 146 | train_dataset = dataset_class(args, tokenizer, clip_processor, 'train') 147 | global_step, tr_loss = train(args, train_dataset, eval_dataset, model, meters, recover_args, logger) 148 | logger.info("global_step = %s, average loss = %s", global_step, tr_loss) 149 | 150 | # test 151 | if args.do_test: 152 | test_dataset = dataset_class(args, tokenizer, clip_processor, 'test') 153 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 154 | 155 | try: 156 | with open(os.path.join(args.output_dir, "last_checkpoint"), "r") as f: 157 | texts = f.read().split('\n') 158 | best_saved = texts[1].split('best: ')[-1].strip() 159 | checkpoints = [ckpt for ckpt in checkpoints if best_saved in ckpt] 160 | except: 161 | logger.info("Cannot load checkpoint!") 162 | pass 163 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 164 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 165 | 166 | test_log_json = [] 167 | for checkpoint in checkpoints: 168 | epoch = 
checkpoint.split('-')[-2] 169 | global_step = checkpoint.split('-')[-1] 170 | model.load_state_dict(torch.load(os.path.join(checkpoint, 'pytorch_model.bin'))) 171 | model.to(args.device) 172 | test_scores = evaluate(args, model, test_dataset, 'test', logger, prefix=global_step) 173 | 174 | epoch_log = {'epoch': epoch, 'test_scores': test_scores} 175 | test_log_json.append(epoch_log) 176 | 177 | if args.local_rank in [-1, 0]: 178 | with open(args.output_dir + '/test_logs.json', 'w') as fp: 179 | json.dump(test_log_json, fp) 180 | 181 | # close the tb logger 182 | meters.close() 183 | logger.info("Good Job Computer!") 184 | 185 | if __name__ == '__main__': 186 | main() 187 | -------------------------------------------------------------------------------- /main_gpp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import json 4 | import glob 5 | import logging 6 | import argparse 7 | 8 | import torch 9 | from utils.misc import ( 10 | set_seed, 11 | ) 12 | from utils.metric_logger import TensorboardLogger 13 | from data.mpchat_gpp import MpchatClipClipGppDataset, MpchatClipSbertGppDataset 14 | from models.gpp_models import ClipClipGpp, ClipSbertGpp 15 | from modules.checkpoint import load_checkpoint_args 16 | from modules.train_gpp import train, evaluate 17 | 18 | from transformers import ( 19 | AutoTokenizer, 20 | CLIPProcessor, 21 | WEIGHTS_NAME, 22 | ) 23 | 24 | MODEL_CLASSES = { 25 | 'clip-clip': (ClipClipGpp, MpchatClipClipGppDataset), 26 | 'clip-sbert': (ClipSbertGpp, MpchatClipSbertGppDataset), 27 | } 28 | logger = logging.getLogger(__name__) 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | 33 | ## Required (or pre-defined) params 34 | parser.add_argument("--dialog_data_dir", default=None, type=str, required=True, help="The dialogue data dir") 35 | parser.add_argument("--dialog_image_data_dir", default=None, type=str, required=True, help="The dialogue image data dir") 36 | parser.add_argument("--persona_image_data_dir", default=None, type=str, required=True, help="The persona image data dir") 37 | parser.add_argument("--output_dir", default=None, type=str, required=True, 38 | help="The output directory where the model checkpoints will be written.") 39 | parser.add_argument("--model_type", default='clip-clip', choices=['clip-clip', 'clip-sbert']) 40 | parser.add_argument("--model_name_or_path", default='', type=str, 41 | help="Path to pre-trained model or shortcut name") 42 | 43 | ## Configs 44 | parser.add_argument("--no_cuda", action='store_true', help="Avoid using CUDA when available") 45 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 46 | parser.add_argument("--num_workers", default=6, type=int) 47 | parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") 48 | parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit (mixed) precision instead of 32-bit") 49 | parser.add_argument("--do_train", action='store_true', help="Whether to run training.") 50 | parser.add_argument("--do_test", action='store_true', help="Whether to run test on the test set.") 51 | parser.add_argument("--freeze_image_encoder", action='store_true', help="Whether to freeze image encoder or not") 52 | parser.add_argument("--freeze_text_encoder", action='store_true', help="Whether to freeze image encoder or not") 53 | parser.add_argument("--sum_persona_images", action='store_true', help="Whether to sum persona images or not") 54 | 
parser.add_argument("--remove_empty_images", action='store_true', help="Whether to remove empty images or not") 55 | parser.add_argument("--use_response", action='store_true', help="Whether to use response or not for gpp task") 56 | 57 | # Misc: other params (model, input, etc) 58 | parser.add_argument("--clip_model_name", default='openai/clip-vit-base-patch32', type=str, help="CLIP model name") 59 | parser.add_argument("--sbert_model_name", default='sentence-transformers/multi-qa-distilbert-cos-v1', type=str, help="SBERT model name") 60 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") 61 | parser.add_argument("--per_gpu_eval_batch_size", default=32, type=int, help="Batch size per GPU/CPU for evaluation.") 62 | parser.add_argument("--max_num_personas", type=int, default=4, help="maximum number of multimodal personas") 63 | parser.add_argument("--max_num_candidate_persona_elements", type=int, default=100, help="maximum number of multimodal personas") 64 | parser.add_argument("--max_seq_length", type=int, default=77) 65 | parser.add_argument("--max_num_imgs", type=int, default=4) 66 | parser.add_argument("--img_size", type=int, default=224) 67 | parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") 68 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 69 | help="Number of updates steps to accumulate before performing a backward/update pass.") 70 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.") 71 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 72 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 73 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 74 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 75 | parser.add_argument('--logging_steps', type=int, default=100, help="Log every X updates steps.") 76 | parser.add_argument('--save_epoch', type=int, default=5, help="Save checkpoint every X epochs.") 77 | parser.add_argument('--save_after_epoch', type=int, default=-1, help="Save checkpoint after epoch.") 78 | 79 | args = parser.parse_args() 80 | 81 | # Setup CUDA, GPU & distributed training 82 | if args.local_rank == -1 or args.no_cuda: 83 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 84 | args.n_gpu = torch.cuda.device_count() 85 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 86 | raise NotImplementedError 87 | args.device = device 88 | 89 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 90 | datefmt = '%m/%d/%Y %H:%M:%S', level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 91 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 92 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 93 | 94 | # set seed 95 | set_seed(args.seed, args.n_gpu) 96 | 97 | # Output config 98 | os.makedirs(args.output_dir, exist_ok=True) 99 | 100 | # Load saved checkpoint 101 | recover_args = load_checkpoint_args(args, logger) 102 | 103 | # tokenizer 104 | tokenizer = AutoTokenizer.from_pretrained(args.sbert_model_name) 105 | clip_processor = 
CLIPProcessor.from_pretrained(args.clip_model_name) 106 | model_class, dataset_class = MODEL_CLASSES[args.model_type] 107 | 108 | # Prepare model 109 | model = model_class(args, clip_processor) 110 | if recover_args['last_checkpoint_dir'] is not None or args.model_name_or_path != '': # recovery 111 | model_logging = model.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'pytorch_model.bin'))) 112 | logger.info(f"{model_logging}") 113 | 114 | # Freeze model 115 | if args.freeze_image_encoder: 116 | for param in model.context_image_encoder.parameters(): 117 | param.requires_grad = False 118 | for param in model.persona_image_encoder.parameters(): 119 | param.requires_grad = False 120 | 121 | if args.freeze_text_encoder: 122 | for param in model.context_text_encoder.parameters(): 123 | param.requires_grad = False 124 | for param in model.persona_text_encoder.parameters(): 125 | param.requires_grad = False 126 | for param in model.response_encoder.parameters(): 127 | param.requires_grad = False 128 | 129 | total_params = sum(p.numel() for p in model.parameters()) 130 | logger.info('Total Parameters: {}'.format(total_params)) 131 | 132 | model.to(args.device) 133 | 134 | logger.info("Training/evaluation parameters %s", args) 135 | 136 | # load eval dataset 137 | eval_dataset = dataset_class(args, tokenizer, clip_processor, 'val') 138 | 139 | # load tensorboard 140 | tb_log_dir = os.path.join(args.output_dir, 'train_logs') 141 | meters = TensorboardLogger( 142 | log_dir=tb_log_dir, 143 | delimiter=" ", 144 | ) 145 | 146 | # training 147 | if args.do_train: 148 | train_dataset = dataset_class(args, tokenizer, clip_processor, 'train') 149 | global_step, tr_loss = train(args, train_dataset, eval_dataset, model, meters, recover_args, logger) 150 | logger.info("global_step = %s, average loss = %s", global_step, tr_loss) 151 | 152 | # test 153 | if args.do_test: 154 | test_dataset = dataset_class(args, tokenizer, clip_processor, 'test') 155 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 156 | 157 | try: 158 | with open(os.path.join(args.output_dir, "last_checkpoint"), "r") as f: 159 | texts = f.read().split('\n') 160 | best_saved = texts[1].split('best: ')[-1].strip() 161 | checkpoints = [ckpt for ckpt in checkpoints if best_saved in ckpt] 162 | except: 163 | logger.info("Cannot load checkpoint!") 164 | pass 165 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 166 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 167 | 168 | test_log_json = [] 169 | for checkpoint in checkpoints: 170 | epoch = checkpoint.split('-')[-2] 171 | global_step = checkpoint.split('-')[-1] 172 | model.load_state_dict(torch.load(os.path.join(checkpoint, 'pytorch_model.bin'))) 173 | model.to(args.device) 174 | test_scores = evaluate(args, model, test_dataset, 'test', logger, prefix=global_step) 175 | 176 | epoch_log = {'epoch': epoch, 'test_scores': test_scores} 177 | test_log_json.append(epoch_log) 178 | 179 | if args.local_rank in [-1, 0]: 180 | with open(args.output_dir + '/test_logs.json', 'w') as fp: 181 | json.dump(test_log_json, fp) 182 | 183 | # close the tb logger 184 | meters.close() 185 | logger.info("Good Job Computer!") 186 | 187 | if __name__ == '__main__': 188 | main() 189 | -------------------------------------------------------------------------------- /modules/train_si.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import json 4 | import time 5 | import numpy as np 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | import torch 10 | from torch.utils.data import ( 11 | RandomSampler, 12 | SequentialSampler, 13 | DataLoader, 14 | ) 15 | from utils.misc import compute_metrics_from_logits 16 | 17 | from transformers import ( 18 | AdamW, 19 | get_linear_schedule_with_warmup, 20 | ) 21 | 22 | def train(args, train_dataset, eval_dataset, model, meters, recover_args, logger): 23 | """ Train the model """ 24 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 25 | train_sampler = RandomSampler(train_dataset) 26 | train_dataloader = DataLoader(train_dataset, num_workers=args.num_workers, sampler=train_sampler, batch_size=args.train_batch_size, pin_memory=True) 27 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 28 | 29 | # Prepare optimizer 30 | param_optimizer = list(model.named_parameters()) 31 | no_decay = [] 32 | optimizer_grouped_parameters = [ 33 | {'params': [p for n, p in param_optimizer if 34 | not any(nd in n for nd in no_decay)], 35 | 'weight_decay': args.weight_decay} 36 | ] 37 | 38 | optimizer = AdamW(optimizer_grouped_parameters, 39 | lr=args.learning_rate, eps=args.adam_epsilon) 40 | scheduler = get_linear_schedule_with_warmup(optimizer, 41 | num_warmup_steps=args.warmup_steps, 42 | num_training_steps=t_total) 43 | 44 | if recover_args['global_step'] > 0 and os.path.isfile(os.path.join(recover_args['last_checkpoint_dir'], 'optimizer.pth')): # recovery 45 | last_checkpoint_dir = recover_args['last_checkpoint_dir'] 46 | logger.info( 47 | "Load optimizer from {}".format(last_checkpoint_dir)) 48 | optimizer_to_load = torch.load( 49 | os.path.join(last_checkpoint_dir, 'optimizer.pth'), 50 | map_location=torch.device("cpu")) 51 | optimizer.load_state_dict(optimizer_to_load.pop("optimizer")) 52 | scheduler.load_state_dict(optimizer_to_load.pop("scheduler")) 53 | 54 | # Train! 55 | logger.info("***** Running training *****") 56 | logger.info(" Num examples = %d", len(train_dataset)) 57 | logger.info(" Num Epochs = %d", args.num_train_epochs) 58 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 59 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", 60 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 61 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 62 | logger.info(" Total optimization steps = %d", t_total) 63 | 64 | global_step = recover_args['global_step'] 65 | start_epoch = recover_args['epoch'] + 1 if global_step > 0 else 0 66 | tr_loss, logging_loss = 0.0, 0.0 67 | model.zero_grad() 68 | 69 | best_scores = { 70 | 'epoch': 0, 71 | 'global_step': 0, 72 | 'scores': {'recall@1': 0.0} 73 | } 74 | if recover_args['last_best_score'] is not None: 75 | best_scores = recover_args['last_best_score'] 76 | 77 | log_json = [] 78 | scaler = torch.cuda.amp.GradScaler(enabled=True) 79 | torch.autograd.set_detect_anomaly(True) 80 | for epoch in range(start_epoch, int(args.num_train_epochs)): 81 | t_start = time.time() 82 | tbar = tqdm(train_dataloader, ncols=70) 83 | for step, batch in enumerate(tbar): 84 | tbar.set_description(f'Training loss = {logging_loss}') 85 | model.train() 86 | 87 | context_input_ids = batch[0].to(args.device, non_blocking=True) 88 | context_attention_mask = batch[1].to(args.device, non_blocking=True) 89 | response_input_ids = batch[2].to(args.device, non_blocking=True) 90 | response_attention_mask = batch[3].to(args.device, non_blocking=True) 91 | persona_input_ids = batch[4].to(args.device, non_blocking=True) 92 | persona_attention_mask = batch[5].to(args.device, non_blocking=True) 93 | dialog_img_feat = batch[6].to(args.device, non_blocking=True) 94 | persona_img_feats = batch[7].to(args.device, non_blocking=True) 95 | dialog_img_mask = batch[8].to(args.device, non_blocking=True) 96 | persona_img_mask = batch[9].to(args.device, non_blocking=True) 97 | 98 | if args.fp16: 99 | with torch.cuda.amp.autocast(enabled=True): 100 | outputs = model( 101 | context_input_ids=context_input_ids, 102 | context_attention_mask=context_attention_mask, 103 | response_input_ids=response_input_ids, 104 | response_attention_mask=response_attention_mask, 105 | persona_input_ids=persona_input_ids, 106 | persona_attention_mask=persona_attention_mask, 107 | dialog_img_feat=dialog_img_feat, 108 | persona_img_feats=persona_img_feats, 109 | dialog_img_mask=dialog_img_mask, 110 | persona_img_mask=persona_img_mask, 111 | mode='train', 112 | ) 113 | loss = outputs[0] 114 | else: 115 | outputs = model( 116 | context_input_ids=context_input_ids, 117 | context_attention_mask=context_attention_mask, 118 | response_input_ids=response_input_ids, 119 | response_attention_mask=response_attention_mask, 120 | persona_input_ids=persona_input_ids, 121 | persona_attention_mask=persona_attention_mask, 122 | dialog_img_feat=dialog_img_feat, 123 | persona_img_feats=persona_img_feats, 124 | dialog_img_mask=dialog_img_mask, 125 | persona_img_mask=persona_img_mask, 126 | mode='train', 127 | ) 128 | loss = outputs[0] 129 | 130 | if args.n_gpu > 1: loss = loss.mean() 131 | 132 | if args.gradient_accumulation_steps > 1: 133 | loss = loss / args.gradient_accumulation_steps 134 | 135 | if args.fp16: 136 | scaler.scale(loss).backward() 137 | else: 138 | loss.backward() 139 | 140 | tr_loss += loss.item() 141 | logging_loss = round(loss.item(), 5) 142 | if (step + 1) % args.gradient_accumulation_steps == 0: 143 | # do gradient clipping 144 | if args.max_grad_norm > 0: 145 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 146 | scheduler.step() # Update learning rate schedule 147 | if 
args.fp16: 148 | scaler.step(optimizer) 149 | scaler.update() 150 | else: 151 | optimizer.step() 152 | model.zero_grad() 153 | global_step += 1 154 | 155 | # update tensorboard 156 | meters.update_metrics({'batch_metrics': {'loss': loss}}) 157 | meters.update_params({'params': {'lr': optimizer.param_groups[0]['lr']}}) 158 | 159 | if args.logging_steps > 0 and (global_step + 1) % args.logging_steps == 0: 160 | meters.get_logs(global_step+1) 161 | 162 | # Evaluation 163 | logger.info("Epoch: %d, global_step: %d" % (epoch, global_step)) 164 | eval_scores = evaluate(args, model, eval_dataset, 'val', logger, prefix=global_step) 165 | 166 | # Select recall@1 score as metric 167 | if eval_scores['recall@1'] > best_scores['scores']['recall@1']: 168 | best_scores['scores'] = eval_scores 169 | best_scores['epoch'] = epoch 170 | best_scores['global_step'] = global_step 171 | 172 | # save checkpoints 173 | if (args.local_rank in [-1, 0]) and (args.save_epoch>0 and epoch % args.save_epoch == 0) and (epoch > args.save_after_epoch): 174 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}-{}'.format(epoch, global_step)) 175 | if not os.path.exists(output_dir): os.makedirs(output_dir) 176 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 177 | optimizer_to_save = { 178 | 'optimizer': optimizer.state_dict(), 179 | 'scheduler': scheduler.state_dict(), 180 | } 181 | 182 | save_num = 0 183 | while (save_num < 3): 184 | try: 185 | logger.info("Saving model attempt: {}".format(save_num)) 186 | torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin')) 187 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 188 | torch.save(optimizer_to_save, os.path.join(output_dir, 'optimizer.pth')) 189 | save_file = os.path.join(args.output_dir, 'last_checkpoint') 190 | with open(save_file, 'w') as f: 191 | f.write('checkpoint-{}-{}/pytorch_model.bin\n'.format(epoch, global_step)) 192 | f.write(f'best: checkpoint-{best_scores["epoch"]}-{best_scores["global_step"]}\n') 193 | json.dump(best_scores, f) 194 | break 195 | except: 196 | save_num += 1 197 | logger.info("Saving model checkpoint {0} to {1}".format(epoch, output_dir)) 198 | 199 | epoch_log = {'epoch': epoch, 'eval_scores': eval_scores, 'best_scores': best_scores['scores']} 200 | log_json.append(epoch_log) 201 | 202 | logger.info("PROGRESS: {}%".format(round(100*(epoch + 1) / args.num_train_epochs, 4))) 203 | 204 | if args.local_rank in [-1, 0]: 205 | with open(args.output_dir + '/eval_logs.json', 'w') as fp: 206 | json.dump(log_json, fp) 207 | 208 | t_end = time.time() 209 | logger.info('Epoch: %d, Train Time: %.3f' % (epoch, t_end - t_start)) 210 | 211 | return global_step, tr_loss / global_step 212 | 213 | def evaluate(args, model, eval_dataset, mode, logger, prefix=''): 214 | t_start = time.time() 215 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 216 | test_sampler = SequentialSampler(eval_dataset) 217 | test_dataloader = DataLoader(eval_dataset, num_workers=args.num_workers, sampler=test_sampler, batch_size=args.eval_batch_size, pin_memory=True) 218 | 219 | # Eval! 
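# NOTE: the evaluation routine below iterates the val/test split with a SequentialSampler,
# collects per-example losses, candidate logits and gold labels, and hands them to
# utils.misc.compute_metrics_from_logits for recall@k and MRR. Conceptually (a sketch,
# not the actual implementation in utils/misc.py):
#   order = logits.argsort(dim=1, descending=True)              # [num_examples, num_candidates]
#   gold_rank = (order == labels.unsqueeze(1)).nonzero()[:, 1]  # rank of the correct candidate
#   recall_at_1 = (gold_rank == 0).float().mean()
#   mrr = (1.0 / (gold_rank.float() + 1.0)).mean()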
220 | logger.info("***** Running evaluation {} *****".format(prefix)) 221 | logger.info(" Num examples = %d", len(eval_dataset)) 222 | logger.info(" Batch size = %d", args.eval_batch_size) 223 | 224 | results_dict = defaultdict(list) 225 | for batch in tqdm(test_dataloader, ncols=70): 226 | model.eval() 227 | 228 | context_input_ids = batch[0].to(args.device, non_blocking=True) 229 | context_attention_mask = batch[1].to(args.device, non_blocking=True) 230 | response_input_ids = batch[2].to(args.device, non_blocking=True) 231 | response_attention_mask = batch[3].to(args.device, non_blocking=True) 232 | persona_input_ids = batch[4].to(args.device, non_blocking=True) 233 | persona_attention_mask = batch[5].to(args.device, non_blocking=True) 234 | dialog_img_feat = batch[6].to(args.device, non_blocking=True) 235 | persona_img_feats = batch[7].to(args.device, non_blocking=True) 236 | labels = batch[8].to(args.device, non_blocking=True) 237 | dialog_img_mask = batch[9].to(args.device, non_blocking=True) 238 | persona_img_mask = batch[10].to(args.device, non_blocking=True) 239 | 240 | with torch.no_grad(): 241 | loss, logits = model( 242 | context_input_ids=context_input_ids, 243 | context_attention_mask=context_attention_mask, 244 | response_input_ids=response_input_ids, 245 | response_attention_mask=response_attention_mask, 246 | persona_input_ids=persona_input_ids, 247 | persona_attention_mask=persona_attention_mask, 248 | dialog_img_feat=dialog_img_feat, 249 | persona_img_feats=persona_img_feats, 250 | labels=labels, 251 | dialog_img_mask=dialog_img_mask, 252 | persona_img_mask=persona_img_mask, 253 | mode=mode, 254 | ) 255 | results_dict['loss'].append(loss.cpu().detach().numpy()) 256 | results_dict['logits'].append(logits.cpu().detach().numpy()) 257 | results_dict['labels'].append(labels.cpu().detach().numpy()) 258 | 259 | for key, value in results_dict.items(): 260 | if results_dict[key][0].shape == (): 261 | results_dict[key] = np.array(value) 262 | else: 263 | results_dict[key] = np.concatenate(value, axis=0) 264 | 265 | recall, mrr = compute_metrics_from_logits(torch.tensor(results_dict['logits']), 266 | torch.tensor(results_dict['labels'])) 267 | 268 | total_scores = { 269 | 'loss': round(np.mean(results_dict['loss']).item(), 4), 270 | 'mrr': round(mrr, 4), 271 | } 272 | for k,v in recall.items(): 273 | total_scores[k] = round(v, 4) 274 | 275 | logger.info("Eval Results:") 276 | logger.info(f'Eval Score: {total_scores}') 277 | 278 | t_end = time.time() 279 | logger.info('Eval Time Cost: %.3f' % (t_end - t_start)) 280 | 281 | return total_scores 282 | -------------------------------------------------------------------------------- /modules/train_nrp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import json 4 | import time 5 | import numpy as np 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | import torch 10 | from torch.utils.data import ( 11 | RandomSampler, 12 | SequentialSampler, 13 | DataLoader, 14 | ) 15 | from utils.misc import compute_metrics_from_logits 16 | 17 | from transformers import ( 18 | AdamW, 19 | get_linear_schedule_with_warmup, 20 | ) 21 | 22 | def train(args, train_dataset, eval_dataset, model, meters, recover_args, logger): 23 | """ Train the model """ 24 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 25 | train_sampler = RandomSampler(train_dataset) 26 | train_dataloader = DataLoader(train_dataset, num_workers=args.num_workers, sampler=train_sampler, 
batch_size=args.train_batch_size, pin_memory=True) 27 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 28 | 29 | # Prepare optimizer 30 | param_optimizer = list(model.named_parameters()) 31 | no_decay = [] 32 | optimizer_grouped_parameters = [ 33 | {'params': [p for n, p in param_optimizer if 34 | not any(nd in n for nd in no_decay)], 35 | 'weight_decay': args.weight_decay} 36 | ] 37 | 38 | optimizer = AdamW(optimizer_grouped_parameters, 39 | lr=args.learning_rate, eps=args.adam_epsilon) 40 | scheduler = get_linear_schedule_with_warmup(optimizer, 41 | num_warmup_steps=args.warmup_steps, 42 | num_training_steps=t_total) 43 | 44 | if recover_args['global_step'] > 0 and os.path.isfile(os.path.join(recover_args['last_checkpoint_dir'], 'optimizer.pth')): # recovery 45 | last_checkpoint_dir = recover_args['last_checkpoint_dir'] 46 | logger.info( 47 | "Load optimizer from {}".format(last_checkpoint_dir)) 48 | optimizer_to_load = torch.load( 49 | os.path.join(last_checkpoint_dir, 'optimizer.pth'), 50 | map_location=torch.device("cpu")) 51 | optimizer.load_state_dict(optimizer_to_load.pop("optimizer")) 52 | scheduler.load_state_dict(optimizer_to_load.pop("scheduler")) 53 | 54 | # Train! 55 | logger.info("***** Running training *****") 56 | logger.info(" Num examples = %d", len(train_dataset)) 57 | logger.info(" Num Epochs = %d", args.num_train_epochs) 58 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 59 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", 60 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 61 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 62 | logger.info(" Total optimization steps = %d", t_total) 63 | 64 | global_step = recover_args['global_step'] 65 | start_epoch = recover_args['epoch'] + 1 if global_step > 0 else 0 66 | tr_loss, logging_loss = 0.0, 0.0 67 | model.zero_grad() 68 | 69 | best_scores = { 70 | 'epoch': 0, 71 | 'global_step': 0, 72 | 'scores': {'recall@1': 0.0} 73 | } 74 | if recover_args['last_best_score'] is not None: 75 | best_scores = recover_args['last_best_score'] 76 | 77 | log_json = [] 78 | scaler = torch.cuda.amp.GradScaler(enabled=True) 79 | torch.autograd.set_detect_anomaly(True) 80 | for epoch in range(start_epoch, int(args.num_train_epochs)): 81 | t_start = time.time() 82 | tbar = tqdm(train_dataloader, ncols=70) 83 | for step, batch in enumerate(tbar): 84 | tbar.set_description(f'Training loss = {logging_loss}') 85 | model.train() 86 | 87 | context_input_ids = batch[0].to(args.device, non_blocking=True) 88 | context_attention_mask = batch[1].to(args.device, non_blocking=True) 89 | response_input_ids = batch[2].to(args.device, non_blocking=True) 90 | response_attention_mask = batch[3].to(args.device, non_blocking=True) 91 | persona_input_ids = batch[4].to(args.device, non_blocking=True) 92 | persona_attention_mask = batch[5].to(args.device, non_blocking=True) 93 | dialog_img_feat = batch[6].to(args.device, non_blocking=True) 94 | persona_img_feats = batch[7].to(args.device, non_blocking=True) 95 | dialog_img_mask = batch[8].to(args.device, non_blocking=True) 96 | persona_img_mask = batch[9].to(args.device, non_blocking=True) 97 | 98 | if args.fp16: 99 | with torch.cuda.amp.autocast(enabled=True): 100 | outputs = model( 101 | context_input_ids=context_input_ids, 102 | context_attention_mask=context_attention_mask, 103 | 
response_input_ids=response_input_ids, 104 | response_attention_mask=response_attention_mask, 105 | persona_input_ids=persona_input_ids, 106 | persona_attention_mask=persona_attention_mask, 107 | dialog_img_feat=dialog_img_feat, 108 | persona_img_feats=persona_img_feats, 109 | dialog_img_mask=dialog_img_mask, 110 | persona_img_mask=persona_img_mask, 111 | mode='train', 112 | ) 113 | loss = outputs[0] 114 | else: 115 | outputs = model( 116 | context_input_ids=context_input_ids, 117 | context_attention_mask=context_attention_mask, 118 | response_input_ids=response_input_ids, 119 | response_attention_mask=response_attention_mask, 120 | persona_input_ids=persona_input_ids, 121 | persona_attention_mask=persona_attention_mask, 122 | dialog_img_feat=dialog_img_feat, 123 | persona_img_feats=persona_img_feats, 124 | dialog_img_mask=dialog_img_mask, 125 | persona_img_mask=persona_img_mask, 126 | mode='train', 127 | ) 128 | loss = outputs[0] 129 | 130 | if args.n_gpu > 1: loss = loss.mean() 131 | 132 | if args.gradient_accumulation_steps > 1: 133 | loss = loss / args.gradient_accumulation_steps 134 | 135 | if args.fp16: 136 | scaler.scale(loss).backward() 137 | else: 138 | loss.backward() 139 | 140 | tr_loss += loss.item() 141 | logging_loss = round(loss.item(), 5) 142 | if (step + 1) % args.gradient_accumulation_steps == 0: 143 | # do gradient clipping 144 | if args.max_grad_norm > 0: 145 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 146 | scheduler.step() # Update learning rate schedule 147 | if args.fp16: 148 | scaler.step(optimizer) 149 | scaler.update() 150 | else: 151 | optimizer.step() 152 | model.zero_grad() 153 | global_step += 1 154 | 155 | # update tensorboard 156 | meters.update_metrics({'batch_metrics': {'loss': loss}}) 157 | meters.update_params({'params': {'lr': optimizer.param_groups[0]['lr']}}) 158 | 159 | if args.logging_steps > 0 and (global_step + 1) % args.logging_steps == 0: 160 | meters.get_logs(global_step+1) 161 | 162 | # Evaluation 163 | logger.info("Epoch: %d, global_step: %d" % (epoch, global_step)) 164 | eval_scores = evaluate(args, model, eval_dataset, 'val', logger, prefix=global_step) 165 | 166 | # Select recall@1 score as metric 167 | if eval_scores['recall@1'] > best_scores['scores']['recall@1']: 168 | best_scores['scores'] = eval_scores 169 | best_scores['epoch'] = epoch 170 | best_scores['global_step'] = global_step 171 | 172 | # save checkpoints 173 | if (args.local_rank in [-1, 0]) and (args.save_epoch>0 and epoch % args.save_epoch == 0) and (epoch > args.save_after_epoch): 174 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}-{}'.format(epoch, global_step)) 175 | if not os.path.exists(output_dir): os.makedirs(output_dir) 176 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 177 | optimizer_to_save = { 178 | 'optimizer': optimizer.state_dict(), 179 | 'scheduler': scheduler.state_dict(), 180 | } 181 | 182 | save_num = 0 183 | while (save_num < 3): 184 | try: 185 | logger.info("Saving model attempt: {}".format(save_num)) 186 | torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin')) 187 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 188 | torch.save(optimizer_to_save, os.path.join(output_dir, 'optimizer.pth')) 189 | save_file = os.path.join(args.output_dir, 'last_checkpoint') 190 | with open(save_file, 'w') as f: 191 | f.write('checkpoint-{}-{}/pytorch_model.bin\n'.format(epoch, global_step)) 192 | f.write(f'best: 
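# NOTE: the writes around this point define the on-disk format of the `last_checkpoint`
# bookkeeping file that checkpoint recovery parses later:
#   line 1: checkpoint-{epoch}-{global_step}/pytorch_model.bin  (most recent checkpoint)
#   line 2: best: checkpoint-{epoch}-{global_step}              (best checkpoint so far)
#   line 3: JSON dump of best_scores                            (epoch, global_step, eval scores)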
checkpoint-{best_scores["epoch"]}-{best_scores["global_step"]}\n') 193 | json.dump(best_scores, f) 194 | break 195 | except: 196 | save_num += 1 197 | logger.info("Saving model checkpoint {0} to {1}".format(epoch, output_dir)) 198 | 199 | epoch_log = {'epoch': epoch, 'eval_scores': eval_scores, 'best_scores': best_scores['scores']} 200 | log_json.append(epoch_log) 201 | 202 | logger.info("PROGRESS: {}%".format(round(100*(epoch + 1) / args.num_train_epochs, 4))) 203 | 204 | if args.local_rank in [-1, 0]: 205 | with open(args.output_dir + '/eval_logs.json', 'w') as fp: 206 | json.dump(log_json, fp) 207 | 208 | t_end = time.time() 209 | logger.info('Epoch: %d, Train Time: %.3f' % (epoch, t_end - t_start)) 210 | 211 | return global_step, tr_loss / global_step 212 | 213 | def evaluate(args, model, eval_dataset, mode, logger, prefix=''): 214 | t_start = time.time() 215 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 216 | test_sampler = SequentialSampler(eval_dataset) 217 | test_dataloader = DataLoader(eval_dataset, num_workers=args.num_workers, sampler=test_sampler, batch_size=args.eval_batch_size, pin_memory=True) 218 | 219 | # Eval! 220 | logger.info("***** Running evaluation {} *****".format(prefix)) 221 | logger.info(" Num examples = %d", len(eval_dataset)) 222 | logger.info(" Batch size = %d", args.eval_batch_size) 223 | 224 | results_dict = defaultdict(list) 225 | for batch in tqdm(test_dataloader, ncols=70): 226 | model.eval() 227 | 228 | context_input_ids = batch[0].to(args.device, non_blocking=True) 229 | context_attention_mask = batch[1].to(args.device, non_blocking=True) 230 | response_input_ids = batch[2].to(args.device, non_blocking=True) 231 | response_attention_mask = batch[3].to(args.device, non_blocking=True) 232 | persona_input_ids = batch[4].to(args.device, non_blocking=True) 233 | persona_attention_mask = batch[5].to(args.device, non_blocking=True) 234 | dialog_img_feat = batch[6].to(args.device, non_blocking=True) 235 | persona_img_feats = batch[7].to(args.device, non_blocking=True) 236 | labels = batch[8].to(args.device, non_blocking=True) 237 | example_indices = batch[9].to(args.device, non_blocking=True) 238 | dialog_img_mask = batch[10].to(args.device, non_blocking=True) 239 | persona_img_mask = batch[11].to(args.device, non_blocking=True) 240 | 241 | with torch.no_grad(): 242 | loss, logits = model( 243 | context_input_ids=context_input_ids, 244 | context_attention_mask=context_attention_mask, 245 | response_input_ids=response_input_ids, 246 | response_attention_mask=response_attention_mask, 247 | persona_input_ids=persona_input_ids, 248 | persona_attention_mask=persona_attention_mask, 249 | dialog_img_feat=dialog_img_feat, 250 | persona_img_feats=persona_img_feats, 251 | labels=labels, 252 | dialog_img_mask=dialog_img_mask, 253 | persona_img_mask=persona_img_mask, 254 | mode=mode, 255 | ) 256 | results_dict['loss'].append(loss.cpu().detach().numpy()) 257 | results_dict['logits'].append(logits.cpu().detach().numpy()) 258 | results_dict['labels'].append(labels.cpu().detach().numpy()) 259 | results_dict['example_indices'].append(example_indices.cpu().detach().numpy()) 260 | 261 | for key, value in results_dict.items(): 262 | if results_dict[key][0].shape == (): 263 | results_dict[key] = np.array(value) 264 | else: 265 | results_dict[key] = np.concatenate(value, axis=0) 266 | 267 | recall, mrr = compute_metrics_from_logits(torch.tensor(results_dict['logits']), 268 | torch.tensor(results_dict['labels'])) 269 | 270 | total_scores = { 271 | 'loss': 
round(np.mean(results_dict['loss']).item(), 4), 272 | 'mrr': round(mrr, 4), 273 | } 274 | for k,v in recall.items(): 275 | total_scores[k] = round(v, 4) 276 | 277 | logger.info("Eval Results:") 278 | logger.info(f'Eval Score: {total_scores}') 279 | 280 | t_end = time.time() 281 | logger.info('Eval Time Cost: %.3f' % (t_end - t_start)) 282 | 283 | return total_scores 284 | -------------------------------------------------------------------------------- /models/si_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import CrossEntropyLoss 4 | from typing import Any, Optional, Tuple, Union 5 | 6 | from .model import ( 7 | ClipClipModel, 8 | ClipSbertModel, 9 | clip_loss, 10 | mean_pooling, 11 | ) 12 | 13 | from transformers import ( 14 | CLIPProcessor, 15 | AdamW, 16 | get_linear_schedule_with_warmup, 17 | WEIGHTS_NAME, 18 | ) 19 | 20 | class ClipClipSi(ClipClipModel): 21 | def forward( 22 | self, 23 | context_input_ids: Optional[torch.LongTensor] = None, 24 | context_attention_mask: Optional[torch.LongTensor] = None, 25 | response_input_ids: Optional[torch.LongTensor] = None, 26 | response_attention_mask: Optional[torch.LongTensor] = None, 27 | persona_input_ids: Optional[torch.LongTensor] = None, 28 | persona_attention_mask: Optional[torch.LongTensor] = None, 29 | dialog_img_feat: Optional[torch.Tensor] = None, 30 | persona_img_feats: Optional[torch.Tensor] = None, 31 | dialog_img_mask: Optional[torch.LongTensor] = None, 32 | persona_img_mask: Optional[torch.LongTensor] = None, 33 | labels: Optional[torch.LongTensor] = None, 34 | mode: str = None, 35 | ): 36 | if mode == 'train': 37 | context_output = self.context_text_encoder( 38 | input_ids=context_input_ids, 39 | attention_mask=context_attention_mask 40 | )[1] 41 | context_output = self.context_text_projection(context_output) 42 | context_output = F.normalize(context_output, p=2, dim=1) 43 | 44 | response_output = self.response_encoder( 45 | input_ids=response_input_ids, 46 | attention_mask=response_attention_mask 47 | )[1] 48 | response_output = self.response_projection(response_output) 49 | response_output = F.normalize(response_output, p=2, dim=1) 50 | 51 | persona_output = self.persona_text_encoder( 52 | input_ids=persona_input_ids, 53 | attention_mask=persona_attention_mask 54 | )[1] 55 | persona_output = self.persona_text_projection(persona_output) 56 | persona_output = F.normalize(persona_output, p=2, dim=1) 57 | 58 | persona_image_output = self.persona_image_encoder(pixel_values=persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size))[1] 59 | persona_image_output = self.persona_image_projection(persona_image_output) 60 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 61 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_imgs, persona_image_output.size(-1)) 62 | 63 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 64 | dialog_image_output = self.context_image_projection(dialog_image_output) 65 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 66 | 67 | assert self.args.sum_persona_images 68 | if self.args.remove_empty_images: 69 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) * persona_image_output, dim=1) / torch.sum(persona_img_mask, dim=1).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 70 | else: 71 | persona_image_output = 
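# NOTE: the branch completed on the next line averages the (up to max_num_imgs) persona
# image embeddings. With --remove_empty_images the masked sum/count above ignores padded
# image slots; otherwise a plain mean is taken and empty padding images dilute the persona
# representation. The fused persona embedding is then the simple average of its text and
# image embeddings, and the fused context embedding averages the dialog image (when
# present), the dialog text and the response embeddings.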
torch.mean(persona_image_output, dim=1) 72 | multimodal_persona_output = (persona_output + persona_image_output) / 2 73 | 74 | if self.args.remove_empty_images: 75 | multimodal_context_output = (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output + context_output + response_output) / (dialog_img_mask + 2).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 76 | else: 77 | multimodal_context_output = (dialog_image_output + context_output + response_output) / 3 78 | 79 | # cosine similarity as logits 80 | logit_scale = self.logit_scale.exp() 81 | dot_products = multimodal_context_output.mm(multimodal_persona_output.t()) * logit_scale 82 | loss = clip_loss(dot_products) 83 | 84 | outputs = (loss,) 85 | else: 86 | context_output = self.context_text_encoder( 87 | input_ids=context_input_ids, 88 | attention_mask=context_attention_mask 89 | )[1] 90 | context_output = self.context_text_projection(context_output) 91 | context_output = F.normalize(context_output, p=2, dim=1) 92 | 93 | response_output = self.response_encoder( 94 | input_ids=response_input_ids, 95 | attention_mask=response_attention_mask 96 | )[1] 97 | response_output = self.response_projection(response_output) 98 | response_output = F.normalize(response_output, p=2, dim=1) 99 | 100 | cand_persona_input_ids = persona_input_ids.view(-1, self.args.max_seq_length) 101 | cand_persona_attention_mask = persona_attention_mask.view(-1, self.args.max_seq_length) 102 | 103 | persona_output = self.persona_text_encoder( 104 | input_ids=cand_persona_input_ids, 105 | attention_mask=cand_persona_attention_mask 106 | )[1] 107 | persona_output = self.persona_text_projection(persona_output) 108 | persona_output = F.normalize(persona_output, p=2, dim=1) 109 | persona_output = persona_output.view(context_input_ids.size(0), self.args.max_num_candidate_authors, persona_output.size(-1)) 110 | 111 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 112 | dialog_image_output = self.context_image_projection(dialog_image_output) 113 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 114 | 115 | cand_persona_img_feats = persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size) 116 | persona_image_output = self.persona_image_encoder(pixel_values=cand_persona_img_feats)[1] 117 | persona_image_output = self.persona_image_projection(persona_image_output) 118 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 119 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_candidate_authors, self.args.max_num_imgs, persona_image_output.size(-1)) 120 | 121 | assert self.args.sum_persona_images 122 | if self.args.remove_empty_images: 123 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,1,dialog_image_output.size(-1)) * persona_image_output, dim=2) / torch.sum(persona_img_mask, dim=2).unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) 124 | else: 125 | persona_image_output = torch.mean(persona_image_output, dim=2) 126 | multimodal_persona_output = (persona_output + persona_image_output) / 2 127 | 128 | if self.args.remove_empty_images: 129 | multimodal_context_output = (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output + context_output + response_output) / (dialog_img_mask + 2).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 130 | else: 131 | multimodal_context_output = (dialog_image_output + context_output + response_output) / 3 132 | 133 
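# NOTE: the scoring step on the next line is a batched matrix product between the fused
# context embedding, viewed as [batch, 1, dim], and the candidate persona embeddings,
# viewed as [batch, max_num_candidate_authors, dim] and transposed, giving one logit per
# candidate; CrossEntropyLoss(reduction='none') then yields a per-example loss.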
| logits = torch.bmm(multimodal_context_output.unsqueeze(1), 134 | multimodal_persona_output.view(context_input_ids.size(0),self.args.max_num_candidate_authors,-1).transpose(1,2)).squeeze(1) 135 | 136 | loss = CrossEntropyLoss(reduction='none')(logits, labels) 137 | outputs = (loss, logits,) 138 | return outputs 139 | 140 | class ClipSbertSi(ClipSbertModel): 141 | def forward( 142 | self, 143 | context_input_ids: Optional[torch.LongTensor] = None, 144 | context_attention_mask: Optional[torch.LongTensor] = None, 145 | response_input_ids: Optional[torch.LongTensor] = None, 146 | response_attention_mask: Optional[torch.LongTensor] = None, 147 | persona_input_ids: Optional[torch.LongTensor] = None, 148 | persona_attention_mask: Optional[torch.LongTensor] = None, 149 | dialog_img_feat: Optional[torch.Tensor] = None, 150 | persona_img_feats: Optional[torch.Tensor] = None, 151 | dialog_img_mask: Optional[torch.LongTensor] = None, 152 | persona_img_mask: Optional[torch.LongTensor] = None, 153 | labels: Optional[torch.LongTensor] = None, 154 | mode: str = None, 155 | ): 156 | if mode == 'train': 157 | context_output = self.context_text_encoder( 158 | input_ids=context_input_ids, 159 | attention_mask=context_attention_mask 160 | ) 161 | context_output = mean_pooling(context_output, context_attention_mask) 162 | 163 | response_output = self.response_encoder( 164 | input_ids=response_input_ids, 165 | attention_mask=response_attention_mask 166 | ) 167 | response_output = mean_pooling(response_output, response_attention_mask) 168 | 169 | persona_output = self.persona_text_encoder( 170 | input_ids=persona_input_ids, 171 | attention_mask=persona_attention_mask 172 | ) 173 | persona_output = mean_pooling(persona_output, persona_attention_mask) 174 | 175 | persona_image_output = self.persona_image_encoder(pixel_values=persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size))[1] 176 | persona_image_output = self.persona_image_projection(persona_image_output) 177 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 178 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_imgs, persona_image_output.size(-1)) 179 | 180 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 181 | dialog_image_output = self.context_image_projection(dialog_image_output) 182 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 183 | 184 | assert self.args.sum_persona_images 185 | if self.args.remove_empty_images: 186 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) * persona_image_output, dim=1) / torch.sum(persona_img_mask, dim=1).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 187 | else: 188 | persona_image_output = torch.mean(persona_image_output, dim=1) 189 | multimodal_persona_output = (persona_output + persona_image_output) / 2 190 | 191 | if self.args.remove_empty_images: 192 | multimodal_context_output = (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output + context_output + response_output) / (dialog_img_mask + 2).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 193 | else: 194 | multimodal_context_output = (dialog_image_output + context_output + response_output) / 3 195 | 196 | targets = torch.arange(context_output.shape[0], device=context_output.device) 197 | # dot_products: [batch, batch] 198 | dot_products = multimodal_context_output.mm(multimodal_persona_output.t()) 199 | log_prob = 
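# NOTE: the objective completed on the next line is an in-batch contrastive loss:
# dot_products is the [batch, batch] similarity matrix between fused context and persona
# embeddings, targets = arange(batch) marks the matching pairs on the diagonal, and
# nll_loss over the row-wise log-softmax is equivalent to a cross-entropy that treats the
# other examples in the batch as negatives (one direction only; the ClipClip variant uses
# clip_loss, presumably the standard symmetric CLIP objective).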
F.log_softmax(dot_products, dim=1) 200 | loss = F.nll_loss(log_prob, targets) 201 | 202 | outputs = (loss,) 203 | else: 204 | context_output = self.context_text_encoder( 205 | input_ids=context_input_ids, 206 | attention_mask=context_attention_mask 207 | ) 208 | context_output = mean_pooling(context_output, context_attention_mask) 209 | 210 | response_output = self.response_encoder( 211 | input_ids=response_input_ids, 212 | attention_mask=response_attention_mask 213 | ) 214 | response_output = mean_pooling(response_output, response_attention_mask) 215 | 216 | cand_persona_input_ids = persona_input_ids.view(-1, self.args.max_seq_length) 217 | cand_persona_attention_mask = persona_attention_mask.view(-1, self.args.max_seq_length) 218 | 219 | persona_output = self.persona_text_encoder( 220 | input_ids=cand_persona_input_ids, 221 | attention_mask=cand_persona_attention_mask 222 | ) 223 | persona_output = mean_pooling(persona_output, cand_persona_attention_mask) 224 | persona_output = persona_output.view(context_input_ids.size(0), self.args.max_num_candidate_authors, persona_output.size(-1)) 225 | 226 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 227 | dialog_image_output = self.context_image_projection(dialog_image_output) 228 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 229 | 230 | cand_persona_img_feats = persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size) 231 | persona_image_output = self.persona_image_encoder(pixel_values=cand_persona_img_feats)[1] 232 | persona_image_output = self.persona_image_projection(persona_image_output) 233 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 234 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_candidate_authors, self.args.max_num_imgs, persona_image_output.size(-1)) 235 | 236 | assert self.args.sum_persona_images 237 | if self.args.remove_empty_images: 238 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,1,dialog_image_output.size(-1)) * persona_image_output, dim=2) / torch.sum(persona_img_mask, dim=2).unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) 239 | else: 240 | persona_image_output = torch.mean(persona_image_output, dim=2) 241 | multimodal_persona_output = (persona_output + persona_image_output) / 2 242 | 243 | if self.args.remove_empty_images: 244 | multimodal_context_output = (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output + context_output + response_output) / (dialog_img_mask + 2).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 245 | else: 246 | multimodal_context_output = (dialog_image_output + context_output + response_output) / 3 247 | 248 | logits = torch.bmm(multimodal_context_output.unsqueeze(1), 249 | multimodal_persona_output.view(context_input_ids.size(0),self.args.max_num_candidate_authors,-1).transpose(1,2)).squeeze(1) 250 | 251 | loss = CrossEntropyLoss(reduction='none')(logits, labels) 252 | outputs = (loss, logits,) 253 | return outputs 254 | -------------------------------------------------------------------------------- /modules/train_gpp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import json 4 | import time 5 | import numpy as np 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | import torch 10 | from torch.utils.data import ( 11 | RandomSampler, 12 | SequentialSampler, 13 | DataLoader, 14 | ) 15 | from 
utils.misc import compute_metrics_from_logits 16 | 17 | from transformers import ( 18 | AdamW, 19 | get_linear_schedule_with_warmup, 20 | ) 21 | 22 | def train(args, train_dataset, eval_dataset, model, meters, recover_args, logger): 23 | """ Train the model """ 24 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 25 | train_sampler = RandomSampler(train_dataset) 26 | train_dataloader = DataLoader(train_dataset, num_workers=args.num_workers, sampler=train_sampler, batch_size=args.train_batch_size, pin_memory=True) 27 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 28 | 29 | # Prepare optimizer 30 | param_optimizer = list(model.named_parameters()) 31 | no_decay = [] 32 | optimizer_grouped_parameters = [ 33 | {'params': [p for n, p in param_optimizer if 34 | not any(nd in n for nd in no_decay)], 35 | 'weight_decay': args.weight_decay} 36 | ] 37 | 38 | optimizer = AdamW(optimizer_grouped_parameters, 39 | lr=args.learning_rate, eps=args.adam_epsilon) 40 | scheduler = get_linear_schedule_with_warmup(optimizer, 41 | num_warmup_steps=args.warmup_steps, 42 | num_training_steps=t_total) 43 | 44 | if recover_args['global_step'] > 0 and os.path.isfile(os.path.join(recover_args['last_checkpoint_dir'], 'optimizer.pth')): # recovery 45 | last_checkpoint_dir = recover_args['last_checkpoint_dir'] 46 | logger.info( 47 | "Load optimizer from {}".format(last_checkpoint_dir)) 48 | optimizer_to_load = torch.load( 49 | os.path.join(last_checkpoint_dir, 'optimizer.pth'), 50 | map_location=torch.device("cpu")) 51 | optimizer.load_state_dict(optimizer_to_load.pop("optimizer")) 52 | scheduler.load_state_dict(optimizer_to_load.pop("scheduler")) 53 | 54 | # Train! 55 | logger.info("***** Running training *****") 56 | logger.info(" Num examples = %d", len(train_dataset)) 57 | logger.info(" Num Epochs = %d", args.num_train_epochs) 58 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 59 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", 60 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 61 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 62 | logger.info(" Total optimization steps = %d", t_total) 63 | 64 | global_step = recover_args['global_step'] 65 | start_epoch = recover_args['epoch'] + 1 if global_step > 0 else 0 66 | tr_loss, logging_loss = 0.0, 0.0 67 | model.zero_grad() 68 | 69 | best_scores = { 70 | 'epoch': 0, 71 | 'global_step': 0, 72 | 'scores': {'recall@1': 0.0} 73 | } 74 | if recover_args['last_best_score'] is not None: 75 | best_scores = recover_args['last_best_score'] 76 | 77 | log_json = [] 78 | scaler = torch.cuda.amp.GradScaler(enabled=True) 79 | torch.autograd.set_detect_anomaly(True) 80 | for epoch in range(start_epoch, int(args.num_train_epochs)): 81 | t_start = time.time() 82 | tbar = tqdm(train_dataloader, ncols=70) 83 | for step, batch in enumerate(tbar): 84 | tbar.set_description(f'Training loss = {logging_loss}') 85 | model.train() 86 | 87 | context_input_ids = batch[0].to(args.device, non_blocking=True) 88 | context_attention_mask = batch[1].to(args.device, non_blocking=True) 89 | response_input_ids = batch[2].to(args.device, non_blocking=True) 90 | response_attention_mask = batch[3].to(args.device, non_blocking=True) 91 | persona_input_ids = batch[4].to(args.device, non_blocking=True) 92 | persona_attention_mask = batch[5].to(args.device, non_blocking=True) 93 | final_persona_input_ids = batch[6].to(args.device, non_blocking=True) 94 | final_persona_attention_mask = batch[7].to(args.device, non_blocking=True) 95 | dialog_img_feat = batch[8].to(args.device, non_blocking=True) 96 | persona_img_feats = batch[9].to(args.device, non_blocking=True) 97 | final_persona_img_feats = batch[10].to(args.device, non_blocking=True) 98 | dialog_img_mask = batch[11].to(args.device, non_blocking=True) 99 | persona_img_mask = batch[12].to(args.device, non_blocking=True) 100 | 101 | if args.fp16: 102 | with torch.cuda.amp.autocast(enabled=True): 103 | outputs = model( 104 | context_input_ids=context_input_ids, 105 | context_attention_mask=context_attention_mask, 106 | response_input_ids=response_input_ids, 107 | response_attention_mask=response_attention_mask, 108 | persona_input_ids=persona_input_ids, 109 | persona_attention_mask=persona_attention_mask, 110 | final_persona_input_ids=final_persona_input_ids, 111 | final_persona_attention_mask=final_persona_attention_mask, 112 | dialog_img_feat=dialog_img_feat, 113 | persona_img_feats=persona_img_feats, 114 | final_persona_img_feats=final_persona_img_feats, 115 | dialog_img_mask=dialog_img_mask, 116 | persona_img_mask=persona_img_mask, 117 | mode='train', 118 | ) 119 | loss = outputs[0] 120 | else: 121 | outputs = model( 122 | context_input_ids=context_input_ids, 123 | context_attention_mask=context_attention_mask, 124 | response_input_ids=response_input_ids, 125 | response_attention_mask=response_attention_mask, 126 | persona_input_ids=persona_input_ids, 127 | persona_attention_mask=persona_attention_mask, 128 | final_persona_input_ids=final_persona_input_ids, 129 | final_persona_attention_mask=final_persona_attention_mask, 130 | dialog_img_feat=dialog_img_feat, 131 | persona_img_feats=persona_img_feats, 132 | final_persona_img_feats=final_persona_img_feats, 133 | dialog_img_mask=dialog_img_mask, 134 | persona_img_mask=persona_img_mask, 135 | mode='train', 136 | ) 137 | loss = outputs[0] 138 | 139 | if 
args.n_gpu > 1: loss = loss.mean() 140 | 141 | if args.gradient_accumulation_steps > 1: 142 | loss = loss / args.gradient_accumulation_steps 143 | 144 | if args.fp16: 145 | scaler.scale(loss).backward() 146 | else: 147 | loss.backward() 148 | 149 | tr_loss += loss.item() 150 | logging_loss = round(loss.item(), 5) 151 | if (step + 1) % args.gradient_accumulation_steps == 0: 152 | # do gradient clipping 153 | if args.max_grad_norm > 0: 154 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 155 | scheduler.step() # Update learning rate schedule 156 | if args.fp16: 157 | scaler.step(optimizer) 158 | scaler.update() 159 | else: 160 | optimizer.step() 161 | model.zero_grad() 162 | global_step += 1 163 | 164 | # update tensorboard 165 | meters.update_metrics({'batch_metrics': {'loss': loss}}) 166 | meters.update_params({'params': {'lr': optimizer.param_groups[0]['lr']}}) 167 | 168 | if args.logging_steps > 0 and (global_step + 1) % args.logging_steps == 0: 169 | meters.get_logs(global_step+1) 170 | 171 | # Evaluation 172 | logger.info("Epoch: %d, global_step: %d" % (epoch, global_step)) 173 | eval_scores = evaluate(args, model, eval_dataset, 'val', logger, prefix=global_step) 174 | 175 | # Select recall@1 score as metric 176 | if eval_scores['recall@1'] > best_scores['scores']['recall@1']: 177 | best_scores['scores'] = eval_scores 178 | best_scores['epoch'] = epoch 179 | best_scores['global_step'] = global_step 180 | 181 | # save checkpoints 182 | if (args.local_rank in [-1, 0]) and (args.save_epoch>0 and epoch % args.save_epoch == 0) and (epoch > args.save_after_epoch): 183 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}-{}'.format(epoch, global_step)) 184 | if not os.path.exists(output_dir): os.makedirs(output_dir) 185 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 186 | optimizer_to_save = { 187 | 'optimizer': optimizer.state_dict(), 188 | 'scheduler': scheduler.state_dict(), 189 | } 190 | 191 | save_num = 0 192 | while (save_num < 3): 193 | try: 194 | logger.info("Saving model attempt: {}".format(save_num)) 195 | torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin')) 196 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 197 | torch.save(optimizer_to_save, os.path.join(output_dir, 'optimizer.pth')) 198 | save_file = os.path.join(args.output_dir, 'last_checkpoint') 199 | with open(save_file, 'w') as f: 200 | f.write('checkpoint-{}-{}/pytorch_model.bin\n'.format(epoch, global_step)) 201 | f.write(f'best: checkpoint-{best_scores["epoch"]}-{best_scores["global_step"]}\n') 202 | json.dump(best_scores, f) 203 | break 204 | except: 205 | save_num += 1 206 | logger.info("Saving model checkpoint {0} to {1}".format(epoch, output_dir)) 207 | 208 | epoch_log = {'epoch': epoch, 'eval_scores': eval_scores, 'best_scores': best_scores['scores']} 209 | log_json.append(epoch_log) 210 | 211 | logger.info("PROGRESS: {}%".format(round(100*(epoch + 1) / args.num_train_epochs, 4))) 212 | 213 | if args.local_rank in [-1, 0]: 214 | with open(args.output_dir + '/eval_logs.json', 'w') as fp: 215 | json.dump(log_json, fp) 216 | 217 | t_end = time.time() 218 | logger.info('Epoch: %d, Train Time: %.3f' % (epoch, t_end - t_start)) 219 | 220 | return global_step, tr_loss / global_step 221 | 222 | def evaluate(args, model, eval_dataset, mode, logger, prefix=''): 223 | t_start = time.time() 224 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 225 | 
test_sampler = SequentialSampler(eval_dataset) 226 | test_dataloader = DataLoader(eval_dataset, num_workers=args.num_workers, sampler=test_sampler, batch_size=args.eval_batch_size, pin_memory=True) 227 | 228 | # Eval! 229 | logger.info("***** Running evaluation {} *****".format(prefix)) 230 | logger.info(" Num examples = %d", len(eval_dataset)) 231 | logger.info(" Batch size = %d", args.eval_batch_size) 232 | 233 | results_dict = defaultdict(list) 234 | for batch in tqdm(test_dataloader, ncols=70): 235 | model.eval() 236 | 237 | context_input_ids = batch[0].to(args.device, non_blocking=True) 238 | context_attention_mask = batch[1].to(args.device, non_blocking=True) 239 | response_input_ids = batch[2].to(args.device, non_blocking=True) 240 | response_attention_mask = batch[3].to(args.device, non_blocking=True) 241 | persona_input_ids = batch[4].to(args.device, non_blocking=True) 242 | persona_attention_mask = batch[5].to(args.device, non_blocking=True) 243 | final_persona_input_ids = batch[6].to(args.device, non_blocking=True) 244 | final_persona_attention_mask = batch[7].to(args.device, non_blocking=True) 245 | dialog_img_feat = batch[8].to(args.device, non_blocking=True) 246 | persona_img_feats = batch[9].to(args.device, non_blocking=True) 247 | final_persona_img_feats = batch[10].to(args.device, non_blocking=True) 248 | labels = batch[11].to(args.device, non_blocking=True) 249 | dialog_img_mask = batch[12].to(args.device, non_blocking=True) 250 | persona_img_mask = batch[13].to(args.device, non_blocking=True) 251 | 252 | with torch.no_grad(): 253 | loss, logits = model( 254 | context_input_ids=context_input_ids, 255 | context_attention_mask=context_attention_mask, 256 | response_input_ids=response_input_ids, 257 | response_attention_mask=response_attention_mask, 258 | persona_input_ids=persona_input_ids, 259 | persona_attention_mask=persona_attention_mask, 260 | final_persona_input_ids=final_persona_input_ids, 261 | final_persona_attention_mask=final_persona_attention_mask, 262 | dialog_img_feat=dialog_img_feat, 263 | persona_img_feats=persona_img_feats, 264 | final_persona_img_feats=final_persona_img_feats, 265 | labels=labels, 266 | dialog_img_mask=dialog_img_mask, 267 | persona_img_mask=persona_img_mask, 268 | mode=mode, 269 | ) 270 | results_dict['loss'].append(loss.cpu().detach().numpy()) 271 | results_dict['logits'].append(logits.cpu().detach().numpy()) 272 | results_dict['labels'].append(labels.cpu().detach().numpy()) 273 | 274 | for key, value in results_dict.items(): 275 | if results_dict[key][0].shape == (): 276 | results_dict[key] = np.array(value) 277 | else: 278 | results_dict[key] = np.concatenate(value, axis=0) 279 | 280 | recall, mrr = compute_metrics_from_logits(torch.tensor(results_dict['logits']), 281 | torch.tensor(results_dict['labels'])) 282 | 283 | total_scores = { 284 | 'loss': round(np.mean(results_dict['loss']).item(), 4), 285 | 'mrr': round(mrr, 4), 286 | } 287 | for k,v in recall.items(): 288 | total_scores[k] = round(v, 4) 289 | 290 | logger.info("Eval Results:") 291 | logger.info(f'Eval Score: {total_scores}') 292 | 293 | t_end = time.time() 294 | logger.info('Eval Time Cost: %.3f' % (t_end - t_start)) 295 | 296 | return total_scores 297 | -------------------------------------------------------------------------------- /models/nrp_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import CrossEntropyLoss 4 | from typing import Any, Optional, Tuple, 
Union 5 | 6 | from .model import ( 7 | ClipClipModel, 8 | ClipSbertModel, 9 | clip_loss, 10 | mean_pooling, 11 | ) 12 | 13 | from transformers import ( 14 | CLIPProcessor, 15 | AdamW, 16 | get_linear_schedule_with_warmup, 17 | WEIGHTS_NAME, 18 | ) 19 | 20 | class ClipClipNrp(ClipClipModel): 21 | def forward( 22 | self, 23 | context_input_ids: Optional[torch.LongTensor] = None, 24 | context_attention_mask: Optional[torch.LongTensor] = None, 25 | response_input_ids: Optional[torch.LongTensor] = None, 26 | response_attention_mask: Optional[torch.LongTensor] = None, 27 | persona_input_ids: Optional[torch.LongTensor] = None, 28 | persona_attention_mask: Optional[torch.LongTensor] = None, 29 | dialog_img_feat: Optional[torch.Tensor] = None, 30 | persona_img_feats: Optional[torch.Tensor] = None, 31 | dialog_img_mask: Optional[torch.LongTensor] = None, 32 | persona_img_mask: Optional[torch.LongTensor] = None, 33 | labels: Optional[torch.LongTensor] = None, 34 | mode: str = None, 35 | ): 36 | if mode == 'train': 37 | context_output = self.context_text_encoder( 38 | input_ids=context_input_ids, 39 | attention_mask=context_attention_mask 40 | )[1] 41 | context_output = self.context_text_projection(context_output) 42 | context_output = F.normalize(context_output, p=2, dim=1) 43 | 44 | response_output = self.response_encoder( 45 | input_ids=response_input_ids, 46 | attention_mask=response_attention_mask 47 | )[1] 48 | response_output = self.response_projection(response_output) 49 | response_output = F.normalize(response_output, p=2, dim=1) 50 | 51 | persona_output = self.persona_text_encoder( 52 | input_ids=persona_input_ids, 53 | attention_mask=persona_attention_mask 54 | )[1] 55 | persona_output = self.persona_text_projection(persona_output) 56 | persona_output = F.normalize(persona_output, p=2, dim=1) 57 | 58 | persona_image_output = self.persona_image_encoder(pixel_values=persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size))[1] 59 | persona_image_output = self.persona_image_projection(persona_image_output) 60 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 61 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_imgs, persona_image_output.size(-1)) 62 | 63 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 64 | dialog_image_output = self.context_image_projection(dialog_image_output) 65 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 66 | 67 | if self.args.sum_persona_images: 68 | if self.args.remove_empty_images: 69 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) * persona_image_output, dim=1) / torch.sum(persona_img_mask, dim=1).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 70 | else: 71 | persona_image_output = torch.mean(persona_image_output, dim=1) 72 | multimodal_persona_output = (persona_output + persona_image_output) / 2 73 | if self.args.remove_empty_images: 74 | multimodal_context_output = (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output + context_output + multimodal_persona_output) 75 | multimodal_context_output /= (dialog_img_mask + 2).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 76 | else: 77 | multimodal_context_output = (context_output + dialog_image_output + multimodal_persona_output) / 3 78 | else: 79 | raise NotImplementedError 80 | 81 | # cosine similarity as logits 82 | logit_scale = self.logit_scale.exp() 83 | logits_per_context = 
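# NOTE: the line completed below scores every context against every response in the batch
# and multiplies by exp(logit_scale), the learned CLIP-style temperature.
# clip_loss over this [batch, batch] matrix is presumably the standard symmetric CLIP
# objective (cross-entropy over rows and columns, averaged), with the diagonal entries as
# the positive context/response pairs.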
torch.matmul(multimodal_context_output, response_output.t()) * logit_scale 84 | loss = clip_loss(logits_per_context) 85 | 86 | outputs = (loss,) 87 | else: 88 | response_input_ids = response_input_ids.view(-1, response_input_ids.shape[-1]) 89 | response_attention_mask = response_attention_mask.view(-1, response_attention_mask.shape[-1]) 90 | context_output = self.context_text_encoder( 91 | input_ids=context_input_ids, 92 | attention_mask=context_attention_mask 93 | )[1] 94 | context_output = self.context_text_projection(context_output) 95 | context_output = F.normalize(context_output, p=2, dim=1) 96 | 97 | response_output = self.response_encoder( 98 | input_ids=response_input_ids, 99 | attention_mask=response_attention_mask 100 | )[1] 101 | response_output = self.response_projection(response_output) 102 | response_output = F.normalize(response_output, p=2, dim=1) 103 | 104 | persona_output = self.persona_text_encoder( 105 | input_ids=persona_input_ids, 106 | attention_mask=persona_attention_mask 107 | )[1] 108 | persona_output = self.persona_text_projection(persona_output) 109 | persona_output = F.normalize(persona_output, p=2, dim=1) 110 | 111 | persona_image_output = self.persona_image_encoder(pixel_values=persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size))[1] 112 | persona_image_output = self.persona_image_projection(persona_image_output) 113 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 114 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_imgs, persona_image_output.size(-1)) 115 | 116 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 117 | dialog_image_output = self.context_image_projection(dialog_image_output) 118 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 119 | 120 | if self.args.sum_persona_images: 121 | if self.args.remove_empty_images: 122 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) * persona_image_output, dim=1) / torch.sum(persona_img_mask, dim=1).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 123 | else: 124 | persona_image_output = torch.mean(persona_image_output, dim=1) 125 | multimodal_persona_output = (persona_output + persona_image_output) / 2 126 | if self.args.remove_empty_images: 127 | multimodal_context_output = (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output + context_output + multimodal_persona_output) 128 | multimodal_context_output /= (dialog_img_mask + 2).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 129 | else: 130 | multimodal_context_output = (context_output + dialog_image_output + multimodal_persona_output) / 3 131 | else: 132 | raise NotImplementedError 133 | 134 | logits = torch.bmm( 135 | multimodal_context_output.unsqueeze(1), 136 | response_output.view(-1, self.args.max_num_responses, response_output.shape[-1]).transpose(1,2) 137 | ).squeeze(1) 138 | 139 | loss = CrossEntropyLoss(reduction='none')(logits, labels) 140 | outputs = (loss, logits,) 141 | return outputs 142 | 143 | class ClipSbertNrp(ClipSbertModel): 144 | def forward( 145 | self, 146 | context_input_ids: Optional[torch.LongTensor] = None, 147 | context_attention_mask: Optional[torch.LongTensor] = None, 148 | response_input_ids: Optional[torch.LongTensor] = None, 149 | response_attention_mask: Optional[torch.LongTensor] = None, 150 | persona_input_ids: Optional[torch.LongTensor] = None, 151 | persona_attention_mask: 
Optional[torch.LongTensor] = None, 152 | dialog_img_feat: Optional[torch.Tensor] = None, 153 | persona_img_feats: Optional[torch.Tensor] = None, 154 | dialog_img_mask: Optional[torch.LongTensor] = None, 155 | persona_img_mask: Optional[torch.LongTensor] = None, 156 | labels: Optional[torch.LongTensor] = None, 157 | mode: str = None, 158 | ): 159 | if mode == 'train': 160 | context_output = self.context_text_encoder( 161 | input_ids=context_input_ids, 162 | attention_mask=context_attention_mask 163 | ) 164 | context_output = mean_pooling(context_output, context_attention_mask) 165 | 166 | response_output = self.response_encoder( 167 | input_ids=response_input_ids, 168 | attention_mask=response_attention_mask 169 | ) 170 | response_output = mean_pooling(response_output, response_attention_mask) 171 | 172 | persona_output = self.persona_text_encoder( 173 | input_ids=persona_input_ids, 174 | attention_mask=persona_attention_mask 175 | ) 176 | persona_output = mean_pooling(persona_output, persona_attention_mask) 177 | 178 | persona_image_output = self.persona_image_encoder(pixel_values=persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size))[1] 179 | persona_image_output = self.persona_image_projection(persona_image_output) 180 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 181 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_imgs, persona_image_output.size(-1)) 182 | 183 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 184 | dialog_image_output = self.context_image_projection(dialog_image_output) 185 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 186 | 187 | if self.args.sum_persona_images: 188 | if self.args.remove_empty_images: 189 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) * persona_image_output, dim=1) / torch.sum(persona_img_mask, dim=1).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 190 | else: 191 | persona_image_output = torch.mean(persona_image_output, dim=1) 192 | multimodal_persona_output = (persona_output + persona_image_output) / 2 193 | if self.args.remove_empty_images: 194 | multimodal_context_output = (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output + context_output + multimodal_persona_output) 195 | multimodal_context_output /= (dialog_img_mask + 2).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 196 | else: 197 | multimodal_context_output = (context_output + dialog_image_output + multimodal_persona_output) / 3 198 | else: 199 | raise NotImplementedError 200 | 201 | targets = torch.arange(context_output.shape[0], device=context_output.device) 202 | # dot_products: [batch, batch] 203 | dot_products = multimodal_context_output.mm(response_output.t()) 204 | log_prob = F.log_softmax(dot_products, dim=1) 205 | loss = F.nll_loss(log_prob, targets) 206 | 207 | outputs = (loss,) 208 | else: 209 | response_input_ids = response_input_ids.view(-1, response_input_ids.shape[-1]) 210 | response_attention_mask = response_attention_mask.view(-1, response_attention_mask.shape[-1]) 211 | context_output = self.context_text_encoder( 212 | input_ids=context_input_ids, 213 | attention_mask=context_attention_mask 214 | ) 215 | context_output = mean_pooling(context_output, context_attention_mask) 216 | 217 | response_output = self.response_encoder( 218 | input_ids=response_input_ids, 219 | attention_mask=response_attention_mask 220 | ) 221 | 
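# Note (annotation): mean_pooling (imported from .model, i.e. models/model.py) collapses the
# encoder's token embeddings into a single sentence embedding per sequence by averaging the
# token vectors weighted by the attention mask, as is conventional for SBERT-style encoders.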
response_output = mean_pooling(response_output, response_attention_mask) 222 | 223 | persona_output = self.persona_text_encoder( 224 | input_ids=persona_input_ids, 225 | attention_mask=persona_attention_mask 226 | ) 227 | persona_output = mean_pooling(persona_output, persona_attention_mask) 228 | 229 | persona_image_output = self.persona_image_encoder(pixel_values=persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size))[1] 230 | persona_image_output = self.persona_image_projection(persona_image_output) 231 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 232 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_imgs, persona_image_output.size(-1)) 233 | 234 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 235 | dialog_image_output = self.context_image_projection(dialog_image_output) 236 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 237 | 238 | if self.args.sum_persona_images: 239 | if self.args.remove_empty_images: 240 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) * persona_image_output, dim=1) / torch.sum(persona_img_mask, dim=1).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 241 | else: 242 | persona_image_output = torch.mean(persona_image_output, dim=1) 243 | multimodal_persona_output = (persona_output + persona_image_output) / 2 244 | if self.args.remove_empty_images: 245 | multimodal_context_output = (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output + context_output + multimodal_persona_output) 246 | multimodal_context_output /= (dialog_img_mask + 2).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 247 | else: 248 | multimodal_context_output = (context_output + dialog_image_output + multimodal_persona_output) / 3 249 | else: 250 | raise NotImplementedError 251 | 252 | logits = torch.bmm( 253 | multimodal_context_output.unsqueeze(1), 254 | response_output.view(-1, self.args.max_num_responses, response_output.shape[-1]).transpose(1,2) 255 | ).squeeze(1) 256 | 257 | loss = CrossEntropyLoss(reduction='none')(logits, labels) 258 | outputs = (loss, logits,) 259 | return outputs 260 | -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | ## Next Response Prediction 2 | 3 | **mpchat_nrp.json** \ 4 | Each split (train/val/test) contains a list of dialogues. \ 5 | A dialogue has the following structure: 6 | 7 | ``` 8 | [ 9 | { 10 | ## Example (not real) 11 | "subreddit": "itookapicture", 12 | "messages": [ 13 | "itap of my cat", 14 | "omg it is so cute! great shot!", 15 | "she is a model, she takes the pose for me." 
16 | ], 17 | "message_ids": [ 18 | "ab3d5f", ## post (or comment) id 19 | "gaq65vy", 20 | "gaquc1k" 21 | ], 22 | "main_author": "johndoe", 23 | "authors": [ 24 | "johndoe", 25 | "mickeymouse", 26 | "johndoe" 27 | ], 28 | "created_utcs": [ 29 | 1604117284.0, 30 | 1604173517.0, 31 | 1604188317.0, 32 | ], 33 | "has_image": true, 34 | "all_personas": [ 35 | { 36 | "id": "iawlby", 37 | "subreddit": "barista", 38 | "url": "https://i.redd.it/ghifjk.jpg", 39 | "score": 5, ## upvotes - downvotes 40 | "author": "johndoe", 41 | "created_utcs": 1597599418.0, 42 | "permalink": "/r/barista/comments/iawlby/my_dripper_art", 43 | "title": "my dripper art pieces that i made into stickers recently!", 44 | "direct_url": "https://i.redd.it/ghifjk.jpg", 45 | "file_name": "iawlby_ghifjk.jpg" ## {post_id}_{direct_url.split('/')[-1]} 46 | }, 47 | ... 48 | ], 49 | "grounded_personas": [ 50 | [], 51 | [], 52 | [ 53 | { 54 | "id": "jg9ml1", 55 | "subreddit": "cats", 56 | "url": "https://i.redd.it/0qBk1ej.jpg", 57 | "score": 238, 58 | "author": "johndoe", 59 | "created_utc": 1597549415, 60 | "permalink": "/r/cats/comments/jg9ml1/from_feline_to_fashion/", 61 | "title": "from feline to fashion: my journey training a cat supermodel", 62 | "direct_url": "https://i.redd.it/0qBk1ej.jpg", 63 | "file_name": "jg9ml1_0qBk1ej.jpg", 64 | "label_overall": "(strong) E", 65 | ################################# 66 | ## (strong) E: response strongly entailed by (=grounded on) the persona element - entailment score = 3/3 67 | ## E: response entailed by the persona element - entailment score = 2/3 68 | ## I: response irrelevant to the persona element - entailment score = 1/3 69 | ## (strong) I: response strongly irrelevant to the persona element - entailment score = 0/3 70 | ################################# 71 | "label_per_worker": [ ## three workers labeled it as entailed (=1) 72 | [ 73 | "A2NSS746CFCT4N", 74 | 1 75 | ], 76 | [ 77 | "AKQAI78JTXXC8", 78 | 1 79 | ], 80 | [ 81 | "ANX6Q4NMZL8EL", 82 | 1 83 | ] ] 84 | } 85 | ] 86 | ], 87 | "ungrounded_personas": [ 88 | [], 89 | [], 90 | [] 91 | ], 92 | "direct_url": "https://i.redd.it/abcdef.jpg", 93 | "file_name": "ab3d5f_abcdef.jpg", ## {post_id}_{direct_url.split('/')[-1]} 94 | "candidate_personas": [ ## max 5 elements among 'all_personas' 95 | { 96 | "id": "iawlby", 97 | "subreddit": "barista", 98 | "url": "https://i.redd.it/ghifjk.jpg", 99 | "score": 5, 100 | "author": "johndoe", 101 | "created_utcs": 1597599418.0, 102 | "permalink": "/r/barista/comments/iawlby/my_dripper_art", 103 | "title": "my dripper art pieces that i made into stickers recently!", 104 | "direct_url": "https://i.redd.it/ghifjk.jpg", 105 | "file_name": "iawlby_ghifjk.jpg" 106 | }, 107 | { 108 | "id": "jg9ml1", 109 | "subreddit": "cats", 110 | "url": "https://i.redd.it/0qBk1ej.jpg", 111 | "score": 238, 112 | "author": "johndoe", 113 | "created_utc": 1597549415, 114 | "permalink": "/r/cats/comments/jg9ml1/from_feline_to_fashion/", 115 | "title": "from feline to fashion: my journey training a cat supermodel", 116 | "direct_url": "https://i.redd.it/0qBk1ej.jpg", 117 | "file_name": "jg9ml1_0qBk1ej.jpg", 118 | }, 119 | ... 120 | ], 121 | "nrp_candidate_responses": [ ## only in val and test set 122 | [ 123 | "itap of my cat", 124 | (99 candidate responses) 125 | ... 126 | ], 127 | [], 128 | [ 129 | "she is a model, she takes the pose for me.", 130 | (99 candidate responses) 131 | ... 132 | ] 133 | ] 134 | }, 135 | ... 136 | ] 137 | ``` 138 | 139 | Please see below for a description of each attribute in the dataset: 140 | 141 | attribute | type | description 142 | --- | --- | --- 143 | `subreddit` | str | subreddit of post 144 | `messages` | list of str | dialogue between multiple authors 145 | `message_ids` | list of str | post (or comment) id of each utterance 146 | `main_author` | str | main author with multimodal persona info 147 | `authors` | list of str | author info of each utterance 148 | `created_utcs` | list of float | UTC epochs when each post (or comment) was submitted 149 | `has_image` | bool | whether post has image or not 150 | `direct_url` | str | direct url of post image 151 | `file_name` | str | saved file name of post image (format: {post_id}_{direct_url.split('/')[-1]}) 152 | `all_personas` | list of dict | main author's all persona elements 153 | `grounded_personas` | list of list of dict | persona elements grounding each utterance, only provided in the main author's turns (labeled by workers) 154 | `ungrounded_personas` | list of list of dict | persona elements not grounding each utterance, only provided in the main author's turns (labeled by workers) 155 | `candidate_personas` | list of dict | main author's candidate (max 5) persona elements 156 | `nrp_candidate_responses` | list of list of str | 100 candidate responses (the gold response comes first), only provided in the main author's turns of the val and test sets 157 | 
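For reference, a minimal loading sketch (illustrative only; the file location and the choice of split are placeholders, and the gold-response-first convention mirrors the repo's data loader in `data/mpchat_nrp.py`):

```
# Illustrative sketch: iterate the main author's turns of the NRP split.
import json

with open("mpchat_nrp.json") as fp:           # adjust the path to your data directory
    data = json.load(fp)

for dialog in data["val"]:                    # splits: "train", "val", "test"
    main_author = dialog["main_author"]
    for turn_idx, author in enumerate(dialog["authors"]):
        if author != main_author:
            continue                          # candidates exist only for the main author's turns
        context = " ".join(dialog["messages"][:turn_idx])
        response = dialog["messages"][turn_idx]
        candidates = dialog["nrp_candidate_responses"][turn_idx]   # val/test only
        assert candidates[0] == response      # gold response is listed first
```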
158 | ## Grounding Persona Prediction 159 | 160 | **mpchat_gpp.json** \ 161 | A dialogue has the following structure: 162 | 163 | ``` 164 | [ 165 | { 166 | "subreddit": "itookapicture", 167 | "messages": [ 168 | "itap of my cat", 169 | "omg it is so cute! great shot!", 170 | "she is a model, she takes the pose for me." 171 | ], 172 | "message_ids": [ 173 | "ab3d5f", 174 | "gaq65vy", 175 | "gaquc1k" 176 | ], 177 | "main_author": "johndoe", 178 | "authors": [ 179 | "johndoe", 180 | "mickeymouse", 181 | "johndoe" 182 | ], 183 | "created_utcs": [ 184 | 1604117284.0, 185 | 1604173517.0, 186 | 1604188317.0, 187 | ], 188 | "has_image": true, 189 | "all_personas": [ 190 | ... 191 | ], 192 | "direct_url": "https://i.redd.it/abcdef.jpg", 193 | "file_name": "ab3d5f_abcdef.jpg", 194 | "gpp_grounded_persona": [ ## the persona element the response is grounded on, for each turn 195 | null, 196 | null, 197 | { 198 | "id": "jg9ml1", 199 | "subreddit": "cats", 200 | "url": "https://i.redd.it/0qBk1ej.jpg", 201 | "score": 238, 202 | "author": "johndoe", 203 | "created_utc": 1597549415, 204 | "permalink": "/r/cats/comments/jg9ml1/from_feline_to_fashion/", 205 | "title": "from feline to fashion: my journey training a cat supermodel", 206 | "direct_url": "https://i.redd.it/0qBk1ej.jpg", 207 | "file_name": "jg9ml1_0qBk1ej.jpg", 208 | } 209 | ], 210 | "gpp_candidate_personas": [ ## max 4 persona elements for each turn 211 | [], 212 | [], 213 | [ 214 | { 215 | "id": "iawlby", 216 | "subreddit": "barista", 217 | "url": "https://i.redd.it/ghifjk.jpg", 218 | "score": 5, 219 | "author": "johndoe", 220 | "created_utcs": 1597599418.0, 221 | "permalink": "/r/barista/comments/iawlby/my_dripper_art", 222 | "title": "my dripper art pieces that i made into stickers recently!", 223 | "direct_url": "https://i.redd.it/ghifjk.jpg", 224 | "file_name": "iawlby_ghifjk.jpg" 225 | }, 226 | ... 227 | ] 228 | ], 229 | "gpp_candidate_authors_candidate_personas": [ ## only in val and test set 230 | [], 231 | [], 232 | [ 233 | { 234 | "id": "jg9ml1", 235 | ... 236 | } 237 | (99 candidate persona elements) 238 | ... 239 | ] 240 | ] 241 | }, 242 | ... 243 | ] 244 | ``` 245 | 246 | Please see below for a description of each attribute in the dataset: 247 | 248 | attribute | type | description 249 | --- | --- | --- 250 | `subreddit` | str | subreddit of post 251 | `messages` | list of str | dialogue between multiple authors 252 | `message_ids` | list of str | post (or comment) id of each utterance 253 | `main_author` | str | main author with multimodal persona info 254 | `authors` | list of str | author info of each utterance 255 | `created_utcs` | list of float | UTC epochs when each post (or comment) was submitted 256 | `has_image` | bool | whether post has image or not 257 | `direct_url` | str | direct url of post image 258 | `file_name` | str | saved file name of post image (format: {post_id}_{direct_url.split('/')[-1]}) 259 | `all_personas` | list of dict | main author's all persona elements 260 | `gpp_grounded_persona` | list of (dict or null) | persona element the utterance is grounded on (null otherwise), only provided in the main author's turns 261 | `gpp_candidate_personas` | list of list of dict | main author's candidate (max 4) persona elements, only provided in the main author's turns and only if `gpp_grounded_persona` exists in the turn 262 | `gpp_candidate_authors_candidate_personas` | list of list of dict | 100 candidate persona elements, only provided in the main author's turns and only if `gpp_grounded_persona` exists in the turn 263 | 
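Analogously, a minimal sketch for collecting GPP examples (illustrative only; the path is a placeholder, and the split keys are assumed to match mpchat_nrp.json):

```
# Illustrative sketch: gather turns that have an annotated grounding persona.
import json

with open("mpchat_gpp.json") as fp:
    data = json.load(fp)

for dialog in data["test"]:                   # splits assumed: "train", "val", "test"
    for turn_idx, grounded in enumerate(dialog["gpp_grounded_persona"]):
        if grounded is None:
            continue                          # no grounding persona annotated for this turn
        context = " ".join(dialog["messages"][:turn_idx])
        response = dialog["messages"][turn_idx]
        # 100-way candidate pool (val/test only); the task is to retrieve `grounded` from it
        candidates = dialog["gpp_candidate_authors_candidate_personas"][turn_idx]
```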
264 | ## Speaker Identification 265 | 266 | **mpchat_si.json** \ 267 | A dialogue has the following structure: 268 | 269 | ``` 270 | [ 271 | { 272 | "subreddit": "itookapicture", 273 | "messages": [ 274 | "itap of my cat", 275 | "omg it is so cute! great shot!", 276 | "she is a model, she takes the pose for me." 277 | ], 278 | "message_ids": [ 279 | "ab3d5f", 280 | "gaq65vy", 281 | "gaquc1k" 282 | ], 283 | "main_author": "johndoe", 284 | "authors": [ 285 | "johndoe", 286 | "mickeymouse", 287 | "johndoe" 288 | ], 289 | "created_utcs": [ 290 | 1604117284.0, 291 | 1604173517.0, 292 | 1604188317.0, 293 | ], 294 | "has_image": true, 295 | "all_personas": [ 296 | { 297 | "id": "iawlby", 298 | "subreddit": "barista", 299 | "url": "https://i.redd.it/ghifjk.jpg", 300 | "score": 5, ## upvotes - downvotes 301 | "author": "johndoe", 302 | "created_utcs": 1597599418.0, 303 | "permalink": "/r/barista/comments/iawlby/my_dripper_art", 304 | "title": "my dripper art pieces that i made into stickers recently!", 305 | "direct_url": "https://i.redd.it/ghifjk.jpg", 306 | "file_name": "iawlby_ghifjk.jpg" 307 | }, 308 | ... 309 | ], 310 | "direct_url": "https://i.redd.it/abcdef.jpg", 311 | "file_name": "ab3d5f_abcdef.jpg", 312 | "si_main_author_candidate_personas": [ ## max 5 elements among 'all_personas' 313 | { 314 | "id": "iawlby", 315 | "subreddit": "barista", 316 | "url": "https://i.redd.it/ghifjk.jpg", 317 | "score": 5, 318 | "author": "johndoe", 319 | "created_utcs": 1597599418.0, 320 | "permalink": "/r/barista/comments/iawlby/my_dripper_art", 321 | "title": "my dripper art pieces that i made into stickers recently!", 322 | "direct_url": "https://i.redd.it/ghifjk.jpg", 323 | "file_name": "iawlby_ghifjk.jpg" 324 | }, 325 | { 326 | "id": "jg9ml1", 327 | "subreddit": "cats", 328 | "url": "https://i.redd.it/0qBk1ej.jpg", 329 | "score": 238, 330 | "author": "johndoe", 331 | "created_utc": 1597549415, 332 | "permalink": "/r/cats/comments/jg9ml1/from_feline_to_fashion/", 333 | "title": "from feline to fashion: my journey training a cat supermodel", 334 | "direct_url": "https://i.redd.it/0qBk1ej.jpg", 335 | "file_name": "jg9ml1_0qBk1ej.jpg", 336 | }, 337 | ...
338 | ], 339 | "si_candidate_authors_candidate_personas": [ ## only in val and test set 340 | [ 341 | { 342 | "id": "jg9ml1", 343 | ... 344 | }, 345 | { 346 | "id": "iawlby", 347 | ... 348 | }, 349 | ... 350 | ], 351 | (99 candidate authors' candidate personas) 352 | ] 353 | }, 354 | ... 355 | ] 356 | ``` 357 | 358 | Please see below for a description of each attribute in the dataset: 359 | 360 | attribute | type | description 361 | --- | --- | --- 362 | `subreddit` | str | subreddit of post 363 | `messages` | list of str | dialogue between multiple authors 364 | `message_ids` | list of str | post (or comment) id of each utterance 365 | `main_author` | str | main author with multimodal persona info 366 | `authors` | list of str | author info of each utterance 367 | `created_utcs` | str | UTC epoch when post (or comment) was submitted 368 | `has_image` | bool | whether post has image or not 369 | `direct_url` | str | direct url of post image 370 | `file_name` | str | saved file name of post image (format: {post_id}_{direct_url.split('/')[-1]}) 371 | `all_personas` | list of dict | main author's all persona elements 372 | `si_main_author_candidate_personas` | list of dict | main author's candidate (max 5) persona elements 373 | `si_candidate_authors_candidate_personas` | list of list of dict | 100 candidate authors' persona elements 374 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution 4.0 International Public License 58 | 59 | By exercising the Licensed Rights (defined below), You accept and agree 60 | to be bound by the terms and conditions of this Creative Commons 61 | Attribution 4.0 International Public License ("Public License"). To the 62 | extent this Public License may be interpreted as a contract, You are 63 | granted the Licensed Rights in consideration of Your acceptance of 64 | these terms and conditions, and the Licensor grants You such rights in 65 | consideration of benefits the Licensor receives from making the 66 | Licensed Material available under these terms and conditions. 67 | 68 | 69 | Section 1 -- Definitions. 70 | 71 | a. Adapted Material means material subject to Copyright and Similar 72 | Rights that is derived from or based upon the Licensed Material 73 | and in which the Licensed Material is translated, altered, 74 | arranged, transformed, or otherwise modified in a manner requiring 75 | permission under the Copyright and Similar Rights held by the 76 | Licensor. For purposes of this Public License, where the Licensed 77 | Material is a musical work, performance, or sound recording, 78 | Adapted Material is always produced where the Licensed Material is 79 | synched in timed relation with a moving image. 80 | 81 | b. Adapter's License means the license You apply to Your Copyright 82 | and Similar Rights in Your contributions to Adapted Material in 83 | accordance with the terms and conditions of this Public License. 84 | 85 | c. Copyright and Similar Rights means copyright and/or similar rights 86 | closely related to copyright including, without limitation, 87 | performance, broadcast, sound recording, and Sui Generis Database 88 | Rights, without regard to how the rights are labeled or 89 | categorized. For purposes of this Public License, the rights 90 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 91 | Rights. 92 | 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. Share means to provide material to the public by any means or 116 | process that requires permission under the Licensed Rights, such 117 | as reproduction, public display, public performance, distribution, 118 | dissemination, communication, or importation, and to make material 119 | available to the public including in ways that members of the 120 | public may access the material from a place and at a time 121 | individually chosen by them. 122 | 123 | j. Sui Generis Database Rights means rights other than copyright 124 | resulting from Directive 96/9/EC of the European Parliament and of 125 | the Council of 11 March 1996 on the legal protection of databases, 126 | as amended and/or succeeded, as well as other essentially 127 | equivalent rights anywhere in the world. 128 | 129 | k. You means the individual or entity exercising the Licensed Rights 130 | under this Public License. Your has a corresponding meaning. 131 | 132 | 133 | Section 2 -- Scope. 134 | 135 | a. License grant. 136 | 137 | 1. Subject to the terms and conditions of this Public License, 138 | the Licensor hereby grants You a worldwide, royalty-free, 139 | non-sublicensable, non-exclusive, irrevocable license to 140 | exercise the Licensed Rights in the Licensed Material to: 141 | 142 | a. reproduce and Share the Licensed Material, in whole or 143 | in part; and 144 | 145 | b. produce, reproduce, and Share Adapted Material. 146 | 147 | 2. Exceptions and Limitations. For the avoidance of doubt, where 148 | Exceptions and Limitations apply to Your use, this Public 149 | License does not apply, and You do not need to comply with 150 | its terms and conditions. 151 | 152 | 3. Term. The term of this Public License is specified in Section 153 | 6(a). 154 | 155 | 4. Media and formats; technical modifications allowed. The 156 | Licensor authorizes You to exercise the Licensed Rights in 157 | all media and formats whether now known or hereafter created, 158 | and to make technical modifications necessary to do so. The 159 | Licensor waives and/or agrees not to assert any right or 160 | authority to forbid You from making technical modifications 161 | necessary to exercise the Licensed Rights, including 162 | technical modifications necessary to circumvent Effective 163 | Technological Measures. For purposes of this Public License, 164 | simply making modifications authorized by this Section 2(a) 165 | (4) never produces Adapted Material. 166 | 167 | 5. Downstream recipients. 168 | 169 | a. Offer from the Licensor -- Licensed Material. Every 170 | recipient of the Licensed Material automatically 171 | receives an offer from the Licensor to exercise the 172 | Licensed Rights under the terms and conditions of this 173 | Public License. 174 | 175 | b. No downstream restrictions. 
You may not offer or impose 176 | any additional or different terms or conditions on, or 177 | apply any Effective Technological Measures to, the 178 | Licensed Material if doing so restricts exercise of the 179 | Licensed Rights by any recipient of the Licensed 180 | Material. 181 | 182 | 6. No endorsement. Nothing in this Public License constitutes or 183 | may be construed as permission to assert or imply that You 184 | are, or that Your use of the Licensed Material is, connected 185 | with, or sponsored, endorsed, or granted official status by, 186 | the Licensor or others designated to receive attribution as 187 | provided in Section 3(a)(1)(A)(i). 188 | 189 | b. Other rights. 190 | 191 | 1. Moral rights, such as the right of integrity, are not 192 | licensed under this Public License, nor are publicity, 193 | privacy, and/or other similar personality rights; however, to 194 | the extent possible, the Licensor waives and/or agrees not to 195 | assert any such rights held by the Licensor to the limited 196 | extent necessary to allow You to exercise the Licensed 197 | Rights, but not otherwise. 198 | 199 | 2. Patent and trademark rights are not licensed under this 200 | Public License. 201 | 202 | 3. To the extent possible, the Licensor waives any right to 203 | collect royalties from You for the exercise of the Licensed 204 | Rights, whether directly or through a collecting society 205 | under any voluntary or waivable statutory or compulsory 206 | licensing scheme. In all other cases the Licensor expressly 207 | reserves any right to collect such royalties. 208 | 209 | 210 | Section 3 -- License Conditions. 211 | 212 | Your exercise of the Licensed Rights is expressly made subject to the 213 | following conditions. 214 | 215 | a. Attribution. 216 | 217 | 1. If You Share the Licensed Material (including in modified 218 | form), You must: 219 | 220 | a. retain the following if it is supplied by the Licensor 221 | with the Licensed Material: 222 | 223 | i. identification of the creator(s) of the Licensed 224 | Material and any others designated to receive 225 | attribution, in any reasonable manner requested by 226 | the Licensor (including by pseudonym if 227 | designated); 228 | 229 | ii. a copyright notice; 230 | 231 | iii. a notice that refers to this Public License; 232 | 233 | iv. a notice that refers to the disclaimer of 234 | warranties; 235 | 236 | v. a URI or hyperlink to the Licensed Material to the 237 | extent reasonably practicable; 238 | 239 | b. indicate if You modified the Licensed Material and 240 | retain an indication of any previous modifications; and 241 | 242 | c. indicate the Licensed Material is licensed under this 243 | Public License, and include the text of, or the URI or 244 | hyperlink to, this Public License. 245 | 246 | 2. You may satisfy the conditions in Section 3(a)(1) in any 247 | reasonable manner based on the medium, means, and context in 248 | which You Share the Licensed Material. For example, it may be 249 | reasonable to satisfy the conditions by providing a URI or 250 | hyperlink to a resource that includes the required 251 | information. 252 | 253 | 3. If requested by the Licensor, You must remove any of the 254 | information required by Section 3(a)(1)(A) to the extent 255 | reasonably practicable. 256 | 257 | 4. If You Share Adapted Material You produce, the Adapter's 258 | License You apply must not prevent recipients of the Adapted 259 | Material from complying with this Public License. 
260 | 261 | 262 | Section 4 -- Sui Generis Database Rights. 263 | 264 | Where the Licensed Rights include Sui Generis Database Rights that 265 | apply to Your use of the Licensed Material: 266 | 267 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 268 | to extract, reuse, reproduce, and Share all or a substantial 269 | portion of the contents of the database; 270 | 271 | b. if You include all or a substantial portion of the database 272 | contents in a database in which You have Sui Generis Database 273 | Rights, then the database in which You have Sui Generis Database 274 | Rights (but not its individual contents) is Adapted Material; and 275 | 276 | c. You must comply with the conditions in Section 3(a) if You Share 277 | all or a substantial portion of the contents of the database. 278 | 279 | For the avoidance of doubt, this Section 4 supplements and does not 280 | replace Your obligations under this Public License where the Licensed 281 | Rights include other Copyright and Similar Rights. 282 | 283 | 284 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 285 | 286 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 287 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 288 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 289 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 290 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 291 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 292 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 293 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 294 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 295 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 296 | 297 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 298 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 299 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 300 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 301 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 302 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 303 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 304 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 305 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 306 | 307 | c. The disclaimer of warranties and limitation of liability provided 308 | above shall be interpreted in a manner that, to the extent 309 | possible, most closely approximates an absolute disclaimer and 310 | waiver of all liability. 311 | 312 | 313 | Section 6 -- Term and Termination. 314 | 315 | a. This Public License applies for the term of the Copyright and 316 | Similar Rights licensed here. However, if You fail to comply with 317 | this Public License, then Your rights under this Public License 318 | terminate automatically. 319 | 320 | b. Where Your right to use the Licensed Material has terminated under 321 | Section 6(a), it reinstates: 322 | 323 | 1. automatically as of the date the violation is cured, provided 324 | it is cured within 30 days of Your discovery of the 325 | violation; or 326 | 327 | 2. upon express reinstatement by the Licensor. 328 | 329 | For the avoidance of doubt, this Section 6(b) does not affect any 330 | right the Licensor may have to seek remedies for Your violations 331 | of this Public License. 332 | 333 | c. 
For the avoidance of doubt, the Licensor may also offer the 334 | Licensed Material under separate terms or conditions or stop 335 | distributing the Licensed Material at any time; however, doing so 336 | will not terminate this Public License. 337 | 338 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 339 | License. 340 | 341 | 342 | Section 7 -- Other Terms and Conditions. 343 | 344 | a. The Licensor shall not be bound by any additional or different 345 | terms or conditions communicated by You unless expressly agreed. 346 | 347 | b. Any arrangements, understandings, or agreements regarding the 348 | Licensed Material not stated herein are separate from and 349 | independent of the terms and conditions of this Public License. 350 | 351 | 352 | Section 8 -- Interpretation. 353 | 354 | a. For the avoidance of doubt, this Public License does not, and 355 | shall not be interpreted to, reduce, limit, restrict, or impose 356 | conditions on any use of the Licensed Material that could lawfully 357 | be made without permission under this Public License. 358 | 359 | b. To the extent possible, if any provision of this Public License is 360 | deemed unenforceable, it shall be automatically reformed to the 361 | minimum extent necessary to make it enforceable. If the provision 362 | cannot be reformed, it shall be severed from this Public License 363 | without affecting the enforceability of the remaining terms and 364 | conditions. 365 | 366 | c. No term or condition of this Public License will be waived and no 367 | failure to comply consented to unless expressly agreed to by the 368 | Licensor. 369 | 370 | d. Nothing in this Public License constitutes or may be interpreted 371 | as a limitation upon, or waiver of, any privileges and immunities 372 | that apply to the Licensor or You, including from the legal 373 | processes of any jurisdiction or authority. 374 | 375 | 376 | ======================================================================= 377 | 378 | Creative Commons is not a party to its public 379 | licenses. Notwithstanding, Creative Commons may elect to apply one of 380 | its public licenses to material it publishes and in those instances 381 | will be considered the “Licensor.” The text of the Creative Commons 382 | public licenses is dedicated to the public domain under the CC0 Public 383 | Domain Dedication. Except for the limited purpose of indicating that 384 | material is shared under a Creative Commons public license or as 385 | otherwise permitted by the Creative Commons policies published at 386 | creativecommons.org/policies, Creative Commons does not authorize the 387 | use of the trademark "Creative Commons" or any other trademark or logo 388 | of Creative Commons without its prior written consent including, 389 | without limitation, in connection with any unauthorized modifications 390 | to any of its public licenses or any other arrangements, 391 | understandings, or agreements concerning use of licensed material. For 392 | the avoidance of doubt, this paragraph does not form part of the 393 | public licenses. 394 | 395 | Creative Commons may be contacted at creativecommons.org. 
396 | -------------------------------------------------------------------------------- /models/gpp_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import CrossEntropyLoss 4 | from typing import Any, Optional, Tuple, Union 5 | 6 | from .model import ( 7 | ClipClipModel, 8 | ClipSbertModel, 9 | clip_loss, 10 | mean_pooling, 11 | ) 12 | 13 | from transformers import ( 14 | CLIPProcessor, 15 | AdamW, 16 | get_linear_schedule_with_warmup, 17 | WEIGHTS_NAME, 18 | ) 19 | 20 | class ClipClipGpp(ClipClipModel): 21 | def forward( 22 | self, 23 | context_input_ids: Optional[torch.LongTensor] = None, 24 | context_attention_mask: Optional[torch.LongTensor] = None, 25 | response_input_ids: Optional[torch.LongTensor] = None, 26 | response_attention_mask: Optional[torch.LongTensor] = None, 27 | persona_input_ids: Optional[torch.LongTensor] = None, 28 | persona_attention_mask: Optional[torch.LongTensor] = None, 29 | final_persona_input_ids: Optional[torch.LongTensor] = None, 30 | final_persona_attention_mask: Optional[torch.LongTensor] = None, 31 | dialog_img_feat: Optional[torch.Tensor] = None, 32 | persona_img_feats: Optional[torch.Tensor] = None, 33 | final_persona_img_feats: Optional[torch.Tensor] = None, 34 | dialog_img_mask: Optional[torch.LongTensor] = None, 35 | persona_img_mask: Optional[torch.LongTensor] = None, 36 | labels: Optional[torch.LongTensor] = None, 37 | mode: str = None, 38 | ): 39 | if mode == 'train': 40 | context_output = self.context_text_encoder( 41 | input_ids=context_input_ids, 42 | attention_mask=context_attention_mask 43 | )[1] 44 | context_output = self.context_text_projection(context_output) 45 | context_output = F.normalize(context_output, p=2, dim=1) 46 | 47 | persona_output = self.persona_text_encoder( 48 | input_ids=persona_input_ids, 49 | attention_mask=persona_attention_mask 50 | )[1] 51 | persona_output = self.persona_text_projection(persona_output) 52 | persona_output = F.normalize(persona_output, p=2, dim=1) 53 | 54 | if self.args.use_response: 55 | response_output = self.response_encoder( 56 | input_ids=response_input_ids, 57 | attention_mask=response_attention_mask 58 | )[1] 59 | response_output = self.response_projection(response_output) 60 | response_output = F.normalize(response_output, p=2, dim=1) 61 | 62 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 63 | dialog_image_output = self.context_image_projection(dialog_image_output) 64 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 65 | 66 | persona_image_output = self.persona_image_encoder(pixel_values=persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size))[1] 67 | persona_image_output = self.persona_image_projection(persona_image_output) 68 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 69 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_imgs, persona_image_output.size(-1)) 70 | 71 | if self.args.sum_persona_images: 72 | if self.args.remove_empty_images: 73 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) * persona_image_output, dim=1) 74 | persona_image_output = persona_image_output / torch.sum(persona_img_mask, dim=1).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 75 | multimodal_persona_output = (persona_output + persona_image_output) / 2 76 | if self.args.use_response: 77 | 
multimodal_context_output = context_output + response_output + multimodal_persona_output 78 | else: 79 | multimodal_context_output = context_output + multimodal_persona_output 80 | multimodal_context_output += (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output) 81 | if self.args.use_response: 82 | multimodal_context_output /= (dialog_img_mask + 3).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 83 | else: 84 | multimodal_context_output /= (dialog_img_mask + 2).unsqueeze(-1).repeat(1, dialog_image_output.size(-1)) 85 | else: 86 | persona_image_output = torch.mean(persona_image_output, dim=1) 87 | multimodal_persona_output = (persona_output + persona_image_output) / 2 88 | if self.args.use_response: 89 | multimodal_context_output = (context_output + response_output + dialog_image_output + multimodal_persona_output) / 4 90 | else: 91 | multimodal_context_output = (context_output + dialog_image_output + multimodal_persona_output) / 3 92 | else: 93 | raise NotImplementedError 94 | 95 | final_persona_output = self.persona_text_encoder( 96 | input_ids=final_persona_input_ids, 97 | attention_mask=final_persona_attention_mask 98 | )[1] 99 | final_persona_output = self.persona_text_projection(final_persona_output) 100 | final_persona_output = F.normalize(final_persona_output, p=2, dim=1) 101 | 102 | final_persona_image_output = self.persona_image_encoder(pixel_values=final_persona_img_feats)[1] 103 | final_persona_image_output = self.persona_image_projection(final_persona_image_output) 104 | final_persona_image_output = F.normalize(final_persona_image_output, p=2, dim=1) 105 | 106 | final_multimodal_persona_output = (final_persona_output + final_persona_image_output) / 2 107 | 108 | logit_scale = self.logit_scale.exp() 109 | dot_products = multimodal_context_output.mm(final_multimodal_persona_output.t()) * logit_scale 110 | loss = clip_loss(dot_products) 111 | 112 | outputs = (loss,) 113 | else: 114 | context_output = self.context_text_encoder( 115 | input_ids=context_input_ids, 116 | attention_mask=context_attention_mask 117 | )[1] 118 | context_output = self.context_text_projection(context_output) 119 | context_output = F.normalize(context_output, p=2, dim=1) 120 | 121 | persona_output = self.persona_text_encoder( 122 | input_ids=persona_input_ids, 123 | attention_mask=persona_attention_mask 124 | )[1] 125 | persona_output = self.persona_text_projection(persona_output) 126 | persona_output = F.normalize(persona_output, p=2, dim=1) 127 | 128 | response_output = self.response_encoder( 129 | input_ids=response_input_ids, 130 | attention_mask=response_attention_mask 131 | )[1] 132 | response_output = self.response_projection(response_output) 133 | response_output = F.normalize(response_output, p=2, dim=1) 134 | 135 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 136 | dialog_image_output = self.context_image_projection(dialog_image_output) 137 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 138 | 139 | persona_image_output = self.persona_image_encoder(pixel_values=persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size))[1] 140 | persona_image_output = self.persona_image_projection(persona_image_output) 141 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 142 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_imgs, persona_image_output.size(-1)) 143 | 144 | if self.args.sum_persona_images: 145 | if 
self.args.remove_empty_images: 146 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) * persona_image_output, dim=1) 147 | persona_image_output = persona_image_output / torch.sum(persona_img_mask, dim=1).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 148 | multimodal_persona_output = (persona_output + persona_image_output) / 2 149 | if self.args.use_response: 150 | multimodal_context_output = context_output + response_output + multimodal_persona_output 151 | else: 152 | multimodal_context_output = context_output + multimodal_persona_output 153 | multimodal_context_output += (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output) 154 | if self.args.use_response: 155 | multimodal_context_output /= (dialog_img_mask + 3).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 156 | else: 157 | multimodal_context_output /= (dialog_img_mask + 2).unsqueeze(-1).repeat(1, dialog_image_output.size(-1)) 158 | else: 159 | persona_image_output = torch.mean(persona_image_output, dim=1) 160 | multimodal_persona_output = (persona_output + persona_image_output) / 2 161 | if self.args.use_response: 162 | multimodal_context_output = (context_output + response_output + dialog_image_output + multimodal_persona_output) / 4 163 | else: 164 | multimodal_context_output = (context_output + dialog_image_output + multimodal_persona_output) / 3 165 | else: 166 | raise NotImplementedError 167 | 168 | cand_final_persona_input_ids = final_persona_input_ids.view(-1, final_persona_input_ids.size(-1)) 169 | cand_final_persona_attention_mask = final_persona_attention_mask.view(-1, final_persona_attention_mask.size(-1)) 170 | cand_final_persona_output = self.persona_text_encoder( 171 | input_ids=cand_final_persona_input_ids, 172 | attention_mask=cand_final_persona_attention_mask 173 | )[1] 174 | cand_final_persona_output = self.persona_text_projection(cand_final_persona_output) 175 | cand_final_persona_output = F.normalize(cand_final_persona_output, p=2, dim=1) 176 | 177 | cand_final_persona_img_feats = final_persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size) 178 | 179 | cand_final_persona_image_output = self.persona_image_encoder(pixel_values=cand_final_persona_img_feats)[1] 180 | cand_final_persona_image_output = self.persona_image_projection(cand_final_persona_image_output) 181 | cand_final_persona_image_output = F.normalize(cand_final_persona_image_output, p=2, dim=1) 182 | 183 | cand_final_multimodal_persona_output = (cand_final_persona_output + cand_final_persona_image_output) / 2 184 | logits = torch.bmm(multimodal_context_output.unsqueeze(1), 185 | cand_final_multimodal_persona_output.view( 186 | context_input_ids.size(0), 187 | self.args.max_num_candidate_persona_elements, 188 | -1).transpose(1,2)).squeeze(1) 189 | 190 | loss = CrossEntropyLoss(reduction='none')(logits, labels) 191 | outputs = (loss, logits,) 192 | return outputs 193 | 194 | class ClipSbertGpp(ClipSbertModel): 195 | def forward( 196 | self, 197 | context_input_ids: Optional[torch.LongTensor] = None, 198 | context_attention_mask: Optional[torch.LongTensor] = None, 199 | response_input_ids: Optional[torch.LongTensor] = None, 200 | response_attention_mask: Optional[torch.LongTensor] = None, 201 | persona_input_ids: Optional[torch.LongTensor] = None, 202 | persona_attention_mask: Optional[torch.LongTensor] = None, 203 | final_persona_input_ids: Optional[torch.LongTensor] = None, 204 | final_persona_attention_mask: Optional[torch.LongTensor] 
= None, 205 | dialog_img_feat: Optional[torch.Tensor] = None, 206 | persona_img_feats: Optional[torch.Tensor] = None, 207 | final_persona_img_feats: Optional[torch.Tensor] = None, 208 | dialog_img_mask: Optional[torch.LongTensor] = None, 209 | persona_img_mask: Optional[torch.LongTensor] = None, 210 | labels: Optional[torch.LongTensor] = None, 211 | mode: str = None, 212 | ): 213 | if mode == 'train': 214 | context_output = self.context_text_encoder( 215 | input_ids=context_input_ids, 216 | attention_mask=context_attention_mask 217 | ) 218 | context_output = mean_pooling(context_output, context_attention_mask) 219 | 220 | persona_output = self.persona_text_encoder( 221 | input_ids=persona_input_ids, 222 | attention_mask=persona_attention_mask 223 | ) 224 | persona_output = mean_pooling(persona_output, persona_attention_mask) 225 | 226 | response_output = self.response_encoder( 227 | input_ids=response_input_ids, 228 | attention_mask=response_attention_mask 229 | ) 230 | response_output = mean_pooling(response_output, response_attention_mask) 231 | 232 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 233 | dialog_image_output = self.context_image_projection(dialog_image_output) 234 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 235 | 236 | persona_image_output = self.persona_image_encoder(pixel_values=persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size))[1] 237 | persona_image_output = self.persona_image_projection(persona_image_output) 238 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 239 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_imgs, persona_image_output.size(-1)) 240 | 241 | if self.args.sum_persona_images: 242 | if self.args.remove_empty_images: 243 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) * persona_image_output, dim=1) 244 | persona_image_output = persona_image_output / torch.sum(persona_img_mask, dim=1).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 245 | multimodal_persona_output = (persona_output + persona_image_output) / 2 246 | if self.args.use_response: 247 | multimodal_context_output = context_output + response_output + multimodal_persona_output 248 | else: 249 | multimodal_context_output = context_output + multimodal_persona_output 250 | multimodal_context_output += (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output) 251 | if self.args.use_response: 252 | multimodal_context_output /= (dialog_img_mask + 3).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 253 | else: 254 | multimodal_context_output /= (dialog_img_mask + 2).unsqueeze(-1).repeat(1, dialog_image_output.size(-1)) 255 | else: 256 | persona_image_output = torch.mean(persona_image_output, dim=1) 257 | multimodal_persona_output = (persona_output + persona_image_output) / 2 258 | if self.args.use_response: 259 | multimodal_context_output = (context_output + response_output + dialog_image_output + multimodal_persona_output) / 4 260 | else: 261 | multimodal_context_output = (context_output + dialog_image_output + multimodal_persona_output) / 3 262 | else: 263 | raise NotImplementedError 264 | 265 | final_persona_output = self.persona_text_encoder( 266 | input_ids=final_persona_input_ids, 267 | attention_mask=final_persona_attention_mask 268 | ) 269 | final_persona_output = mean_pooling(final_persona_output, final_persona_attention_mask) 270 | 271 | 
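# Note (annotation): the gold persona element's image is encoded by persona_image_encoder,
# projected into the shared space, and L2-normalized below; it is then averaged with the
# pooled persona text embedding to form the multimodal persona target for the in-batch loss.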
final_persona_image_output = self.persona_image_encoder(pixel_values=final_persona_img_feats)[1] 272 | final_persona_image_output = self.persona_image_projection(final_persona_image_output) 273 | final_persona_image_output = F.normalize(final_persona_image_output, p=2, dim=1) 274 | 275 | final_multimodal_persona_output = (final_persona_output + final_persona_image_output) / 2 276 | 277 | targets = torch.arange(context_output.shape[0], device=context_output.device) 278 | # dot_products: [batch, batch] 279 | dot_products = multimodal_context_output.mm(final_multimodal_persona_output.t()) 280 | log_prob = F.log_softmax(dot_products, dim=1) 281 | loss = F.nll_loss(log_prob, targets) 282 | 283 | outputs = (loss,) 284 | else: 285 | context_output = self.context_text_encoder( 286 | input_ids=context_input_ids, 287 | attention_mask=context_attention_mask 288 | ) 289 | context_output = mean_pooling(context_output, context_attention_mask) 290 | 291 | persona_output = self.persona_text_encoder( 292 | input_ids=persona_input_ids, 293 | attention_mask=persona_attention_mask 294 | ) 295 | persona_output = mean_pooling(persona_output, persona_attention_mask) 296 | 297 | response_output = self.response_encoder( 298 | input_ids=response_input_ids, 299 | attention_mask=response_attention_mask 300 | ) 301 | response_output = mean_pooling(response_output, response_attention_mask) 302 | 303 | dialog_image_output = self.context_image_encoder(pixel_values=dialog_img_feat)[1] 304 | dialog_image_output = self.context_image_projection(dialog_image_output) 305 | dialog_image_output = F.normalize(dialog_image_output, p=2, dim=1) 306 | 307 | persona_image_output = self.persona_image_encoder(pixel_values=persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size))[1] 308 | persona_image_output = self.persona_image_projection(persona_image_output) 309 | persona_image_output = F.normalize(persona_image_output, p=2, dim=1) 310 | persona_image_output = persona_image_output.view(persona_img_feats.size(0), self.args.max_num_imgs, persona_image_output.size(-1)) 311 | 312 | if self.args.sum_persona_images: 313 | if self.args.remove_empty_images: 314 | persona_image_output = torch.sum(persona_img_mask.unsqueeze(-1).repeat(1,1,dialog_image_output.size(-1)) * persona_image_output, dim=1) 315 | persona_image_output = persona_image_output / torch.sum(persona_img_mask, dim=1).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 316 | multimodal_persona_output = (persona_output + persona_image_output) / 2 317 | if self.args.use_response: 318 | multimodal_context_output = context_output + response_output + multimodal_persona_output 319 | else: 320 | multimodal_context_output = context_output + multimodal_persona_output 321 | multimodal_context_output += (dialog_img_mask.unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) * dialog_image_output) 322 | if self.args.use_response: 323 | multimodal_context_output /= (dialog_img_mask + 3).unsqueeze(-1).repeat(1,dialog_image_output.size(-1)) 324 | else: 325 | multimodal_context_output /= (dialog_img_mask + 2).unsqueeze(-1).repeat(1, dialog_image_output.size(-1)) 326 | else: 327 | persona_image_output = torch.mean(persona_image_output, dim=1) 328 | multimodal_persona_output = (persona_output + persona_image_output) / 2 329 | if self.args.use_response: 330 | multimodal_context_output = (context_output + response_output + dialog_image_output + multimodal_persona_output) / 4 331 | else: 332 | multimodal_context_output = (context_output + dialog_image_output + multimodal_persona_output) 
/ 3 333 | else: 334 | raise NotImplementedError 335 | 336 | cand_final_persona_input_ids = final_persona_input_ids.view(-1, final_persona_input_ids.size(-1)) 337 | cand_final_persona_attention_mask = final_persona_attention_mask.view(-1, final_persona_attention_mask.size(-1)) 338 | cand_final_persona_output = self.persona_text_encoder( 339 | input_ids=cand_final_persona_input_ids, 340 | attention_mask=cand_final_persona_attention_mask 341 | ) 342 | cand_final_persona_output = mean_pooling(cand_final_persona_output, cand_final_persona_attention_mask) 343 | 344 | cand_final_persona_img_feats = final_persona_img_feats.view(-1, 3, self.args.img_size, self.args.img_size) 345 | 346 | cand_final_persona_image_output = self.persona_image_encoder(pixel_values=cand_final_persona_img_feats)[1] 347 | cand_final_persona_image_output = self.persona_image_projection(cand_final_persona_image_output) 348 | cand_final_persona_image_output = F.normalize(cand_final_persona_image_output, p=2, dim=1) 349 | 350 | cand_final_multimodal_persona_output = (cand_final_persona_output + cand_final_persona_image_output) / 2 351 | logits = torch.bmm(multimodal_context_output.unsqueeze(1), 352 | cand_final_multimodal_persona_output.view( 353 | context_input_ids.size(0), 354 | self.args.max_num_candidate_persona_elements, 355 | -1).transpose(1,2)).squeeze(1) 356 | 357 | loss = CrossEntropyLoss(reduction='none')(logits, labels) 358 | outputs = (loss, logits,) 359 | return outputs 360 | -------------------------------------------------------------------------------- /data/mpchat_nrp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | 5 | import torch 6 | from torch.utils.data import ( 7 | Dataset, 8 | ) 9 | 10 | from utils.misc import pil_loader 11 | 12 | class MpchatClipClipNrpDataset(Dataset): 13 | def __init__(self, 14 | args, 15 | tokenizer, 16 | clip_processor, 17 | mode): 18 | super(MpchatClipClipNrpDataset, self).__init__() 19 | assert mode in ['train', 'val', 'test'] 20 | 21 | self.args = args 22 | self.clip_processor = clip_processor 23 | self.mode = mode 24 | self.examples = [] 25 | 26 | with open(os.path.join(args.dialog_data_dir, 'mpchat_nrp.json'), 'r') as fp: 27 | data = json.load(fp)[f'{mode}'] 28 | 29 | num_examples = 0 30 | for dialog_idx, dialog in enumerate(data): 31 | main_author = dialog['main_author'] 32 | turn_indices = [] 33 | for turn_idx, author in enumerate(dialog['authors']): 34 | if main_author == author: 35 | turn_indices.append(turn_idx) 36 | 37 | dialog_subreddit = dialog['subreddit'] 38 | for turn_idx in turn_indices: 39 | context = ' '.join(dialog['messages'][:turn_idx]) 40 | response = dialog['messages'][turn_idx] 41 | persona_sentences = ' '.join([f"{x['title']}" for x in dialog['candidate_personas']]) 42 | persona_fpaths = [os.path.join(args.persona_image_data_dir, x['file_name']) for x in dialog['candidate_personas']] 43 | if dialog['has_image']: 44 | fname_context = dialog['file_name'] 45 | dialog_fpath = os.path.join(args.dialog_image_data_dir, fname_context) 46 | else: 47 | dialog_fpath = '' 48 | 49 | if mode == 'train': 50 | self.examples.append((context, response, dialog_fpath, persona_sentences, persona_fpaths, mode)) 51 | else: 52 | assert response == dialog['nrp_candidate_responses'][turn_idx][0] 53 | self.examples.append((context, dialog['nrp_candidate_responses'][turn_idx], dialog_fpath, persona_sentences, persona_fpaths, num_examples, 0, mode)) 54 | num_examples += 1 55 | 56 | print(f'num. 
of {mode} dataset: {len(self.examples)}') 57 | 58 | def __len__(self): 59 | return len(self.examples) 60 | 61 | def __getitem__(self, index): 62 | mode = self.examples[index][-1] 63 | 64 | if mode == 'train': 65 | context, response, dialog_fpath, persona_sentences, persona_fpaths, mode = self.examples[index] 66 | 67 | context_inputs = self.clip_processor(text=context) 68 | context_input_ids = context_inputs['input_ids'] 69 | context_attention_mask = context_inputs['attention_mask'] 70 | 71 | if len(context_input_ids) > self.args.max_seq_length: 72 | context_input_ids = context_input_ids[len(context_input_ids) - self.args.max_seq_length:] 73 | context_attention_mask = context_attention_mask[len(context_attention_mask) - self.args.max_seq_length:] 74 | while len(context_input_ids) < self.args.max_seq_length: 75 | context_input_ids.append(self.clip_processor.tokenizer.pad_token_id) 76 | context_attention_mask.append(0) 77 | 78 | assert len(context_input_ids) == self.args.max_seq_length 79 | assert len(context_attention_mask) == self.args.max_seq_length 80 | 81 | response_inputs = self.clip_processor(text=response) 82 | response_input_ids = response_inputs['input_ids'] 83 | response_attention_mask = response_inputs['attention_mask'] 84 | 85 | if len(response_input_ids) > self.args.max_seq_length: 86 | response_input_ids = response_input_ids[len(response_input_ids) - self.args.max_seq_length:] 87 | response_attention_mask = response_attention_mask[len(response_attention_mask) - self.args.max_seq_length:] 88 | while len(response_input_ids) < self.args.max_seq_length: 89 | response_input_ids.append(self.clip_processor.tokenizer.pad_token_id) 90 | response_attention_mask.append(0) 91 | 92 | assert len(response_input_ids) == self.args.max_seq_length 93 | assert len(response_attention_mask) == self.args.max_seq_length 94 | 95 | persona_inputs = self.clip_processor(text=persona_sentences) 96 | persona_input_ids = persona_inputs['input_ids'] 97 | persona_attention_mask = persona_inputs['attention_mask'] 98 | 99 | if len(persona_input_ids) > self.args.max_seq_length: 100 | persona_input_ids = persona_input_ids[len(persona_input_ids) - self.args.max_seq_length:] 101 | persona_attention_mask = persona_attention_mask[len(persona_attention_mask) - self.args.max_seq_length:] 102 | while len(persona_input_ids) < self.args.max_seq_length: 103 | persona_input_ids.append(self.clip_processor.tokenizer.pad_token_id) 104 | persona_attention_mask.append(0) 105 | 106 | assert len(persona_input_ids) == self.args.max_seq_length 107 | assert len(persona_attention_mask) == self.args.max_seq_length 108 | 109 | if dialog_fpath == '': 110 | dialog_img_feat = torch.rand(3, self.args.img_size, self.args.img_size) 111 | dialog_img_mask = 0 112 | else: 113 | dialog_img = pil_loader(dialog_fpath) 114 | dialog_img_feat = self.clip_processor(images=dialog_img, return_tensors='pt')['pixel_values'].squeeze(0) 115 | dialog_img_mask = 1 116 | assert dialog_img_feat.shape == torch.Size([3, self.args.img_size, self.args.img_size]) 117 | 118 | persona_imgs = [pil_loader(x) for x in persona_fpaths] 119 | persona_img_feats = self.clip_processor(images=persona_imgs, return_tensors='pt')['pixel_values'] 120 | persona_img_mask = [1 for _ in persona_fpaths] 121 | if len(persona_fpaths) < self.args.max_num_imgs: 122 | empty_img_feats = torch.stack([torch.rand(3, self.args.img_size, self.args.img_size) for _ in range(self.args.max_num_imgs - len(persona_fpaths))]) 123 | persona_img_feats = torch.cat([persona_img_feats, empty_img_feats], dim=0) 
124 | persona_img_mask += [0 for _ in range(self.args.max_num_imgs - len(persona_fpaths))] 125 | assert persona_img_feats.shape == torch.Size([self.args.max_num_imgs, 3, self.args.img_size, self.args.img_size]) 126 | assert len(persona_img_mask) == self.args.max_num_imgs 127 | 128 | feature = [ 129 | torch.as_tensor(context_input_ids, dtype=torch.long), 130 | torch.as_tensor(context_attention_mask, dtype=torch.long), 131 | torch.as_tensor(response_input_ids, dtype=torch.long), 132 | torch.as_tensor(response_attention_mask, dtype=torch.long), 133 | torch.as_tensor(persona_input_ids, dtype=torch.long), 134 | torch.as_tensor(persona_attention_mask, dtype=torch.long), 135 | dialog_img_feat, 136 | persona_img_feats, 137 | torch.as_tensor(dialog_img_mask, dtype=torch.long), 138 | torch.as_tensor(persona_img_mask, dtype=torch.long), 139 | ] 140 | else: 141 | context, responses, dialog_fpath, persona_sentences, persona_fpaths, example_idx, label_idx, mode = self.examples[index] 142 | 143 | context_inputs = self.clip_processor(text=context) 144 | context_input_ids = context_inputs['input_ids'] 145 | context_attention_mask = context_inputs['attention_mask'] 146 | 147 | if len(context_input_ids) > self.args.max_seq_length: 148 | context_input_ids = context_input_ids[len(context_input_ids) - self.args.max_seq_length:] 149 | context_attention_mask = context_attention_mask[len(context_attention_mask) - self.args.max_seq_length:] 150 | while len(context_input_ids) < self.args.max_seq_length: 151 | context_input_ids.append(self.clip_processor.tokenizer.pad_token_id) 152 | context_attention_mask.append(0) 153 | 154 | assert len(context_input_ids) == self.args.max_seq_length 155 | assert len(context_attention_mask) == self.args.max_seq_length 156 | 157 | response_inputs = self.clip_processor.tokenizer(responses, 158 | truncation=True, 159 | padding='max_length', 160 | max_length=self.args.max_seq_length) 161 | response_input_ids = torch.as_tensor(response_inputs['input_ids'], dtype=torch.long) 162 | response_attention_mask = torch.as_tensor(response_inputs['attention_mask'], dtype=torch.long) 163 | 164 | assert response_input_ids.shape[0] == self.args.max_num_responses and response_input_ids.shape[1] == self.args.max_seq_length 165 | assert response_attention_mask.shape[0] == self.args.max_num_responses and response_attention_mask.shape[1] == self.args.max_seq_length 166 | 167 | persona_inputs = self.clip_processor(text=persona_sentences) 168 | persona_input_ids = persona_inputs['input_ids'] 169 | persona_attention_mask = persona_inputs['attention_mask'] 170 | 171 | if len(persona_input_ids) > self.args.max_seq_length: 172 | persona_input_ids = persona_input_ids[len(persona_input_ids) - self.args.max_seq_length:] 173 | persona_attention_mask = persona_attention_mask[len(persona_attention_mask) - self.args.max_seq_length:] 174 | while len(persona_input_ids) < self.args.max_seq_length: 175 | persona_input_ids.append(self.clip_processor.tokenizer.pad_token_id) 176 | persona_attention_mask.append(0) 177 | 178 | assert len(persona_input_ids) == self.args.max_seq_length 179 | assert len(persona_attention_mask) == self.args.max_seq_length 180 | 181 | if dialog_fpath == '': 182 | dialog_img_feat = torch.rand(3, self.args.img_size, self.args.img_size) 183 | dialog_img_mask = 0 184 | else: 185 | dialog_img = pil_loader(dialog_fpath) 186 | dialog_img_feat = self.clip_processor(images=dialog_img, return_tensors='pt')['pixel_values'].squeeze(0) 187 | dialog_img_mask = 1 188 | assert dialog_img_feat.shape == 
torch.Size([3, self.args.img_size, self.args.img_size]) 189 | 190 | persona_imgs = [pil_loader(x) for x in persona_fpaths] 191 | persona_img_feats = self.clip_processor(images=persona_imgs, return_tensors='pt')['pixel_values'] 192 | persona_img_mask = [1 for _ in persona_fpaths] 193 | if len(persona_fpaths) < self.args.max_num_imgs: 194 | empty_img_feats = torch.stack([torch.rand(3, self.args.img_size, self.args.img_size) for _ in range(self.args.max_num_imgs - len(persona_fpaths))]) 195 | persona_img_feats = torch.cat([persona_img_feats, empty_img_feats], dim=0) 196 | persona_img_mask += [0 for _ in range(self.args.max_num_imgs - len(persona_fpaths))] 197 | assert persona_img_feats.shape == torch.Size([self.args.max_num_imgs, 3, self.args.img_size, self.args.img_size]) 198 | assert len(persona_img_mask) == self.args.max_num_imgs 199 | 200 | feature = [ 201 | torch.as_tensor(context_input_ids, dtype=torch.long), 202 | torch.as_tensor(context_attention_mask, dtype=torch.long), 203 | response_input_ids, 204 | response_attention_mask, 205 | torch.as_tensor(persona_input_ids, dtype=torch.long), 206 | torch.as_tensor(persona_attention_mask, dtype=torch.long), 207 | dialog_img_feat, 208 | persona_img_feats, 209 | torch.as_tensor(label_idx, dtype=torch.long), 210 | torch.as_tensor(example_idx, dtype=torch.long), 211 | torch.as_tensor(dialog_img_mask, dtype=torch.long), 212 | torch.as_tensor(persona_img_mask, dtype=torch.long), 213 | ] 214 | 215 | return feature 216 | 217 | class MpchatClipSbertNrpDataset(Dataset): 218 | def __init__(self, 219 | args, 220 | tokenizer, 221 | clip_processor, 222 | mode): 223 | super(MpchatClipSbertNrpDataset, self).__init__() 224 | assert mode in ['train', 'val', 'test'] 225 | 226 | self.args = args 227 | self.tokenizer = tokenizer 228 | self.clip_processor = clip_processor 229 | self.mode = mode 230 | self.examples = [] 231 | 232 | with open(os.path.join(args.dialog_data_dir, 'mpchat_nrp.json'), 'r') as fp: 233 | data = json.load(fp)[f'{mode}'] 234 | 235 | num_examples = 0 236 | for dialog_idx, dialog in enumerate(data): 237 | main_author = dialog['main_author'] 238 | turn_indices = [] 239 | for turn_idx, author in enumerate(dialog['authors']): 240 | if main_author == author: 241 | turn_indices.append(turn_idx) 242 | 243 | for turn_idx in turn_indices: 244 | context = ' '.join(dialog['messages'][:turn_idx]) 245 | response = dialog['messages'][turn_idx] 246 | persona_sentences = ' '.join([f"{x['title']}" for x in dialog['candidate_personas']]) 247 | persona_fpaths = [os.path.join(args.persona_image_data_dir, x['file_name']) for x in dialog['candidate_personas']] 248 | if dialog['has_image']: 249 | fname_context = dialog['file_name'] 250 | dialog_fpath = os.path.join(args.dialog_image_data_dir, fname_context) 251 | else: 252 | dialog_fpath = '' 253 | 254 | if mode == 'train': 255 | self.examples.append((context, response, dialog_fpath, persona_sentences, persona_fpaths, mode)) 256 | else: 257 | assert response == dialog['nrp_candidate_responses'][turn_idx][0] 258 | self.examples.append((context, dialog['nrp_candidate_responses'][turn_idx], dialog_fpath, persona_sentences, persona_fpaths, num_examples, 0, mode)) 259 | num_examples += 1 260 | 261 | print(f'num. 
of {mode} dataset: {len(self.examples)}') 262 | 263 | def __len__(self): 264 | return len(self.examples) 265 | 266 | def __getitem__(self, index): 267 | mode = self.examples[index][-1] 268 | 269 | if mode == 'train': 270 | context, response, dialog_fpath, persona_sentences, persona_fpaths, mode = self.examples[index] 271 | 272 | context_inputs = self.tokenizer(context, 273 | truncation=True, 274 | padding='max_length', 275 | max_length=self.args.max_seq_length) 276 | context_input_ids = context_inputs['input_ids'] 277 | context_attention_mask = context_inputs['attention_mask'] 278 | 279 | assert len(context_input_ids) == self.args.max_seq_length 280 | assert len(context_attention_mask) == self.args.max_seq_length 281 | 282 | response_inputs = self.tokenizer(response, 283 | truncation=True, 284 | padding='max_length', 285 | max_length=self.args.max_seq_length) 286 | response_input_ids = response_inputs['input_ids'] 287 | response_attention_mask = response_inputs['attention_mask'] 288 | 289 | assert len(response_input_ids) == self.args.max_seq_length 290 | assert len(response_attention_mask) == self.args.max_seq_length 291 | 292 | persona_inputs = self.tokenizer(persona_sentences, 293 | truncation=True, 294 | padding='max_length', 295 | max_length=self.args.max_seq_length) 296 | persona_input_ids = persona_inputs['input_ids'] 297 | persona_attention_mask = persona_inputs['attention_mask'] 298 | 299 | assert len(persona_input_ids) == self.args.max_seq_length 300 | assert len(persona_attention_mask) == self.args.max_seq_length 301 | 302 | if dialog_fpath == '': 303 | dialog_img_feat = torch.rand(3, self.args.img_size, self.args.img_size) 304 | dialog_img_mask = 0 305 | else: 306 | dialog_img = pil_loader(dialog_fpath) 307 | dialog_img_feat = self.clip_processor(images=dialog_img, return_tensors='pt')['pixel_values'].squeeze(0) 308 | dialog_img_mask = 1 309 | assert dialog_img_feat.shape == torch.Size([3, self.args.img_size, self.args.img_size]) 310 | 311 | persona_imgs = [pil_loader(x) for x in persona_fpaths] 312 | persona_img_feats = self.clip_processor(images=persona_imgs, return_tensors='pt')['pixel_values'] 313 | persona_img_mask = [1 for _ in persona_fpaths] 314 | if len(persona_fpaths) < self.args.max_num_imgs: 315 | empty_img_feats = torch.stack([torch.rand(3, self.args.img_size, self.args.img_size) for _ in range(self.args.max_num_imgs - len(persona_fpaths))]) 316 | persona_img_feats = torch.cat([persona_img_feats, empty_img_feats], dim=0) 317 | persona_img_mask += [0 for _ in range(self.args.max_num_imgs - len(persona_fpaths))] 318 | assert persona_img_feats.shape == torch.Size([self.args.max_num_imgs, 3, self.args.img_size, self.args.img_size]) 319 | assert len(persona_img_mask) == self.args.max_num_imgs 320 | 321 | feature = [ 322 | torch.as_tensor(context_input_ids, dtype=torch.long), 323 | torch.as_tensor(context_attention_mask, dtype=torch.long), 324 | torch.as_tensor(response_input_ids, dtype=torch.long), 325 | torch.as_tensor(response_attention_mask, dtype=torch.long), 326 | torch.as_tensor(persona_input_ids, dtype=torch.long), 327 | torch.as_tensor(persona_attention_mask, dtype=torch.long), 328 | dialog_img_feat, 329 | persona_img_feats, 330 | torch.as_tensor(dialog_img_mask, dtype=torch.long), 331 | torch.as_tensor(persona_img_mask, dtype=torch.long), 332 | ] 333 | else: 334 | context, responses, dialog_fpath, persona_sentences, persona_fpaths, example_idx, label_idx, mode = self.examples[index] 335 | 336 | context_inputs = self.tokenizer(context, 337 | truncation=True, 338 
| padding='max_length', 339 | max_length=self.args.max_seq_length) 340 | context_input_ids = context_inputs['input_ids'] 341 | context_attention_mask = context_inputs['attention_mask'] 342 | 343 | assert len(context_input_ids) == self.args.max_seq_length 344 | assert len(context_attention_mask) == self.args.max_seq_length 345 | 346 | response_inputs = self.tokenizer(responses, 347 | truncation=True, 348 | padding='max_length', 349 | max_length=self.args.max_seq_length) 350 | response_input_ids = torch.as_tensor(response_inputs['input_ids'], dtype=torch.long) 351 | response_attention_mask = torch.as_tensor(response_inputs['attention_mask'], dtype=torch.long) 352 | 353 | assert response_input_ids.shape[0] == self.args.max_num_responses and response_input_ids.shape[1] == self.args.max_seq_length 354 | assert response_attention_mask.shape[0] == self.args.max_num_responses and response_attention_mask.shape[1] == self.args.max_seq_length 355 | 356 | persona_inputs = self.tokenizer(persona_sentences, 357 | truncation=True, 358 | padding='max_length', 359 | max_length=self.args.max_seq_length) 360 | persona_input_ids = persona_inputs['input_ids'] 361 | persona_attention_mask = persona_inputs['attention_mask'] 362 | 363 | assert len(persona_input_ids) == self.args.max_seq_length 364 | assert len(persona_attention_mask) == self.args.max_seq_length 365 | 366 | if dialog_fpath == '': 367 | dialog_img_feat = torch.rand(3, self.args.img_size, self.args.img_size) 368 | dialog_img_mask = 0 369 | else: 370 | dialog_img = pil_loader(dialog_fpath) 371 | dialog_img_feat = self.clip_processor(images=dialog_img, return_tensors='pt')['pixel_values'].squeeze(0) 372 | dialog_img_mask = 1 373 | assert dialog_img_feat.shape == torch.Size([3, self.args.img_size, self.args.img_size]) 374 | 375 | persona_imgs = [pil_loader(x) for x in persona_fpaths] 376 | persona_img_feats = self.clip_processor(images=persona_imgs, return_tensors='pt')['pixel_values'] 377 | persona_img_mask = [1 for _ in persona_fpaths] 378 | if len(persona_fpaths) < self.args.max_num_imgs: 379 | empty_img_feats = torch.stack([torch.rand(3, self.args.img_size, self.args.img_size) for _ in range(self.args.max_num_imgs - len(persona_fpaths))]) 380 | persona_img_feats = torch.cat([persona_img_feats, empty_img_feats], dim=0) 381 | persona_img_mask += [0 for _ in range(self.args.max_num_imgs - len(persona_fpaths))] 382 | assert persona_img_feats.shape == torch.Size([self.args.max_num_imgs, 3, self.args.img_size, self.args.img_size]) 383 | assert len(persona_img_mask) == self.args.max_num_imgs 384 | 385 | feature = [ 386 | torch.as_tensor(context_input_ids, dtype=torch.long), 387 | torch.as_tensor(context_attention_mask, dtype=torch.long), 388 | response_input_ids, 389 | response_attention_mask, 390 | torch.as_tensor(persona_input_ids, dtype=torch.long), 391 | torch.as_tensor(persona_attention_mask, dtype=torch.long), 392 | dialog_img_feat, 393 | persona_img_feats, 394 | torch.as_tensor(label_idx, dtype=torch.long), 395 | torch.as_tensor(example_idx, dtype=torch.long), 396 | torch.as_tensor(dialog_img_mask, dtype=torch.long), 397 | torch.as_tensor(persona_img_mask, dtype=torch.long), 398 | ] 399 | 400 | return feature 401 | --------------------------------------------------------------------------------
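
A minimal usage sketch (editor's addition, not a file from the repository): it shows how the training split of MpchatClipClipNrpDataset defined above could be loaded with a CLIP processor and a default PyTorch DataLoader, and how the 10-element feature list built in __getitem__ unpacks. The directory layout, the max_num_imgs value, and the CLIP checkpoint name are placeholder assumptions rather than values taken from the repository; the actual entry point (main_nrp.py) builds args via argparse and may batch differently.

    # usage_sketch_nrp.py -- hypothetical, assumption-laden example; not repository code
    from argparse import Namespace

    from torch.utils.data import DataLoader
    from transformers import CLIPProcessor

    from data.mpchat_nrp import MpchatClipClipNrpDataset

    # Placeholder arguments mirroring the attributes the dataset reads from `args`.
    args = Namespace(
        dialog_data_dir='data',                      # folder containing mpchat_nrp.json (assumed)
        dialog_image_data_dir='data/dialog_images',  # assumed image folders
        persona_image_data_dir='data/persona_images',
        max_seq_length=77,                           # CLIP text encoder context length
        img_size=224,                                # CLIP ViT-B/32 input resolution
        max_num_imgs=5,                              # padded size of persona image sets (assumed)
    )

    clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
    # The CLIP-CLIP dataset does not use the `tokenizer` argument, so None is passed here.
    train_set = MpchatClipClipNrpDataset(args, tokenizer=None,
                                         clip_processor=clip_processor, mode='train')
    train_loader = DataLoader(train_set, batch_size=8, shuffle=True)

    # Each training item is a 10-element list of tensors; the default collate
    # function stacks them position by position into batched tensors.
    (context_ids, context_mask, response_ids, response_mask,
     persona_ids, persona_mask, dialog_img_feat, persona_img_feats,
     dialog_img_mask, persona_img_mask) = next(iter(train_loader))
    print(context_ids.shape)        # [8, 77]
    print(persona_img_feats.shape)  # [8, 5, 3, 224, 224] under the assumed max_num_imgs

The validation/test splits return a longer feature list (candidate responses plus label and example indices), so an evaluation loop would unpack 12 elements instead of 10.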