├── .gitignore ├── LICENSE ├── README.md ├── configs ├── FashionIQ_trans_config.json └── FashionIQ_trans_g2_res50_config.json ├── data ├── __init__.py ├── abc.py ├── collate_fns.py ├── fashionIQ.py └── utils.py ├── evaluators ├── __init__.py ├── abc.py ├── metric_calculators.py ├── tirg_evaluator.py ├── utils.py └── visualizers.py ├── jupyter_files ├── fashion_iq_vocab.pkl ├── how_to_create_fashion_iq_vocab.ipynb └── how_to_create_vocab.ipynb ├── language ├── __init__.py ├── abc.py ├── test_tokenizers.py ├── test_vocabulary.py ├── tokenizers.py ├── utils.py └── vocabulary.py ├── loggers ├── __init__.py ├── file_loggers.py └── wandb_loggers.py ├── losses ├── __init__.py └── batch_based_classification_loss.py ├── main.py ├── models ├── __init__.py ├── abc.py ├── attention_modules │ ├── __init__.py │ ├── self_attention.py │ └── test_self_attention.py ├── compositors │ ├── __init__.py │ ├── global_style_models.py │ └── transformers.py ├── image_encoders │ ├── __init__.py │ └── resnet.py ├── text_encoders │ ├── __init__.py │ ├── lstm.py │ ├── test_lstm.py │ ├── test_utils.py │ └── utils.py └── utils.py ├── optimizers.py ├── options ├── __init__.py ├── command_line.py └── config_file.py ├── readme_resources ├── CoSMo poster.pdf └── cosmo_fig.png ├── requirements.txt ├── set_up.py ├── trainers ├── __init__.py ├── abc.py └── tirg_trainer.py ├── transforms ├── __init__.py ├── image_transforms.py └── text_transforms.py └── utils ├── __init__.py ├── metrics.py ├── mixins.py └── tests.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Custom 4 | data/dataset_from_tirg.py 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | ### JetBrains template 111 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 112 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 113 | 114 | # User-specific stuff 115 | .idea/**/workspace.xml 116 | .idea/**/tasks.xml 117 | .idea/**/usage.statistics.xml 118 | .idea/**/dictionaries 119 | .idea/**/shelf 120 | 121 | # Sensitive or high-churn files 122 | .idea/**/dataSources/ 123 | .idea/**/dataSources.ids 124 | .idea/**/dataSources.local.xml 125 | .idea/**/sqlDataSources.xml 126 | .idea/**/dynamic.xml 127 | .idea/**/uiDesigner.xml 128 | .idea/**/dbnavigator.xml 129 | 130 | # Gradle 131 | .idea/**/gradle.xml 132 | .idea/**/libraries 133 | 134 | # Gradle and Maven with auto-import 135 | # When using Gradle or Maven with auto-import, you should exclude module files, 136 | # since they will be recreated, and may cause churn. Uncomment if using 137 | # auto-import. 138 | # .idea/modules.xml 139 | # .idea/*.iml 140 | # .idea/modules 141 | 142 | # CMake 143 | cmake-build-*/ 144 | 145 | # Mongo Explorer plugin 146 | .idea/**/mongoSettings.xml 147 | 148 | # File-based project format 149 | *.iws 150 | 151 | # IntelliJ 152 | out/ 153 | 154 | # mpeltonen/sbt-idea plugin 155 | .idea_modules/ 156 | 157 | # JIRA plugin 158 | atlassian-ide-plugin.xml 159 | 160 | # Cursive Clojure plugin 161 | .idea/replstate.xml 162 | 163 | # Crashlytics plugin (for Android Studio and IntelliJ) 164 | com_crashlytics_export_strings.xml 165 | crashlytics.properties 166 | crashlytics-build.properties 167 | fabric.properties 168 | 169 | # Editor-based Rest Client 170 | .idea/httpRequests 171 | .idea/ 172 | 173 | # Custom 174 | mappings.zip 175 | experiments/ 176 | wandb/ 177 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2021 Samsung Electronics Co. LTD 2 | 3 | This software is a property of Samsung Electronics. 
4 | No part of this software, either material or conceptual may be copied or distributed, transmitted, 5 | transcribed, stored in a retrieval system or translated into any human or computer language in any form by any means, 6 | electronic, mechanical, manual or otherwise, or disclosed 7 | to third parties without the express written permission of Samsung Electronics. 8 | (Use of the Software is restricted to non-commercial, personal or academic, research purpose only) 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoSMo.pytorch 2 | 3 | Official Implementation of **[CoSMo: Content-Style Modulation for Image Retrieval with Text Feedback](https://openaccess.thecvf.com/content/CVPR2021/html/Lee_CoSMo_Content-Style_Modulation_for_Image_Retrieval_With_Text_Feedback_CVPR_2021_paper.html)**, Seungmin Lee*, Dongwan Kim*, Bohyung Han. *(*denotes equal contribution)* 4 | 5 | Presented at [CVPR2021](http://cvpr2021.thecvf.com/) 6 | 7 | [Paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Lee_CoSMo_Content-Style_Modulation_for_Image_Retrieval_With_Text_Feedback_CVPR_2021_paper.pdf) | [Poster](readme_resources/CoSMo%20poster.pdf) | [5 min Video](https://youtu.be/GPwTILo6fS4) 8 | 9 | ![fig](readme_resources/cosmo_fig.png) 10 | 11 | ## :gear: Setup 12 | Python: python3.7 13 | 14 | ### :package: Install required packages 15 | 16 | Install torch and torchvision via following command (CUDA10) 17 | 18 | ```bash 19 | pip install torch==1.2.0 torchvision==0.4.0 -f https://download.pytorch.org/whl/torch_stable.html 20 | ``` 21 | 22 | Install other packages 23 | ```bash 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | ### :open_file_folder: Dataset 28 | Download the FashionIQ dataset by following the instructions on this [link](https://github.com/XiaoxiaoGuo/fashion-iq). 29 | 30 | We have set the default path for FashionIQ datasets in [data/fashionIQ.py](data/fashionIQ.py) as `_DEFAULT_FASHION_IQ_DATASET_ROOT = '/data/image_retrieval/fashionIQ'`. You can change this path to wherever you plan on storing the dataset. 31 | 32 | #### :arrows_counterclockwise: Update Dec 8th, 2021 33 | It seems like more and more download links to FashionIQ images are being taken down. As a *temporary* solution, we have uploaded a version of the dataset in [Google Drive](https://drive.google.com/drive/folders/14JG_w0V58iex62bVUHSBDYGBUECbDdx9?usp=sharing). Please be aware that this link is not permanent, and may be taken down in the near future. 34 | 35 | ### :books: Vocabulary file 36 | Open up a python console and run the following lines to download NLTK punkt: 37 | ```python 38 | import nltk 39 | nltk.download('punkt') 40 | ``` 41 | 42 | Then, open up a Jupyter notebook and run [jupyter_files/how_to_create_fashion_iq_vocab.ipynb](jupyter_files/how_to_create_fashion_iq_vocab.ipynb). As with the dataset, the default path is set in [data/fashionIQ.py](data/fashionIQ.py). 43 | 44 | **We have provided a vocab file in `jupyter_files/fashion_iq_vocab.pkl`.** 45 | 46 | ### :chart_with_upwards_trend: Weights & Biases 47 | We use [Weights and Biases](https://wandb.ai/) to log our experiments. 48 | 49 | If you already have a Weights & Biases account, head over to `configs/FashionIQ_trans_g2_res50_config.json` and fill out your `wandb_account_name`. You can also change the default at `options/command_line.py`. 
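For reference, the relevant entries of [configs/FashionIQ_trans_g2_res50_config.json](configs/FashionIQ_trans_g2_res50_config.json) are the two keys excerpted below; the placeholder value is what ships with the repository, and the rest of the config can be left untouched:

```json
{
  "wandb_project_name": "CoSMo.pytorch",
  "wandb_account_name": "your_account_name_here"
}
```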
50 | 51 | If you do not have a Weights & Biases account, you can either create one or change the code and logging functions to your liking. 52 | 53 | ## :running_man: Run 54 | 55 | You can run the code with the following command: 56 | ```bash 57 | python main.py --config_path=configs/FashionIQ_trans_g2_res50_config.json --experiment_description=test_cosmo_fashionIQDress --device_idx=0,1,2,3 58 | ``` 59 | 60 | Note that you do not need to assign `--device_idx` if you have already specified `CUDA_VISIBLE_DEVICES=0,1,2,3` in your terminal. 61 | 62 | We run our experiments on four 12GB GPUs, and the main GPU (`gpu:0`) uses around 4GB of VRAM. 63 | 64 | ### :warning: Notes on Evaluation 65 | In our paper, we mentioned that we use a slightly different evaluation method than the original FashionIQ dataset. This was done to match the evaluation method used by [VAL](https://openaccess.thecvf.com/content_CVPR_2020/html/Chen_Image_Search_With_Text_Feedback_by_Visiolinguistic_Attention_Learning_CVPR_2020_paper.html). 66 | 67 | By default, this code uses the proper evaluation method (as intended by the creators of the dataset). The results for this are shown in our supplementary materials. If you'd like to use the same evaluation method as our main paper (and VAL), head over to [data/fashionIQ.py](data/fashionIQ.py#L129) and uncomment the commented section. 68 | 69 | 70 | ## :scroll: Citation 71 | If you use our code, please cite our work: 72 | ``` 73 | @InProceedings{CoSMo2021_CVPR, 74 | author = {Lee, Seungmin and Kim, Dongwan and Han, Bohyung}, 75 | title = {CoSMo: Content-Style Modulation for Image Retrieval With Text Feedback}, 76 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 77 | month = {June}, 78 | year = {2021}, 79 | pages = {802-812} 80 | } 81 | ``` 82 | -------------------------------------------------------------------------------- /configs/FashionIQ_trans_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "epoch": 40, 3 | "batch_size": 128, 4 | "num_workers": 16, 5 | "metric_loss": "batch_based_classification_loss", 6 | "global_styler": "global2", 7 | "selector": "all", 8 | "feature_size": 512, 9 | "text_encoder": "lstm", 10 | "compositor": "transformer", 11 | "image_encoder": "resnet18_layer4", 12 | "norm_scale": 4, 13 | "optimizer": "RAdam", 14 | "weight_decay": 5e-5, 15 | "momentum": 0.9, 16 | "decay_step": 10, 17 | "gamma": 0.1, 18 | "random_seed": 0, 19 | "experiment_dir": "experiments", 20 | "lr": 2e-4, 21 | "topk": "1,5,10,50", 22 | "num_heads": 8, 23 | "dataset": "fashionIQ_dress", 24 | "device_idx": "", 25 | "experiment_description": "" 26 | } -------------------------------------------------------------------------------- /configs/FashionIQ_trans_g2_res50_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "epoch": 50, 3 | "batch_size": 32, 4 | "num_workers": 16, 5 | "metric_loss": "batch_based_classification_loss", 6 | "selector": "all", 7 | "global_styler": "global2", 8 | "lstm_hidden_size": 512, 9 | "text_feature_size": 512, 10 | "image_encoder": "resnet50_layer4", 11 | "text_encoder": "lstm", 12 | "compositor": "transformer", 13 | "norm_scale": 4, 14 | "optimizer": "RAdam", 15 | "weight_decay": 5e-5, 16 | "momentum": 0.9, 17 | "decay_step": 30, 18 | "gamma": 0.1, 19 | "random_seed": 0, 20 | "experiment_dir": "experiments", 21 | "num_heads": 8, 22 | "lr": 2e-4, 23 | "topk": "1,5,10,50", 24 | "margin": 12, 25 | "dataset":
"fashionIQ_dress", 26 | "wandb_project_name": "CoSMo.pytorch", 27 | "wandb_account_name": "your_account_name_here", 28 | "device_idx": "", 29 | "experiment_description": "" 30 | } -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader, Subset 3 | 4 | from data.collate_fns import PaddingCollateFunction, PaddingCollateFunctionTest 5 | from data.fashionIQ import FashionIQDataset, FashionIQTestDataset, FashionIQTestQueryDataset 6 | from language import AbstractBaseVocabulary 7 | 8 | DEFAULT_VOCAB_PATHS = { 9 | **dict.fromkeys(FashionIQDataset.all_codes(), FashionIQDataset.vocab_path()) 10 | } 11 | 12 | 13 | def _random_indices(dataset_length, limit_size): 14 | return np.random.randint(0, dataset_length, limit_size) 15 | 16 | 17 | def train_dataset_factory(transforms, config): 18 | image_transform = transforms['image_transform'] 19 | text_transform = transforms['text_transform'] 20 | dataset_code = config['dataset'] 21 | use_subset = config.get('use_subset', False) 22 | 23 | if FashionIQDataset.code() in dataset_code: 24 | dataset_clothing_split = dataset_code.split("_") 25 | if len(dataset_clothing_split) == 1: 26 | raise ValueError("Please specify clothing type for this dataset: fashionIQ_[dress_type]") 27 | clothing_type = dataset_clothing_split[1] 28 | dataset = FashionIQDataset(split='train', clothing_type=clothing_type, img_transform=image_transform, 29 | text_transform=text_transform) 30 | else: 31 | raise ValueError("There's no {} dataset".format(dataset_code)) 32 | 33 | if use_subset: 34 | return Subset(dataset, _random_indices(len(dataset), 1000)) 35 | 36 | return dataset 37 | 38 | 39 | def test_dataset_factory(transforms, config, split='val'): 40 | image_transform = transforms['image_transform'] 41 | text_transform = transforms['text_transform'] 42 | dataset_code = config['dataset'] 43 | use_subset = config.get('use_subset', False) 44 | 45 | if FashionIQDataset.code() in dataset_code: 46 | dataset_clothing_split = dataset_code.split("_") 47 | if len(dataset_clothing_split) == 1: 48 | raise ValueError("Please specify clothing type for this dataset: fashionIQ_[dress_type]") 49 | clothing_type = dataset_clothing_split[1] 50 | test_samples_dataset = FashionIQTestDataset(split=split, clothing_type=clothing_type, 51 | img_transform=image_transform, text_transform=text_transform) 52 | test_query_dataset = FashionIQTestQueryDataset(split=split, clothing_type=clothing_type, 53 | img_transform=image_transform, text_transform=text_transform) 54 | else: 55 | raise ValueError("There's no {} dataset".format(dataset_code)) 56 | 57 | if use_subset: 58 | return {"samples": Subset(test_samples_dataset, _random_indices(len(test_samples_dataset), 1000)), 59 | "query": Subset(test_query_dataset, _random_indices(len(test_query_dataset), 1000))} 60 | 61 | return {"samples": test_samples_dataset, 62 | "query": test_query_dataset} 63 | 64 | 65 | def train_dataloader_factory(dataset, config, collate_fn=None): 66 | batch_size = config['batch_size'] 67 | num_workers = config.get('num_workers', 16) 68 | shuffle = config.get('shuffle', True) 69 | # TODO: remove this 70 | drop_last = batch_size == 32 71 | 72 | return DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, 73 | collate_fn=collate_fn, drop_last=drop_last) 74 | 75 | 76 | def test_dataloader_factory(datasets, config, 
collate_fn=None): 77 | batch_size = config['batch_size'] 78 | num_workers = config.get('num_workers', 16) 79 | shuffle = False 80 | 81 | return { 82 | 'query': DataLoader(datasets['query'], batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, 83 | collate_fn=collate_fn), 84 | 'samples': DataLoader(datasets['samples'], batch_size, shuffle=shuffle, num_workers=num_workers, 85 | pin_memory=True, 86 | collate_fn=collate_fn) 87 | } 88 | 89 | 90 | def create_dataloaders(image_transform, text_transform, configs): 91 | train_dataset = train_dataset_factory( 92 | transforms={'image_transform': image_transform['train'], 'text_transform': text_transform['train']}, 93 | config=configs) 94 | test_datasets = test_dataset_factory( 95 | transforms={'image_transform': image_transform['val'], 'text_transform': text_transform['val']}, 96 | config=configs) 97 | train_val_datasets = test_dataset_factory( 98 | transforms={'image_transform': image_transform['val'], 'text_transform': text_transform['val']}, 99 | config=configs, split='train') 100 | collate_fn = PaddingCollateFunction(padding_idx=AbstractBaseVocabulary.pad_id()) 101 | collate_fn_test = PaddingCollateFunctionTest(padding_idx=AbstractBaseVocabulary.pad_id()) 102 | train_dataloader = train_dataloader_factory(dataset=train_dataset, config=configs, collate_fn=collate_fn) 103 | test_dataloaders = test_dataloader_factory(datasets=test_datasets, config=configs, collate_fn=collate_fn_test) 104 | train_val_dataloaders = test_dataloader_factory(datasets=train_val_datasets, config=configs, 105 | collate_fn=collate_fn_test) 106 | return train_dataloader, test_dataloaders, train_val_dataloaders 107 | -------------------------------------------------------------------------------- /data/abc.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class AbstractBaseDataset(Dataset, abc.ABC): 8 | """Base class for a dataset.""" 9 | 10 | def __init__(self, root_path, split='train', img_transform=None, text_transform=None): 11 | pass 12 | 13 | @classmethod 14 | @abc.abstractmethod 15 | def code(cls): 16 | raise NotImplementedError 17 | 18 | @classmethod 19 | @abc.abstractmethod 20 | def vocab_path(cls): 21 | raise NotImplementedError 22 | 23 | 24 | class AbstractBaseTestDataset(Dataset, abc.ABC): 25 | @abc.abstractmethod 26 | def sample_img_for_visualizing(self, gt) -> Image: 27 | """return gt image""" 28 | raise NotImplementedError 29 | -------------------------------------------------------------------------------- /data/collate_fns.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch.nn.utils.rnn import pad_sequence 5 | 6 | 7 | class PaddingCollateFunction(object): 8 | def __init__(self, padding_idx): 9 | self.padding_idx = padding_idx 10 | 11 | def __call__(self, batch: List[tuple]): 12 | reference_images, target_images, modifiers, lengths = zip(*batch) 13 | 14 | reference_images = torch.stack(reference_images, dim=0) 15 | target_images = torch.stack(target_images, dim=0) 16 | seq_lengths = torch.tensor(lengths).long() 17 | modifiers = pad_sequence(modifiers, padding_value=self.padding_idx, batch_first=True) 18 | return reference_images, target_images, modifiers, seq_lengths 19 | 20 | 21 | class PaddingCollateFunctionTest(object): 22 | def __init__(self, padding_idx): 23 | self.padding_idx = padding_idx 24 | 25 | @staticmethod 
26 | def _collate_test_dataset(batch): 27 | reference_images, ids = zip(*batch) 28 | reference_images = torch.stack(reference_images, dim=0) 29 | return reference_images, ids 30 | 31 | def _collate_test_query_dataset(self, batch): 32 | reference_images, ref_attrs, modifiers, target_attrs, lengths = zip(*batch) 33 | reference_images = torch.stack(reference_images, dim=0) 34 | seq_lengths = torch.tensor(lengths).long() 35 | modifiers = pad_sequence(modifiers, padding_value=self.padding_idx, batch_first=True) 36 | return reference_images, ref_attrs, modifiers, target_attrs, seq_lengths 37 | 38 | def __call__(self, batch: List[tuple]): 39 | num_items = len(batch[0]) 40 | if num_items > 2: 41 | return self._collate_test_query_dataset(batch) 42 | else: 43 | return self._collate_test_dataset(batch) 44 | -------------------------------------------------------------------------------- /data/fashionIQ.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import os 4 | 5 | from data.utils import _get_img_from_path 6 | from data.abc import AbstractBaseDataset, AbstractBaseTestDataset 7 | 8 | _DEFAULT_FASHION_IQ_DATASET_ROOT = '/data/image_retrieval/fashionIQ' 9 | _DEFAULT_FASHION_IQ_VOCAB_PATH = '/data/image_retrieval/fashionIQ/fashion_iq_vocab.pkl' 10 | 11 | 12 | def _get_img_caption_json(dataset_root, clothing_type, split): 13 | with open(os.path.join(dataset_root, 'captions', 'cap.{}.{}.json'.format(clothing_type, split))) as json_file: 14 | img_caption_data = json.load(json_file) 15 | return img_caption_data 16 | 17 | 18 | def _get_img_split_json_as_list(dataset_root, clothing_type, split): 19 | with open(os.path.join(dataset_root, 'image_splits', 'split.{}.{}.json'.format(clothing_type, split))) as json_file: 20 | img_split_list = json.load(json_file) 21 | return img_split_list 22 | 23 | 24 | def _create_img_path_from_id(root, id): 25 | return os.path.join(root, '{}.jpg'.format(id)) 26 | 27 | 28 | def _get_img_path_using_idx(img_caption_data, img_root, idx, is_ref=True): 29 | img_caption_pair = img_caption_data[idx] 30 | key = 'candidate' if is_ref else 'target' 31 | 32 | img = _create_img_path_from_id(img_root, img_caption_pair[key]) 33 | id = img_caption_pair[key] 34 | return img, id 35 | 36 | 37 | def _get_modifier(img_caption_data, idx, reverse=False): 38 | img_caption_pair = img_caption_data[idx] 39 | cap1, cap2 = img_caption_pair['captions'] 40 | return _create_modifier_from_attributes(cap1, cap2) if not reverse else _create_modifier_from_attributes(cap2, cap1) 41 | 42 | 43 | def _create_modifier_from_attributes(ref_attribute, targ_attribute): 44 | return ref_attribute + " and " + targ_attribute 45 | 46 | 47 | class AbstractBaseFashionIQDataset(AbstractBaseDataset): 48 | 49 | @classmethod 50 | def code(cls): 51 | return 'fashionIQ' 52 | 53 | @classmethod 54 | def all_codes(cls): 55 | return ['fashionIQ_dress', 'fashionIQ_shirt', 'fashionIQ_toptee'] 56 | 57 | @classmethod 58 | def vocab_path(cls): 59 | return _DEFAULT_FASHION_IQ_VOCAB_PATH 60 | 61 | 62 | class FashionIQDataset(AbstractBaseFashionIQDataset): 63 | """ 64 | Fashion200K dataset. 
65 | Image pairs in {root_path}/image_pairs/{split}_pairs.pkl 66 | 67 | """ 68 | 69 | def __init__(self, root_path=_DEFAULT_FASHION_IQ_DATASET_ROOT, clothing_type='dress', split='train', 70 | img_transform=None, text_transform=None): 71 | super().__init__(root_path, split, img_transform, text_transform) 72 | self.root_path = root_path 73 | self.img_root_path = os.path.join(self.root_path, 'images') 74 | self.clothing_type = clothing_type 75 | self.split = split 76 | self.img_transform = img_transform 77 | self.text_transform = text_transform 78 | self.img_caption_data = _get_img_caption_json(root_path, clothing_type, split) 79 | 80 | def __getitem__(self, idx): 81 | safe_idx = idx // 2 82 | reverse = (idx % 2 == 1) 83 | 84 | ref_img_path, _ = _get_img_path_using_idx(self.img_caption_data, self.img_root_path, safe_idx, is_ref=True) 85 | targ_img_path, _ = _get_img_path_using_idx(self.img_caption_data, self.img_root_path, safe_idx, is_ref=False) 86 | reference_img = _get_img_from_path(ref_img_path, self.img_transform) 87 | target_img = _get_img_from_path(targ_img_path, self.img_transform) 88 | 89 | modifier = _get_modifier(self.img_caption_data, safe_idx, reverse=reverse) 90 | modifier = self.text_transform(modifier) if self.text_transform else modifier 91 | 92 | return reference_img, target_img, modifier, len(modifier) 93 | 94 | def get_original_item(self, idx): 95 | safe_idx = idx // 2 96 | reverse = (idx % 2 == 1) 97 | 98 | ref_img_path, _ = _get_img_path_using_idx(self.img_caption_data, self.img_root_path, safe_idx, is_ref=True) 99 | targ_img_path, _ = _get_img_path_using_idx(self.img_caption_data, self.img_root_path, safe_idx, is_ref=False) 100 | reference_img = _get_img_from_path(ref_img_path) 101 | target_img = _get_img_from_path(targ_img_path) 102 | 103 | modifier = _get_modifier(self.img_caption_data, safe_idx, reverse=reverse) 104 | 105 | return reference_img, target_img, modifier, len(modifier) 106 | 107 | def __len__(self): 108 | return len(self.img_caption_data) * 2 109 | 110 | 111 | class FashionIQTestDataset(AbstractBaseFashionIQDataset, AbstractBaseTestDataset): 112 | """ 113 | FashionIQ Test (Samples) dataset. 
114 | indexing returns target samples and their unique ID 115 | """ 116 | 117 | def __init__(self, root_path=_DEFAULT_FASHION_IQ_DATASET_ROOT, clothing_type='dress', split='val', 118 | img_transform=None, text_transform=None): 119 | super().__init__(root_path, split, img_transform, text_transform) 120 | self.root_path = root_path 121 | self.img_root_path = os.path.join(self.root_path, 'images') 122 | self.clothing_type = clothing_type 123 | self.img_transform = img_transform 124 | self.text_transform = text_transform 125 | 126 | self.img_list = _get_img_split_json_as_list(root_path, clothing_type, split) 127 | 128 | ''' Uncomment below for VAL Evaluation method ''' 129 | # self.img_caption_data = _get_img_caption_json(root_path, clothing_type, split) 130 | # self.img_list = [] 131 | # for d in self.img_caption_data: 132 | # self.img_list.append(d['target']) 133 | # self.img_list.append(d['candidate']) 134 | # self.img_list = list(set(self.img_list)) 135 | 136 | def __getitem__(self, idx, use_transform=True): 137 | img_transform = self.img_transform if use_transform else None 138 | img_id = self.img_list[idx] 139 | img_path = _create_img_path_from_id(self.img_root_path, img_id) 140 | 141 | target_img = _get_img_from_path(img_path, img_transform) 142 | 143 | return target_img, img_id 144 | 145 | def sample_img_for_visualizing(self, gt): 146 | img_path = _create_img_path_from_id(self.img_root_path, gt) 147 | img = _get_img_from_path(img_path, None) 148 | return img 149 | 150 | def __len__(self): 151 | return len(self.img_list) 152 | 153 | 154 | class FashionIQTestQueryDataset(AbstractBaseFashionIQDataset): 155 | """ 156 | FashionIQ Test (Query) dataset. 157 | indexing returns ref samples, modifier, target attribute (caption, text) and modifier length 158 | """ 159 | 160 | def __init__(self, root_path=_DEFAULT_FASHION_IQ_DATASET_ROOT, clothing_type='dress', split='val', 161 | img_transform=None, text_transform=None): 162 | super().__init__(root_path, split, img_transform, text_transform) 163 | self.root_path = root_path 164 | self.img_root_path = os.path.join(self.root_path, 'images') 165 | self.clothing_type = clothing_type 166 | self.img_transform = img_transform 167 | self.text_transform = text_transform 168 | 169 | self.img_caption_data = _get_img_caption_json(root_path, clothing_type, split) 170 | 171 | def __getitem__(self, idx, use_transform=True): 172 | safe_idx = idx // 2 173 | reverse = (idx % 2 == 1) 174 | 175 | img_transform = self.img_transform if use_transform else None 176 | text_transform = self.text_transform if use_transform else None 177 | ref_img_path, ref_id = _get_img_path_using_idx(self.img_caption_data, self.img_root_path, safe_idx, is_ref=True) 178 | targ_img_path, targ_id = _get_img_path_using_idx(self.img_caption_data, self.img_root_path, safe_idx, 179 | is_ref=False) 180 | ref_img = _get_img_from_path(ref_img_path, img_transform) 181 | 182 | modifier = _get_modifier(self.img_caption_data, safe_idx, reverse=reverse) 183 | modifier = text_transform(modifier) if text_transform else modifier 184 | 185 | return ref_img, ref_id, modifier, targ_id, len(modifier) 186 | 187 | def __len__(self): 188 | return len(self.img_caption_data) * 2 189 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | 4 | def _get_img_from_path(img_path, transform=None): 5 | with open(img_path, 'rb') as f: 6 | img = 
Image.open(f).convert('RGB') 7 | if transform is not None: 8 | img = transform(img) 9 | return img 10 | -------------------------------------------------------------------------------- /evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluators.tirg_evaluator import SimpleEvaluator 2 | 3 | 4 | def get_evaluator_cls(configs): 5 | evaluator_code = configs['evaluator'] 6 | if evaluator_code == 'simple': 7 | return SimpleEvaluator 8 | else: 9 | raise ValueError("There's no evaluator that has {} as a code".format(evaluator_code)) 10 | -------------------------------------------------------------------------------- /evaluators/abc.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from evaluators.metric_calculators import ValidationMetricsCalculator 8 | from evaluators.utils import multiple_index_from_attribute_list 9 | from utils.metrics import AverageMeterSet 10 | 11 | 12 | class AbstractBaseEvaluator(abc.ABC): 13 | def __init__(self, models, dataloaders, top_k=(1, 10, 50), visualizer=None): 14 | self.models = models 15 | self.test_samples_dataloader = dataloaders['samples'] 16 | self.test_query_dataloader = dataloaders['query'] 17 | self.top_k = top_k if type(top_k) is tuple else tuple([int(k) for k in top_k.split(",")]) 18 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | self.visualizer = visualizer 20 | self.attribute_matching_matrix = None 21 | self.ref_matching_matrix = None 22 | 23 | def evaluate(self, epoch): 24 | all_results = {} 25 | all_test_features, all_test_attributes = self.extract_test_features_and_attributes() 26 | all_original_query_features, all_composed_query_features, all_query_attributes, all_ref_attributes = \ 27 | self.extract_query_features_and_attributes() 28 | 29 | # Make sure test_loader is not shuffled! 
Otherwise, this will be incorrect 30 | if self.attribute_matching_matrix is None: 31 | self.attribute_matching_matrix = self._calculate_attribute_matching_matrix(all_query_attributes, 32 | all_test_attributes) 33 | if self.ref_matching_matrix is None: 34 | self.ref_matching_matrix = self._calculate_attribute_matching_matrix(all_ref_attributes, 35 | all_test_attributes) 36 | 37 | recall_calculator = ValidationMetricsCalculator(all_original_query_features, all_composed_query_features, 38 | all_test_features, self.attribute_matching_matrix, 39 | self.ref_matching_matrix, self.top_k) 40 | recall_results = recall_calculator() 41 | all_results.update(recall_results) 42 | print(all_results) 43 | 44 | pos_similarity_histogram, neg_similarity_histogram, orig_pos_similarity_histogram = \ 45 | recall_calculator.get_similarity_histogram() 46 | all_results.update({"composed_pos_similarities": pos_similarity_histogram, 47 | "composed_neg_similarities": neg_similarity_histogram, 48 | "original_pos_similarities": orig_pos_similarity_histogram}) 49 | 50 | positive_samples_info = recall_calculator.get_positive_sample_info(num_samples=3, num_imgs_per_sample=5, 51 | positive_at_k=min(self.top_k)) 52 | negative_samples_info = recall_calculator.get_negative_sample_info(num_samples=3, num_imgs_per_sample=5, 53 | negative_at_k=min(self.top_k)) 54 | 55 | if self.visualizer is not None: 56 | positive_visualizations = self.visualizer(positive_samples_info, is_positive=True) 57 | negative_visualizations = self.visualizer(negative_samples_info, is_positive=False) 58 | all_results.update(positive_visualizations) 59 | all_results.update(negative_visualizations) 60 | return all_results, recall_calculator 61 | 62 | @abc.abstractmethod 63 | def _extract_image_features(self, images): 64 | raise NotImplementedError 65 | 66 | @abc.abstractmethod 67 | def _extract_original_and_composed_features(self, images, modifiers, len_modifiers): 68 | raise NotImplementedError 69 | 70 | def extract_test_features_and_attributes(self): 71 | """ 72 | Returns: (1) torch.Tensor of all test features, with shape (N_test, Embed_size) 73 | (2) list of test attributes, Size = N_test 74 | """ 75 | self._to_eval_mode() 76 | 77 | dataloader = tqdm(self.test_samples_dataloader) 78 | all_test_attributes = [] 79 | all_test_features = [] 80 | with torch.no_grad(): 81 | for batch_idx, (test_images, test_attr) in enumerate(dataloader): 82 | batch_size = test_images.size(0) 83 | test_images = test_images.to(self.device) 84 | 85 | features = self._extract_image_features(test_images) 86 | features = features.view(batch_size, -1).cpu() 87 | 88 | all_test_features.extend(features) 89 | all_test_attributes.extend(test_attr) 90 | 91 | return torch.stack(all_test_features), all_test_attributes 92 | 93 | def extract_query_features_and_attributes(self): 94 | """ 95 | Returns: (1) torch.Tensor of all query features, with shape (N_query, Embed_size) 96 | (2) list of target attributes, Size = N_query 97 | """ 98 | self._to_eval_mode() 99 | 100 | dataloader = tqdm(self.test_query_dataloader) 101 | all_target_attributes = [] 102 | all_ref_attributes = [] 103 | all_composed_query_features = [] 104 | all_original_query_features = [] 105 | 106 | with torch.no_grad(): 107 | for batch_idx, (ref_images, ref_attribute, modifiers, target_attribute, len_modifiers) in enumerate( 108 | dataloader): 109 | batch_size = ref_images.size(0) 110 | ref_images = ref_images.to(self.device) 111 | modifiers, len_modifiers = modifiers.to(self.device), len_modifiers.to(self.device) 112 | 113 | 
original_features, composed_features = \ 114 | self._extract_original_and_composed_features(ref_images, modifiers, len_modifiers) 115 | original_features = original_features.view(batch_size, -1).cpu() 116 | composed_features = composed_features.view(batch_size, -1).cpu() 117 | all_original_query_features.extend(original_features) 118 | all_composed_query_features.extend(composed_features) 119 | all_target_attributes.extend(target_attribute) 120 | all_ref_attributes.extend(ref_attribute) 121 | return torch.stack(all_original_query_features), torch.stack( 122 | all_composed_query_features), all_target_attributes, all_ref_attributes 123 | 124 | def _to_eval_mode(self, keys=None): 125 | keys = keys if keys else self.models.keys() 126 | for key in keys: 127 | self.models[key].eval() 128 | 129 | def _calculate_recall_at_k(self, most_similar_idx, all_test_attributes, all_target_attributes): 130 | average_meter_set = AverageMeterSet() 131 | 132 | for k in self.top_k: 133 | k_most_similar_idx = most_similar_idx[:, :k] 134 | for i, row in enumerate(k_most_similar_idx): 135 | most_similar_sample_attributes = multiple_index_from_attribute_list(all_test_attributes, row) 136 | target_attribute = all_target_attributes[i] 137 | correct = 1 if target_attribute in most_similar_sample_attributes else 0 138 | average_meter_set.update('recall_@{}'.format(k), correct) 139 | recall_results = average_meter_set.averages() 140 | return recall_results 141 | 142 | @staticmethod 143 | def _calculate_attribute_matching_matrix(all_query_attributes, all_test_attributes): 144 | all_query_attributes, all_test_attributes = np.array(all_query_attributes).reshape((-1, 1)), \ 145 | np.array(all_test_attributes).reshape((1, -1)) 146 | return all_test_attributes == all_query_attributes 147 | -------------------------------------------------------------------------------- /evaluators/metric_calculators.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import numpy as np 5 | import wandb 6 | 7 | from utils.metrics import AverageMeterSet 8 | 9 | 10 | class ValidationMetricsCalculator: 11 | def __init__(self, original_query_features: torch.tensor, composed_query_features: torch.tensor, 12 | test_features: torch.tensor, attribute_matching_matrix: np.array, 13 | ref_attribute_matching_matrix: np.array, top_k: tuple): 14 | self.original_query_features = original_query_features 15 | self.composed_query_features = composed_query_features 16 | self.test_features = test_features 17 | self.top_k = top_k 18 | self.attribute_matching_matrix = attribute_matching_matrix 19 | self.ref_attribute_matching_matrix = ref_attribute_matching_matrix 20 | self.num_query_features = composed_query_features.size(0) 21 | self.num_test_features = test_features.size(0) 22 | self.similarity_matrix = torch.zeros(self.num_query_features, self.num_test_features) 23 | self.top_scores = torch.zeros(self.num_query_features, max(top_k)) 24 | self.most_similar_idx = torch.zeros(self.num_query_features, max(top_k)) 25 | self.recall_results = {} 26 | self.recall_positive_queries_idxs = {k: [] for k in top_k} 27 | self.similarity_matrix_calculated = False 28 | self.top_scores_calculated = False 29 | 30 | def __call__(self): 31 | self._calculate_similarity_matrix() 32 | # Filter query_feat == target_feat 33 | assert self.similarity_matrix.shape == self.ref_attribute_matching_matrix.shape 34 | self.similarity_matrix[self.ref_attribute_matching_matrix == True] = self.similarity_matrix.min() 35 | 
return self._calculate_recall_at_k() 36 | 37 | def _calculate_similarity_matrix(self) -> torch.tensor: 38 | """ 39 | query_features = torch.tensor. Size = (N_test_query, Embed_size) 40 | test_features = torch.tensor. Size = (N_test_dataset, Embed_size) 41 | output = torch.tensor, similarity matrix. Size = (N_test_query, N_test_dataset) 42 | """ 43 | if not self.similarity_matrix_calculated: 44 | self.similarity_matrix = self.composed_query_features.mm(self.test_features.t()) 45 | self.similarity_matrix_calculated = True 46 | 47 | def _calculate_recall_at_k(self): 48 | average_meter_set = AverageMeterSet() 49 | self.top_scores, self.most_similar_idx = self.similarity_matrix.topk(max(self.top_k)) 50 | self.top_scores_calculated = True 51 | topk_attribute_matching = np.take_along_axis(self.attribute_matching_matrix, self.most_similar_idx.numpy(), 52 | axis=1) 53 | 54 | for k in self.top_k: 55 | query_matched_vector = topk_attribute_matching[:, :k].sum(axis=1).astype(bool) 56 | self.recall_positive_queries_idxs[k] = list(np.where(query_matched_vector > 0)[0]) 57 | num_correct = query_matched_vector.sum() 58 | num_samples = len(query_matched_vector) 59 | average_meter_set.update('recall_@{}'.format(k), num_correct, n=num_samples) 60 | recall_results = average_meter_set.averages() 61 | return recall_results 62 | 63 | def get_positive_sample_info(self, num_samples, num_imgs_per_sample, positive_at_k): 64 | info = [] 65 | num_samples = min(num_samples, len(self.recall_positive_queries_idxs[positive_at_k])) 66 | for ref_idx in random.sample(self.recall_positive_queries_idxs[positive_at_k], num_samples): 67 | targ_img_ids = self.most_similar_idx[ref_idx, :num_imgs_per_sample].tolist() 68 | targ_scores = self.top_scores[ref_idx, :num_imgs_per_sample].tolist() 69 | gt_idx = np.where(self.attribute_matching_matrix[ref_idx, :] == True)[0] 70 | gt_score = self.similarity_matrix[ref_idx, gt_idx[0]].item() 71 | info.append( 72 | {'ref_idx': ref_idx, 'targ_idxs': targ_img_ids, 'targ_scores': targ_scores, 'gt_score': gt_score}) 73 | return info 74 | 75 | def get_negative_sample_info(self, num_samples, num_imgs_per_sample, negative_at_k): 76 | info = [] 77 | negative_idxs_list = list( 78 | set(range(self.num_query_features)) - set(self.recall_positive_queries_idxs[negative_at_k])) 79 | num_samples = min(num_samples, len(negative_idxs_list)) 80 | for ref_idx in random.sample(negative_idxs_list, num_samples): 81 | targ_img_ids = self.most_similar_idx[ref_idx, :num_imgs_per_sample].tolist() 82 | targ_scores = self.top_scores[ref_idx, :num_imgs_per_sample].tolist() 83 | gt_idx = np.where(self.attribute_matching_matrix[ref_idx, :] == True)[0] 84 | gt_score = self.similarity_matrix[ref_idx, gt_idx[0]].item() 85 | info.append( 86 | {'ref_idx': ref_idx, 'targ_idxs': targ_img_ids, 'targ_scores': targ_scores, 'gt_score': gt_score}) 87 | return info 88 | 89 | def get_similarity_histogram(self, negative_hist_topk=10) -> (wandb.Histogram, wandb.Histogram, wandb.Histogram): 90 | self._calculate_similarity_matrix() 91 | sim_matrix_np = self.similarity_matrix.numpy() 92 | original_features_sim_matrix_np = self.original_query_features.mm(self.test_features.t()).numpy() 93 | 94 | if not self.top_scores_calculated: 95 | self.top_scores, self.most_similar_idx = self.similarity_matrix.topk(max(self.top_k)) 96 | 97 | # Get the scores of negative images that are in topk=negative_hist_topk 98 | hardest_k_negative_mask = np.zeros_like(self.attribute_matching_matrix) 99 | np.put_along_axis(hardest_k_negative_mask, 
self.most_similar_idx[:, :negative_hist_topk].numpy(), True, axis=1) 100 | hardest_k_negative_mask = hardest_k_negative_mask & ~self.attribute_matching_matrix 101 | 102 | composed_positive_score_distr = sim_matrix_np[self.attribute_matching_matrix] 103 | composed_negative_score_distr = sim_matrix_np[hardest_k_negative_mask] 104 | original_positive_score_distr = original_features_sim_matrix_np[self.attribute_matching_matrix] 105 | 106 | composed_pos_histogram = wandb.Histogram(np_histogram=np.histogram(composed_positive_score_distr, bins=200)) 107 | composed_neg_histogram = wandb.Histogram(np_histogram=np.histogram(composed_negative_score_distr, bins=200)) 108 | original_pos_histogram = wandb.Histogram(np_histogram=np.histogram(original_positive_score_distr, bins=200)) 109 | 110 | return composed_pos_histogram, composed_neg_histogram, original_pos_histogram 111 | 112 | @staticmethod 113 | def _multiple_index_from_attribute_list(attribute_list, indices): 114 | attributes = [] 115 | for idx in indices: 116 | attributes.append(attribute_list[idx.item()]) 117 | return attributes 118 | -------------------------------------------------------------------------------- /evaluators/tirg_evaluator.py: -------------------------------------------------------------------------------- 1 | from evaluators.abc import AbstractBaseEvaluator 2 | 3 | 4 | class SimpleEvaluator(AbstractBaseEvaluator): 5 | def __init__(self, models, dataloaders, top_k=(1, 10, 50), visualizer=None): 6 | super().__init__(models, dataloaders, top_k, visualizer) 7 | self.lower_image_encoder = self.models['lower_image_encoder'] 8 | self.upper_image_encoder = self.models['upper_image_encoder'] 9 | self.text_encoder = self.models['text_encoder'] 10 | self.compositor = self.models['layer4'] 11 | 12 | def _extract_image_features(self, images): 13 | mid_features, _ = self.lower_image_encoder(images) 14 | return self.upper_image_encoder(mid_features) 15 | 16 | def _extract_original_and_composed_features(self, images, modifiers, len_modifiers): 17 | mid_image_features, _ = self.lower_image_encoder(images) 18 | text_features = self.text_encoder(modifiers, len_modifiers) 19 | composed_features, _ = self.compositor(mid_image_features, text_features) 20 | return self.upper_image_encoder(mid_image_features), self.upper_image_encoder(composed_features) 21 | -------------------------------------------------------------------------------- /evaluators/utils.py: -------------------------------------------------------------------------------- 1 | def multiple_index_from_attribute_list(attribute_list, indices): 2 | attributes = [] 3 | for idx in indices: 4 | attributes.append(attribute_list[idx.item()]) 5 | return attributes 6 | -------------------------------------------------------------------------------- /evaluators/visualizers.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from PIL import ImageOps, Image 3 | 4 | 5 | def draw_border(img, color='red'): 6 | return ImageOps.expand(img, border=5, fill=color) 7 | 8 | 9 | class RecallVisualizer(object): 10 | def __init__(self, test_dataloaders): 11 | self.test_dataset = test_dataloaders['samples'].dataset 12 | self.query_dataset = test_dataloaders['query'].dataset 13 | 14 | def __call__(self, sample_info, is_positive=True): 15 | prefix_label = "positive sample " if is_positive else "negative_sample" 16 | visualization_dict = {} 17 | for i, info in enumerate(sample_info): 18 | ref_idx = info['ref_idx'] 19 | targ_idxs = info['targ_idxs'] 20 | 
targ_scores = info['targ_scores'] 21 | gt_score = info['gt_score'] 22 | img_data = [] 23 | ref_img, ref_gt, modifier, targ_gt, _ = self.query_dataset.__getitem__(ref_idx, use_transform=False) 24 | ref_caption = 'Ref: {}'.format(modifier) 25 | formatted_ref_img = self._crop_and_center_img(ref_img) 26 | img_data.append(wandb.Image(formatted_ref_img, caption=ref_caption)) 27 | 28 | # Load GT 29 | gt_img = self.test_dataset.sample_img_for_visualizing(targ_gt) 30 | formatted_gt_img = self._crop_and_center_img(gt_img) 31 | img_data.append(wandb.Image(formatted_gt_img, caption='GT: {:.3f}'.format(gt_score))) 32 | 33 | for score, targ_idx in zip(targ_scores, targ_idxs): 34 | targ_img, targ_attr = self.test_dataset.__getitem__(targ_idx, use_transform=False) 35 | caption = targ_attr + ": {:.3f}".format(score) 36 | border_color = 'green' if targ_attr == targ_gt else 'red' 37 | formatted_targ_img = draw_border(targ_img, color=border_color) 38 | formatted_targ_img = self._crop_and_center_img(formatted_targ_img) 39 | img_data.append(wandb.Image(formatted_targ_img, caption=caption)) 40 | visualization_dict[prefix_label + str(i)] = img_data 41 | return visualization_dict 42 | 43 | @staticmethod 44 | def _crop_and_center_img(img, background_size=(300, 500)): 45 | background_w, background_h = background_size 46 | background = Image.new('RGB', background_size, (255, 255, 255)) 47 | img_w, img_h = img.size 48 | reduce_rate_w = background_w / img_w 49 | reduce_rate_h = background_h / background_w 50 | epsilon = 2 51 | if int(img_h * reduce_rate_w) <= background_h + epsilon: 52 | new_size = (int(img_w * reduce_rate_w), int(img_h * reduce_rate_w)) 53 | else: 54 | new_size = (int(img_w * reduce_rate_h), int(img_h * reduce_rate_h)) 55 | resized_img = img.resize(new_size) 56 | offset = ((background_w - new_size[0]) // 2, (background_h - new_size[1]) // 2) 57 | background.paste(resized_img, offset) 58 | return background 59 | -------------------------------------------------------------------------------- /jupyter_files/fashion_iq_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/postBG/CosMo.pytorch/ab4af36fafd7a0214d55810f42e74b034fb31f3c/jupyter_files/fashion_iq_vocab.pkl -------------------------------------------------------------------------------- /jupyter_files/how_to_create_fashion_iq_vocab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append(\"../\")" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from data.fashionIQ import FashionIQDataset, FashionIQTestQueryDataset\n", 20 | "\n", 21 | "clothing_types = [\"dress\", \"shirt\", \"toptee\"]\n", 22 | "datasets = []\n", 23 | "for clothing in clothing_types:\n", 24 | " datasets.append(FashionIQDataset(clothing_type=clothing))\n", 25 | " datasets.append(FashionIQTestQueryDataset(clothing_type=clothing))\n", 26 | "caption_positions = [2 for _ in datasets]" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "from language.vocabulary import SimpleVocabulary\n", 36 | "from language.tokenizers import NltkTokenizer\n", 37 | "from language.utils import create_read_func, create_write_func" 38 | ] 39 | }, 40 | { 41 
| "cell_type": "code", 42 | "execution_count": 4, 43 | "metadata": { 44 | "scrolled": true 45 | }, 46 | "outputs": [ 47 | { 48 | "name": "stderr", 49 | "output_type": "stream", 50 | "text": [ 51 | "100%|██████████| 11970/11970 [00:55<00:00, 216.73it/s]\n", 52 | "100%|██████████| 4034/4034 [00:09<00:00, 437.02it/s]\n", 53 | "100%|██████████| 11976/11976 [00:56<00:00, 213.67it/s]\n", 54 | "100%|██████████| 4076/4076 [00:09<00:00, 420.40it/s]\n", 55 | "100%|██████████| 12054/12054 [00:57<00:00, 208.97it/s]\n", 56 | "100%|██████████| 3922/3922 [00:09<00:00, 409.20it/s]\n" 57 | ] 58 | }, 59 | { 60 | "data": { 61 | "text/plain": [ 62 | "" 63 | ] 64 | }, 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "output_type": "execute_result" 68 | } 69 | ], 70 | "source": [ 71 | "tokenizer = NltkTokenizer()\n", 72 | "write_func = create_write_func('fashion_iq_vocab.pkl')\n", 73 | "read_func = create_read_func('fashion_iq_vocab.pkl')\n", 74 | "SimpleVocabulary.create_and_store_vocabulary_from_datasets(datasets, tokenizer, write_func, caption_pos=caption_positions)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 7, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "4676\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "vocab = SimpleVocabulary.create_vocabulary_from_storage(read_func)\n", 92 | "print(len(vocab))" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [] 101 | } 102 | ], 103 | "metadata": { 104 | "kernelspec": { 105 | "display_name": "dongwan", 106 | "language": "python", 107 | "name": "dongwan" 108 | }, 109 | "language_info": { 110 | "codemirror_mode": { 111 | "name": "ipython", 112 | "version": 3 113 | }, 114 | "file_extension": ".py", 115 | "mimetype": "text/x-python", 116 | "name": "python", 117 | "nbconvert_exporter": "python", 118 | "pygments_lexer": "ipython3", 119 | "version": "3.7.7" 120 | } 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 4 124 | } 125 | -------------------------------------------------------------------------------- /jupyter_files/how_to_create_vocab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append(\"../\")" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 3, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from data.mit_states import MITStatesDataset\n", 20 | "\n", 21 | "mit_train = MITStatesDataset(split='train')\n", 22 | "mit_test = MITStatesDataset(split='test')\n", 23 | "datasets = [mit_train, mit_test]" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 4, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "data": { 33 | "text/plain": [ 34 | "" 35 | ] 36 | }, 37 | "execution_count": 4, 38 | "metadata": {}, 39 | "output_type": "execute_result" 40 | } 41 | ], 42 | "source": [ 43 | "from language.vocabulary import SimpleVocabulary\n", 44 | "from language.tokenizers import BasicTokenizer\n", 45 | "from language.utils import create_read_func, create_write_func\n", 46 | "\n", 47 | "tokenizer = BasicTokenizer()\n", 48 | "read_func = create_read_func('test.pkl')\n", 49 | "write_func = create_write_func('test.pkl')\n", 50 | "SimpleVocabulary.create_and_store_vocabulary_from_datasets(datasets, tokenizer, write_func)" 51 | 
] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 5, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "vocab = SimpleVocabulary.create_vocabulary_from_storage(read_func)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 6, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "data": { 69 | "text/plain": [ 70 | "{0: '',\n", 71 | " 1: '',\n", 72 | " 2: '',\n", 73 | " 3: '',\n", 74 | " 4: 'tiny',\n", 75 | " 5: 'huge',\n", 76 | " 6: 'old',\n", 77 | " 7: 'young',\n", 78 | " 8: 'burnt',\n", 79 | " 9: 'small',\n", 80 | " 10: 'large',\n", 81 | " 11: 'weathered',\n", 82 | " 12: 'modern',\n", 83 | " 13: 'unpainted',\n", 84 | " 14: 'chipped',\n", 85 | " 15: 'painted',\n", 86 | " 16: 'ancient',\n", 87 | " 17: 'dirty',\n", 88 | " 18: 'frozen',\n", 89 | " 19: 'spilled',\n", 90 | " 20: 'fresh',\n", 91 | " 21: 'melted',\n", 92 | " 22: 'moldy',\n", 93 | " 23: 'narrow',\n", 94 | " 24: 'frayed',\n", 95 | " 25: 'thin',\n", 96 | " 26: 'folded',\n", 97 | " 27: 'wide',\n", 98 | " 28: 'engraved',\n", 99 | " 29: 'ruffled',\n", 100 | " 30: 'thick',\n", 101 | " 31: 'brushed',\n", 102 | " 32: 'broken',\n", 103 | " 33: 'muddy',\n", 104 | " 34: 'dry',\n", 105 | " 35: 'eroded',\n", 106 | " 36: 'barren',\n", 107 | " 37: 'windblown',\n", 108 | " 38: 'verdant',\n", 109 | " 39: 'mossy',\n", 110 | " 40: 'crushed',\n", 111 | " 41: 'molten',\n", 112 | " 42: 'whipped',\n", 113 | " 43: 'caramelized',\n", 114 | " 44: 'crumpled',\n", 115 | " 45: 'wilted',\n", 116 | " 46: 'pressed',\n", 117 | " 47: 'crinkled',\n", 118 | " 48: 'deflated',\n", 119 | " 49: 'cored',\n", 120 | " 50: 'coiled',\n", 121 | " 51: 'rusty',\n", 122 | " 52: 'cracked',\n", 123 | " 53: 'draped',\n", 124 | " 54: 'pierced',\n", 125 | " 55: 'shiny',\n", 126 | " 56: 'dented',\n", 127 | " 57: 'dull',\n", 128 | " 58: 'blunt',\n", 129 | " 59: 'curved',\n", 130 | " 60: 'sharp',\n", 131 | " 61: 'straight',\n", 132 | " 62: 'bent',\n", 133 | " 63: 'clean',\n", 134 | " 64: 'empty',\n", 135 | " 65: 'full',\n", 136 | " 66: 'cluttered',\n", 137 | " 67: 'grimy',\n", 138 | " 68: 'diced',\n", 139 | " 69: 'sliced',\n", 140 | " 70: 'pureed',\n", 141 | " 71: 'ripe',\n", 142 | " 72: 'unripe',\n", 143 | " 73: 'peeled',\n", 144 | " 74: 'new',\n", 145 | " 75: 'browned',\n", 146 | " 76: 'cooked',\n", 147 | " 77: 'thawed',\n", 148 | " 78: 'raw',\n", 149 | " 79: 'clear',\n", 150 | " 80: 'steaming',\n", 151 | " 81: 'heavy',\n", 152 | " 82: 'lightweight',\n", 153 | " 83: 'torn',\n", 154 | " 84: 'shattered',\n", 155 | " 85: 'fallen',\n", 156 | " 86: 'creased',\n", 157 | " 87: 'foggy',\n", 158 | " 88: 'squished',\n", 159 | " 89: 'runny',\n", 160 | " 90: 'viscous',\n", 161 | " 91: 'cut',\n", 162 | " 92: 'rough',\n", 163 | " 93: 'smooth',\n", 164 | " 94: 'mashed',\n", 165 | " 95: 'loose',\n", 166 | " 96: 'tight',\n", 167 | " 97: 'wet',\n", 168 | " 98: 'wrinkled',\n", 169 | " 99: 'worn',\n", 170 | " 100: 'damp',\n", 171 | " 101: 'splintered',\n", 172 | " 102: 'filled',\n", 173 | " 103: 'dark',\n", 174 | " 104: 'bright',\n", 175 | " 105: 'inflated',\n", 176 | " 106: 'ripped',\n", 177 | " 107: 'scratched',\n", 178 | " 108: 'toppled',\n", 179 | " 109: 'upright',\n", 180 | " 110: 'short',\n", 181 | " 111: 'tall',\n", 182 | " 112: 'murky',\n", 183 | " 113: 'winding',\n", 184 | " 114: 'sunny',\n", 185 | " 115: 'standing',\n", 186 | " 116: 'closed',\n", 187 | " 117: 'cloudy',\n", 188 | " 118: 'open'}" 189 | ] 190 | }, 191 | "execution_count": 6, 192 | "metadata": {}, 193 | "output_type": "execute_result" 194 | } 195 | ], 196 | "source": [ 197 | 
"vocab._id2token" 198 | ] 199 | } 200 | ], 201 | "metadata": { 202 | "kernelspec": { 203 | "display_name": "Python 3", 204 | "language": "python", 205 | "name": "python3" 206 | }, 207 | "language_info": { 208 | "codemirror_mode": { 209 | "name": "ipython", 210 | "version": 3 211 | }, 212 | "file_extension": ".py", 213 | "mimetype": "text/x-python", 214 | "name": "python", 215 | "nbconvert_exporter": "python", 216 | "pygments_lexer": "ipython3", 217 | "version": "3.7.6" 218 | } 219 | }, 220 | "nbformat": 4, 221 | "nbformat_minor": 4 222 | } 223 | -------------------------------------------------------------------------------- /language/__init__.py: -------------------------------------------------------------------------------- 1 | from language.abc import AbstractBaseVocabulary 2 | from language.tokenizers import BasicTokenizer 3 | from language.utils import create_read_func 4 | from language.vocabulary import SimpleVocabulary 5 | 6 | 7 | # TODO: Automatically generate vocab file 8 | def vocabulary_factory(config): 9 | vocab_path = config['vocab_path'] 10 | vocab_threshold = config['vocab_threshold'] 11 | 12 | read_func = create_read_func(vocab_path) 13 | 14 | vocab = SimpleVocabulary.create_vocabulary_from_storage(read_func) 15 | vocab.threshold_rare_words(vocab_threshold) 16 | return vocab 17 | -------------------------------------------------------------------------------- /language/abc.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import abc 4 | 5 | _UNK_TOKEN = '' 6 | _BOS_TOKEN = '' 7 | _EOS_TOKEN = '' 8 | _PAD_TOKEN = '' 9 | _DEFAULT_TOKEN2ID = {_PAD_TOKEN: 0, _UNK_TOKEN: 1, _BOS_TOKEN: 2, _EOS_TOKEN: 3} 10 | 11 | 12 | class AbstractBaseVocabulary(abc.ABC): 13 | @abc.abstractmethod 14 | def add_text_to_vocab(self, text): 15 | raise NotImplementedError 16 | 17 | @abc.abstractmethod 18 | def convert_text_to_ids(self, text: str) -> List[int]: 19 | raise NotImplementedError 20 | 21 | @abc.abstractmethod 22 | def convert_ids_to_text(self, ids: List[int]) -> str: 23 | raise NotImplementedError 24 | 25 | @abc.abstractmethod 26 | def threshold_rare_words(self, wordcount_threshold=5): 27 | raise NotImplementedError 28 | 29 | @staticmethod 30 | def pad_id(): 31 | return _DEFAULT_TOKEN2ID[_PAD_TOKEN] 32 | 33 | @staticmethod 34 | def eos_id(): 35 | return _DEFAULT_TOKEN2ID[_EOS_TOKEN] 36 | 37 | @abc.abstractmethod 38 | def __len__(self): 39 | raise NotImplementedError 40 | -------------------------------------------------------------------------------- /language/test_tokenizers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from language import tokenizers 4 | from language.abc import _BOS_TOKEN, _EOS_TOKEN, _PAD_TOKEN 5 | 6 | 7 | class TestBasicTokenizer(unittest.TestCase): 8 | def test_basic_tokenizers_make_text_lowercase_and_add_start_end_tokens_when_tokenizing(self): 9 | tokenizer = tokenizers.BasicTokenizer() 10 | self.assertListEqual([_BOS_TOKEN, 'hi', 'bye', 'mama', _EOS_TOKEN], tokenizer.tokenize('hi, bye, mama!')) 11 | 12 | def test_basic_tokenizers_remove_start_end_tokens_and_padding_when_detokenizing(self): 13 | tokenizer = tokenizers.BasicTokenizer() 14 | tokens = ['1', '2', _BOS_TOKEN, 'hi', _PAD_TOKEN, 'bye', 'mama', _PAD_TOKEN, _EOS_TOKEN, '3', '4'] 15 | self.assertEqual('hi bye mama', tokenizer.detokenize(tokens)) 16 | 17 | 18 | class TestNltkTokenizer(unittest.TestCase): 19 | def 
test_nltk_tokenizers_make_text_lowercase_and_add_start_end_tokens_when_tokenizing(self): 20 | tokenizer = tokenizers.NltkTokenizer() 21 | self.assertListEqual([_BOS_TOKEN, 'hi', ',', 'bye', ',', 'mama', '!', _EOS_TOKEN], 22 | tokenizer.tokenize('hi, bye, Mama!')) 23 | 24 | def test_basic_tokenizers_remove_start_end_tokens_and_padding_when_detokenizing(self): 25 | tokenizer = tokenizers.NltkTokenizer() 26 | tokens = ['1', '2', _BOS_TOKEN, 'hi', _PAD_TOKEN, 'bye', 'mama', _PAD_TOKEN, _EOS_TOKEN, '3', '4'] 27 | self.assertEqual('hi bye mama', tokenizer.detokenize(tokens)) 28 | 29 | 30 | if __name__ == '__main__': 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /language/test_vocabulary.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from language.tokenizers import BasicTokenizer 4 | from language.vocabulary import SimpleVocabulary 5 | 6 | 7 | class TestSimpleVocabulary(unittest.TestCase): 8 | def setUp(self): 9 | self.vocabulary = SimpleVocabulary(BasicTokenizer()) 10 | self.vocabulary.add_text_to_vocab("Huge cat") 11 | self.vocabulary.add_text_to_vocab("Tiny Tiger") 12 | self.vocabulary.add_text_to_vocab("Huge Tiger") 13 | self.vocabulary.add_text_to_vocab("tiny cat") 14 | 15 | def test_tokenizing_then_detokenizing_reproduces_the_same_text_when_there_is_no_unknown_word(self): 16 | text = 'Huge Tiger Tiny cat' 17 | ids = self.vocabulary.convert_text_to_ids(text) 18 | self.assertEqual(text.lower(), self.vocabulary.convert_ids_to_text(ids)) 19 | 20 | def test_tokenizing_and_detokenizing_replaces_unknown_text_with_unknown_token(self): 21 | text = 'Huge Tiger and Tiny cat' 22 | ids = self.vocabulary.convert_text_to_ids(text) 23 | self.assertEqual("huge tiger tiny cat", self.vocabulary.convert_ids_to_text(ids)) 24 | 25 | 26 | if __name__ == '__main__': 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /language/tokenizers.py: -------------------------------------------------------------------------------- 1 | import string 2 | from typing import List 3 | 4 | import nltk 5 | 6 | from language.vocabulary import AbstractBaseTokenizer 7 | 8 | _punctuation_translator = str.maketrans('', '', string.punctuation) 9 | 10 | 11 | class BasicTokenizer(AbstractBaseTokenizer): 12 | def _tokenize(self, text): 13 | tokens = str(text).lower().translate(_punctuation_translator).strip().split() 14 | return tokens 15 | 16 | def _detokenize(self, tokens: List[str]) -> str: 17 | return ' '.join(tokens) 18 | 19 | 20 | class NltkTokenizer(AbstractBaseTokenizer): 21 | def _tokenize(self, text: str) -> List[str]: 22 | return nltk.tokenize.word_tokenize(text.lower()) 23 | 24 | def _detokenize(self, tokens: List[str]) -> str: 25 | return ' '.join(tokens) 26 | -------------------------------------------------------------------------------- /language/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 4 | def create_read_func(vocab_path): 5 | def read_func(): 6 | with open(vocab_path, 'rb') as f: 7 | data = pickle.load(f) 8 | return data 9 | 10 | return read_func 11 | 12 | 13 | def create_write_func(vocab_path): 14 | def write_func(data): 15 | with open(vocab_path, 'wb') as f: 16 | pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) 17 | 18 | return write_func 19 | -------------------------------------------------------------------------------- /language/vocabulary.py: 
-------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List 3 | 4 | import abc 5 | from tqdm import tqdm 6 | 7 | from language.abc import _UNK_TOKEN, _BOS_TOKEN, _EOS_TOKEN, _PAD_TOKEN, _DEFAULT_TOKEN2ID, AbstractBaseVocabulary 8 | 9 | 10 | class AbstractBaseTokenizer(abc.ABC): 11 | def tokenize(self, text: str) -> List[str]: 12 | return [_BOS_TOKEN] + self._tokenize(text) + [_EOS_TOKEN] 13 | 14 | def detokenize(self, tokens: List[str]) -> str: 15 | start_idx = tokens.index(_BOS_TOKEN) 16 | end_idx = tokens.index(_EOS_TOKEN) 17 | tokens = tokens[start_idx + 1: end_idx] 18 | tokens = list(filter(_PAD_TOKEN.__ne__, tokens)) 19 | return self._detokenize(tokens) 20 | 21 | @abc.abstractmethod 22 | def _tokenize(self, text: str) -> List[str]: 23 | raise NotImplementedError 24 | 25 | @abc.abstractmethod 26 | def _detokenize(self, tokens: List[str]) -> str: 27 | raise NotImplementedError 28 | 29 | 30 | # TODO: We can add read / write to persistent memory for optimizing this process 31 | class SimpleVocabulary(AbstractBaseVocabulary): 32 | def __init__(self, tokenizer: AbstractBaseTokenizer): 33 | self.tokenizer = tokenizer 34 | self._token2id = _DEFAULT_TOKEN2ID 35 | self._id2token = {i: token for token, i in _DEFAULT_TOKEN2ID.items()} 36 | self._token_count = defaultdict(int) 37 | self._token_count[_UNK_TOKEN] = int(9e9) 38 | self._token_count[_PAD_TOKEN] = int(9e9) 39 | self._token_count[_BOS_TOKEN] = int(9e9) 40 | self._token_count[_EOS_TOKEN] = int(9e9) 41 | 42 | def add_text_to_vocab(self, text): 43 | tokens = self.tokenizer.tokenize(text) 44 | for token in tokens: 45 | if token not in self._token2id: 46 | idx = len(self._token2id) 47 | self._token2id[token] = idx 48 | self._id2token[idx] = token 49 | self._token_count[token] += 1 50 | 51 | def threshold_rare_words(self, wordcount_threshold=5): 52 | for w in self._token2id: 53 | if self._token_count[w] < wordcount_threshold: 54 | self._token2id[w] = _DEFAULT_TOKEN2ID[_UNK_TOKEN] 55 | 56 | def convert_text_to_ids(self, text): 57 | tokens = self.tokenizer.tokenize(text) 58 | encoded_text = [self._token2id.get(t, _DEFAULT_TOKEN2ID[_UNK_TOKEN]) for t in tokens] 59 | return encoded_text 60 | 61 | def convert_ids_to_text(self, ids): 62 | tokens = [self._id2token.get(token_id, _UNK_TOKEN) for token_id in ids] 63 | return self.tokenizer.detokenize(tokens) 64 | 65 | def __len__(self): 66 | return len(self._token2id) 67 | 68 | @staticmethod 69 | def create_and_store_vocabulary_from_txt_files(txt_file_paths, tokenizer, write_func, txt_reader_func): 70 | vocab = SimpleVocabulary(tokenizer) 71 | for txt_path in txt_file_paths: 72 | texts = txt_reader_func(txt_path) 73 | for t in tqdm(texts): 74 | vocab.add_text_to_vocab(t) 75 | write_func(vocab) 76 | return vocab 77 | 78 | @staticmethod 79 | def create_and_store_vocabulary_from_list(list_data, tokenizer, write_func): 80 | vocab = SimpleVocabulary(tokenizer) 81 | for l in tqdm(list_data): 82 | vocab.add_text_to_vocab(l) 83 | write_func(vocab) 84 | return vocab 85 | 86 | @staticmethod 87 | def create_and_store_vocabulary_from_datasets(datasets, tokenizer, write_func, caption_pos=(2, 1)): 88 | vocab = SimpleVocabulary(tokenizer) 89 | for pos, dataset in zip(caption_pos, datasets): 90 | for record in tqdm(dataset): 91 | vocab.add_text_to_vocab(record[pos]) 92 | write_func(vocab) 93 | return vocab 94 | 95 | @staticmethod 96 | def create_vocabulary_from_storage(read_func): 97 | return read_func() 98 | 
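The pieces above fit together as follows: a SimpleVocabulary is built once from raw captions, pickled through create_write_func, and later reloaded through create_read_func exactly as vocabulary_factory does. Below is a minimal usage sketch, not part of the repository; the file name my_vocab.pkl and the example captions are illustrative assumptions.

from language.tokenizers import BasicTokenizer
from language.utils import create_read_func, create_write_func
from language.vocabulary import SimpleVocabulary

# Build and persist a vocabulary from raw captions (illustrative data, assumed path).
captions = ["is darker and has long sleeves", "is shorter and more colorful"]
write_func = create_write_func('my_vocab.pkl')
SimpleVocabulary.create_and_store_vocabulary_from_list(captions, BasicTokenizer(), write_func)

# Reload it the same way vocabulary_factory does, then encode and decode text.
vocab = SimpleVocabulary.create_vocabulary_from_storage(create_read_func('my_vocab.pkl'))
ids = vocab.convert_text_to_ids("is darker and has pockets")  # unseen words map to the <unk> id
print(ids)
print(vocab.convert_ids_to_text(ids))

Encoding always wraps the caption in <bos>/<eos> ids, and decoding strips the <bos>, <eos>, and <pad> markers again, so only the recognizable words come back.
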
-------------------------------------------------------------------------------- /loggers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /loggers/file_loggers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from trainers.abc import AbstractBaseLogger 6 | 7 | 8 | def _checkpoint_file_path(export_path, filename): 9 | return os.path.join(export_path, filename) 10 | 11 | 12 | def _set_up_path(path): 13 | if not os.path.exists(path): 14 | os.mkdir(path) 15 | 16 | 17 | def _save_state_dict_with_step(log_data, step, path, filename): 18 | log_data = {k: v for k, v in log_data.items() if isinstance(v, dict)} 19 | log_data['step'] = step 20 | torch.save(log_data, _checkpoint_file_path(path, filename)) 21 | 22 | 23 | class RecentModelTracker(AbstractBaseLogger): 24 | def __init__(self, export_path, ckpt_filename='recent.pth'): 25 | self.export_path = export_path 26 | _set_up_path(self.export_path) 27 | self.ckpt_filename = ckpt_filename 28 | 29 | def log(self, log_data, step, commit=False): 30 | _save_state_dict_with_step(log_data, step, self.export_path, self.ckpt_filename) 31 | 32 | def complete(self, log_data, step): 33 | pass 34 | 35 | 36 | class BestModelTracker(AbstractBaseLogger): 37 | def __init__(self, export_path, ckpt_filename='best.pth', metric_key='recall_@10'): 38 | self.export_path = export_path 39 | _set_up_path(self.export_path) 40 | 41 | self.metric_key = metric_key 42 | self.ckpt_filename = ckpt_filename 43 | 44 | self.best_value = -9e9 45 | 46 | def log(self, log_data, step, commit=False): 47 | if self.metric_key not in log_data: 48 | print("WARNING: The key: {} is not in logged data. 
Not saving best model".format(self.metric_key)) 49 | return 50 | recent_value = log_data[self.metric_key] 51 | if self.best_value < recent_value: 52 | self.best_value = recent_value 53 | _save_state_dict_with_step(log_data, step, self.export_path, self.ckpt_filename) 54 | print("Update Best {} Model at Step {} with value {}".format(self.metric_key, step, self.best_value)) 55 | 56 | def complete(self, *args, **kwargs): 57 | pass 58 | -------------------------------------------------------------------------------- /loggers/wandb_loggers.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | 3 | from trainers.abc import AbstractBaseLogger 4 | 5 | 6 | class WandbSimplePrinter(AbstractBaseLogger): 7 | def __init__(self, prefix): 8 | self.prefix = prefix 9 | 10 | def log(self, log_data, step, commit=False): 11 | log_metrics = {self.prefix + k: v for k, v in log_data.items() if not isinstance(v, dict)} 12 | wandb.log(log_metrics, step=step, commit=commit) 13 | 14 | def complete(self, log_data, step): 15 | self.log(log_data, step) 16 | 17 | 18 | class WandbSummaryPrinter(AbstractBaseLogger): 19 | def __init__(self, prefix, summary_keys: list): 20 | self.prefix = prefix 21 | self.summary_keys = summary_keys 22 | self.previous_best_vals = {key: 0 for key in self.summary_keys} 23 | 24 | def log(self, log_data, step, commit=False): 25 | for key in self.summary_keys: 26 | if key in log_data: 27 | log_value = log_data[key] 28 | if log_value > self.previous_best_vals[key]: 29 | wandb.run.summary[self.prefix+key] = log_value 30 | self.previous_best_vals[key] = log_value 31 | 32 | def complete(self, log_data, step): 33 | self.log(log_data, step) 34 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from losses.batch_based_classification_loss import BatchBasedClassificationLoss 2 | 3 | 4 | def loss_factory(config): 5 | return { 6 | 'metric_loss': BatchBasedClassificationLoss(), 7 | } 8 | -------------------------------------------------------------------------------- /losses/batch_based_classification_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from trainers.abc import AbstractBaseMetricLoss 5 | 6 | 7 | class BatchBasedClassificationLoss(AbstractBaseMetricLoss): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, ref_features, tar_features): 12 | batch_size = ref_features.size(0) 13 | device = ref_features.device 14 | 15 | pred = ref_features.mm(tar_features.transpose(0, 1)) 16 | labels = torch.arange(0, batch_size).long().to(device) 17 | return F.cross_entropy(pred, labels) 18 | 19 | @classmethod 20 | def code(cls): 21 | return 'batch_based_classification_loss' 22 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from data import DEFAULT_VOCAB_PATHS, create_dataloaders 2 | from evaluators import get_evaluator_cls 3 | from evaluators.visualizers import RecallVisualizer 4 | from language import vocabulary_factory 5 | from loggers.file_loggers import BestModelTracker 6 | from loggers.wandb_loggers import WandbSimplePrinter, WandbSummaryPrinter 7 | from losses import loss_factory 8 | from models import create_models 9 | from optimizers import create_optimizers, 
create_lr_schedulers 10 | from options import get_experiment_config 11 | from set_up import setup_experiment 12 | from trainers import get_trainer_cls 13 | from transforms import image_transform_factory, text_transform_factory 14 | 15 | 16 | def main(): 17 | configs = get_experiment_config() 18 | export_root, configs = setup_experiment(configs) 19 | 20 | vocabulary = vocabulary_factory(config={ 21 | 'vocab_path': configs['vocab_path'] if configs['vocab_path'] else DEFAULT_VOCAB_PATHS[configs['dataset']], 22 | 'vocab_threshold': configs['vocab_threshold'] 23 | }) 24 | image_transform = image_transform_factory(config=configs) 25 | text_transform = text_transform_factory(config={'vocabulary': vocabulary}) 26 | train_dataloader, test_dataloaders, train_val_dataloaders = create_dataloaders(image_transform, text_transform, 27 | configs) 28 | criterions = loss_factory(configs) 29 | models = create_models(configs, vocabulary) 30 | optimizers = create_optimizers(models=models, config=configs) 31 | lr_schedulers = create_lr_schedulers(optimizers, config=configs) 32 | train_loggers = [WandbSimplePrinter('train/')] 33 | summary_keys = get_summary_keys(configs) 34 | best_metric_key = [key for key in summary_keys if '10' in key][0] 35 | val_loggers = [WandbSimplePrinter('val/'), WandbSummaryPrinter('best_', summary_keys), 36 | BestModelTracker(export_root, metric_key=best_metric_key)] 37 | visualizer = RecallVisualizer(test_dataloaders) 38 | evaluator = get_evaluator_cls(configs)(models, test_dataloaders, top_k=configs['topk'], visualizer=visualizer) 39 | train_evaluator = get_evaluator_cls(configs)(models, train_val_dataloaders, top_k=configs['topk'], 40 | visualizer=None) 41 | trainer = get_trainer_cls(configs)(models, train_dataloader, criterions, optimizers, lr_schedulers, 42 | configs['epoch'], train_loggers, val_loggers, evaluator, train_evaluator, 43 | start_epoch=0) 44 | trainer.run() 45 | 46 | 47 | def get_summary_keys(configs): 48 | summary_keys = ['recall_@{}'.format(k) for k in configs['topk'].split(",")] 49 | return summary_keys 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.compositors import transformer_factory 2 | from models.image_encoders import image_encoder_factory 3 | from models.text_encoders import text_encoder_factory 4 | from utils.mixins import GradientControlDataParallel 5 | 6 | 7 | def create_models(configs, vocabulary): 8 | text_encoder = text_encoder_factory(vocabulary, config=configs) 9 | lower_img_encoder, upper_img_encoder = image_encoder_factory(config=configs) 10 | 11 | layer_shapes = lower_img_encoder.layer_shapes() 12 | compositors = transformer_factory({'layer4': layer_shapes['layer4'], 13 | 'image_feature_size': upper_img_encoder.feature_size, 14 | 'text_feature_size': text_encoder.feature_size}, configs=configs) 15 | 16 | models = { 17 | 'text_encoder': text_encoder, 18 | 'lower_image_encoder': lower_img_encoder, 19 | 'upper_image_encoder': upper_img_encoder 20 | } 21 | models.update(compositors) 22 | 23 | if configs['num_gpu'] >= 1: 24 | for name, model in models.items(): 25 | models[name] = GradientControlDataParallel(model.cuda()) 26 | 27 | return models 28 | -------------------------------------------------------------------------------- /models/abc.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 
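create_models in models/__init__.py above is driven entirely by the flat configs dictionary. The sketch below is not part of the repository: it wires up the text encoder, the two-stage image encoder, and the layer4 compositor on the CPU. The key names are inferred from the factory functions and the command-line defaults in options/command_line.py; the concrete values, the toy vocabulary, and lstm_hidden_size (normally supplied via the config JSON) are illustrative assumptions.

from language.tokenizers import BasicTokenizer
from language.vocabulary import SimpleVocabulary
from models import create_models

# A toy vocabulary so the text encoder can be constructed.
vocab = SimpleVocabulary(BasicTokenizer())
vocab.add_text_to_vocab("is shorter and more colorful")

configs = {
    'num_gpu': 0,                        # keep the models on CPU for this sketch
    'image_encoder': 'resnet50_layer4',  # or 'resnet18_layer4'
    'feature_size': 512,
    'pretrained': False,                 # skip the torchvision weight download
    'norm_scale': 4,
    'text_encoder': 'lstm',
    'text_feature_size': 512,
    'word_embedding_size': 512,
    'lstm_hidden_size': 512,             # assumed; normally read from the config file
    'num_heads': 8,
    'global_styler': 'global2',
}
models = create_models(configs, vocab)
print(sorted(models))  # ['layer4', 'lower_image_encoder', 'text_encoder', 'upper_image_encoder']

Setting 'num_gpu': 0 skips the GradientControlDataParallel wrapping and 'pretrained': False avoids downloading torchvision weights, which keeps the sketch self-contained.
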
-------------------------------------------------------------------------------- /models/attention_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/postBG/CosMo.pytorch/ab4af36fafd7a0214d55810f42e74b034fb31f3c/models/attention_modules/__init__.py -------------------------------------------------------------------------------- /models/attention_modules/self_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.utils import reshape_text_features_to_concat 5 | 6 | 7 | class AttentionModule(nn.Module): 8 | def __init__(self, feature_size, text_feature_size, num_heads, *args, **kwargs): 9 | super().__init__() 10 | 11 | self.n_heads = num_heads 12 | self.c_per_head = feature_size // num_heads 13 | assert feature_size == self.n_heads * self.c_per_head 14 | 15 | self.self_att_generator = SelfAttentionMap(feature_size, num_heads, *args, **kwargs) 16 | self.global_att_generator = GlobalCrossAttentionMap(feature_size, text_feature_size, num_heads, *args, **kwargs) 17 | 18 | self.merge = nn.Conv2d(feature_size + text_feature_size, feature_size, kernel_size=1, bias=False) 19 | self.W_v = nn.Conv2d(feature_size, feature_size, kernel_size=1, bias=False) 20 | self.W_r = nn.Conv2d(feature_size, feature_size, kernel_size=1) 21 | 22 | def forward(self, x, t, return_map=False, *args, **kwargs): 23 | b, c, h, w = x.size() 24 | 25 | t_reshaped = reshape_text_features_to_concat(t, x.size()) 26 | vl_features = self.merge(torch.cat([x, t_reshaped], dim=1)) # (b, c, h, w) 27 | 28 | values = self.W_v(vl_features) 29 | values = values.view(b * self.n_heads, self.c_per_head, h, w).view(b * self.n_heads, self.c_per_head, h * w) 30 | 31 | self_att_map = self.self_att_generator(x) # (b, num_heads, h * w, h * w) 32 | global_cross_att_map = self.global_att_generator(x, t) 33 | global_cross_att_map = global_cross_att_map.view(b, self.n_heads, 1, h * w) # (b, num_heads, 1, h * w) 34 | att_map = self_att_map + global_cross_att_map # (b, num_heads, h * w, h * w) 35 | att_map_reshaped = att_map.view(b * self.n_heads, h * w, h * w) # (b * num_heads, h * w, h * w) 36 | 37 | att_out = torch.bmm(values, att_map_reshaped.transpose(1, 2)) # (b * num_heads, c_per_head, h * w) 38 | att_out = att_out.view(b, self.n_heads * self.c_per_head, h * w) 39 | att_out = att_out.view(b, self.n_heads * self.c_per_head, h, w) 40 | att_out = self.W_r(att_out) 41 | 42 | return att_out, att_map if return_map else att_out 43 | 44 | 45 | class SelfAttentionMap(nn.Module): 46 | def __init__(self, feature_size, num_heads, *args, **kwargs): 47 | super().__init__() 48 | 49 | self.n_heads = num_heads 50 | self.c_per_head = feature_size // num_heads 51 | assert feature_size == self.n_heads * self.c_per_head 52 | 53 | self.W_k = nn.Conv2d(feature_size, feature_size, kernel_size=1, bias=False) 54 | self.W_q = nn.Conv2d(feature_size, feature_size, kernel_size=1, bias=False) 55 | self.softmax = nn.Softmax(dim=2) 56 | 57 | def forward(self, x, *args, **kwargs): 58 | b, c, h, w = x.size() 59 | 60 | keys, queries = self.W_k(x), self.W_q(x) 61 | keys = keys.view(b * self.n_heads, self.c_per_head, h, w).view(b * self.n_heads, self.c_per_head, h * w) 62 | queries = queries.view(b * self.n_heads, self.c_per_head, h, w).view(b * self.n_heads, self.c_per_head, h * w) 63 | 64 | att_map = torch.bmm(queries.transpose(1, 2), keys) / (self.c_per_head ** 0.5) 65 | att_map = 
self.softmax(att_map) # (b * num_heads, h * w, h * w), torch.sum(att_map[batch_idx][?]) == 1 66 | att_map = att_map.view(b, self.n_heads, h * w, h * w) 67 | 68 | return att_map 69 | 70 | 71 | class GlobalCrossAttentionMap(nn.Module): 72 | def __init__(self, feature_size, text_feature_size, num_heads, normalizer=None, *args, **kwargs): 73 | super().__init__() 74 | 75 | self.n_heads = num_heads 76 | self.c_per_head = feature_size // num_heads 77 | assert feature_size == self.n_heads * self.c_per_head 78 | 79 | self.W_t = nn.Linear(text_feature_size, feature_size) 80 | self.normalize = normalizer if normalizer else nn.Softmax(dim=1) 81 | 82 | def forward(self, x, t): 83 | b, c, h, w = x.size() 84 | 85 | x_reshape = x.view(b * self.n_heads, self.c_per_head, h, w) 86 | x_reshape = x_reshape.view(b * self.n_heads, self.c_per_head, h * w) 87 | 88 | t_mapped = self.W_t(t) 89 | t_mapped = t_mapped.view(b * self.n_heads, self.c_per_head, 1) 90 | 91 | att_map = torch.bmm(x_reshape.transpose(1, 2), t_mapped).squeeze(-1) / (self.c_per_head ** 0.5) 92 | att_map = self.normalize(att_map) # (b * n_heads, h * w) 93 | att_map = att_map.view(b * self.n_heads, 1, h * w) 94 | att_map = att_map.view(b, self.n_heads, h * w) 95 | 96 | return att_map 97 | -------------------------------------------------------------------------------- /models/attention_modules/test_self_attention.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from models.attention_modules.self_attention import SelfAttentionMap, GlobalCrossAttentionMap, AttentionModule 6 | 7 | 8 | class TestSelfAttentionMap(unittest.TestCase): 9 | b, c, h, w = 2, 10, 5, 5 10 | t_size = 8 11 | n_head = 2 12 | 13 | def test_attention_map(self): 14 | m = SelfAttentionMap(self.c, self.n_head) 15 | x = torch.randn(self.b, self.c, self.h, self.w) 16 | map = m(x) 17 | 18 | self.assertTupleEqual((self.b, self.n_head, self.h * self.w, self.h * self.w), map.size()) 19 | for b_i in range(self.b): 20 | for head_i in range(self.n_head): 21 | for pos in range(self.h * self.w): 22 | self.assertTrue(1, torch.sum(map[b_i][head_i][pos]).data) 23 | 24 | def test_global_cross_attention_map(self): 25 | m = GlobalCrossAttentionMap(self.c, self.t_size, self.n_head) 26 | x = torch.randn(self.b, self.c, self.h, self.w) 27 | t = torch.randn(self.b, self.t_size) 28 | map = m(x, t) 29 | 30 | self.assertTupleEqual((self.b, self.n_head, self.h * self.w), map.size()) 31 | for b_i in range(self.b): 32 | for head_i in range(self.n_head): 33 | self.assertTrue(1, torch.sum(map[b_i][head_i]).data) 34 | 35 | def test_attention(self): 36 | m = AttentionModule(self.c, self.t_size, self.n_head) 37 | x = torch.randn(self.b, self.c, self.h, self.w) 38 | t = torch.randn(self.b, self.t_size) 39 | out, map = m(x, t, return_map=True) 40 | 41 | self.assertTupleEqual((self.b, self.c, self.h, self.w), out.size()) 42 | self.assertTupleEqual((self.b, self.n_head, self.h * self.w, self.h * self.w), map.size()) 43 | for b_i in range(self.b): 44 | for head_i in range(self.n_head): 45 | for pos in range(self.h * self.w): 46 | self.assertTrue(2, torch.sum(map[b_i][head_i][pos]).data) 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /models/compositors/__init__.py: -------------------------------------------------------------------------------- 1 | from models.compositors.global_style_models import GlobalStyleTransformer2 2 | from 
models.compositors.transformers import DisentangledTransformer 3 | 4 | 5 | def global_styler_factory(code, feature_size, text_feature_size): 6 | if code == GlobalStyleTransformer2.code(): 7 | return GlobalStyleTransformer2(feature_size, text_feature_size) 8 | else: 9 | raise ValueError("{} not exists".format(code)) 10 | 11 | 12 | def transformer_factory(feature_sizes, configs): 13 | text_feature_size = feature_sizes['text_feature_size'] 14 | num_heads = configs['num_heads'] 15 | 16 | global_styler_code = configs['global_styler'] 17 | global_styler = global_styler_factory(global_styler_code, feature_sizes['layer4'], text_feature_size) 18 | return {'layer4': DisentangledTransformer(feature_sizes['layer4'], text_feature_size, num_heads=num_heads, 19 | global_styler=global_styler)} 20 | -------------------------------------------------------------------------------- /models/compositors/global_style_models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from models.utils import calculate_mean_std, EqualLinear 4 | from trainers.abc import AbstractGlobalStyleTransformer 5 | 6 | 7 | class GlobalStyleTransformer2(AbstractGlobalStyleTransformer): 8 | def __init__(self, feature_size, text_feature_size, *args, **kwargs): 9 | super().__init__() 10 | self.global_transform = EqualLinear(text_feature_size, feature_size * 2) 11 | self.gate = EqualLinear(text_feature_size, feature_size * 2) 12 | self.sigmoid = nn.Sigmoid() 13 | 14 | self.init_style_weights(feature_size) 15 | 16 | def forward(self, normed_x, t, *args, **kwargs): 17 | x_mu, x_std = calculate_mean_std(kwargs['x']) 18 | gate = self.sigmoid(self.gate(t)).unsqueeze(-1).unsqueeze(-1) 19 | std_gate, mu_gate = gate.chunk(2, 1) 20 | 21 | global_style = self.global_transform(t).unsqueeze(2).unsqueeze(3) 22 | gamma, beta = global_style.chunk(2, 1) 23 | 24 | gamma = std_gate * x_std + gamma 25 | beta = mu_gate * x_mu + beta 26 | out = gamma * normed_x + beta 27 | return out 28 | 29 | def init_style_weights(self, feature_size): 30 | self.global_transform.linear.bias.data[:feature_size] = 1 31 | self.global_transform.linear.bias.data[feature_size:] = 0 32 | 33 | @classmethod 34 | def code(cls) -> str: 35 | return 'global2' 36 | -------------------------------------------------------------------------------- /models/compositors/transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.attention_modules.self_attention import AttentionModule 5 | 6 | 7 | class DisentangledTransformer(nn.Module): 8 | def __init__(self, feature_size, text_feature_size, num_heads, global_styler=None, *args, **kwargs): 9 | super().__init__() 10 | self.n_heads = num_heads 11 | self.c_per_head = feature_size // num_heads 12 | assert feature_size == self.n_heads * self.c_per_head 13 | 14 | self.att_module = AttentionModule(feature_size, text_feature_size, num_heads, *args, **kwargs) 15 | self.att_module2 = AttentionModule(feature_size, text_feature_size, num_heads, *args, **kwargs) 16 | self.global_styler = global_styler 17 | 18 | self.weights = nn.Parameter(torch.tensor([1., 1.])) 19 | self.instance_norm = nn.InstanceNorm2d(feature_size) 20 | 21 | def forward(self, x, t, *args, **kwargs): 22 | normed_x = self.instance_norm(x) 23 | att_out, att_map = self.att_module(normed_x, t, return_map=True) 24 | out = normed_x + self.weights[0] * att_out 25 | 26 | att_out2, att_map2 = self.att_module2(out, t, 
return_map=True) 27 | out = out + self.weights[1] * att_out2 28 | 29 | out = self.global_styler(out, t, x=x) 30 | 31 | return out, att_map 32 | -------------------------------------------------------------------------------- /models/image_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from trainers.abc import AbstractBaseImageLowerEncoder, AbstractBaseImageUpperEncoder 4 | from models.image_encoders.resnet import ResNet18Layer4Lower, ResNet18Layer4Upper, ResNet50Layer4Lower, \ 5 | ResNet50Layer4Upper 6 | 7 | 8 | def image_encoder_factory(config: dict) -> Tuple[AbstractBaseImageLowerEncoder, AbstractBaseImageUpperEncoder]: 9 | model_code = config['image_encoder'] 10 | feature_size = config['feature_size'] 11 | pretrained = config.get('pretrained', True) 12 | norm_scale = config.get('norm_scale', 4) 13 | 14 | if model_code == 'resnet18_layer4': 15 | lower_encoder = ResNet18Layer4Lower(pretrained) 16 | lower_feature_shape = lower_encoder.layer_shapes()['layer4'] 17 | upper_encoder = ResNet18Layer4Upper(lower_feature_shape, feature_size, pretrained=pretrained, 18 | norm_scale=norm_scale) 19 | return lower_encoder, upper_encoder 20 | elif model_code == 'resnet50_layer4': 21 | lower_encoder = ResNet50Layer4Lower(pretrained) 22 | lower_feature_shape = lower_encoder.layer_shapes()['layer4'] 23 | upper_encoder = ResNet50Layer4Upper(lower_feature_shape, feature_size, pretrained=pretrained, 24 | norm_scale=norm_scale) 25 | return lower_encoder, upper_encoder 26 | else: 27 | raise ValueError("There's no image encoder matched with {}".format(model_code)) 28 | -------------------------------------------------------------------------------- /models/image_encoders/resnet.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision.models import resnet18, resnet50 7 | 8 | from trainers.abc import AbstractBaseImageLowerEncoder, AbstractBaseImageUpperEncoder 9 | 10 | 11 | class ResNet18Layer4Lower(AbstractBaseImageLowerEncoder): 12 | def __init__(self, pretrained=True): 13 | super().__init__() 14 | self._model = resnet18(pretrained=pretrained) 15 | 16 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: 17 | x = self._model.conv1(x) 18 | x = self._model.bn1(x) 19 | x = self._model.relu(x) 20 | x = self._model.maxpool(x) 21 | 22 | layer1_out = self._model.layer1(x) 23 | layer2_out = self._model.layer2(layer1_out) 24 | layer3_out = self._model.layer3(layer2_out) 25 | layer4_out = self._model.layer4(layer3_out) 26 | 27 | return layer4_out, (layer3_out, layer2_out, layer1_out) 28 | 29 | def layer_shapes(self): 30 | return {'layer4': 512, 'layer3': 256, 'layer2': 128, 'layer1': 64} 31 | 32 | 33 | class ResNet18Layer4Upper(AbstractBaseImageUpperEncoder): 34 | def __init__(self, lower_feature_shape, feature_size, pretrained=True, *args, **kwargs): 35 | super().__init__(lower_feature_shape, feature_size, pretrained=pretrained, *args, **kwargs) 36 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 37 | self.fc = nn.Linear(self.lower_feature_shape, self.feature_size) 38 | self.norm_scale = kwargs['norm_scale'] 39 | 40 | def forward(self, layer4_out: torch.Tensor) -> torch.Tensor: 41 | x = self.avgpool(layer4_out) 42 | x = torch.flatten(x, 1) 43 | x = F.normalize(self.fc(x)) * self.norm_scale 44 | 45 | return x 46 | 47 | 48 | class 
GAPResNet18Layer4Upper(AbstractBaseImageUpperEncoder): 49 | def __init__(self, lower_feature_shape, feature_size, pretrained=True, *args, **kwargs): 50 | super().__init__(lower_feature_shape, feature_size, pretrained=pretrained, *args, **kwargs) 51 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 52 | self.norm_scale = kwargs['norm_scale'] 53 | 54 | def forward(self, layer4_out: torch.Tensor) -> torch.Tensor: 55 | x = self.avgpool(layer4_out) 56 | x = torch.flatten(x, 1) 57 | x = F.normalize(x) * self.norm_scale 58 | 59 | return x 60 | 61 | 62 | class ResNet50Layer4Lower(AbstractBaseImageLowerEncoder): 63 | def __init__(self, pretrained=True): 64 | super().__init__() 65 | self._model = resnet50(pretrained=pretrained) 66 | 67 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: 68 | x = self._model.conv1(x) 69 | x = self._model.bn1(x) 70 | x = self._model.relu(x) 71 | x = self._model.maxpool(x) 72 | 73 | layer1_out = self._model.layer1(x) 74 | layer2_out = self._model.layer2(layer1_out) 75 | layer3_out = self._model.layer3(layer2_out) 76 | layer4_out = self._model.layer4(layer3_out) 77 | 78 | return layer4_out, (layer3_out, layer2_out, layer1_out) 79 | 80 | def layer_shapes(self): 81 | return {'layer4': 2048, 'layer3': 1024, 'layer2': 512, 'layer1': 256} 82 | 83 | 84 | class ResNet50Layer4Upper(AbstractBaseImageUpperEncoder): 85 | def __init__(self, lower_feature_shape, feature_size, pretrained=True, *args, **kwargs): 86 | super().__init__(lower_feature_shape, feature_size, pretrained=pretrained, *args, **kwargs) 87 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 88 | self.fc = nn.Linear(self.lower_feature_shape, self.feature_size) 89 | self.norm_scale = kwargs['norm_scale'] 90 | 91 | def forward(self, layer4_out: torch.Tensor) -> torch.Tensor: 92 | x = self.avgpool(layer4_out) 93 | x = torch.flatten(x, 1) 94 | x = F.normalize(self.fc(x)) * self.norm_scale 95 | 96 | return x 97 | -------------------------------------------------------------------------------- /models/text_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from language.vocabulary import AbstractBaseVocabulary 2 | from trainers.abc import AbstractBaseTextEncoder 3 | from models.text_encoders.lstm import SimpleLSTMEncoder, NormalizationLSTMEncoder, SimplerLSTMEncoder 4 | 5 | TEXT_MODEL_CODES = [SimpleLSTMEncoder.code(), NormalizationLSTMEncoder.code(), SimplerLSTMEncoder.code()] 6 | 7 | 8 | def text_encoder_factory(vocabulary: AbstractBaseVocabulary, config: dict) -> AbstractBaseTextEncoder: 9 | model_code = config['text_encoder'] 10 | feature_size = config['text_feature_size'] 11 | word_embedding_size = config['word_embedding_size'] 12 | lstm_hidden_size = config['lstm_hidden_size'] 13 | 14 | if model_code == SimpleLSTMEncoder.code(): 15 | return SimpleLSTMEncoder(vocabulary_len=len(vocabulary), padding_idx=vocabulary.pad_id(), 16 | feature_size=feature_size, word_embedding_size=word_embedding_size, 17 | lstm_hidden_size=lstm_hidden_size) 18 | elif model_code == NormalizationLSTMEncoder.code(): 19 | return NormalizationLSTMEncoder(vocabulary_len=len(vocabulary), padding_idx=vocabulary.pad_id(), 20 | feature_size=feature_size, norm_scale=config['norm_scale'], 21 | word_embedding_size=word_embedding_size, 22 | lstm_hidden_size=lstm_hidden_size) 23 | elif model_code == SimplerLSTMEncoder.code(): 24 | return SimplerLSTMEncoder(vocabulary_len=len(vocabulary), padding_idx=vocabulary.pad_id(), 25 | feature_size=feature_size, word_embedding_size=word_embedding_size, 26 
| lstm_hidden_size=lstm_hidden_size) 27 | else: 28 | raise ValueError("There's no text encoder matched with {}".format(model_code)) 29 | -------------------------------------------------------------------------------- /models/text_encoders/lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from trainers.abc import AbstractBaseTextEncoder 6 | from models.text_encoders.utils import retrieve_last_timestamp_output 7 | 8 | 9 | class SimpleLSTMEncoder(AbstractBaseTextEncoder): 10 | def __init__(self, vocabulary_len, padding_idx, feature_size, *args, **kwargs): 11 | super().__init__(vocabulary_len, padding_idx, feature_size, *args, **kwargs) 12 | word_embedding_size = kwargs.get('word_embedding_size', 512) 13 | lstm_hidden_size = kwargs.get('lstm_hidden_size', 512) 14 | feature_size = feature_size 15 | 16 | self.embedding_layer = nn.Embedding(vocabulary_len, word_embedding_size, padding_idx=padding_idx) 17 | self.lstm = nn.LSTM(word_embedding_size, lstm_hidden_size, batch_first=True) 18 | self.fc = nn.Sequential( 19 | nn.Dropout(p=0.1), 20 | nn.Linear(lstm_hidden_size, feature_size), 21 | ) 22 | 23 | def forward(self, x, lengths): 24 | # x is a tensor that has shape of (batch_size * seq_len) 25 | x = self.embedding_layer(x) # x's shape (batch_size * seq_len * word_embed_dim) 26 | self.lstm.flatten_parameters() 27 | lstm_outputs, _ = self.lstm(x) 28 | outputs = retrieve_last_timestamp_output(lstm_outputs, lengths) 29 | 30 | outputs = self.fc(outputs) 31 | return outputs 32 | 33 | @classmethod 34 | def code(cls) -> str: 35 | return 'lstm' 36 | 37 | 38 | class NormalizationLSTMEncoder(SimpleLSTMEncoder): 39 | def __init__(self, vocabulary_len, padding_idx, feature_size, *args, **kwargs): 40 | super().__init__(vocabulary_len, padding_idx, feature_size, *args, **kwargs) 41 | self.norm_scale = kwargs['norm_scale'] 42 | 43 | def forward(self, x: torch.Tensor, lengths: torch.LongTensor) -> torch.Tensor: 44 | outputs = super().forward(x, lengths) 45 | return F.normalize(outputs) * self.norm_scale 46 | 47 | @classmethod 48 | def code(cls) -> str: 49 | return 'norm_lstm' 50 | 51 | 52 | class SimplerLSTMEncoder(AbstractBaseTextEncoder): 53 | def __init__(self, vocabulary_len, padding_idx, feature_size, *args, **kwargs): 54 | super().__init__(vocabulary_len, padding_idx, feature_size, *args, **kwargs) 55 | word_embedding_size = kwargs.get('word_embedding_size', 512) 56 | 57 | self.embedding_layer = nn.Embedding(vocabulary_len, word_embedding_size, padding_idx=padding_idx) 58 | self.lstm = nn.LSTM(word_embedding_size, self.feature_size, batch_first=True) 59 | 60 | def forward(self, x, lengths): 61 | # x is a tensor that has shape of (batch_size * seq_len) 62 | x = self.embedding_layer(x) # x's shape (batch_size * seq_len * word_embed_dim) 63 | self.lstm.flatten_parameters() 64 | lstm_outputs, _ = self.lstm(x) 65 | return retrieve_last_timestamp_output(lstm_outputs, lengths) 66 | 67 | @classmethod 68 | def code(cls) -> str: 69 | return 'simpler_lstm' 70 | -------------------------------------------------------------------------------- /models/text_encoders/test_lstm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from models.text_encoders.lstm import SimpleLSTMEncoder, NormalizationLSTMEncoder, SimplerLSTMEncoder 6 | 7 | 8 | class TestLSTMModel(unittest.TestCase): 9 | def setUp(self): 10 | 
self.vocabulary_size = 10 11 | self.batch_size, self.max_seq_len, self.feature_size = 4, 5, 3 12 | self.word_embedding_size, self.lstm_hidden_size, self.fc_output_size = 5, 10, 11 13 | self.norm_scale = 4 14 | self.epsilon = 1e-4 15 | 16 | self.some_big_scala = 10000 17 | self.inputs = torch.randint(0, self.vocabulary_size, [self.batch_size, self.max_seq_len]) 18 | 19 | def test_the_input_output_size_of_simple_lstm_model(self): 20 | model = SimpleLSTMEncoder(self.vocabulary_size, padding_idx=0, word_embedding_size=self.word_embedding_size, 21 | lstm_hidden_size=self.lstm_hidden_size, feature_size=self.fc_output_size) 22 | lengths = torch.LongTensor([3, 1, 2, 4]) 23 | 24 | outputs = model(self.inputs, lengths) 25 | self.assertTupleEqual((self.batch_size, self.fc_output_size), outputs.size()) 26 | 27 | def test_the_input_output_size_of_simpler_lstm_model(self): 28 | model = SimplerLSTMEncoder(self.vocabulary_size, padding_idx=0, word_embedding_size=self.word_embedding_size, 29 | lstm_hidden_size=self.lstm_hidden_size, feature_size=self.fc_output_size) 30 | lengths = torch.LongTensor([3, 1, 2, 4]) 31 | 32 | outputs = model(self.inputs, lengths) 33 | self.assertTupleEqual((self.batch_size, self.fc_output_size), outputs.size()) 34 | 35 | def test_the_input_output_size_of_norm_lstm_model(self): 36 | model = NormalizationLSTMEncoder(self.vocabulary_size, padding_idx=0, 37 | word_embedding_size=self.word_embedding_size, 38 | lstm_hidden_size=self.lstm_hidden_size, 39 | feature_size=self.fc_output_size, norm_scale=self.norm_scale) 40 | lengths = torch.LongTensor([3, 1, 2, 4]) 41 | 42 | outputs = model(self.inputs, lengths) 43 | 44 | self.assertTupleEqual((self.batch_size, self.fc_output_size), outputs.size()) 45 | self.assertEqual(self.batch_size, torch.sum(torch.norm(outputs, dim=1) < self.norm_scale + self.epsilon)) 46 | 47 | def test_the_input_output_size_of_batch_size1_lstm_model2(self): 48 | batch_size, max_seq_len, feature_size = 1, 14, 3 49 | inputs = torch.randint(0, self.vocabulary_size, [batch_size, max_seq_len]) 50 | 51 | model = SimpleLSTMEncoder(self.vocabulary_size, padding_idx=0, 52 | word_embedding_size=self.word_embedding_size, 53 | lstm_hidden_size=self.lstm_hidden_size, 54 | feature_size=self.fc_output_size, norm_scale=self.norm_scale) 55 | lengths = torch.LongTensor([14]) 56 | 57 | outputs = model(inputs, lengths) 58 | 59 | self.assertTupleEqual((batch_size, self.fc_output_size), outputs.size()) 60 | self.assertEqual(batch_size, torch.sum(torch.norm(outputs, dim=1) < self.norm_scale + self.epsilon)) 61 | 62 | 63 | if __name__ == '__main__': 64 | unittest.main() 65 | -------------------------------------------------------------------------------- /models/text_encoders/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from models.text_encoders.utils import retrieve_last_timestamp_output 6 | 7 | 8 | class TestTextModelUtils(unittest.TestCase): 9 | def test_retrieve_last_timestamp_output(self): 10 | batch_size, max_seq_len, feature_size = 4, 5, 3 11 | lstm_outputs = torch.rand([batch_size, max_seq_len, feature_size]) 12 | last_timestamps = torch.LongTensor([3, 1, 4, 2]) 13 | 14 | outputs = retrieve_last_timestamp_output(lstm_outputs, last_timestamps) 15 | 16 | for output, lstm_output, last_timestamp in zip(outputs, lstm_outputs, last_timestamps): 17 | self.assertTrue(output.equal(lstm_output[last_timestamp - 1].squeeze(0))) 18 | 19 | 20 | if __name__ == '__main__': 21 | unittest.main() 22 
| -------------------------------------------------------------------------------- /models/text_encoders/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def retrieve_last_timestamp_output(lstm_outputs: torch.Tensor, lengths: torch.LongTensor, timestamp_dim=1): 5 | batch_size, max_seq_len, lstm_hidden_dim = lstm_outputs.size() 6 | 7 | last_timestamps = (lengths - 1).view(-1, 1).expand(batch_size, lstm_hidden_dim) # (batch_size, feature_size) 8 | last_timestamps = last_timestamps.unsqueeze(timestamp_dim) # (batch_size, 1, feature_size) 9 | return lstm_outputs.gather(timestamp_dim, last_timestamps).squeeze(1) # (batch_size, feature_size) 10 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def reshape_text_features_to_concat(text_features, image_features_shapes): 8 | return text_features.view((*text_features.size(), 1, 1)).repeat(1, 1, *image_features_shapes[2:]) 9 | 10 | 11 | def calculate_mean_std(x): 12 | mu = torch.mean(x, dim=(2, 3), keepdim=True).detach() 13 | std = torch.std(x, dim=(2, 3), keepdim=True, unbiased=False).detach() 14 | return mu, std 15 | 16 | 17 | class EqualLR: 18 | def __init__(self, name): 19 | self.name = name 20 | 21 | def compute_weight(self, module): 22 | weight = getattr(module, self.name + '_orig') 23 | fan_in = weight.data.size(1) * weight.data[0][0].numel() 24 | 25 | return weight * sqrt(2 / fan_in) 26 | 27 | @staticmethod 28 | def apply(module, name): 29 | fn = EqualLR(name) 30 | 31 | weight = getattr(module, name) 32 | del module._parameters[name] 33 | module.register_parameter(name + '_orig', nn.Parameter(weight.data)) 34 | module.register_forward_pre_hook(fn) 35 | 36 | return fn 37 | 38 | def __call__(self, module, input): 39 | weight = self.compute_weight(module) 40 | setattr(module, self.name, weight) 41 | 42 | 43 | def equal_lr(module, name='weight'): 44 | EqualLR.apply(module, name) 45 | 46 | return module 47 | 48 | 49 | class EqualLinear(nn.Module): 50 | def __init__(self, in_dim, out_dim, bias=True): 51 | super().__init__() 52 | 53 | linear = nn.Linear(in_dim, out_dim, bias=bias) 54 | linear.weight.data.normal_() 55 | if bias: 56 | linear.bias.data.zero_() 57 | 58 | self.linear = equal_lr(linear) 59 | 60 | def forward(self, inputs): 61 | return self.linear(inputs) -------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim import Adam, SGD 5 | from torch.optim.optimizer import Optimizer 6 | from torch.optim.lr_scheduler import StepLR 7 | 8 | 9 | def create_optimizers(models, config): 10 | optimizer = config['optimizer'] 11 | lr = config['lr'] 12 | weight_decay = config['weight_decay'] 13 | momentum = config['momentum'] 14 | 15 | parameterized_models = {key: model for key, model in models.items() if len(list(model.parameters())) > 0} 16 | 17 | optimizers = {} 18 | if optimizer == 'Adam': 19 | optimizers = {key: Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 20 | for key, model in parameterized_models.items()} 21 | elif optimizer == 'SGD': 22 | optimizers = {key: SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum) 23 | for key, model in 
parameterized_models.items()} 24 | elif optimizer == 'RAdam': 25 | optimizers = {key: RAdam(model.parameters(), lr=lr, weight_decay=weight_decay) 26 | for key, model in parameterized_models.items()} 27 | return optimizers 28 | 29 | 30 | def create_lr_schedulers(optimizers, config): 31 | decay_step = config['decay_step'] 32 | gamma = config['gamma'] 33 | return {key: StepLR(optimizer, step_size=decay_step, gamma=gamma) for key, optimizer in optimizers.items()} 34 | 35 | 36 | class RAdam(Optimizer): 37 | 38 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 39 | if not 0.0 <= lr: 40 | raise ValueError("Invalid learning rate: {}".format(lr)) 41 | if not 0.0 <= eps: 42 | raise ValueError("Invalid epsilon value: {}".format(eps)) 43 | if not 0.0 <= betas[0] < 1.0: 44 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 45 | if not 0.0 <= betas[1] < 1.0: 46 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 47 | 48 | self.degenerated_to_sgd = degenerated_to_sgd 49 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 50 | for param in params: 51 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 52 | param['buffer'] = [[None, None, None] for _ in range(10)] 53 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 54 | buffer=[[None, None, None] for _ in range(10)]) 55 | super(RAdam, self).__init__(params, defaults) 56 | 57 | def __setstate__(self, state): 58 | super(RAdam, self).__setstate__(state) 59 | 60 | def step(self, closure=None): 61 | 62 | loss = None 63 | if closure is not None: 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | 68 | for p in group['params']: 69 | if p.grad is None: 70 | continue 71 | grad = p.grad.data.float() 72 | if grad.is_sparse: 73 | raise RuntimeError('RAdam does not support sparse gradients') 74 | 75 | p_data_fp32 = p.data.float() 76 | 77 | state = self.state[p] 78 | 79 | if len(state) == 0: 80 | state['step'] = 0 81 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 82 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 83 | else: 84 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 85 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 86 | 87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 88 | beta1, beta2 = group['betas'] 89 | 90 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 91 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 92 | 93 | state['step'] += 1 94 | buffered = group['buffer'][int(state['step'] % 10)] 95 | if state['step'] == buffered[0]: 96 | N_sma, step_size = buffered[1], buffered[2] 97 | else: 98 | buffered[0] = state['step'] 99 | beta2_t = beta2 ** state['step'] 100 | N_sma_max = 2 / (1 - beta2) - 1 101 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 102 | buffered[1] = N_sma 103 | 104 | # more conservative since it's an approximated value 105 | if N_sma >= 5: 106 | step_size = math.sqrt( 107 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 108 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 109 | elif self.degenerated_to_sgd: 110 | step_size = 1.0 / (1 - beta1 ** state['step']) 111 | else: 112 | step_size = -1 113 | buffered[2] = step_size 114 | 115 | # more conservative since it's an approximated value 116 | if N_sma >= 5: 117 | if group['weight_decay'] != 0: 118 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], 
p_data_fp32) 119 | denom = exp_avg_sq.sqrt().add_(group['eps']) 120 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 121 | p.data.copy_(p_data_fp32) 122 | elif step_size > 0: 123 | if group['weight_decay'] != 0: 124 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 125 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 126 | p.data.copy_(p_data_fp32) 127 | 128 | return loss 129 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | from options.command_line import load_config_from_command 2 | from options.config_file import load_config_from_file 3 | 4 | 5 | def _merge_configs(configs_ordered_by_increasing_priority): 6 | merged_config = {} 7 | for config in configs_ordered_by_increasing_priority: 8 | for k, v in config.items(): 9 | merged_config[k] = v 10 | return merged_config 11 | 12 | 13 | def _check_mandatory_config(config_from_config_file, user_defined_configs, 14 | exception_keys=('experiment_description', 'device_idx')): 15 | exception_keys = [] if exception_keys is None else exception_keys 16 | trigger = False 17 | undefined_configs = [] 18 | for key, val in config_from_config_file.items(): 19 | if val == "": 20 | if key not in user_defined_configs and key not in exception_keys: 21 | trigger = True 22 | undefined_configs.append(key) 23 | print("Must define {} setting from command".format(key)) 24 | if trigger: 25 | raise Exception('Mandatory configs not defined:', undefined_configs) 26 | 27 | 28 | def _generate_experiment_description(configs, config_from_command): 29 | experiment_description = configs['experiment_description'] 30 | 31 | if experiment_description == "": 32 | remove_keys = ['dataset', 'trainer', 'config_path', 'device_idx'] 33 | for key in remove_keys: 34 | if key in config_from_command: 35 | config_from_command.pop(key) 36 | 37 | descriptors = [] 38 | for key, val in config_from_command.items(): 39 | descriptors.append(key + str(val)) 40 | experiment_description = "_".join([configs['dataset'], configs['trainer'], *descriptors]) 41 | return experiment_description 42 | 43 | 44 | def get_experiment_config(): 45 | config_from_command, user_defined_configs = load_config_from_command() 46 | config_from_config_file = load_config_from_file(config_from_command['config_path']) 47 | _check_mandatory_config(config_from_config_file, user_defined_configs) 48 | merged_configs = _merge_configs([config_from_command, config_from_config_file, user_defined_configs]) 49 | experiment_description = _generate_experiment_description(merged_configs, user_defined_configs) 50 | merged_configs['experiment_description'] = experiment_description 51 | return merged_configs 52 | -------------------------------------------------------------------------------- /options/command_line.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | parser = argparse.ArgumentParser(description='Options for CoSMo.pytorch') 5 | 6 | ######################### 7 | # Load Template 8 | ######################### 9 | parser.add_argument('--config_path', type=str, default='', help='config json path') 10 | 11 | ######################### 12 | # Trainer Settings 13 | ######################### 14 | parser.add_argument('--trainer', type=str, default="tirg", help='Select Trainer') 15 | parser.add_argument('--epoch', type=int, default=80, help='epoch (default: 80)') 16 | 
parser.add_argument('--evaluator', type=str, default="simple", help='Select Evaluator') 17 | 18 | ######################### 19 | # Language Template 20 | ######################### 21 | parser.add_argument('--vocab_path', type=str, default='', help='Vocabulary path') 22 | parser.add_argument('--vocab_threshold', type=int, default=0, help='Vocabulary word count threshold') 23 | 24 | ######################### 25 | # Dataset / DataLoader Settings 26 | ######################### 27 | parser.add_argument('--dataset', type=str, default='fashionIQ_dress', 28 | choices=['fashionIQ_dress', 'fashionIQ_toptee', 'fashionIQ_shirt'], help='Dataset') 29 | parser.add_argument('--batch_size', type=int, default=32, help='Batch Size') 30 | parser.add_argument('--num_workers', type=int, default=16, help='The Number of Workers') 31 | parser.add_argument('--shuffle', type=bool, default=True, help='Shuffle Dataset') 32 | parser.add_argument('--use_subset', type=bool, default=False, help='Test Using Subset') 33 | 34 | ######################### 35 | # Image Transform Settings 36 | ######################### 37 | parser.add_argument('--use_transform', type=bool, default=True, help='Use Transform') 38 | parser.add_argument('--img_size', type=int, default=224, help='Img Size') 39 | 40 | ######################### 41 | # Loss Settings 42 | ######################### 43 | parser.add_argument('--metric_loss', type=str, default="batch_based_classification_loss", help='Metric Loss Code') 44 | 45 | ######################### 46 | # Encoder Settings 47 | ######################### 48 | parser.add_argument('--feature_size', type=int, default=512, help='Image Feature Size') 49 | parser.add_argument('--text_feature_size', type=int, default=512, help='Text Feature Size') 50 | parser.add_argument('--word_embedding_size', type=int, default=512, help='Word Embedding Size') 51 | parser.add_argument('--image_encoder', type=str, default='resnet50_layer4', help='Image Model') 52 | parser.add_argument('--text_encoder', type=str, default="lstm", help='Text Model') 53 | 54 | ######################### 55 | # Composition Model Settings 56 | ######################### 57 | parser.add_argument('--compositor', type=str, default="transformer", help='Composition Model') 58 | parser.add_argument('--norm_scale', type=float, default=4, help='Norm Scale') 59 | parser.add_argument('--num_heads', type=int, default=8, help='Num Heads') 60 | parser.add_argument('--global_styler', type=str, default='global2', help='Global Styler') 61 | 62 | ######################### 63 | # Optimizer Settings 64 | ######################### 65 | parser.add_argument('--optimizer', type=str, default='RAdam', choices=['SGD', 'Adam', 'RAdam'], help='Optimizer') 66 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate (default: 2e-4)') 67 | parser.add_argument('--weight_decay', type=float, default=5e-5, help='l2 regularization lambda (default: 5e-5)') 68 | parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum') 69 | parser.add_argument('--decay_step', type=int, default=30, help='num epochs for decaying learning rate') 70 | parser.add_argument('--gamma', type=float, default=0.1, help='learning rate decay gamma') 71 | 72 | ######################### 73 | # Logging Settings 74 | ######################### 75 | parser.add_argument('--topk', type=str, default='1,5,10,50', help='topK recall for evaluation') 76 | parser.add_argument('--wandb_project_name', type=str, default='CoSMo.pytorch', help='Weights & Biases project name') 77 | 
parser.add_argument('--wandb_account_name', type=str, default='your_account_name', help='Weights & Biases account name') 78 | 79 | ######################### 80 | # Resume Training 81 | ######################### 82 | parser.add_argument('--checkpoint_path', type=str, default='', help='Path to saved checkpoint file') 83 | 84 | ######################### 85 | # Misc 86 | ######################### 87 | parser.add_argument('--device_idx', type=str, default='0,1', help='Gpu idx') 88 | parser.add_argument('--random_seed', type=int, default=0, help='Random seed value') 89 | parser.add_argument('--experiment_dir', type=str, default='experiments', help='Experiment save directory') 90 | parser.add_argument('--experiment_description', type=str, default='NO', help='Experiment description') 91 | 92 | 93 | def _get_user_defined_arguments(argvs): 94 | prefix, conjugator = '--', '=' 95 | return [argv.replace(prefix, '').split(conjugator)[0] for argv in argvs] 96 | 97 | 98 | def load_config_from_command(): 99 | user_defined_argument = _get_user_defined_arguments(sys.argv[1:]) 100 | 101 | configs = vars(parser.parse_args()) 102 | user_defined_configs = {k: v for k, v in configs.items() if k in user_defined_argument} 103 | return configs, user_defined_configs 104 | -------------------------------------------------------------------------------- /options/config_file.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def load_config_from_file(json_path): 5 | if not json_path: 6 | return {} 7 | 8 | with open(json_path, 'r') as f: 9 | config = json.load(f) 10 | 11 | print("Config at '{}' has been loaded".format(json_path)) 12 | return config 13 | -------------------------------------------------------------------------------- /readme_resources/CoSMo poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/postBG/CosMo.pytorch/ab4af36fafd7a0214d55810f42e74b034fb31f3c/readme_resources/CoSMo poster.pdf -------------------------------------------------------------------------------- /readme_resources/cosmo_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/postBG/CosMo.pytorch/ab4af36fafd7a0214d55810f42e74b034fb31f3c/readme_resources/cosmo_fig.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mmcv 2 | argh==0.26.2 3 | blessings==1.7 4 | certifi==2019.11.28 5 | chardet==3.0.4 6 | Click==7.0 7 | configparser==4.0.2 8 | docker-pycreds==0.4.0 9 | gitdb2==2.0.6 10 | GitPython==3.0.5 11 | gpustat==0.6.0 12 | gql==0.2.0 13 | graphql-core==1.1 14 | idna==2.8 15 | numpy==1.18.1 16 | nltk==3.5 17 | nvidia-ml-py3==7.352.0 18 | pathtools==0.1.2 19 | Pillow==6.1.0 20 | promise==2.3 21 | psutil==5.6.7 22 | python-dateutil==2.8.1 23 | PyYAML==5.3 24 | requests==2.22.0 25 | sentry-sdk==0.14.0 26 | shortuuid==0.5.0 27 | six==1.13.0 28 | smmap2==2.0.5 29 | subprocess32==3.5.4 30 | torch==1.2.0 31 | torchvision==0.4.0 32 | tqdm==4.41.1 33 | urllib3==1.25.7 34 | wandb==0.8.35 35 | warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git@6b5e8953a80aef5b324104dc0c2e9b8c34d622bd 36 | watchdog==0.9.0 37 | xxhash==1.4.3 38 | -------------------------------------------------------------------------------- /set_up.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pprint as pp 4 | import random 5 | from datetime import date 6 | 7 | import numpy as np 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import wandb 11 | 12 | 13 | def fix_random_seed_as(random_seed): 14 | if random_seed == -1: 15 | random_seed = np.random.randint(100000) 16 | print("RANDOM SEED: {}".format(random_seed)) 17 | 18 | random.seed(random_seed) 19 | torch.manual_seed(random_seed) 20 | torch.cuda.manual_seed_all(random_seed) 21 | np.random.seed(random_seed) 22 | cudnn.deterministic = True 23 | cudnn.benchmark = False 24 | return random_seed 25 | 26 | 27 | def _get_experiment_index(experiment_path): 28 | idx = 0 29 | while os.path.exists(experiment_path + "_" + str(idx)): 30 | idx += 1 31 | return idx 32 | 33 | 34 | def create_experiment_export_folder(experiment_dir, experiment_description): 35 | print(os.path.abspath(experiment_dir)) 36 | if not os.path.exists(experiment_dir): 37 | os.mkdir(experiment_dir) 38 | experiment_path = get_name_of_experiment_path(experiment_dir, experiment_description) 39 | print(os.path.abspath(experiment_path)) 40 | os.mkdir(experiment_path) 41 | print("folder created: " + os.path.abspath(experiment_path)) 42 | return experiment_path 43 | 44 | 45 | def get_name_of_experiment_path(experiment_dir, experiment_description): 46 | experiment_path = os.path.join(experiment_dir, (experiment_description + "_" + str(date.today()))) 47 | idx = _get_experiment_index(experiment_path) 48 | experiment_path = experiment_path + "_" + str(idx) 49 | return experiment_path 50 | 51 | 52 | def export_config_as_json(config, experiment_path): 53 | with open(os.path.join(experiment_path, 'config.json'), 'w') as outfile: 54 | json.dump(config, outfile, indent=2) 55 | 56 | 57 | def generate_tags(config): 58 | tags = [] 59 | tags.append(config.get('generator', config.get('text_encoder'))) 60 | tags.append(config.get('trainer')) 61 | tags = [tag for tag in tags if tag is not None] 62 | return tags 63 | 64 | 65 | def set_up_gpu(device_idx): 66 | if device_idx: 67 | os.environ['CUDA_VISIBLE_DEVICES'] = device_idx 68 | return { 69 | 'num_gpu': len(device_idx.split(",")) 70 | } 71 | else: 72 | idxs = os.environ['CUDA_VISIBLE_DEVICES'] 73 | return { 74 | 'num_gpu': len(idxs.split(",")) 75 | } 76 | 77 | 78 | def setup_experiment(config): 79 | device_info = set_up_gpu(config['device_idx']) 80 | config.update(device_info) 81 | 82 | random_seed = fix_random_seed_as(config['random_seed']) 83 | config['random_seed'] = random_seed 84 | export_root = create_experiment_export_folder(config['experiment_dir'], config['experiment_description']) 85 | export_config_as_json(config, export_root) 86 | config['export_root'] = export_root 87 | 88 | pp.pprint(config, width=1) 89 | os.environ['WANDB_SILENT'] = "true" 90 | tags = generate_tags(config) 91 | project_name = config['wandb_project_name'] 92 | wandb_account_name = config['wandb_account_name'] 93 | experiment_name = config['experiment_description'] 94 | experiment_name = experiment_name if config['random_seed'] != -1 else experiment_name + "_{}".format(random_seed) 95 | wandb.init(config=config, name=experiment_name, project=project_name, 96 | entity=wandb_account_name, tags=tags) 97 | return export_root, config 98 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from 
trainers.tirg_trainer import TIRGTrainer 2 | 3 | TRAINER_CODE_DICT = { 4 | TIRGTrainer.code(): TIRGTrainer, 5 | } 6 | 7 | 8 | def get_trainer_cls(configs): 9 | trainer_code = configs['trainer'] 10 | return TRAINER_CODE_DICT[trainer_code] 11 | -------------------------------------------------------------------------------- /trainers/abc.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from abc import ABC 3 | from typing import Sequence, Tuple, Any 4 | 5 | import torch 6 | from torch import nn as nn 7 | 8 | 9 | class AbstractBaseLogger(ABC): 10 | @abc.abstractmethod 11 | def log(self, log_data: dict, step: int, commit: bool) -> None: 12 | raise NotImplementedError 13 | 14 | @abc.abstractmethod 15 | def complete(self, log_data: dict, step: int) -> None: 16 | raise NotImplementedError 17 | 18 | 19 | class LoggingService(object): 20 | def __init__(self, loggers: Sequence[AbstractBaseLogger]): 21 | self.loggers = loggers 22 | 23 | def log(self, log_data: dict, step: int, commit=False): 24 | for logger in self.loggers: 25 | logger.log(log_data, step, commit=commit) 26 | 27 | def complete(self, log_data: dict, step: int): 28 | for logger in self.loggers: 29 | logger.complete(log_data, step) 30 | 31 | 32 | class AbstractBaseMetricLoss(nn.Module, ABC): 33 | @abc.abstractmethod 34 | def forward(self, ref_features: torch.Tensor, tar_features: torch.Tensor) -> torch.Tensor: 35 | raise NotImplementedError 36 | 37 | @classmethod 38 | @abc.abstractmethod 39 | def code(cls) -> str: 40 | raise NotImplementedError 41 | 42 | 43 | class AbstractBaseImageLowerEncoder(nn.Module, abc.ABC): 44 | @abc.abstractmethod 45 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: 46 | raise NotImplementedError 47 | 48 | @abc.abstractmethod 49 | def layer_shapes(self): 50 | raise NotImplementedError 51 | 52 | 53 | class AbstractBaseImageUpperEncoder(nn.Module, abc.ABC): 54 | def __init__(self, lower_feature_shape, feature_size, pretrained=True, *args, **kwargs): 55 | super().__init__() 56 | self.lower_feature_shape = lower_feature_shape 57 | self.feature_size = feature_size 58 | 59 | @abc.abstractmethod 60 | def forward(self, x: torch.Tensor) -> torch.Tensor: 61 | raise NotImplementedError 62 | 63 | 64 | class AbstractBaseTextEncoder(nn.Module, abc.ABC): 65 | @abc.abstractmethod 66 | def __init__(self, vocabulary_len, padding_idx, feature_size, *args, **kwargs): 67 | super().__init__() 68 | self.feature_size = feature_size 69 | 70 | @abc.abstractmethod 71 | def forward(self, x: torch.Tensor, lengths: torch.LongTensor) -> torch.Tensor: 72 | raise NotImplementedError 73 | 74 | @classmethod 75 | @abc.abstractmethod 76 | def code(cls) -> str: 77 | raise NotImplementedError 78 | 79 | 80 | class AbstractBaseCompositor(nn.Module, abc.ABC): 81 | @abc.abstractmethod 82 | def forward(self, mid_image_features, text_features, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 83 | raise NotImplementedError 84 | 85 | @classmethod 86 | @abc.abstractmethod 87 | def code(cls) -> str: 88 | raise NotImplementedError 89 | 90 | 91 | class AbstractGlobalStyleTransformer(nn.Module, abc.ABC): 92 | @abc.abstractmethod 93 | def forward(self, normed_x, t, *args, **kwargs): 94 | raise NotImplementedError 95 | 96 | @classmethod 97 | @abc.abstractmethod 98 | def code(cls) -> str: 99 | raise NotImplementedError 100 | 101 | 102 | class AbstractBaseTrainer(ABC): 103 | def __init__(self, models, train_dataloader, criterions, optimizers, lr_schedulers, num_epochs, 104 | train_loggers, 
val_loggers, evaluator, train_evaluator, *args, **kwargs): 105 | self.models = models 106 | self.train_dataloader = train_dataloader 107 | self.criterions = criterions 108 | self.optimizers = optimizers 109 | self.lr_schedulers = lr_schedulers 110 | self.num_epochs = num_epochs 111 | self.train_logging_service = LoggingService(train_loggers) 112 | self.val_logging_service = LoggingService(val_loggers) 113 | self.evaluator = evaluator 114 | self.train_evaluator = train_evaluator 115 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 116 | self.start_epoch = kwargs['start_epoch'] if 'start_epoch' in kwargs else 0 117 | 118 | def train_one_epoch(self, epoch) -> dict: 119 | raise NotImplementedError 120 | 121 | def run(self) -> dict: 122 | self._load_models_to_device() 123 | for epoch in range(self.start_epoch, self.num_epochs): 124 | for phase in ['train', 'val']: 125 | if phase == 'train': 126 | self._to_train_mode() 127 | train_results = self.train_one_epoch(epoch) 128 | self.train_logging_service.log(train_results, step=epoch) 129 | print(train_results) 130 | else: 131 | self._to_eval_mode() 132 | val_results, _ = self.evaluator.evaluate(epoch) 133 | # train_val_results = self.train_evaluator.evaluate(epoch) 134 | # self.train_logging_service.log(train_val_results, step=epoch) 135 | model_state_dicts = self._get_state_dicts(self.models) 136 | optimizer_state_dicts = self._get_state_dicts(self.optimizers) 137 | val_results['model_state_dict'] = model_state_dicts 138 | val_results['optimizer_state_dict'] = optimizer_state_dicts 139 | self.val_logging_service.log(val_results, step=epoch, commit=True) 140 | 141 | return self.models 142 | 143 | def _load_models_to_device(self): 144 | for model in self.models.values(): 145 | model.to(self.device) 146 | 147 | def _to_train_mode(self, keys=None): 148 | keys = keys if keys else self.models.keys() 149 | for key in keys: 150 | self.models[key].train() 151 | 152 | def _to_eval_mode(self, keys=None): 153 | keys = keys if keys else self.models.keys() 154 | for key in keys: 155 | self.models[key].eval() 156 | 157 | def _reset_grad(self, keys=None): 158 | keys = keys if keys else self.optimizers.keys() 159 | for key in keys: 160 | self.optimizers[key].zero_grad() 161 | 162 | def _update_grad(self, keys=None, exclude_keys=None): 163 | keys = keys if keys else list(self.optimizers.keys()) 164 | if exclude_keys: 165 | keys = [key for key in keys if key not in exclude_keys] 166 | for key in keys: 167 | self.optimizers[key].step() 168 | 169 | def _step_schedulers(self): 170 | for scheduler in self.lr_schedulers.values(): 171 | scheduler.step() 172 | 173 | @staticmethod 174 | def _get_state_dicts(dict_of_models): 175 | state_dicts = {} 176 | for model_name, model in dict_of_models.items(): 177 | if isinstance(model, nn.DataParallel): 178 | state_dicts[model_name] = model.module.state_dict() 179 | else: 180 | state_dicts[model_name] = model.state_dict() 181 | return state_dicts 182 | 183 | @classmethod 184 | def code(cls) -> str: 185 | raise NotImplementedError 186 | 187 | 188 | METRIC_LOGGING_KEYS = { 189 | 'train_loss': 'train/loss', 190 | 'val_loss': 'val/loss', 191 | 'val_correct': 'val/correct' 192 | } 193 | STATE_DICT_KEY = 'state_dict' 194 | -------------------------------------------------------------------------------- /trainers/tirg_trainer.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | from trainers.abc import AbstractBaseTrainer 4 | from utils.metrics import 
AverageMeterSet 5 | 6 | 7 | class TIRGTrainer(AbstractBaseTrainer): 8 | def __init__(self, models, train_dataloader, criterions, optimizers, lr_schedulers, num_epochs, 9 | train_loggers, val_loggers, evaluator, *args, **kwargs): 10 | super().__init__(models, train_dataloader, criterions, optimizers, lr_schedulers, num_epochs, 11 | train_loggers, val_loggers, evaluator, *args, **kwargs) 12 | self.lower_image_encoder = self.models['lower_image_encoder'] 13 | self.upper_image_encoder = self.models['upper_image_encoder'] 14 | self.text_encoder = self.models['text_encoder'] 15 | self.compositor = self.models['layer4'] 16 | self.metric_loss = self.criterions['metric_loss'] 17 | 18 | def train_one_epoch(self, epoch): 19 | average_meter_set = AverageMeterSet() 20 | train_dataloader = tqdm(self.train_dataloader, desc="Epoch {}".format(epoch)) 21 | 22 | for batch_idx, (ref_images, tar_images, modifiers, len_modifiers) in enumerate(train_dataloader): 23 | ref_images, tar_images = ref_images.to(self.device), tar_images.to(self.device) 24 | modifiers, len_modifiers = modifiers.to(self.device), len_modifiers.to(self.device) 25 | 26 | self._reset_grad() 27 | # Encode Target Images 28 | tar_mid_features, _ = self.lower_image_encoder(tar_images) 29 | tar_features = self.upper_image_encoder(tar_mid_features) 30 | 31 | # Encode and Fuse Reference Images with Texts 32 | ref_mid_features, _ = self.lower_image_encoder(ref_images) 33 | text_features = self.text_encoder(modifiers, len_modifiers) 34 | composed_ref_features, _ = self.compositor(ref_mid_features, text_features) 35 | composed_ref_features = self.upper_image_encoder(composed_ref_features) 36 | 37 | # Compute Loss 38 | loss = self.metric_loss(composed_ref_features, tar_features) 39 | loss.backward() 40 | average_meter_set.update('loss', loss.item()) 41 | self._update_grad() 42 | 43 | self._step_schedulers() 44 | train_results = average_meter_set.averages() 45 | return train_results 46 | 47 | @classmethod 48 | def code(cls) -> str: 49 | return 'tirg' 50 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from transforms.image_transforms import image_transform_factory 2 | from transforms.text_transforms import text_transform_factory 3 | -------------------------------------------------------------------------------- /transforms/image_transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | IMAGENET_STATS = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]} 4 | 5 | 6 | def get_train_transform(transform_config: dict): 7 | use_transform = transform_config['use_transform'] 8 | img_size = transform_config['img_size'] 9 | 10 | if use_transform: 11 | return transforms.Compose([transforms.RandomResizedCrop(size=img_size, scale=(0.75, 1.33)), 12 | transforms.RandomHorizontalFlip(), 13 | transforms.ToTensor(), 14 | transforms.Normalize(**IMAGENET_STATS)]) 15 | 16 | return transforms.Compose([transforms.Resize((img_size, img_size)), 17 | transforms.RandomHorizontalFlip(), 18 | transforms.ToTensor(), 19 | transforms.Normalize(**IMAGENET_STATS)]) 20 | 21 | 22 | def get_val_transform(transform_config: dict): 23 | img_size = transform_config['img_size'] 24 | 25 | return transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(), 26 | transforms.Normalize(**IMAGENET_STATS)]) 27 | 28 | 29 | def 
image_transform_factory(config: dict): 30 | return { 31 | 'train': get_train_transform(config), 32 | 'val': get_val_transform(config) 33 | } 34 | -------------------------------------------------------------------------------- /transforms/text_transforms.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torchvision import transforms 5 | 6 | from language import AbstractBaseVocabulary 7 | 8 | 9 | class ToIds(object): 10 | def __init__(self, vocabulary: AbstractBaseVocabulary): 11 | self.vocabulary = vocabulary 12 | 13 | def __call__(self, text: str) -> List[int]: 14 | return self.vocabulary.convert_text_to_ids(text) 15 | 16 | 17 | class ToLongTensor(object): 18 | def __call__(self, ids: List[int]) -> torch.LongTensor: 19 | return torch.LongTensor(ids) 20 | 21 | 22 | def text_transform_factory(config: dict): 23 | vocabulary = config['vocabulary'] 24 | 25 | return { 26 | 'train': transforms.Compose([ToIds(vocabulary), ToLongTensor()]), 27 | 'val': transforms.Compose([ToIds(vocabulary), ToLongTensor()]) 28 | } 29 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/postBG/CosMo.pytorch/ab4af36fafd7a0214d55810f42e74b034fb31f3c/utils/__init__.py -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | class AverageMeterSet(object): 2 | def __init__(self, meters=None): 3 | self.meters = meters if meters else {} 4 | 5 | def __getitem__(self, key): 6 | if key not in self.meters: 7 | meter = AverageMeter() 8 | meter.update(0) 9 | return meter 10 | return self.meters[key] 11 | 12 | def update(self, name, value, n=1): 13 | if name not in self.meters: 14 | self.meters[name] = AverageMeter() 15 | self.meters[name].update(value, n) 16 | 17 | def reset(self): 18 | for meter in self.meters.values(): 19 | meter.reset() 20 | 21 | def values(self, format_string='{}'): 22 | return {format_string.format(name): meter.val for name, meter in self.meters.items()} 23 | 24 | def averages(self, format_string='{}'): 25 | return {format_string.format(name): meter.avg for name, meter in self.meters.items()} 26 | 27 | def sums(self, format_string='{}'): 28 | return {format_string.format(name): meter.sum for name, meter in self.meters.items()} 29 | 30 | def counts(self, format_string='{}'): 31 | return {format_string.format(name): meter.count for name, meter in self.meters.items()} 32 | 33 | 34 | class AverageMeter(object): 35 | """Computes and stores the average and current value""" 36 | 37 | def __init__(self): 38 | self.val = 0 39 | self.avg = 0 40 | self.sum = 0 41 | self.count = 0 42 | 43 | def reset(self): 44 | self.val = 0 45 | self.avg = 0 46 | self.sum = 0 47 | self.count = 0 48 | 49 | def update(self, val, n=1): 50 | self.val = val 51 | self.sum += val 52 | self.count += n 53 | self.avg = self.sum / self.count 54 | 55 | def __format__(self, format): 56 | return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format) 57 | -------------------------------------------------------------------------------- /utils/mixins.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class AbstractGradientControl(abc.ABC): 8 | 
@abc.abstractmethod 9 | def stash_grad(self, grad_dict): 10 | raise NotImplementedError 11 | 12 | @abc.abstractmethod 13 | def restore_grad(self, grad_dict): 14 | raise NotImplementedError 15 | 16 | 17 | class GradientControlMixin(AbstractGradientControl): 18 | def stash_grad(self, grad_dict): 19 | for k, v in self.named_parameters(): 20 | if k in grad_dict: 21 | grad_dict[k] += v.grad.clone() 22 | else: 23 | grad_dict[k] = v.grad.clone() 24 | self.zero_grad() 25 | return grad_dict 26 | 27 | def restore_grad(self, grad_dict): 28 | for k, v in self.named_parameters(): 29 | grad = grad_dict[k] if k in grad_dict else torch.zeros_like(v.grad) 30 | 31 | if v.grad is None: 32 | v.grad = grad 33 | else: 34 | v.grad += grad 35 | 36 | 37 | class GradientControlDataParallel(nn.DataParallel, AbstractGradientControl): 38 | def stash_grad(self, grad_dict): 39 | if isinstance(self.module, GradientControlMixin): 40 | return self.module.stash_grad(grad_dict) 41 | else: 42 | raise RuntimeError("A module should be an instance of GradientControlMixin") 43 | 44 | def restore_grad(self, grad_dict): 45 | if isinstance(self.module, GradientControlMixin): 46 | self.module.restore_grad(grad_dict) 47 | else: 48 | raise RuntimeError("A module should be an instance of GradientControlMixin") 49 | -------------------------------------------------------------------------------- /utils/tests.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def is_almost_equal(tens_a, tens_b, delta=1e-5): 5 | return torch.all(torch.lt(torch.abs(torch.add(tens_a, -tens_b)), delta)) 6 | --------------------------------------------------------------------------------
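As a quick sanity check on the last two utility modules above, here is a minimal usage sketch (not a file from the repository; the numeric values are invented for illustration). It exercises AverageMeterSet from utils/metrics.py the way the trainer uses it for per-batch losses, and is_almost_equal from utils/tests.py with its default tolerance of 1e-5:

import torch

from utils.metrics import AverageMeterSet
from utils.tests import is_almost_equal

# Running average over two simulated batches, mirroring TIRGTrainer.train_one_epoch.
meters = AverageMeterSet()
meters.update('loss', 0.8)          # batch 1
meters.update('loss', 0.4)          # batch 2
print(meters.averages())            # {'loss': 0.6}

# Tolerance-based tensor comparison used by the unit tests.
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0 + 1e-7])
print(bool(is_almost_equal(a, b)))  # True: the max abs difference is below the default delta of 1e-5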