├── .DS_Store ├── GSC ├── model │ ├── __pycache__ │ │ ├── DECL.cpython-37.pyc │ │ ├── DECL.cpython-39.pyc │ │ ├── RINCE.cpython-37.pyc │ │ ├── RINCE.cpython-39.pyc │ │ ├── SGRAF.cpython-37.pyc │ │ └── SGRAF.cpython-39.pyc │ ├── RINCE.py │ └── SGRAF.py ├── eval.py ├── utils.py ├── train_coco.sh ├── train_f30k.sh ├── train_cc152k.sh ├── vocab.py ├── data.py ├── opts.py ├── train.py └── evaluation.py ├── GSC+ ├── model │ ├── __pycache__ │ │ ├── DECL.cpython-37.pyc │ │ ├── DECL.cpython-39.pyc │ │ ├── RINCE.cpython-37.pyc │ │ ├── RINCE.cpython-39.pyc │ │ ├── SGRAF.cpython-37.pyc │ │ └── SGRAF.cpython-39.pyc │ ├── RINCE.py │ └── SGRAF.py ├── eval.py ├── utils.py ├── train_f30k.sh ├── train_coco.sh ├── train_cc152k.sh ├── vocab.py ├── data.py ├── opts.py ├── train.py └── evaluation.py ├── README.md └── requirements.txt /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/.DS_Store -------------------------------------------------------------------------------- /GSC/model/__pycache__/DECL.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/DECL.cpython-37.pyc -------------------------------------------------------------------------------- /GSC/model/__pycache__/DECL.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/DECL.cpython-39.pyc -------------------------------------------------------------------------------- /GSC+/model/__pycache__/DECL.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/DECL.cpython-37.pyc -------------------------------------------------------------------------------- /GSC+/model/__pycache__/DECL.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/DECL.cpython-39.pyc -------------------------------------------------------------------------------- /GSC+/model/__pycache__/RINCE.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/RINCE.cpython-37.pyc -------------------------------------------------------------------------------- /GSC+/model/__pycache__/RINCE.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/RINCE.cpython-39.pyc -------------------------------------------------------------------------------- /GSC+/model/__pycache__/SGRAF.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/SGRAF.cpython-37.pyc -------------------------------------------------------------------------------- /GSC+/model/__pycache__/SGRAF.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/SGRAF.cpython-39.pyc -------------------------------------------------------------------------------- 
/GSC/model/__pycache__/RINCE.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/RINCE.cpython-37.pyc -------------------------------------------------------------------------------- /GSC/model/__pycache__/RINCE.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/RINCE.cpython-39.pyc -------------------------------------------------------------------------------- /GSC/model/__pycache__/SGRAF.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/SGRAF.cpython-37.pyc -------------------------------------------------------------------------------- /GSC/model/__pycache__/SGRAF.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/SGRAF.cpython-39.pyc -------------------------------------------------------------------------------- /GSC/eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | from evaluation import eval_DECL 4 | 5 | if __name__ == '__main__': 6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 7 | # DECL-SGRAF. 8 | avg_SGRAF = False 9 | 10 | data_path= '/remote-home/share/zhaozh/NC_Datasets/data' 11 | vocab_path= '/remote-home/share/zhaozh/NC_Datasets/vocab' 12 | ns = [0] 13 | for n in ns: 14 | print(f"\n============================>f30k noise:{n}") 15 | # checkpoint_paths = [ 16 | # f'./f30K_SAF_noise{n}_best.tar', 17 | # f'./f30K_SGR_noise{n}_best.tar'] 18 | checkpoint_paths = [f'/remote-home/zhaozh/NC/ELCL_TOPO/runs/f30k_gamma/r_corr_0/gamma_01/checkpoint_dir/model_best.pth.tar'] 19 | eval_DECL(checkpoint_paths, avg_SGRAF, data_path=data_path, vocab_path=vocab_path) 20 | 21 | # for n in ns: 22 | # print(f"\n============================>coco noise:{n}") 23 | # # checkpoint_paths = [ 24 | # # f'./coco_SAF_noise{n}_best.tar', 25 | # # f'./coco_SGR_noise{n}_best.tar'] 26 | # checkpoint_paths = [f'/remote-home/zhaozh/NC/DECL/runs/coco/r_corr_{n}/checkpoint_dir/model_best.pth.tar'] 27 | # eval_DECL(checkpoint_paths, avg_SGRAF, data_path=data_path, vocab_path=vocab_path) 28 | 29 | # print(f"\n============================>cc152k") 30 | # eval_DECL(['./cc152k_SAF_best.tar', './cc152k_SGR_best.tar'], avg_SGRAF=avg_SGRAF) 31 | -------------------------------------------------------------------------------- /GSC+/eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | from evaluation import eval_DECL 4 | 5 | if __name__ == '__main__': 6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 7 | # DECL-SGRAF. 
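# Note (assumption): avg_SGRAF appears to tell eval_DECL whether to average the similarity
# scores of the listed checkpoints (an SGRAF-style SAF/SGR ensemble). GSC/eval.py above sets
# it to False for a single checkpoint, while this GSC+ script evaluates two checkpoints jointly.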
8 | avg_SGRAF = True 9 | 10 | data_path= '/remote-home/share/zhaozh/NC_Datasets/data' 11 | vocab_path= '/remote-home/share/zhaozh/NC_Datasets/vocab' 12 | ns = [4] 13 | for n in ns: 14 | print(f"\n============================>f30k noise:{n}") 15 | # checkpoint_paths = [ 16 | # f'./f30K_SAF_noise{n}_best.tar', 17 | # f'./f30K_SGR_noise{n}_best.tar'] 18 | checkpoint_paths = [f'/remote-home/zhaozh/NC/ELCL_TOPO+/runs/f30k_batch128/r_corr_4/gamma_001_biinfo/checkpoint_dir/model_best.pth.tar', 19 | f'/remote-home/zhaozh/NC/ELCL_TOPO+/runs/f30k_batch128/r_corr_4/gamma_001/checkpoint_dir/model_best.pth.tar'] 20 | eval_DECL(checkpoint_paths, avg_SGRAF, data_path=data_path, vocab_path=vocab_path) 21 | 22 | # for n in ns: 23 | # print(f"\n============================>coco noise:{n}") 24 | # # checkpoint_paths = [ 25 | # # f'./coco_SAF_noise{n}_best.tar', 26 | # # f'./coco_SGR_noise{n}_best.tar'] 27 | # checkpoint_paths = [f'/remote-home/zhaozh/NC/ELCL_TOPO+/runs/coco_batch128/r_corr_2/gamma_002/checkpoint_dir/model_best.pth.tar'] 28 | # eval_DECL(checkpoint_paths, avg_SGRAF, data_path=data_path, vocab_path=vocab_path) 29 | 30 | # print(f"\n============================>cc152k") 31 | # eval_DECL(['./cc152k_SAF_best.tar', './cc152k_SGR_best.tar'], avg_SGRAF=avg_SGRAF) 32 | -------------------------------------------------------------------------------- /GSC+/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self, name="", fmt=":f"): 8 | self.name = name 9 | self.fmt = fmt 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | def __str__(self): 25 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 26 | return fmtstr.format(**self.__dict__) 27 | 28 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix='',ckpt=True): 29 | tries = 15 30 | error = None 31 | # deal with unstable I/O. Usually not necessary. 
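# The try/except/else below retries torch.save up to 15 times: on success the `else: break`
# exits the loop immediately; on failure the attempt is logged and, once all tries are
# exhausted, the last IOError is re-raised.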
32 | while tries: 33 | try: 34 | if ckpt: 35 | torch.save(state, prefix + filename) 36 | if is_best: 37 | torch.save(state, prefix + 'model_best.pth.tar') 38 | except IOError as e: 39 | error = e 40 | tries -= 1 41 | else: 42 | break 43 | print('model save {} failed, remaining {} trials'.format(filename, tries)) 44 | if not tries: 45 | raise error 46 | 47 | 48 | def save_config(opt, file_path): 49 | with open(file_path, "w") as f: 50 | json.dump(opt.__dict__, f, indent=2) 51 | 52 | 53 | def load_config(opt, file_path): 54 | with open(file_path, "r") as f: 55 | opt.__dict__ = json.load(f) 56 | -------------------------------------------------------------------------------- /GSC/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self, name="", fmt=":f"): 8 | self.name = name 9 | self.fmt = fmt 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | def __str__(self): 25 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 26 | return fmtstr.format(**self.__dict__) 27 | 28 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix='',ckpt=True): 29 | tries = 15 30 | error = None 31 | # deal with unstable I/O. Usually not necessary. 32 | while tries: 33 | try: 34 | if ckpt: 35 | torch.save(state, prefix + filename) 36 | if is_best: 37 | torch.save(state, prefix + 'model_best.pth.tar') 38 | except IOError as e: 39 | error = e 40 | tries -= 1 41 | else: 42 | break 43 | print('model save {} failed, remaining {} trials'.format(filename, tries)) 44 | if not tries: 45 | raise error 46 | 47 | 48 | def save_config(opt, file_path): 49 | with open(file_path, "w") as f: 50 | json.dump(opt.__dict__, f, indent=2) 51 | 52 | 53 | def load_config(opt, file_path): 54 | with open(file_path, "r") as f: 55 | opt.__dict__ = json.load(f) 56 | -------------------------------------------------------------------------------- /GSC/train_coco.sh: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data_name=coco_precomp 3 | data_path=/remote-home/share/zhaozh/NC_Datasets/data 4 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab 5 | 6 | # noise settings 7 | noise_ratio=0.6 8 | noise_file=noise_file_by_caption/coco/coco_precomp_0.6.npy 9 | 10 | # loss settings 11 | contrastive_loss=InfoNCE 12 | temp=0.07 13 | lam=0.01 14 | q=0.01 15 | 16 | # train settings 17 | gpu=0 18 | warmup_epoch=0 19 | num_epochs=60 20 | lr_update=20 21 | batch_size=128 22 | module_name=SGR 23 | folder_name=coco/r_corr_2 24 | 25 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \ 26 | --warmup_epoch $warmup_epoch \ 27 | --folder_name $folder_name \ 28 | --noise_ratio $noise_ratio \ 29 | --noise_file $noise_file \ 30 | --num_epochs $num_epochs \ 31 | --batch_size $batch_size \ 32 | --lr_update $lr_update \ 33 | --module_name $module_name \ 34 | --learning_rate 0.0002 \ 35 | --data_name $data_name \ 36 | --data_path $data_path \ 37 | --vocab_path $vocab_path \ 38 | --contrastive_loss $contrastive_loss \ 39 | --temp $temp \ 40 | --lam $lam \ 41 | --q $q \ 42 | 43 | -------------------------------------------------------------------------------- 
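For reference, the AverageMeter class defined in utils.py above is a standard running-average tracker used for logging. A minimal usage sketch (the loss values and batch size below are illustrative placeholders, not from the repo):

```python
from utils import AverageMeter

# Track a running training loss; fmt controls how val/avg are printed.
loss_meter = AverageMeter(name="Loss", fmt=":.4f")
for step in range(3):
    batch_loss = 0.5 / (step + 1)         # placeholder standing in for a real loss value
    loss_meter.update(batch_loss, n=128)  # n = number of samples in the batch
    print(loss_meter)                     # prints "Loss <current> (<running average>)"
```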
/GSC+/train_f30k.sh: -------------------------------------------------------------------------------- 1 | export OPENBLAS_NUM_THREADS=1 2 | export GOTO_NUM_THREADS=1 3 | export OMP_NUM_THREADS=1 4 | # dataset settings 5 | data_name=f30k_precomp 6 | data_path=/remote-home/share/zhaozh/NC_Datasets/data 7 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab 8 | 9 | # noise settings 10 | noise_ratio=0.2 11 | noise_file=noise_file_by_caption/f30k/f30k_precomp_0.2.npy 12 | 13 | # loss settings 14 | contrastive_loss=InfoNCE 15 | temp=0.07 16 | lam=0.01 17 | q=0.01 18 | 19 | beta=0.7 20 | gamma=0.01 21 | 22 | # train settings 23 | gpu=0 24 | warmup_epoch=0 25 | num_epochs=60 26 | lr_update=15 27 | batch_size=128 28 | module_name=SGR 29 | folder_name=f30k/r_corr_2 30 | 31 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \ 32 | --warmup_epoch $warmup_epoch \ 33 | --folder_name $folder_name \ 34 | --noise_ratio $noise_ratio \ 35 | --noise_file $noise_file \ 36 | --num_epochs $num_epochs \ 37 | --batch_size $batch_size \ 38 | --lr_update $lr_update \ 39 | --module_name $module_name \ 40 | --learning_rate 0.0002 \ 41 | --data_name $data_name \ 42 | --data_path $data_path \ 43 | --vocab_path $vocab_path \ 44 | --contrastive_loss $contrastive_loss \ 45 | --temp $temp \ 46 | --lam $lam \ 47 | --q $q \ 48 | --beta $beta \ 49 | --gamma $gamma \ 50 | 51 | -------------------------------------------------------------------------------- /GSC/train_f30k.sh: -------------------------------------------------------------------------------- 1 | export OPENBLAS_NUM_THREADS=1 2 | export GOTO_NUM_THREADS=1 3 | export OMP_NUM_THREADS=1 4 | # dataset settings 5 | data_name=f30k_precomp 6 | data_path=/remote-home/share/zhaozh/NC_Datasets/data 7 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab 8 | 9 | # noise settings 10 | noise_ratio=0.2 11 | noise_file=noise_file_by_caption/f30k/f30k_precomp_0.2.npy 12 | 13 | # loss settings 14 | contrastive_loss=InfoNCE 15 | temp=0.07 16 | lam=0.01 17 | q=0.01 18 | 19 | beta=0.7 20 | gamma=0.01 21 | 22 | # train settings 23 | gpu=0 24 | warmup_epoch=0 25 | num_epochs=60 26 | lr_update=15 27 | batch_size=128 28 | module_name=SGR 29 | folder_name=f30k/r_corr_2 30 | 31 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \ 32 | --warmup_epoch $warmup_epoch \ 33 | --folder_name $folder_name \ 34 | --noise_ratio $noise_ratio \ 35 | --noise_file $noise_file \ 36 | --num_epochs $num_epochs \ 37 | --batch_size $batch_size \ 38 | --lr_update $lr_update \ 39 | --module_name $module_name \ 40 | --learning_rate 0.0002 \ 41 | --data_name $data_name \ 42 | --data_path $data_path \ 43 | --vocab_path $vocab_path \ 44 | --contrastive_loss $contrastive_loss \ 45 | --temp $temp \ 46 | --lam $lam \ 47 | --q $q \ 48 | --beta $beta \ 49 | --gamma $gamma \ 50 | 51 | -------------------------------------------------------------------------------- /GSC/train_cc152k.sh: -------------------------------------------------------------------------------- 1 | export OPENBLAS_NUM_THREADS=1 2 | export GOTO_NUM_THREADS=1 3 | export OMP_NUM_THREADS=1 4 | # dataset settings 5 | data_name=cc152k_precomp 6 | data_path=/remote-home/share/zhaozh/NC_Datasets/data 7 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab 8 | 9 | # noise settings 10 | noise_ratio=0.0 11 | noise_file=/remote-home/share/zhaozh/NC_Datasets/data 12 | 13 | # loss settings 14 | contrastive_loss=InfoNCE 15 | temp=0.07 16 | lam=0.01 17 | q=0.01 18 | 19 | beta=0.7 20 | gamma=0.01 21 | 22 | # train settings 23 | gpu=0 24 | 
warmup_epoch=0 25 | num_epochs=60 26 | lr_update=15 27 | batch_size=128 28 | module_name=SGR 29 | folder_name=cc152k/r_corr_2 30 | 31 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \ 32 | --warmup_epoch $warmup_epoch \ 33 | --folder_name $folder_name \ 34 | --noise_ratio $noise_ratio \ 35 | --noise_file $noise_file \ 36 | --num_epochs $num_epochs \ 37 | --batch_size $batch_size \ 38 | --lr_update $lr_update \ 39 | --module_name $module_name \ 40 | --learning_rate 0.0002 \ 41 | --data_name $data_name \ 42 | --data_path $data_path \ 43 | --vocab_path $vocab_path \ 44 | --contrastive_loss $contrastive_loss \ 45 | --temp $temp \ 46 | --lam $lam \ 47 | --q $q \ 48 | --beta $beta \ 49 | --gamma $gamma \ 50 | 51 | 52 | -------------------------------------------------------------------------------- /GSC+/train_coco.sh: -------------------------------------------------------------------------------- 1 | export OPENBLAS_NUM_THREADS=1 2 | export GOTO_NUM_THREADS=1 3 | export OMP_NUM_THREADS=1 4 | # dataset settings 5 | data_name=coco_precomp 6 | data_path=/remote-home/share/zhaozh/NC_Datasets/data 7 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab 8 | 9 | # noise settings 10 | noise_ratio=0.2 11 | noise_file=noise_file_by_caption/coco/coco_precomp_0.2.npy 12 | 13 | # loss settings 14 | contrastive_loss=InfoNCE 15 | temp=0.07 16 | lam=0.01 17 | q=0.01 18 | 19 | beta=0.7 20 | gamma=0.05 21 | 22 | # train settings 23 | gpu=0 24 | warmup_epoch=0 25 | num_epochs=30 26 | lr_update=15 27 | batch_size=128 28 | module_name=SGR 29 | folder_name=coco/r_corr_2 30 | 31 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \ 32 | --warmup_epoch $warmup_epoch \ 33 | --folder_name $folder_name \ 34 | --noise_ratio $noise_ratio \ 35 | --noise_file $noise_file \ 36 | --num_epochs $num_epochs \ 37 | --batch_size $batch_size \ 38 | --lr_update $lr_update \ 39 | --module_name $module_name \ 40 | --learning_rate 0.0002 \ 41 | --data_name $data_name \ 42 | --data_path $data_path \ 43 | --vocab_path $vocab_path \ 44 | --contrastive_loss $contrastive_loss \ 45 | --temp $temp \ 46 | --lam $lam \ 47 | --q $q \ 48 | --beta $beta \ 49 | --gamma $gamma \ 50 | 51 | -------------------------------------------------------------------------------- /GSC+/train_cc152k.sh: -------------------------------------------------------------------------------- 1 | export OPENBLAS_NUM_THREADS=1 2 | export GOTO_NUM_THREADS=1 3 | export OMP_NUM_THREADS=1 4 | # dataset settings 5 | data_name=cc152k_precomp 6 | data_path=/remote-home/share/zhaozh/NC_Datasets/data 7 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab 8 | 9 | # noise settings 10 | noise_ratio=0.0 11 | noise_file=/remote-home/share/zhaozh/NC_Datasets/data 12 | 13 | # loss settings 14 | contrastive_loss=InfoNCE 15 | temp=0.07 16 | lam=0.01 17 | q=0.01 18 | 19 | beta=0.7 20 | gamma=0.02 21 | 22 | # train settings 23 | gpu=0 24 | warmup_epoch=0 25 | num_epochs=60 26 | lr_update=20 27 | batch_size=128 28 | module_name=SGR 29 | folder_name=cc152k/r_corr_0 30 | 31 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \ 32 | --warmup_epoch $warmup_epoch \ 33 | --folder_name $folder_name \ 34 | --noise_ratio $noise_ratio \ 35 | --noise_file $noise_file \ 36 | --num_epochs $num_epochs \ 37 | --batch_size $batch_size \ 38 | --lr_update $lr_update \ 39 | --module_name $module_name \ 40 | --learning_rate 0.0002 \ 41 | --data_name $data_name \ 42 | --data_path $data_path \ 43 | --vocab_path $vocab_path \ 44 | --contrastive_loss $contrastive_loss \ 45 | --temp $temp \ 46 | 
--lam $lam \ 47 | --q $q \ 48 | --beta $beta \ 49 | --gamma $gamma \ 50 | 51 | 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mitigating Noisy Correspondence by Geometrical Structure Consistency Learning 2 | 3 |
arXiv OpenReview 4 | GitHub
5 | 6 | by Zihua Zhao, Mengxi Chen, Tianjie Dai, Jiangchao Yao, Bo Han, Ya Zhang, Yanfeng Wang at Cooperative Medianet Innovation Center, Shanghai Jiao Tong University; Shanghai Artificial Intelligence Laboratory; Hong Kong Baptist University 7 | 8 | Computer Vision and Pattern Recognition (CVPR), 2024. 9 | 10 | This repo is the official PyTorch implementation of GSC. 11 | 12 | ⚠️ This repository is still being organized and updated; this version is not the final release. 13 | 14 | ## Environment 15 | 16 | Our code is built on the original DECL codebase (https://github.com/QinYang79/DECL), which provides the standard retrieval framework, data loaders and evaluation metrics. 17 | 18 | Then create the environment for running our code: 19 | 20 | ```bash 21 | conda create --name GSC python=3.9.13 22 | conda activate GSC 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ## Data Preparation 27 | 28 | For all dataset settings, we follow DECL, including the data splits, sample preprocessing and noise simulation. 29 | 30 | ### MS-COCO and Flickr30K 31 | 32 | MS-COCO and Flickr30K are two datasets with simulated noisy correspondence; the noise is created by randomly shuffling a proportion of the captions. The images in both datasets are first extracted into region features by SCAN (https://github.com/kuanghuei/SCAN). 33 | 34 | ### CC152K 35 | 36 | CC152K is a real-world noisy correspondence dataset. Its images are also extracted into features by SCAN. 37 | 38 | Note that SCAN is implemented in Python 2.7 and Caffe, which may not be supported on the latest CUDA versions. We extracted the features on machines rented from the AutoDL platform (https://www.autodl.com/login?url=/console/homepage/personal). 39 | 40 | ## Running 41 | 42 | We provide integrated bash scripts for training on the different datasets. Evaluation metrics are implemented in eval.py. 43 | 44 | ```bash 45 | # training and evaluation on Flickr30K as an example 46 | conda activate GSC 47 | sh train_f30k.sh 48 | python eval.py 49 | ``` 50 | 51 | ## Citation 52 | 53 | If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation. 54 | 55 | ``` 56 | @inproceedings{zhao2024mitigating, 57 | title={Mitigating Noisy Correspondence by Geometrical Structure Consistency Learning}, 58 | author={Zhao, Zihua and Chen, Mengxi and Dai, Tianjie and Yao, Jiangchao and Han, Bo and Zhang, Ya and Wang, Yanfeng}, 59 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 60 | pages={27381--27390}, 61 | year={2024} 62 | } 63 | ``` 64 | 65 | If you have any problems with this code, please feel free to contact **[sjtuszzh@sjtu.edu.cn](mailto:sjtuszzh@sjtu.edu.cn)**.
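A practical note on evaluation: eval.py hard-codes server-specific dataset and checkpoint paths, so edit them before running `python eval.py`. A minimal sketch of the same call with placeholder paths (assuming a single Flickr30K checkpoint; eval_DECL is imported from evaluation.py exactly as in eval.py):

```python
import os
from evaluation import eval_DECL

if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # Placeholder paths: replace with your own dataset, vocab and checkpoint locations.
    data_path = '/path/to/NC_Datasets/data'
    vocab_path = '/path/to/NC_Datasets/vocab'
    checkpoint_paths = ['./runs/f30k/r_corr_2/checkpoint_dir/model_best.pth.tar']
    # Second argument is the avg_SGRAF flag; False evaluates a single checkpoint.
    eval_DECL(checkpoint_paths, False, data_path=data_path, vocab_path=vocab_path)
```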
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | alabaster==0.7.13 3 | argon2-cffi==21.1.0 4 | async-generator==1.10 5 | attrs==21.2.0 6 | Babel==2.12.1 7 | backcall==0.2.0 8 | bleach==4.1.0 9 | cachetools==5.3.0 10 | certifi==2021.5.30 11 | cffi==1.15.0 12 | charset-normalizer==2.0.7 13 | click==8.1.3 14 | cloudpickle==1.6.0 15 | cycler==0.10.0 16 | Cython==0.29.34 17 | dask==2021.3.0 18 | dataclasses==0.6 19 | decorator==4.4.2 20 | defusedxml==0.7.1 21 | docopt==0.6.2 22 | docutils==0.19 23 | entrypoints==0.3 24 | filelock==3.12.0 25 | fsspec==2023.1.0 26 | google-auth==2.17.3 27 | google-auth-oauthlib==0.4.6 28 | greenlet==1.1.2 29 | grpcio==1.41.1 30 | huggingface-hub==0.14.0 31 | idna==3.3 32 | imageio==2.9.0 33 | imagesize==1.4.1 34 | importlib-metadata==4.8.2 35 | ipykernel==5.5.6 36 | ipython==7.16.1 37 | ipython-genutils==0.2.0 38 | ipywidgets==7.6.5 39 | jedi==0.17.0 40 | Jinja2==3.0.3 41 | joblib==1.0.1 42 | jsonschema==3.2.0 43 | jupyter==1.0.0 44 | jupyter-client==7.0.6 45 | jupyter-console==6.4.0 46 | jupyter-core==4.9.1 47 | jupyterlab-pygments==0.1.2 48 | jupyterlab-widgets==1.0.2 49 | kiwisolver==1.3.1 50 | lmdb==1.4.1 51 | Markdown==3.3.4 52 | MarkupSafe==2.0.1 53 | matplotlib==3.3.4 54 | MedPy==0.4.0 55 | mistune==0.8.4 56 | monai==0.7.0 57 | nbclient==0.5.8 58 | nbconvert==6.0.7 59 | nbformat==5.1.3 60 | nest-asyncio==1.5.1 61 | networkx==2.5.1 62 | nibabel==3.2.1 63 | nlpaug==1.1.7 64 | nltk==3.8.1 65 | notebook==6.4.5 66 | numpy==1.21.6 67 | oauthlib==3.2.2 68 | olefile==0.46 69 | opencv-contrib-python==4.5.2.52 70 | opencv-python-headless==4.5.4.58 71 | packaging==21.2 72 | pandas==1.1.5 73 | pandocfilters==1.5.0 74 | parso==0.8.2 75 | pexpect==4.8.0 76 | pickleshare==0.7.5 77 | Pillow==8.4.0 78 | pipreqs==0.4.11 79 | prefetch-generator==1.0.3 80 | prometheus-client==0.12.0 81 | prompt-toolkit==3.0.22 82 | protobuf==3.17.3 83 | ptyprocess==0.7.0 84 | pyasn1==0.5.0 85 | pyasn1-modules==0.3.0 86 | pycparser==2.21 87 | pydicom==2.2.2 88 | Pygments==2.15.1 89 | pylidc==0.2.2 90 | pyparsing==2.4.7 91 | pyrsistent==0.18.0 92 | python-dateutil==2.8.2 93 | pytz==2021.1 94 | PyWavelets==1.1.1 95 | PyYAML==5.4.1 96 | pyzmq==20.0.0 97 | qtconsole==5.2.0 98 | QtPy==1.9.0 99 | regex==2022.10.31 100 | requests==2.26.0 101 | requests-oauthlib==1.3.1 102 | rsa==4.9 103 | seaborn==0.12.2 104 | Send2Trash==1.8.0 105 | setuptools==58.0.4 106 | SimpleITK==2.0.2 107 | six==1.16.0 108 | snowballstemmer==2.2.0 109 | Sphinx==5.3.0 110 | sphinx-gallery==0.13.0 111 | sphinxcontrib-applehelp==1.0.2 112 | sphinxcontrib-devhelp==1.0.2 113 | sphinxcontrib-htmlhelp==2.0.0 114 | sphinxcontrib-jsmath==1.0.1 115 | sphinxcontrib-qthelp==1.0.3 116 | sphinxcontrib-serializinghtml==1.1.5 117 | SQLAlchemy==1.2.0 118 | tensorboard==2.11.2 119 | tensorboard-data-server==0.6.1 120 | tensorboard-logger==0.1.0 121 | tensorboard-plugin-wit==1.8.1 122 | tensorboardX==2.4 123 | terminado==0.9.2 124 | testpath==0.5.0 125 | threadpoolctl==2.1.0 126 | tifffile==2020.9.3 127 | tokenizers==0.13.3 128 | toolz==0.11.1 129 | torch-tb-profiler==0.4.1 130 | tornado==6.1 131 | tqdm==4.62.3 132 | traitlets==4.3.3 133 | transformers==4.28.1 134 | typing-extensions==3.7.4.3 135 | urllib3==1.26.7 136 | wcwidth==0.2.5 137 | webencodings==0.5.1 138 | Werkzeug==2.0.2 139 | wheel==0.37.0 140 | widgetsnbextension==3.5.1 141 | yarg==0.1.9 142 | zipp==3.6.0 143 | 
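One setup detail not covered by requirements.txt: vocab.py and data.py below tokenize captions with nltk.tokenize.word_tokenize, which needs the NLTK "punkt" tokenizer data to be downloaded once (a one-off step, not part of the original instructions):

```python
import nltk

# Download the tokenizer models used by nltk.tokenize.word_tokenize in vocab.py / data.py.
nltk.download('punkt')
```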
-------------------------------------------------------------------------------- /GSC+/vocab.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Stacked Cross Attention Network implementation based on 3 | # https://arxiv.org/abs/1803.08024. 4 | # "Stacked Cross Attention for Image-Text Matching" 5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He 6 | # 7 | # Writen by Kuang-Huei Lee, 2018 8 | # --------------------------------------------------------------- 9 | """Vocabul ary wrapper""" 10 | 11 | import nltk 12 | from collections import Counter 13 | import argparse 14 | import os 15 | import json 16 | import csv 17 | 18 | annotations = { 19 | 'coco_precomp': ['train_caps.txt', 'dev_caps.txt'], 20 | 'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'], 21 | 'cc152k_precomp': ['train_caps.tsv', 'dev_caps.tsv'], 22 | } 23 | 24 | 25 | class Vocabulary(object): 26 | """Simple vocabulary wrapper.""" 27 | 28 | def __init__(self): 29 | self.word2idx = {} 30 | self.idx2word = {} 31 | self.idx = 0 32 | 33 | def add_word(self, word): 34 | if word not in self.word2idx: 35 | self.word2idx[word] = self.idx 36 | self.idx2word[self.idx] = word 37 | self.idx += 1 38 | 39 | def __call__(self, word): 40 | if word not in self.word2idx: 41 | return self.word2idx[''] 42 | return self.word2idx[word] 43 | 44 | def __len__(self): 45 | return len(self.word2idx) 46 | 47 | 48 | def serialize_vocab(vocab, dest): 49 | d = {} 50 | d['word2idx'] = vocab.word2idx 51 | d['idx2word'] = vocab.idx2word 52 | d['idx'] = vocab.idx 53 | with open(dest, "w") as f: 54 | json.dump(d, f) 55 | 56 | 57 | def deserialize_vocab(src): 58 | with open(src) as f: 59 | d = json.load(f) 60 | vocab = Vocabulary() 61 | vocab.word2idx = d['word2idx'] 62 | vocab.idx2word = d['idx2word'] 63 | vocab.idx = d['idx'] 64 | return vocab 65 | 66 | 67 | def from_txt(txt): 68 | captions = [] 69 | with open(txt, 'r') as f: 70 | for line in f: 71 | captions.append(line.strip()) 72 | return captions 73 | 74 | def from_tsv(tsv): 75 | captions = [] 76 | img_ids = [] 77 | with open(tsv) as f: 78 | tsvreader = csv.reader(f, delimiter='\t') 79 | for line in tsvreader: 80 | captions.append(line[1]) 81 | img_ids.append(line[0]) 82 | return captions 83 | 84 | def build_vocab(data_path, data_name, caption_file, threshold): 85 | """Build a simple vocabulary wrapper.""" 86 | counter = Counter() 87 | for path in caption_file[data_name]: 88 | if data_name == 'cc152k_precomp': 89 | full_path = os.path.join(os.path.join(data_path, data_name), path) 90 | captions = from_tsv(full_path) 91 | elif data_name in ['coco_precomp', 'f30k_precomp']: 92 | full_path = os.path.join(os.path.join(data_path, data_name), path) 93 | captions = from_txt(full_path) 94 | else: 95 | raise NotImplementedError('Not support!') 96 | 97 | for i, caption in enumerate(captions): 98 | tokens = nltk.tokenize.word_tokenize(caption.lower()) 99 | counter.update(tokens) 100 | 101 | if i % 1000 == 0: 102 | print("[%d/%d] tokenized the captions." % (i, len(captions))) 103 | 104 | # Discard if the occurrence of the word is less than min_word_cnt. 105 | words = [word for word, cnt in counter.items() if cnt >= threshold] 106 | 107 | # Create a vocab wrapper and add some special tokens. 108 | vocab = Vocabulary() 109 | vocab.add_word('') 110 | vocab.add_word('') 111 | vocab.add_word('') 112 | vocab.add_word('') 113 | 114 | # Add words to the vocabulary. 
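# Note (assumption): the vocab.add_word('') calls above presumably correspond to the special
# tokens '<pad>', '<start>', '<end>' and '<unk>' of the original SCAN vocabulary, with
# __call__ falling back to the '<unk>' index for out-of-vocabulary words.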
115 | for i, word in enumerate(words): 116 | vocab.add_word(word) 117 | return vocab 118 | 119 | 120 | def main(data_path, data_name): 121 | vocab = build_vocab(data_path, data_name, caption_file=annotations, threshold=4) 122 | serialize_vocab(vocab, './%s_vocab.json' % data_name) 123 | print("Saved vocabulary file to ", './%s_vocab.json' % data_name) 124 | 125 | 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--data_path', default='/data/RR/data') 130 | parser.add_argument('--data_name', default='cc152k_precomp', 131 | help='{coco,f30k,cc152k}_precomp') 132 | opt = parser.parse_args() 133 | main(opt.data_path, opt.data_name) 134 | -------------------------------------------------------------------------------- /GSC/vocab.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Stacked Cross Attention Network implementation based on 3 | # https://arxiv.org/abs/1803.08024. 4 | # "Stacked Cross Attention for Image-Text Matching" 5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He 6 | # 7 | # Writen by Kuang-Huei Lee, 2018 8 | # --------------------------------------------------------------- 9 | """Vocabul ary wrapper""" 10 | 11 | import nltk 12 | from collections import Counter 13 | import argparse 14 | import os 15 | import json 16 | import csv 17 | 18 | annotations = { 19 | 'coco_precomp': ['train_caps.txt', 'dev_caps.txt'], 20 | 'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'], 21 | 'cc152k_precomp': ['train_caps.tsv', 'dev_caps.tsv'], 22 | } 23 | 24 | 25 | class Vocabulary(object): 26 | """Simple vocabulary wrapper.""" 27 | 28 | def __init__(self): 29 | self.word2idx = {} 30 | self.idx2word = {} 31 | self.idx = 0 32 | 33 | def add_word(self, word): 34 | if word not in self.word2idx: 35 | self.word2idx[word] = self.idx 36 | self.idx2word[self.idx] = word 37 | self.idx += 1 38 | 39 | def __call__(self, word): 40 | if word not in self.word2idx: 41 | return self.word2idx[''] 42 | return self.word2idx[word] 43 | 44 | def __len__(self): 45 | return len(self.word2idx) 46 | 47 | 48 | def serialize_vocab(vocab, dest): 49 | d = {} 50 | d['word2idx'] = vocab.word2idx 51 | d['idx2word'] = vocab.idx2word 52 | d['idx'] = vocab.idx 53 | with open(dest, "w") as f: 54 | json.dump(d, f) 55 | 56 | 57 | def deserialize_vocab(src): 58 | with open(src) as f: 59 | d = json.load(f) 60 | vocab = Vocabulary() 61 | vocab.word2idx = d['word2idx'] 62 | vocab.idx2word = d['idx2word'] 63 | vocab.idx = d['idx'] 64 | return vocab 65 | 66 | 67 | def from_txt(txt): 68 | captions = [] 69 | with open(txt, 'r') as f: 70 | for line in f: 71 | captions.append(line.strip()) 72 | return captions 73 | 74 | def from_tsv(tsv): 75 | captions = [] 76 | img_ids = [] 77 | with open(tsv) as f: 78 | tsvreader = csv.reader(f, delimiter='\t') 79 | for line in tsvreader: 80 | captions.append(line[1]) 81 | img_ids.append(line[0]) 82 | return captions 83 | 84 | def build_vocab(data_path, data_name, caption_file, threshold): 85 | """Build a simple vocabulary wrapper.""" 86 | counter = Counter() 87 | for path in caption_file[data_name]: 88 | if data_name == 'cc152k_precomp': 89 | full_path = os.path.join(os.path.join(data_path, data_name), path) 90 | captions = from_tsv(full_path) 91 | elif data_name in ['coco_precomp', 'f30k_precomp']: 92 | full_path = os.path.join(os.path.join(data_path, data_name), path) 93 | captions = from_txt(full_path) 94 | else: 95 | raise 
NotImplementedError('Not support!') 96 | 97 | for i, caption in enumerate(captions): 98 | tokens = nltk.tokenize.word_tokenize(caption.lower()) 99 | counter.update(tokens) 100 | 101 | if i % 1000 == 0: 102 | print("[%d/%d] tokenized the captions." % (i, len(captions))) 103 | 104 | # Discard if the occurrence of the word is less than min_word_cnt. 105 | words = [word for word, cnt in counter.items() if cnt >= threshold] 106 | 107 | # Create a vocab wrapper and add some special tokens. 108 | vocab = Vocabulary() 109 | vocab.add_word('') 110 | vocab.add_word('') 111 | vocab.add_word('') 112 | vocab.add_word('') 113 | 114 | # Add words to the vocabulary. 115 | for i, word in enumerate(words): 116 | vocab.add_word(word) 117 | return vocab 118 | 119 | 120 | def main(data_path, data_name): 121 | vocab = build_vocab(data_path, data_name, caption_file=annotations, threshold=4) 122 | serialize_vocab(vocab, './%s_vocab.json' % data_name) 123 | print("Saved vocabulary file to ", './%s_vocab.json' % data_name) 124 | 125 | 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--data_path', default='/data/RR/data') 130 | parser.add_argument('--data_name', default='cc152k_precomp', 131 | help='{coco,f30k,cc152k}_precomp') 132 | opt = parser.parse_args() 133 | main(opt.data_path, opt.data_name) 134 | -------------------------------------------------------------------------------- /GSC/data.py: -------------------------------------------------------------------------------- 1 | """Dataloader""" 2 | 3 | import csv 4 | import torch 5 | import torch.utils.data as data 6 | import os 7 | import nltk 8 | import numpy as np 9 | 10 | 11 | class PrecompDataset(data.Dataset): 12 | """ 13 | Load precomputed captions and image features 14 | Possible options: f30k_precomp, coco_precomp, cc152k_precomp 15 | """ 16 | 17 | def __init__(self, data_path, data_split, vocab, opt=None): 18 | self.vocab = vocab 19 | loc = data_path + '/' 20 | 21 | # load the raw captions 22 | self.captions = [] 23 | 24 | if 'cc152k' in opt.data_name: 25 | with open(loc + '%s_caps.tsv' % data_split) as f: 26 | tsvreader = csv.reader(f, delimiter='\t') 27 | for line in tsvreader: 28 | self.captions.append(line[1].strip()) 29 | else: 30 | with open(loc + '%s_caps.txt' % data_split, 'r', encoding='utf-8') as f: 31 | for line in f.readlines(): 32 | self.captions.append(line.strip()) 33 | 34 | # load the image features 35 | self.images = np.load(loc + '%s_ims.npy' % data_split) 36 | img_len = self.images.shape[0] 37 | cap_len = len(self.captions) 38 | 39 | # rkiros data has redundancy in images, we divide by 5 40 | if img_len != cap_len: 41 | self.im_div = 5 42 | else: 43 | self.im_div = 1 44 | 45 | # if use noise index by image: self.noisy_inx = np.arange(img_len) 46 | self.noisy_inx = np.arange(cap_len) // self.im_div 47 | if data_split == 'train' and opt.noise_ratio > 0.0: 48 | noise_file = opt.noise_file 49 | if os.path.exists(noise_file): 50 | print('=> load noisy index from {}'.format(noise_file)) 51 | self.noisy_inx = np.load(noise_file) 52 | else: 53 | noise_ratio = opt.noise_ratio 54 | inx = np.arange(img_len) 55 | np.random.shuffle(inx) 56 | noisy_inx = inx[0: int(noise_ratio * img_len)] 57 | shuffle_noisy_inx = np.array(noisy_inx) 58 | np.random.shuffle(shuffle_noisy_inx) 59 | self.noisy_inx[noisy_inx] = shuffle_noisy_inx 60 | print('Noisy ratio: %g' % noise_ratio) 61 | self.length = len(self.captions) 62 | self.corrs = (self.noisy_inx == (np.arange(self.length) // 
self.im_div)).astype(np.int) 63 | 64 | # the development set for coco is large and so validation would be slow 65 | if data_split == 'dev': 66 | self.length = 5000 67 | if 'cc152k' in opt.data_name: 68 | self.length = 1000 69 | 70 | def __getitem__(self, index): 71 | # handle the image redundancy 72 | # if use noise index by images: img_id = self.noisy_inx[int(index / self.im_div)] 73 | img_id = self.noisy_inx[index] 74 | image = torch.Tensor(self.images[img_id]) 75 | caption = self.captions[index] 76 | vocab = self.vocab 77 | 78 | # convert caption (string) to word ids. 79 | tokens = nltk.tokenize.word_tokenize(caption.lower()) 80 | caption = [] 81 | caption.append(vocab('')) 82 | caption.extend([vocab(token) for token in tokens]) 83 | caption.append(vocab('')) 84 | target = torch.Tensor(caption) 85 | if img_id == index // self.im_div: 86 | label = 1 87 | else: 88 | label = 0 89 | return image, target, index, img_id, label 90 | 91 | def __len__(self): 92 | return self.length 93 | 94 | 95 | def collate_fn(data): 96 | # Sort a data list by caption length 97 | data.sort(key=lambda x: len(x[1]), reverse=True) 98 | images, captions, ids, img_ids, label = zip(*data) 99 | 100 | # Merge images (convert tuple of 2D tensor to 3D tensor) 101 | images = torch.stack(images, 0) 102 | 103 | # Merget captions (convert tuple of 1D tensor to 2D tensor) 104 | lengths = [len(cap) for cap in captions] 105 | targets = torch.zeros(len(captions), max(lengths)).long() 106 | for i, cap in enumerate(captions): 107 | end = lengths[i] 108 | targets[i, :end] = cap[:end] 109 | 110 | return images, targets, lengths, list(ids), list(label) 111 | 112 | 113 | def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100, 114 | shuffle=True, num_workers=2): 115 | dset = PrecompDataset(data_path, data_split, vocab, opt) 116 | 117 | data_loader = torch.utils.data.DataLoader(dataset=dset, 118 | batch_size=batch_size, 119 | shuffle=shuffle, 120 | pin_memory=False, 121 | num_workers=num_workers, 122 | collate_fn=collate_fn) 123 | return data_loader 124 | 125 | 126 | def get_loaders(data_name, vocab, batch_size, workers, opt): 127 | # get the data path 128 | dpath = os.path.join(opt.data_path, data_name) 129 | 130 | # get the train_loader 131 | train_loader = get_precomp_loader(dpath, 'train', vocab, opt, 132 | batch_size, True, workers) 133 | # get the val_loader 134 | val_loader = get_precomp_loader(dpath, 'dev', vocab, opt, 135 | batch_size, False, workers) 136 | # get the test_loader 137 | test_loader = get_precomp_loader(dpath, 'test', vocab, opt, 138 | batch_size, False, workers) 139 | return train_loader, val_loader, test_loader 140 | 141 | 142 | def get_test_loader(split_name, data_name, vocab, batch_size, workers, opt): 143 | # get the data path 144 | dpath = os.path.join(opt.data_path, data_name) 145 | 146 | # get the test_loader 147 | test_loader = get_precomp_loader(dpath, split_name, vocab, opt, 148 | batch_size, False, workers) 149 | return test_loader 150 | -------------------------------------------------------------------------------- /GSC/opts.py: -------------------------------------------------------------------------------- 1 | """Argument parser""" 2 | 3 | import argparse 4 | import os 5 | import random 6 | import sys 7 | import numpy as np 8 | import torch 9 | 10 | from utils import save_config 11 | 12 | 13 | def parse_opt(): 14 | parser = argparse.ArgumentParser() 15 | # --------------------------- data path -------------------------# 16 | parser.add_argument('--data_name', 
default='cc152k_precomp', 17 | help='{coco,f30k,cc152k}_precomp') 18 | parser.add_argument('--data_path', default='/remote-home/share/zhaozh/NC_Datasets/data', 19 | help='path to datasets') 20 | parser.add_argument('--vocab_path', default='/remote-home/share/zhaozh/NC_Datasets/vocab', 21 | help='Path to saved vocabulary json files.') 22 | parser.add_argument('--folder_name', default='debug', help='Folder name to save the running results') 23 | 24 | # ----------------------- training setting ----------------------# 25 | parser.add_argument('--batch_size', default=128, type=int, 26 | help='Size of a training mini-batch.') 27 | parser.add_argument('--num_epochs', default=60, type=int, 28 | help='Number of training epochs.') 29 | parser.add_argument('--lr_update', default=20, type=int, 30 | help='Number of epochs to update the learning rate.') 31 | parser.add_argument('--learning_rate', default=2e-4, type=float, 32 | help='Initial learning rate.') 33 | parser.add_argument('--workers', default=4, type=int, 34 | help='Number of data loader workers.') 35 | parser.add_argument('--log_step', default=100, type=int, 36 | help='Number of steps to print and record the log.') 37 | parser.add_argument('--val_step', default=1000, type=int, 38 | help='Number of steps to run validation.') 39 | parser.add_argument('--grad_clip', default=2., type=float, 40 | help='Gradient clipping threshold.') 41 | 42 | # ------------------------- model settings (SGR, SAF, and other similarity models) -----------------------# 43 | parser.add_argument('--margin', default=0.2, type=float, 44 | help='Rank loss margin.') 45 | parser.add_argument('--max_violation', action='store_false', 46 | help='Use max instead of sum in the rank loss.') 47 | parser.add_argument('--img_dim', default=2048, type=int, 48 | help='Dimensionality of the image embedding.') 49 | parser.add_argument('--word_dim', default=300, type=int, 50 | help='Dimensionality of the word embedding.') 51 | parser.add_argument('--embed_size', default=1024, type=int, 52 | help='Dimensionality of the joint embedding.') 53 | parser.add_argument('--sim_dim', default=256, type=int, 54 | help='Dimensionality of the sim embedding.') 55 | parser.add_argument('--num_layers', default=1, type=int, 56 | help='Number of GRU layers.') 57 | parser.add_argument('--bi_gru', action='store_false', 58 | help='Use bidirectional GRU.') 59 | parser.add_argument('--no_imgnorm', action='store_true', 60 | help='Do not normalize the image embeddings.') 61 | parser.add_argument('--no_txtnorm', action='store_true', 62 | help='Do not normalize the text embeddings.') 63 | parser.add_argument('--module_name', default='SGR', type=str, 64 | help='SGR, SAF') 65 | parser.add_argument('--sgr_step', default=3, type=int, 66 | help='Step of the SGR.') 67 | 68 | # ------------------------- our DECL settings -----------------------# 69 | parser.add_argument('--warmup_if', default=False, type=bool, 70 | help='Need warmup training?') 71 | parser.add_argument('--warmup_epochs', default=0, type=int, 72 | help='Number of warmup epochs.') 73 | parser.add_argument('--warmup_model_path', default='', 74 | help='Path to load a pre-model') 75 | # noise settings 76 | parser.add_argument('--noise_ratio', default=0.0, type=float, 77 | help='Noisy ratio') 78 | parser.add_argument('--noise_file', default='/remote-home/share/zhaozh/NC_Datasets/data/noise_file_by_caption/f30k/f30k_precomp_0.0.npy', 79 | help='Path to noise index file') 80 | parser.add_argument("--seed", default=random.randint(0, 100), type=int, help="Random 
seed.") 81 | 82 | # loss settings 83 | parser.add_argument('--contrastive_loss', default='InfoNCE', help='Choose in Triplet/InfoNCE/RINCE.') 84 | parser.add_argument('--temp', default=0.07, type=float, help='Temperature for Info NCE Loss.') 85 | parser.add_argument('--lam', default=0.01, type=float, help='Parameter for RINCE.') 86 | parser.add_argument('--q', default=0.1, type=float, help='Parameter for RINCE.') 87 | 88 | # elr settings 89 | parser.add_argument('--beta', default=0.7, type=float) 90 | parser.add_argument('--alpha', default=1.0, type=float) 91 | parser.add_argument('--gamma', default=1.0, type=float) 92 | 93 | # gpu settings 94 | parser.add_argument('--gpu', default='0', help='Which gpu to use.') 95 | 96 | opt = parser.parse_args() 97 | project_path = str(sys.path[0]) 98 | opt.log_dir = f'{project_path}/runs/{opt.folder_name}/log_dir' 99 | opt.checkpoint_dir = f'{project_path}/runs/{opt.folder_name}/checkpoint_dir' 100 | if not os.path.isdir(opt.log_dir): 101 | os.makedirs(opt.log_dir) 102 | if not os.path.isdir(opt.checkpoint_dir): 103 | os.makedirs(opt.checkpoint_dir) 104 | 105 | # save opt 106 | save_config(opt, os.path.join(opt.log_dir, "config.json")) 107 | 108 | # set random seed 109 | random.seed(opt.seed) 110 | np.random.seed(opt.seed) 111 | torch.manual_seed(opt.seed) 112 | torch.random.manual_seed(opt.seed) 113 | if torch.cuda.is_available(): 114 | torch.cuda.manual_seed(opt.seed) 115 | torch.backends.cudnn.benchmark = True 116 | return opt 117 | -------------------------------------------------------------------------------- /GSC+/data.py: -------------------------------------------------------------------------------- 1 | """Dataloader""" 2 | 3 | import csv 4 | import torch 5 | import torch.utils.data as data 6 | import os 7 | import nltk 8 | import numpy as np 9 | 10 | 11 | class PrecompDataset(data.Dataset): 12 | """ 13 | Load precomputed captions and image features 14 | Possible options: f30k_precomp, coco_precomp, cc152k_precomp 15 | """ 16 | 17 | def __init__(self, data_path, data_split, vocab, opt=None): 18 | self.vocab = vocab 19 | loc = data_path + '/' 20 | 21 | # load the raw captions 22 | self.captions = [] 23 | 24 | if 'cc152k' in opt.data_name: 25 | with open(loc + '%s_caps.tsv' % data_split) as f: 26 | tsvreader = csv.reader(f, delimiter='\t') 27 | for line in tsvreader: 28 | self.captions.append(line[1].strip()) 29 | else: 30 | with open(loc + '%s_caps.txt' % data_split, 'r', encoding='utf-8') as f: 31 | for line in f.readlines(): 32 | self.captions.append(line.strip()) 33 | 34 | # load the image features 35 | self.images = np.load(loc + '%s_ims.npy' % data_split) 36 | img_len = self.images.shape[0] 37 | cap_len = len(self.captions) 38 | 39 | # rkiros data has redundancy in images, we divide by 5 40 | if img_len != cap_len: 41 | self.im_div = 5 42 | else: 43 | self.im_div = 1 44 | 45 | # if use noise index by image: self.noisy_inx = np.arange(img_len) 46 | self.noisy_inx = np.arange(cap_len) // self.im_div 47 | if data_split == 'train' and opt.noise_ratio > 0.0: 48 | noise_file = opt.noise_file 49 | if os.path.exists(noise_file): 50 | print('=> load noisy index from {}'.format(noise_file)) 51 | self.noisy_inx = np.load(noise_file) 52 | else: 53 | noise_ratio = opt.noise_ratio 54 | inx = np.arange(cap_len) 55 | np.random.shuffle(inx) 56 | noisy_inx = inx[0: int(noise_ratio * cap_len)] 57 | shuffle_noisy_inx = np.array(noisy_inx) // 5 58 | np.random.shuffle(shuffle_noisy_inx) 59 | self.noisy_inx[noisy_inx] = shuffle_noisy_inx 60 | np.save(noise_file, 
self.noisy_inx) 61 | print('Noisy ratio: %g' % noise_ratio) 62 | self.length = len(self.captions) 63 | self.corrs = (self.noisy_inx == (np.arange(self.length) // self.im_div)).astype(int) 64 | 65 | # the development set for coco is large and so validation would be slow 66 | if data_split == 'dev': 67 | self.length = 5000 68 | if 'cc152k' in opt.data_name: 69 | self.length = 1000 70 | 71 | def __getitem__(self, index): 72 | # handle the image redundancy 73 | # if use noise index by images: img_id = self.noisy_inx[int(index / self.im_div)] 74 | img_id = self.noisy_inx[index] 75 | image = torch.Tensor(self.images[img_id]) 76 | caption = self.captions[index] 77 | vocab = self.vocab 78 | 79 | # convert caption (string) to word ids. 80 | tokens = nltk.tokenize.word_tokenize(caption.lower()) 81 | caption = [] 82 | caption.append(vocab('')) 83 | caption.extend([vocab(token) for token in tokens]) 84 | caption.append(vocab('')) 85 | target = torch.Tensor(caption) 86 | if img_id == index // self.im_div: 87 | label = 1 88 | else: 89 | label = 0 90 | return image, target, index, img_id, label 91 | 92 | def __len__(self): 93 | return self.length 94 | 95 | 96 | def collate_fn(data): 97 | # Sort a data list by caption length 98 | data.sort(key=lambda x: len(x[1]), reverse=True) 99 | images, captions, ids, img_ids, label = zip(*data) 100 | 101 | # Merge images (convert tuple of 2D tensor to 3D tensor) 102 | images = torch.stack(images, 0) 103 | 104 | # Merget captions (convert tuple of 1D tensor to 2D tensor) 105 | lengths = [len(cap) for cap in captions] 106 | targets = torch.zeros(len(captions), max(lengths)).long() 107 | for i, cap in enumerate(captions): 108 | end = lengths[i] 109 | targets[i, :end] = cap[:end] 110 | 111 | return images, targets, lengths, list(ids), list(label) 112 | 113 | 114 | def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100, 115 | shuffle=True, num_workers=2): 116 | dset = PrecompDataset(data_path, data_split, vocab, opt) 117 | 118 | data_loader = torch.utils.data.DataLoader(dataset=dset, 119 | batch_size=batch_size, 120 | shuffle=shuffle, 121 | pin_memory=False, 122 | num_workers=num_workers, 123 | collate_fn=collate_fn) 124 | return data_loader 125 | 126 | 127 | def get_loaders(data_name, vocab, batch_size, workers, opt): 128 | # get the data path 129 | dpath = os.path.join(opt.data_path, data_name) 130 | 131 | # get the train_loader 132 | train_loader = get_precomp_loader(dpath, 'train', vocab, opt, 133 | batch_size, True, workers) 134 | # get the val_loader 135 | val_loader = get_precomp_loader(dpath, 'dev', vocab, opt, 136 | batch_size, False, workers) 137 | # get the test_loader 138 | test_loader = get_precomp_loader(dpath, 'test', vocab, opt, 139 | batch_size, False, workers) 140 | return train_loader, val_loader, test_loader 141 | 142 | 143 | def get_test_loader(split_name, data_name, vocab, batch_size, workers, opt): 144 | # get the data path 145 | dpath = os.path.join(opt.data_path, data_name) 146 | 147 | # get the test_loader 148 | test_loader = get_precomp_loader(dpath, split_name, vocab, opt, 149 | batch_size, False, workers) 150 | return test_loader 151 | -------------------------------------------------------------------------------- /GSC+/opts.py: -------------------------------------------------------------------------------- 1 | """Argument parser""" 2 | 3 | import argparse 4 | import os 5 | import random 6 | import sys 7 | import numpy as np 8 | import torch 9 | 10 | from utils import save_config 11 | 12 | 13 | def parse_opt(): 14 | parser 
= argparse.ArgumentParser() 15 | # --------------------------- data path -------------------------# 16 | parser.add_argument('--data_name', default='f30k_precomp', 17 | help='{coco,f30k,cc152k}_precomp') 18 | parser.add_argument('--data_path', default='/remote-home/share/zhaozh/NC_Datasets/data', 19 | help='path to datasets') 20 | parser.add_argument('--vocab_path', default='/remote-home/share/zhaozh/NC_Datasets/vocab', 21 | help='Path to saved vocabulary json files.') 22 | parser.add_argument('--folder_name', default='debug', help='Folder name to save the running results') 23 | 24 | # ----------------------- training setting ----------------------# 25 | parser.add_argument('--batch_size', default=32, type=int, 26 | help='Size of a training mini-batch.') 27 | parser.add_argument('--num_epochs', default=60, type=int, 28 | help='Number of training epochs.') 29 | parser.add_argument('--lr_update', default=20, type=int, 30 | help='Number of epochs to update the learning rate.') 31 | parser.add_argument('--learning_rate', default=2e-4, type=float, 32 | help='Initial learning rate.') 33 | parser.add_argument('--workers', default=4, type=int, 34 | help='Number of data loader workers.') 35 | parser.add_argument('--log_step', default=100, type=int, 36 | help='Number of steps to print and record the log.') 37 | parser.add_argument('--val_step', default=1000, type=int, 38 | help='Number of steps to run validation.') 39 | parser.add_argument('--grad_clip', default=2., type=float, 40 | help='Gradient clipping threshold.') 41 | 42 | # ------------------------- model settings (SGR, SAF, and other similarity models) -----------------------# 43 | parser.add_argument('--margin', default=0.2, type=float, 44 | help='Rank loss margin.') 45 | parser.add_argument('--max_violation', action='store_false', 46 | help='Use max instead of sum in the rank loss.') 47 | parser.add_argument('--img_dim', default=2048, type=int, 48 | help='Dimensionality of the image embedding.') 49 | parser.add_argument('--word_dim', default=300, type=int, 50 | help='Dimensionality of the word embedding.') 51 | parser.add_argument('--embed_size', default=1024, type=int, 52 | help='Dimensionality of the joint embedding.') 53 | parser.add_argument('--sim_dim', default=256, type=int, 54 | help='Dimensionality of the sim embedding.') 55 | parser.add_argument('--num_layers', default=1, type=int, 56 | help='Number of GRU layers.') 57 | parser.add_argument('--bi_gru', action='store_false', 58 | help='Use bidirectional GRU.') 59 | parser.add_argument('--no_imgnorm', action='store_true', 60 | help='Do not normalize the image embeddings.') 61 | parser.add_argument('--no_txtnorm', action='store_true', 62 | help='Do not normalize the text embeddings.') 63 | parser.add_argument('--module_name', default='SGR', type=str, 64 | help='SGR, SAF') 65 | parser.add_argument('--sgr_step', default=3, type=int, 66 | help='Step of the SGR.') 67 | 68 | # ------------------------- our DECL settings -----------------------# 69 | parser.add_argument('--warmup_if', default=False, type=bool, 70 | help='Need warmup training?') 71 | parser.add_argument('--warmup_epochs', default=0, type=int, 72 | help='Number of warmup epochs.') 73 | parser.add_argument('--warmup_model_path', default='/remote-home/zhaozh/NC/ELCL_TOPO/runs_1031/f30k_upper/r_corr_4/gamma_001/checkpoint_dir/model_best.pth.tar', 74 | help='Path to load a pre-model') 75 | # noise settings 76 | parser.add_argument('--noise_ratio', default=0.4, type=float, 77 | help='Noisy ratio') 78 | 
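# Note: as implemented in data.py, noise is only injected when noise_ratio > 0; if noise_file
# already exists the saved noise index is loaded as-is, otherwise a new index is generated at
# this ratio and saved to noise_file.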
parser.add_argument('--noise_file', default='/remote-home/share/zhaozh/NC_Datasets/data/noise_file_by_caption/f30k/f30k_precomp_0.4.npy', 79 | help='Path to noise index file') 80 | parser.add_argument("--seed", default=random.randint(0, 100), type=int, help="Random seed.") 81 | 82 | # loss settings 83 | parser.add_argument('--contrastive_loss', default='InfoNCE', help='Choose in Triplet/InfoNCE/RINCE.') 84 | parser.add_argument('--temp', default=0.07, type=float, help='Temperature for Info NCE Loss.') 85 | parser.add_argument('--lam', default=0.01, type=float, help='Parameter for RINCE.') 86 | parser.add_argument('--q', default=0.1, type=float, help='Parameter for RINCE.') 87 | 88 | # elr settings 89 | parser.add_argument('--beta', default=0.7, type=float) 90 | parser.add_argument('--alpha', default=1.0, type=float) 91 | parser.add_argument('--gamma', default=1.0, type=float) 92 | 93 | # gpu settings 94 | parser.add_argument('--gpu', default='0', help='Which gpu to use.') 95 | 96 | opt = parser.parse_args() 97 | project_path = str(sys.path[0]) 98 | opt.log_dir = f'{project_path}/runs/{opt.folder_name}/log_dir' 99 | opt.checkpoint_dir = f'{project_path}/runs/{opt.folder_name}/checkpoint_dir' 100 | if not os.path.isdir(opt.log_dir): 101 | os.makedirs(opt.log_dir) 102 | if not os.path.isdir(opt.checkpoint_dir): 103 | os.makedirs(opt.checkpoint_dir) 104 | 105 | # save opt 106 | save_config(opt, os.path.join(opt.log_dir, "config.json")) 107 | 108 | # set random seed 109 | random.seed(opt.seed) 110 | np.random.seed(opt.seed) 111 | torch.manual_seed(opt.seed) 112 | torch.random.manual_seed(opt.seed) 113 | if torch.cuda.is_available(): 114 | torch.cuda.manual_seed(opt.seed) 115 | torch.backends.cudnn.benchmark = True 116 | return opt 117 | -------------------------------------------------------------------------------- /GSC/model/RINCE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import clip_grad_norm_ 5 | from model.SGRAF import SGRAF 6 | from sklearn.mixture import GaussianMixture 7 | 8 | import numpy as np 9 | from scipy.spatial.distance import cdist 10 | from utils import AverageMeter 11 | 12 | class RINCE(nn.Module): 13 | def __init__(self, opt): 14 | super(RINCE, self).__init__() 15 | self.opt = opt 16 | self.grad_clip = opt.grad_clip 17 | self.similarity_model = SGRAF(opt) 18 | self.params = list(self.similarity_model.params) 19 | self.optimizer = torch.optim.Adam(self.params, lr=opt.learning_rate) 20 | self.step = 0 21 | 22 | def state_dict(self): 23 | return self.similarity_model.state_dict() 24 | 25 | def load_state_dict(self, state_dict): 26 | self.similarity_model.load_state_dict(state_dict) 27 | 28 | def train_start(self): 29 | """switch to train mode""" 30 | self.similarity_model.train_start() 31 | 32 | def val_start(self): 33 | """switch to valuate mode""" 34 | self.similarity_model.val_start() 35 | 36 | def info_nce_loss(self, similarity_matrix, temp=None): 37 | if temp is None: 38 | temp = self.opt.temp 39 | labels = torch.eye(len(similarity_matrix)).float().cuda() 40 | 41 | # select and combine multiple positives 42 | pos = similarity_matrix[labels.bool()].view(similarity_matrix.shape[0], -1) 43 | # preds = pos 44 | pos = torch.exp(pos / temp).squeeze(1) 45 | 46 | # select only the negatives the negatives 47 | neg = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1) 48 | neg = torch.exp(neg / temp) 49 | 50 | neg_sum = 
neg.sum(1) 51 | 52 | loss = -(torch.log(pos / (pos + neg_sum))) 53 | preds = pos / (pos + neg_sum) 54 | return loss, preds 55 | 56 | def warmup_batch(self, images, captions, lengths, ids, corrs): 57 | self.step += 1 58 | batch_length = images.size(0) 59 | img_embs, cap_embs, cap_lens = self.similarity_model.forward_emb(images, captions, lengths) 60 | targets_batch = self.similarity_model.targets[ids] 61 | sims = self.similarity_model.forward_sim(img_embs, cap_embs, cap_lens, 'sim') 62 | 63 | if self.opt.contrastive_loss == 'Triplet': 64 | loss_cl = self.similarity_model.forward_loss(sims) 65 | elif self.opt.contrastive_loss == 'InfoNCE': 66 | loss_cl, preds = self.info_nce_loss(sims) 67 | 68 | loss = loss_cl.mean() 69 | 70 | self.optimizer.zero_grad() 71 | loss = loss.mean() 72 | loss.backward() 73 | if self.grad_clip > 0: 74 | clip_grad_norm_(self.params, self.grad_clip) 75 | self.optimizer.step() 76 | self.logger.update('Step', self.step) 77 | self.logger.update('Lr', self.optimizer.param_groups[0]['lr']) 78 | self.logger.update('Loss', loss.item(), batch_length) 79 | 80 | def train_batch(self, images, captions, lengths, ids, corrs): 81 | self.step += 1 82 | batch_length = images.size(0) 83 | 84 | targets_batch = self.similarity_model.targets[ids] 85 | 86 | img_embs, cap_embs, cap_lens = self.similarity_model.forward_emb(images, captions, lengths) 87 | sims = self.similarity_model.forward_sim(img_embs, cap_embs, cap_lens, 'sim') 88 | loss_cl, preds = self.info_nce_loss(sims) 89 | loss_cl_t, preds_t = self.info_nce_loss(sims.t()) 90 | loss_cl = (loss_cl + loss_cl_t) * 0.5 * targets_batch 91 | preds = (preds + preds_t) * 0.5 92 | 93 | sims_img, sims_cap = self.similarity_model.sim_enc.forward_topology(img_embs, cap_embs, cap_lens) 94 | sims_img = torch.sigmoid(sims_img) * targets_batch 95 | sims_cap = torch.sigmoid(sims_cap) * targets_batch 96 | sims_topo = sims_img @ sims_cap.t() 97 | loss_topo, _ = self.info_nce_loss(sims_topo, temp=1.0) 98 | loss_topo = loss_topo * targets_batch 99 | 100 | self.similarity_model.topos[ids] = torch.cosine_similarity(sims_img, sims_cap, dim=1).data.detach().float() 101 | self.similarity_model.preds[ids] = preds.data.detach().float() 102 | target_clean = (targets_batch * corrs).sum() / corrs.sum() 103 | target_noise = (targets_batch * (1 - corrs)).sum() / (1 - corrs).sum() 104 | 105 | loss = loss_cl.mean() + loss_topo.mean() * self.opt.gamma 106 | 107 | self.optimizer.zero_grad() 108 | loss.backward() 109 | if self.grad_clip > 0: 110 | clip_grad_norm_(self.params, self.grad_clip) 111 | self.optimizer.step() 112 | self.logger.update('Step', self.step) 113 | self.logger.update('Lr', self.optimizer.param_groups[0]['lr']) 114 | self.logger.update('Loss', loss.item(), batch_length) 115 | self.logger.update('Loss_CL', loss_cl.mean().item(), batch_length) 116 | self.logger.update('Loss_TOPO', loss_topo.mean().item(), batch_length) 117 | self.logger.update('Target Clean', target_clean.item(), batch_length) 118 | self.logger.update('Target Noise', target_noise.item(), batch_length) 119 | 120 | def split_batch(self, corrs=None): 121 | topos = self.similarity_model.topos 122 | topos = (topos - topos.min()) / (topos.max() - topos.min()) 123 | 124 | gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4) 125 | gmm.fit(topos.detach().cpu().numpy().reshape(-1, 1)) 126 | prob = gmm.predict_proba(topos.detach().cpu().numpy().reshape(-1, 1)) 127 | prob = prob[:, gmm.means_.argmax()] 128 | 129 | self.similarity_model.topo_targets = self.opt.beta * 
torch.tensor(prob).cuda() + (1 - self.opt.beta) * self.similarity_model.topo_targets 130 | self.similarity_model.preds_targets = self.opt.beta * self.similarity_model.preds + (1 - self.opt.beta) * self.similarity_model.preds_targets 131 | self.similarity_model.targets = torch.minimum(self.similarity_model.preds_targets, self.similarity_model.topo_targets) 132 | 133 | if corrs is not None: 134 | pred = prob > 0.5 135 | clean_acc = (pred * corrs).sum() / corrs.sum() 136 | noise_acc = ((1 - pred) * (1 - corrs)).sum() / (1 - corrs).sum() 137 | print('GMM Clean Acc: ' + str(clean_acc)) 138 | print('GMM Noise Acc: ' + str(noise_acc)) -------------------------------------------------------------------------------- /GSC+/model/RINCE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import clip_grad_norm_ 5 | from model.SGRAF import SGRAF 6 | from sklearn.mixture import GaussianMixture 7 | 8 | import numpy as np 9 | from scipy.spatial.distance import cdist 10 | from utils import AverageMeter 11 | 12 | class RINCE(nn.Module): 13 | def __init__(self, opt): 14 | super(RINCE, self).__init__() 15 | self.opt = opt 16 | self.grad_clip = opt.grad_clip 17 | self.similarity_model = SGRAF(opt) 18 | self.params = list(self.similarity_model.params) 19 | self.optimizer = torch.optim.Adam(self.params, lr=opt.learning_rate) 20 | self.step = 0 21 | 22 | def state_dict(self): 23 | return self.similarity_model.state_dict() 24 | 25 | def load_state_dict(self, state_dict): 26 | self.similarity_model.load_state_dict(state_dict) 27 | 28 | def train_start(self): 29 | """switch to train mode""" 30 | self.similarity_model.train_start() 31 | 32 | def val_start(self): 33 | """switch to valuate mode""" 34 | self.similarity_model.val_start() 35 | 36 | def info_nce_loss(self, similarity_matrix, temp=None): 37 | if temp is None: 38 | temp = self.opt.temp 39 | labels = torch.eye(len(similarity_matrix)).float().cuda() 40 | 41 | # select and combine multiple positives 42 | pos = similarity_matrix[labels.bool()].view(similarity_matrix.shape[0], -1) 43 | pos = torch.exp(pos / temp).squeeze(1) 44 | 45 | # select only the negatives the negatives 46 | neg = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1) 47 | neg = torch.exp(neg / temp) 48 | 49 | neg_sum = neg.sum(1) 50 | 51 | loss = -(torch.log(pos / (pos + neg_sum))) 52 | preds = pos / (pos + neg_sum) 53 | return loss, preds 54 | 55 | def warmup_batch(self, images, captions, lengths, ids, corrs): 56 | self.step += 1 57 | batch_length = images.size(0) 58 | img_embs, cap_embs, cap_lens = self.similarity_model.forward_emb(images, captions, lengths) 59 | sims = self.similarity_model.forward_sim(img_embs, cap_embs, cap_lens, 'sim') 60 | 61 | loss_cl, preds = self.info_nce_loss(sims) 62 | loss_cl_t, preds_t = self.info_nce_loss(sims.t()) 63 | loss_cl = (loss_cl + loss_cl_t) * 0.5 64 | 65 | sims_img, sims_cap = self.similarity_model.sim_enc.forward_topology(img_embs, cap_embs, cap_lens) 66 | sims_img = torch.sigmoid(sims_img) 67 | sims_cap = torch.sigmoid(sims_cap) 68 | sims_topo = sims_img @ sims_cap.t() 69 | loss_topo, _ = self.info_nce_loss(sims_topo, temp=1.0) 70 | loss_topo = loss_topo 71 | 72 | loss = loss_cl.mean() + loss_topo.mean() * self.opt.gamma 73 | 74 | self.optimizer.zero_grad() 75 | loss.backward() 76 | if self.grad_clip > 0: 77 | clip_grad_norm_(self.params, self.grad_clip) 78 | self.optimizer.step() 79 | 
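A minimal standalone sketch of what info_nce_loss above computes (illustrative only, not part of the repository): the diagonal of the similarity matrix is treated as the positive pair for each row and every off-diagonal entry as a negative, so each sample contributes -log(exp(s_ii/t) / sum_j exp(s_ij/t)). The 0.07 temperature mirrors the --temp default in opts.py.

import torch

def toy_info_nce(sim, temp=0.07):
    # sim: (B, B) image-caption similarity matrix; the diagonal holds the matched pairs
    labels = torch.eye(len(sim), dtype=torch.bool)
    pos = torch.exp(sim[labels] / temp)                       # (B,) positive scores
    neg = torch.exp(sim[~labels].view(len(sim), -1) / temp)   # (B, B-1) negative scores
    preds = pos / (pos + neg.sum(1))                          # probability mass on the matched pair
    loss = -torch.log(preds)                                  # per-sample InfoNCE loss
    return loss, preds

loss, preds = toy_info_nce(torch.tensor([[0.9, 0.1], [0.2, 0.8]]))
# preds approaches 1 when the matched pair dominates its row; RINCE stores these
# per-sample values in similarity_model.preds and reuses them in split_batch.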
self.logger.update('Step', self.step) 80 | self.logger.update('Lr', self.optimizer.param_groups[0]['lr']) 81 | self.logger.update('Loss', loss.item(), batch_length) 82 | 83 | def train_batch(self, images, captions, lengths, ids, corrs, targets_batch): 84 | self.step += 1 85 | batch_length = images.size(0) 86 | 87 | img_embs, cap_embs, cap_lens = self.similarity_model.forward_emb(images, captions, lengths) 88 | sims = self.similarity_model.forward_sim(img_embs, cap_embs, cap_lens, 'sim') 89 | loss_cl, preds = self.info_nce_loss(sims) 90 | loss_cl_t, preds_t = self.info_nce_loss(sims.t()) 91 | loss_cl = (loss_cl + loss_cl_t) * 0.5 * targets_batch 92 | preds = (preds + preds_t) * 0.5 93 | 94 | sims_img, sims_cap = self.similarity_model.sim_enc.forward_topology(img_embs, cap_embs, cap_lens) 95 | sims_img = torch.sigmoid(sims_img) * targets_batch 96 | sims_cap = torch.sigmoid(sims_cap) * targets_batch 97 | sims_topo = sims_img @ sims_cap.t() 98 | loss_topo, _ = self.info_nce_loss(sims_topo, temp=1.0) 99 | loss_topo = loss_topo * targets_batch 100 | 101 | self.similarity_model.topos[ids] = torch.cosine_similarity(sims_img, sims_cap, dim=1).data.detach().float() 102 | self.similarity_model.preds[ids] = preds.data.detach().float() 103 | target_clean = (targets_batch * corrs).sum() / corrs.sum() 104 | target_noise = (targets_batch * (1 - corrs)).sum() / (1 - corrs).sum() 105 | 106 | loss = loss_cl.mean() + loss_topo.mean() * self.opt.gamma 107 | 108 | self.optimizer.zero_grad() 109 | loss.backward() 110 | if self.grad_clip > 0: 111 | clip_grad_norm_(self.params, self.grad_clip) 112 | self.optimizer.step() 113 | self.logger.update('Step', self.step) 114 | self.logger.update('Lr', self.optimizer.param_groups[0]['lr']) 115 | self.logger.update('Loss', loss.item(), batch_length) 116 | self.logger.update('Loss_CL', loss_cl.mean().item(), batch_length) 117 | self.logger.update('Loss_TOPO', loss_topo.mean().item(), batch_length) 118 | self.logger.update('Target Clean', target_clean.item(), batch_length) 119 | self.logger.update('Target Noise', target_noise.item(), batch_length) 120 | 121 | def split_batch(self, corrs=None): 122 | topos = self.similarity_model.topos 123 | topos = (topos - topos.min()) / (topos.max() - topos.min()) 124 | 125 | gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4) 126 | gmm.fit(topos.detach().cpu().numpy().reshape(-1, 1)) 127 | prob = gmm.predict_proba(topos.detach().cpu().numpy().reshape(-1, 1)) 128 | prob = prob[:, gmm.means_.argmax()] 129 | 130 | self.similarity_model.preds_targets = self.opt.beta * self.similarity_model.preds + (1 - self.opt.beta) * self.similarity_model.preds_targets 131 | self.similarity_model.topo_targets = self.opt.beta * torch.tensor(prob).cuda() + (1 - self.opt.beta) * self.similarity_model.topo_targets 132 | self.similarity_model.targets = torch.minimum(self.similarity_model.preds_targets, self.similarity_model.topo_targets) 133 | 134 | if corrs is not None: 135 | pred = prob > 0.5 136 | clean_acc = (pred * corrs).sum() / corrs.sum() 137 | noise_acc = ((1 - pred) * (1 - corrs)).sum() / (1 - corrs).sum() 138 | print('GMM Clean Acc: ' + str(clean_acc)) 139 | print('GMM Noise Acc: ' + str(noise_acc)) -------------------------------------------------------------------------------- /GSC/train.py: -------------------------------------------------------------------------------- 1 | """Training script""" 2 | 3 | import logging 4 | import os 5 | import time 6 | from datetime import datetime 7 | import numpy as np 8 | import torch 9 | 
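Zooming out from split_batch above (both GSC and GSC+ share it): likely-noisy pairs are separated from likely-clean ones by fitting a two-component Gaussian mixture on the min-max normalised topology scores and reading off the posterior of the higher-mean component as a clean probability. A minimal standalone sketch with a synthetic score vector in place of similarity_model.topos (illustrative only):

import numpy as np
from sklearn.mixture import GaussianMixture

# synthetic per-sample scores: clean pairs tend to score high, noisy pairs low
scores = np.concatenate([np.random.normal(0.8, 0.05, 600),
                         np.random.normal(0.3, 0.05, 400)])
scores = (scores - scores.min()) / (scores.max() - scores.min())  # min-max normalise

gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
gmm.fit(scores.reshape(-1, 1))
prob = gmm.predict_proba(scores.reshape(-1, 1))
clean_prob = prob[:, gmm.means_.argmax()]   # posterior of the higher-mean (clean) component
is_clean = clean_prob > 0.5                 # hard split, as in the GMM accuracy printout

# split_batch then blends this into running targets with momentum beta and keeps the
# element-wise minimum of the prediction-based and topology-based estimates:
#   topo_targets  = beta * clean_prob + (1 - beta) * topo_targets
#   preds_targets = beta * preds      + (1 - beta) * preds_targets
#   targets       = minimum(preds_targets, topo_targets)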
from sklearn.metrics import confusion_matrix 10 | import tensorboard_logger as tb_logger 11 | import data 12 | import opts 13 | from model.RINCE import RINCE 14 | from evaluation import encode_data, shard_attn_scores, i2t, t2i, AverageMeter, LogCollector 15 | from utils import save_checkpoint 16 | from vocab import deserialize_vocab 17 | import warnings 18 | 19 | warnings.filterwarnings("ignore") 20 | 21 | def adjust_learning_rate(opt, optimizer, epoch): 22 | """ 23 | Sets the learning rate to the initial LR 24 | decayed by 10 after opt.lr_update epoch 25 | """ 26 | lr = opt.learning_rate * (0.2 ** (epoch // opt.lr_update)) 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = lr 29 | 30 | 31 | def train(opt, data_loader, val_loader, model, epoch, preds=None, mode='warmup', best_rsum=0): 32 | batch_time = AverageMeter() 33 | data_time = AverageMeter() 34 | train_logger = LogCollector() 35 | num_loader_iter = len(train_loader.dataset) // train_loader.batch_size + 1 36 | end = time.time() 37 | logger.info("=> {mode} epoch: {0}".format(epoch, mode=mode)) 38 | for i, (images, captions, lengths, ids, corrs) in enumerate(data_loader): 39 | if images.size(0) == 1: 40 | break 41 | model.train_start() 42 | data_time.update(time.time() - end) 43 | model.logger = train_logger 44 | ids = torch.tensor(ids).cuda() 45 | corrs = torch.tensor(corrs).cuda() 46 | if mode == 'warmup': 47 | model.warmup_batch(images, captions, lengths, ids, corrs) 48 | else: 49 | model.train_batch(images, captions, lengths, ids, corrs) 50 | batch_time.update(time.time() - end) 51 | if model.step % opt.log_step == 0: 52 | logger.info( 53 | 'Epoch ({mode}): [{0}][{1}/{2}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 54 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) \t{loss}'.format(epoch, i, num_loader_iter, 55 | mode=mode, 56 | batch_time=batch_time, 57 | data_time=data_time, 58 | loss=str(model.logger))) 59 | # Record logs in tensorboard 60 | tb_logger.log_value('epoch', epoch, step=model.step) 61 | tb_logger.log_value('step', i, step=model.step) 62 | tb_logger.log_value('batch_time', batch_time.val, step=model.step) 63 | tb_logger.log_value('data_time', data_time.val, step=model.step) 64 | model.logger.tb_log(tb_logger, step=model.step) 65 | 66 | model.split_batch(train_loader.dataset.corrs) 67 | 68 | 69 | def validation(opt, val_loader, model, test=False): 70 | # compute the encoding for all the validation images and captions 71 | if opt.data_name == 'cc152k_precomp': 72 | per_captions = 1 73 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']: 74 | per_captions = 5 75 | else: 76 | logger.info(f"No dataset") 77 | return 0 78 | if test: 79 | logger.info(f"=> Test") 80 | else: 81 | logger.info(f"=> Validation") 82 | # compute the encoding for all the validation images and captions 83 | img_embs, cap_embs, cap_lens = encode_data(model.similarity_model, val_loader, opt.log_step) 84 | 85 | # clear duplicate 5*images and keep 1*images 86 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)]) 87 | 88 | # record computation time of validation 89 | start = time.time() 90 | sims = shard_attn_scores(model.similarity_model, img_embs, cap_embs, cap_lens, opt, shard_size=1000) 91 | end = time.time() 92 | logger.info(f"calculate similarity time: {end - start}") 93 | 94 | with torch.no_grad(): 95 | loss_cl, preds = model.info_nce_loss(torch.tensor(sims[:, :1000]).float().cuda()) 96 | loss_cl_t, preds_t = model.info_nce_loss(torch.tensor(sims[:, :1000]).float().cuda().t()) 97 | 
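A worked reading of adjust_learning_rate above with the opts.py defaults (learning_rate=2e-4, lr_update=20); note the code multiplies by 0.2 per step, i.e. divides by 5, even though the docstring says "decayed by 10":

learning_rate, lr_update = 2e-4, 20
for epoch in (0, 19, 20, 39, 40):
    lr = learning_rate * (0.2 ** (epoch // lr_update))
    print(epoch, lr)   # epochs 0-19 -> 2e-4, epochs 20-39 -> 4e-5, epochs 40-59 -> 8e-6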
loss_cl = (loss_cl + loss_cl_t) * 0.5 98 | 99 | sims_img, sims_cap = model.similarity_model.sim_enc.forward_topology(torch.tensor(img_embs[:128]).float().cuda(), torch.tensor(cap_embs[:128]).float().cuda(), cap_lens[:128]) 100 | sims_img = torch.sigmoid(sims_img) 101 | sims_cap = torch.sigmoid(sims_cap) 102 | sims_topo = sims_img @ sims_cap.t() 103 | loss_topo, _ = model.info_nce_loss(sims_topo, temp=1.0) 104 | 105 | # caption retrieval 106 | (r1, r5, r10, medr, meanr) = i2t(img_embs.shape[0], sims, per_captions, return_ranks=False) 107 | logger.info("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3)) 108 | logger.info("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr)) 109 | # image retrieval 110 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs.shape[0], sims, per_captions, return_ranks=False) 111 | logger.info("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3)) 112 | logger.info("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr)) 113 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i 114 | 115 | logger.info("Sum of Recall: %.2f" % (r_sum)) 116 | # sum of recalls to be used for early stopping 117 | if test: 118 | # record metrics in tensorboard 119 | tb_logger.log_value('r1', r1, step=model.step) 120 | tb_logger.log_value('r5', r5, step=model.step) 121 | tb_logger.log_value('r10', r10, step=model.step) 122 | tb_logger.log_value('medr', medr, step=model.step) 123 | tb_logger.log_value('meanr', meanr, step=model.step) 124 | tb_logger.log_value('r1i', r1i, step=model.step) 125 | tb_logger.log_value('r5i', r5i, step=model.step) 126 | tb_logger.log_value('r10i', r10i, step=model.step) 127 | tb_logger.log_value('medri', medri, step=model.step) 128 | tb_logger.log_value('meanr', meanr, step=model.step) 129 | tb_logger.log_value('r_sum', r_sum, step=model.step) 130 | tb_logger.log_value('loss_cl', loss_cl.mean().item(), step=model.step) 131 | tb_logger.log_value('loss_topo', loss_topo.mean().item(), step=model.step) 132 | else: 133 | # record metrics in tensorboard 134 | tb_logger.log_value('t-r1', r1, step=model.step) 135 | tb_logger.log_value('t-r5', r5, step=model.step) 136 | tb_logger.log_value('t-r10', r10, step=model.step) 137 | tb_logger.log_value('t-medr', medr, step=model.step) 138 | tb_logger.log_value('t-meanr', meanr, step=model.step) 139 | tb_logger.log_value('t-r1i', r1i, step=model.step) 140 | tb_logger.log_value('t-r5i', r5i, step=model.step) 141 | tb_logger.log_value('t-r10i', r10i, step=model.step) 142 | tb_logger.log_value('t-medri', medri, step=model.step) 143 | tb_logger.log_value('t-meanr', meanr, step=model.step) 144 | tb_logger.log_value('t-r_sum', r_sum, step=model.step) 145 | tb_logger.log_value('t-loss_cl', loss_cl.mean().item(), step=model.step) 146 | tb_logger.log_value('t-loss_topo', loss_topo.mean().item(), step=model.step) 147 | return r_sum 148 | 149 | 150 | def init_logging(log_file_path): 151 | logger = logging.getLogger(__name__) 152 | logger.setLevel(level=logging.DEBUG) 153 | formatter = logging.Formatter('%(asctime)s %(message)s') 154 | file_handler = logging.FileHandler(log_file_path) 155 | file_handler.setLevel(level=logging.INFO) 156 | file_handler.setFormatter(formatter) 157 | stream_handler = logging.StreamHandler() 158 | stream_handler.setLevel(logging.DEBUG) 159 | stream_handler.setFormatter(formatter) 160 | logger.addHandler(file_handler) 161 | logger.addHandler(stream_handler) 162 | return logger 163 | 164 | 165 | if __name__ == '__main__': 166 | opt = opts.parse_opt() 167 | 
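For reference, the R@K values that validation above sums into r_sum come from the rank of the first correct match for each query. A minimal standalone sketch of that computation on a tiny similarity matrix (illustrative only, one caption per image):

import numpy as np

sims = np.array([[0.9, 0.2, 0.1],    # rows: images, columns: captions
                 [0.1, 0.3, 0.8],
                 [0.2, 0.7, 0.4]])
ranks = []
for i, row in enumerate(sims):
    order = np.argsort(row)[::-1]                # captions sorted by similarity, best first
    ranks.append(np.where(order == i)[0][0])     # position of the ground-truth caption
ranks = np.array(ranks)
r1 = 100.0 * np.mean(ranks < 1)
r5 = 100.0 * np.mean(ranks < 5)
r10 = 100.0 * np.mean(ranks < 10)
# r_sum in train.py adds the same three recalls for both retrieval directions (i2t and t2i)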
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu 168 | tb_logger.configure(opt.log_dir, flush_secs=5) 169 | logger = init_logging(opt.log_dir + '/log.txt') 170 | logger.info(opt) 171 | logger.info(f"=> PID:{os.getpid()}, GPU:[{opt.gpu}], Noise ratio: {opt.noise_ratio}") 172 | logger.info(f"=> Log save path: '{opt.log_dir}'") 173 | logger.info(f"=> Checkpoint save path: '{opt.checkpoint_dir}'") 174 | # Load Vocabulary 175 | logger.info(f"=> Load vocabulary from '{opt.vocab_path}'") 176 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name)) 177 | opt.vocab_size = len(vocab) 178 | # Load data loaders 179 | logger.info(f"=> Load loaders from '{opt.data_path}/{opt.data_name}'") 180 | train_loader, val_loader, test_loader = data.get_loaders(opt.data_name, vocab, opt.batch_size, 181 | opt.workers, opt) 182 | # Construct the model (DECL-) 183 | logger.info(f"=> Similarity model is {opt.module_name}") 184 | model = RINCE(opt) 185 | best_rsum = 0 186 | start_epoch = 0 187 | if opt.warmup_if: 188 | if os.path.isfile(opt.warmup_model_path): 189 | checkpoint = torch.load(opt.warmup_model_path) 190 | model.load_state_dict(checkpoint['model']) 191 | logger.info( 192 | "=> Load warmup(pre-) checkpoint '{}' (epoch {})".format(opt.warmup_model_path, checkpoint['epoch'])) 193 | if 'best_rsum' in checkpoint: 194 | if 'warmup' not in opt.warmup_model_path: 195 | start_epoch = checkpoint['epoch'] + 1 196 | best_rsum = checkpoint['best_rsum'] 197 | model.step = checkpoint['step'] 198 | else: 199 | logger.info(f"=> no checkpoint found at '{opt.warmup_model_path}', warmup start!") 200 | for e in range(opt.warmup_epochs): 201 | train(opt, train_loader, val_loader, model, e, mode='warmup') 202 | save_checkpoint({ 203 | 'epoch': e, 204 | 'model': model.state_dict(), 205 | 'opt': opt, 206 | 'step': model.step 207 | }, is_best=False, filename='warmup_model_{}.pth.tar'.format(e), prefix=opt.checkpoint_dir + '/') 208 | else: 209 | logger.info("=> No warmup stage") 210 | 211 | for epoch in range(start_epoch, opt.num_epochs): 212 | adjust_learning_rate(opt, model.optimizer, epoch) 213 | start_time = datetime.now() 214 | train(opt, train_loader, val_loader, model, epoch, mode='train', best_rsum=best_rsum) 215 | end_time = datetime.now() 216 | tb_logger.log_value('cost_time', int((end_time - start_time).seconds), step=epoch) 217 | validation(opt, test_loader, model, test=True) 218 | rsum = validation(opt, val_loader, model) 219 | is_best = rsum > best_rsum 220 | best_rsum = max(rsum, best_rsum) 221 | save_checkpoint({ 222 | 'epoch': epoch, 223 | 'model': model.state_dict(), 224 | 'best_rsum': best_rsum, 225 | 'opt': opt, 226 | 'step': model.step, 227 | }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.checkpoint_dir + '/') 228 | -------------------------------------------------------------------------------- /GSC+/train.py: -------------------------------------------------------------------------------- 1 | """Training script""" 2 | 3 | import logging 4 | import os 5 | import time 6 | from datetime import datetime 7 | import numpy as np 8 | import torch 9 | from sklearn.metrics import confusion_matrix 10 | import tensorboard_logger as tb_logger 11 | import data 12 | import opts 13 | from model.RINCE import RINCE 14 | from evaluation import encode_data, shard_attn_scores, i2t, t2i, AverageMeter, LogCollector 15 | from utils import save_checkpoint 16 | from vocab import deserialize_vocab 17 | import warnings 18 | 19 | warnings.filterwarnings("ignore") 20 | 21 | def 
adjust_learning_rate(opt, optimizer, epoch): 22 | """ 23 | Sets the learning rate to the initial LR 24 | decayed by 10 after opt.lr_update epoch 25 | """ 26 | lr = opt.learning_rate * (0.2 ** (epoch // opt.lr_update)) 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = lr 29 | 30 | 31 | def train(opt, data_loader, val_loader, model_A, model_B, epoch, preds=None, mode='warmup', best_rsum=0): 32 | batch_time = AverageMeter() 33 | data_time = AverageMeter() 34 | train_logger = LogCollector() 35 | num_loader_iter = len(train_loader.dataset) // train_loader.batch_size + 1 36 | end = time.time() 37 | logger.info("=> {mode} epoch: {0}".format(epoch, mode=mode)) 38 | for i, (images, captions, lengths, ids, corrs) in enumerate(data_loader): 39 | # print(i) 40 | if images.size(0) == 1: 41 | break 42 | model_A.train_start() 43 | model_B.train_start() 44 | data_time.update(time.time() - end) 45 | model_A.logger = train_logger 46 | model_B.logger = train_logger 47 | ids = torch.tensor(ids).cuda() 48 | corrs = torch.tensor(corrs).cuda() 49 | if mode == 'warmup': 50 | model_A.warmup_batch(images, captions, lengths) 51 | model_B.warmup_batch(images, captions, lengths) 52 | else: 53 | model_A.train_batch(images, captions, lengths, ids, corrs, model_B.similarity_model.targets[ids]) 54 | model_B.train_batch(images, captions, lengths, ids, corrs, model_A.similarity_model.targets[ids]) 55 | batch_time.update(time.time() - end) 56 | if model_A.step % opt.log_step == 0: 57 | logger.info( 58 | 'Epoch ({mode}): [{0}][{1}/{2}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 59 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) \t{loss}'.format(epoch, i, num_loader_iter, 60 | mode=mode, 61 | batch_time=batch_time, 62 | data_time=data_time, 63 | loss=str(model_A.logger))) 64 | # Record logs in tensorboard 65 | tb_logger.log_value('epoch', epoch, step=model_A.step) 66 | tb_logger.log_value('step', i, step=model_A.step) 67 | tb_logger.log_value('batch_time', batch_time.val, step=model_A.step) 68 | tb_logger.log_value('data_time', data_time.val, step=model_A.step) 69 | model_A.logger.tb_log(tb_logger, step=model_A.step) 70 | model_B.logger.tb_log(tb_logger, step=model_B.step) 71 | 72 | if mode != 'warmup': 73 | model_A.split_batch(train_loader.dataset.corrs) 74 | model_B.split_batch(train_loader.dataset.corrs) 75 | 76 | 77 | def validation(opt, val_loader, model_A, model_B, test=False): 78 | # compute the encoding for all the validation images and captions 79 | if opt.data_name == 'cc152k_precomp': 80 | per_captions = 1 81 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']: 82 | per_captions = 5 83 | else: 84 | logger.info(f"No dataset") 85 | return 0 86 | if test: 87 | logger.info(f"=> Test") 88 | else: 89 | logger.info(f"=> Validation") 90 | # compute the encoding for all the validation images and captions 91 | img_embs_A, cap_embs_A, cap_lens_A = encode_data(model_A.similarity_model, val_loader, opt.log_step) 92 | img_embs_A = np.array([img_embs_A[i] for i in range(0, len(img_embs_A), per_captions)]) 93 | 94 | img_embs_B, cap_embs_B, cap_lens_B = encode_data(model_B.similarity_model, val_loader, opt.log_step) 95 | img_embs_B = np.array([img_embs_B[i] for i in range(0, len(img_embs_B), per_captions)]) 96 | 97 | 98 | # record computation time of validation 99 | start = time.time() 100 | sims_A = shard_attn_scores(model_A.similarity_model, img_embs_A, cap_embs_A, cap_lens_A, opt, shard_size=1000) 101 | end = time.time() 102 | logger.info(f"calculate similarity time: {end - start}") 103 | 
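The co-training step in the GSC+ train() above is the main difference from GSC: each network's train_batch receives the per-sample targets estimated by its peer (model_B.similarity_model.targets[ids] for model_A and vice versa), and those targets act as soft weights on the per-sample losses, so pairs the peer flags as noisy barely contribute. A minimal sketch of the weighting effect (illustrative values only):

import torch

per_sample_loss = torch.tensor([1.2, 0.9, 1.5, 1.1])     # e.g. symmetric InfoNCE per pair
peer_targets = torch.tensor([0.95, 0.90, 0.05, 0.85])    # ~0 where the peer model suspects noise
weighted_loss = (per_sample_loss * peer_targets).mean()  # the suspected-noisy pair is down-weighted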
104 | start = time.time() 105 | sims_B = shard_attn_scores(model_B.similarity_model, img_embs_B, cap_embs_B, cap_lens_B, opt, shard_size=1000) 106 | end = time.time() 107 | logger.info(f"calculate similarity time: {end - start}") 108 | sims = (sims_A + sims_B) / 2 109 | 110 | # caption retrieval 111 | (r1, r5, r10, medr, meanr) = i2t(img_embs_A.shape[0], sims, per_captions, return_ranks=False) 112 | logger.info("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3)) 113 | logger.info("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr)) 114 | # image retrieval 115 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs_A.shape[0], sims, per_captions, return_ranks=False) 116 | logger.info("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3)) 117 | logger.info("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr)) 118 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i 119 | 120 | logger.info("Sum of Recall: %.2f" % (r_sum)) 121 | # sum of recalls to be used for early stopping 122 | if test: 123 | # record metrics in tensorboard 124 | tb_logger.log_value('r1', r1, step=model_A.step) 125 | tb_logger.log_value('r5', r5, step=model_A.step) 126 | tb_logger.log_value('r10', r10, step=model_A.step) 127 | tb_logger.log_value('medr', medr, step=model_A.step) 128 | tb_logger.log_value('meanr', meanr, step=model_A.step) 129 | tb_logger.log_value('r1i', r1i, step=model_A.step) 130 | tb_logger.log_value('r5i', r5i, step=model_A.step) 131 | tb_logger.log_value('r10i', r10i, step=model_A.step) 132 | tb_logger.log_value('medri', medri, step=model_A.step) 133 | tb_logger.log_value('meanr', meanr, step=model_A.step) 134 | tb_logger.log_value('r_sum', r_sum, step=model_A.step) 135 | else: 136 | # record metrics in tensorboard 137 | tb_logger.log_value('t-r1', r1, step=model_A.step) 138 | tb_logger.log_value('t-r5', r5, step=model_A.step) 139 | tb_logger.log_value('t-r10', r10, step=model_A.step) 140 | tb_logger.log_value('t-medr', medr, step=model_A.step) 141 | tb_logger.log_value('t-meanr', meanr, step=model_A.step) 142 | tb_logger.log_value('t-r1i', r1i, step=model_A.step) 143 | tb_logger.log_value('t-r5i', r5i, step=model_A.step) 144 | tb_logger.log_value('t-r10i', r10i, step=model_A.step) 145 | tb_logger.log_value('t-medri', medri, step=model_A.step) 146 | tb_logger.log_value('t-meanr', meanr, step=model_A.step) 147 | tb_logger.log_value('t-r_sum', r_sum, step=model_A.step) 148 | return r_sum 149 | 150 | def init_logging(log_file_path): 151 | logger = logging.getLogger(__name__) 152 | logger.setLevel(level=logging.DEBUG) 153 | formatter = logging.Formatter('%(asctime)s %(message)s') 154 | file_handler = logging.FileHandler(log_file_path) 155 | file_handler.setLevel(level=logging.INFO) 156 | file_handler.setFormatter(formatter) 157 | stream_handler = logging.StreamHandler() 158 | stream_handler.setLevel(logging.DEBUG) 159 | stream_handler.setFormatter(formatter) 160 | logger.addHandler(file_handler) 161 | logger.addHandler(stream_handler) 162 | return logger 163 | 164 | 165 | if __name__ == '__main__': 166 | opt = opts.parse_opt() 167 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu 168 | tb_logger.configure(opt.log_dir, flush_secs=5) 169 | logger = init_logging(opt.log_dir + '/log.txt') 170 | logger.info(opt) 171 | logger.info(f"=> PID:{os.getpid()}, GPU:[{opt.gpu}], Noise ratio: {opt.noise_ratio}") 172 | logger.info(f"=> Log save path: '{opt.log_dir}'") 173 | logger.info(f"=> Checkpoint save path: '{opt.checkpoint_dir}'") 174 | # Load Vocabulary 175 | logger.info(f"=> Load 
vocabulary from '{opt.vocab_path}'") 176 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name)) 177 | opt.vocab_size = len(vocab) 178 | # Load data loaders 179 | logger.info(f"=> Load loaders from '{opt.data_path}/{opt.data_name}'") 180 | train_loader, val_loader, test_loader = data.get_loaders(opt.data_name, vocab, opt.batch_size, 181 | opt.workers, opt) 182 | # Construct the model (DECL-) 183 | logger.info(f"=> Similarity model is {opt.module_name}") 184 | model_A = RINCE(opt) 185 | model_B = RINCE(opt) 186 | best_rsum = 0 187 | start_epoch = 0 188 | if opt.warmup_if: 189 | if os.path.isfile(opt.warmup_model_path): 190 | checkpoint = torch.load(opt.warmup_model_path) 191 | model_A.load_state_dict(checkpoint['model_A']) 192 | model_B.load_state_dict(checkpoint['model_B']) 193 | logger.info( 194 | "=> Load warmup(pre-) checkpoint '{}' (epoch {})".format(opt.warmup_model_path, checkpoint['epoch'])) 195 | if 'best_rsum' in checkpoint: 196 | if 'warmup' not in opt.warmup_model_path: 197 | start_epoch = checkpoint['epoch'] + 1 198 | best_rsum = checkpoint['best_rsum'] 199 | model_A.step = checkpoint['step'] 200 | model_B.step = checkpoint['step'] 201 | else: 202 | logger.info(f"=> no checkpoint found at '{opt.warmup_model_path}', warmup start!") 203 | for e in range(opt.warmup_epochs): 204 | train(opt, train_loader, val_loader, model_A, model_B, e, mode='warmup') 205 | save_checkpoint({ 206 | 'epoch': e, 207 | 'model_A': model_A.state_dict(), 208 | 'model_B': model_B.state_dict(), 209 | 'opt': opt, 210 | 'step': model_A.step 211 | }, is_best=False, filename='warmup_model_{}.pth.tar'.format(e), prefix=opt.checkpoint_dir + '/') 212 | else: 213 | logger.info("=> No warmup stage") 214 | 215 | for epoch in range(start_epoch, opt.num_epochs): 216 | adjust_learning_rate(opt, model_A.optimizer, epoch) 217 | adjust_learning_rate(opt, model_B.optimizer, epoch) 218 | start_time = datetime.now() 219 | train(opt, train_loader, val_loader, model_A, model_B, epoch, mode='train', best_rsum=best_rsum) 220 | end_time = datetime.now() 221 | tb_logger.log_value('cost_time', int((end_time - start_time).seconds), step=epoch) 222 | validation(opt, test_loader, model_A, model_B, test=True) 223 | rsum = validation(opt, val_loader, model_A, model_B) 224 | is_best = rsum > best_rsum 225 | best_rsum = max(rsum, best_rsum) 226 | save_checkpoint({ 227 | 'epoch': epoch, 228 | 'model_A': model_A.state_dict(), 229 | 'model_B': model_B.state_dict(), 230 | 'best_rsum': best_rsum, 231 | 'opt': opt, 232 | 'step': model_A.step, 233 | }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.checkpoint_dir + '/') 234 | -------------------------------------------------------------------------------- /GSC/evaluation.py: -------------------------------------------------------------------------------- 1 | """Evaluation""" 2 | 3 | from __future__ import print_function 4 | import os 5 | import sys 6 | import torch 7 | import numpy as np 8 | 9 | from model.RINCE import RINCE 10 | from data import get_test_loader 11 | from vocab import deserialize_vocab 12 | from collections import OrderedDict 13 | 14 | 15 | class AverageMeter(object): 16 | """Computes and stores the average and current value""" 17 | 18 | def __init__(self): 19 | self.reset() 20 | 21 | def reset(self): 22 | self.val = 0 23 | self.avg = 0 24 | self.sum = 0 25 | self.count = 0 26 | 27 | def update(self, val, n=0): 28 | self.val = val 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / (.0001 + self.count) 
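A quick usage note on the AverageMeter defined here (the same helper appears in both GSC and GSC+ evaluation code): update(val, n) keeps a weighted running average, and the 1e-4 term in the denominator only guards against dividing by zero before any weighted update has arrived. A minimal sketch, assuming it is run from inside the GSC directory so the class can be imported:

from evaluation import AverageMeter

meter = AverageMeter()
meter.update(2.0, n=4)   # sum = 8.0, count = 4
meter.update(1.0, n=2)   # sum = 10.0, count = 6
print(meter)             # prints "1.0000 (1.6666)": current value and running average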
32 | 33 | def __str__(self): 34 | """String representation for logging 35 | """ 36 | # for values that should be recorded exactly e.g. iteration number 37 | if self.count == 0: 38 | return str(self.val) 39 | # for stats 40 | return '%.4f (%.4f)' % (self.val, self.avg) 41 | 42 | 43 | class LogCollector(object): 44 | """A collection of logging objects that can change from train to val""" 45 | 46 | def __init__(self): 47 | # to keep the order of logged variables deterministic 48 | self.meters = OrderedDict() 49 | 50 | def update(self, k, v, n=0): 51 | # create a new meter if previously not recorded 52 | if k not in self.meters: 53 | self.meters[k] = AverageMeter() 54 | self.meters[k].update(v, n) 55 | 56 | def __str__(self): 57 | """Concatenate the meters in one log line 58 | """ 59 | s = '' 60 | for i, (k, v) in enumerate(self.meters.items()): 61 | if i > 0: 62 | s += ' ' 63 | s += k + ' ' + str(v) 64 | return s 65 | 66 | def tb_log(self, tb_logger, prefix='', step=None): 67 | """Log using tensorboard 68 | """ 69 | for k, v in self.meters.items(): 70 | tb_logger.log_value(prefix + k, v.val, step=step) 71 | 72 | 73 | def encode_data(model, data_loader, log_step=10, logging=print, sub=0): 74 | """Encode all images and captions loadable by `data_loader` 75 | """ 76 | val_logger = LogCollector() 77 | 78 | # switch to evaluate mode 79 | model.val_start() 80 | 81 | # np array to keep all the embeddings 82 | img_embs = None 83 | cap_embs = None 84 | 85 | max_n_word = 0 86 | for i, (images, captions, lengths, ids, _) in enumerate(data_loader): 87 | max_n_word = max(max_n_word, max(lengths)) 88 | ids_ = [] 89 | for i, (images, captions, lengths, ids, _) in enumerate(data_loader): 90 | # make sure val logger is used 91 | model.logger = val_logger 92 | ids_ += ids 93 | # compute the embeddings 94 | with torch.no_grad(): 95 | img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths) 96 | if img_embs is None: 97 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2))) 98 | cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2))) 99 | cap_lens = [0] * len(data_loader.dataset) 100 | # cache embeddings 101 | img_embs[ids, :, :] = img_emb.data.cpu().numpy().copy() 102 | cap_embs[ids, :max(lengths), :] = cap_emb.data.cpu().numpy().copy() 103 | 104 | for j, nid in enumerate(ids): 105 | cap_lens[nid] = cap_len[j] 106 | 107 | del images, captions 108 | if sub > 0: 109 | print(f"===>batch {i}") 110 | if sub > 0 and i > sub: 111 | break 112 | if sub > 0: 113 | return np.array(img_embs)[ids_].tolist(), np.array(cap_embs)[ids_].tolist(), np.array(cap_lens)[ 114 | ids_].tolist(), ids_ 115 | else: 116 | return img_embs, cap_embs, cap_lens 117 | 118 | 119 | def shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=100, mode="sim"): 120 | n_im_shard = (len(img_embs) - 1) // shard_size + 1 121 | n_cap_shard = (len(cap_embs) - 1) // shard_size + 1 122 | 123 | sims = np.zeros((len(img_embs), len(cap_embs))) 124 | for i in range(n_im_shard): 125 | im_start, im_end = shard_size * i, min(shard_size * (i + 1), len(img_embs)) 126 | for j in range(n_cap_shard): 127 | sys.stdout.write('\r>> shard_attn_scores batch (%d,%d)' % (i, j)) 128 | ca_start, ca_end = shard_size * j, min(shard_size * (j + 1), len(cap_embs)) 129 | with torch.no_grad(): 130 | im = torch.from_numpy(img_embs[im_start:im_end]).float().cuda() 131 | ca = torch.from_numpy(cap_embs[ca_start:ca_end]).float().cuda() 132 | l = cap_lens[ca_start:ca_end] 133 | if mode == "sim": 134 | sim = 
model.forward_sim(im, ca, l, mode) 135 | else: 136 | _, sim, _ = model.forward_sim(im, ca, l, mode) # Calculate evidence for retrieval 137 | 138 | sims[im_start:im_end, ca_start:ca_end] = sim.data.cpu().numpy() 139 | sys.stdout.write('\n') 140 | return sims 141 | 142 | 143 | def t2i(npts, sims, per_captions=1, return_ranks=False): 144 | """ 145 | Text->Images (Image Search) 146 | Images: (N, n_region, d) matrix of images 147 | Captions: (per_captions * N, max_n_word, d) matrix of captions 148 | CapLens: (per_captions * N) array of caption lengths 149 | sims: (N, per_captions * N) matrix of similarity im-cap 150 | """ 151 | ranks = np.zeros(per_captions * npts) 152 | top1 = np.zeros(per_captions * npts) 153 | top5 = np.zeros((per_captions * npts, 5), dtype=int) 154 | 155 | # --> (per_captions * N(caption), N(image)) 156 | sims = sims.T 157 | retreivaled_index = [] 158 | for index in range(npts): 159 | for i in range(per_captions): 160 | inds = np.argsort(sims[per_captions * index + i])[::-1] 161 | retreivaled_index.append(inds) 162 | ranks[per_captions * index + i] = np.where(inds == index)[0][0] 163 | top1[per_captions * index + i] = inds[0] 164 | top5[per_captions * index + i] = inds[0:5] 165 | 166 | # Compute metrics 167 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 168 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 169 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 170 | medr = np.floor(np.median(ranks)) + 1 171 | meanr = ranks.mean() + 1 172 | if return_ranks: 173 | return (r1, r5, r10, medr, meanr), (ranks, top1, top5, retreivaled_index) 174 | else: 175 | return (r1, r5, r10, medr, meanr) 176 | 177 | 178 | def i2t(npts, sims, per_captions=1, return_ranks=False): 179 | """ 180 | Images->Text (Image Annotation) 181 | Images: (N, n_region, d) matrix of images 182 | Captions: (per_captions * N, max_n_word, d) matrix of captions 183 | CapLens: (per_captions * N) array of caption lengths 184 | sims: (N, per_captions * N) matrix of similarity im-cap 185 | """ 186 | ranks = np.zeros(npts) 187 | top1 = np.zeros(npts) 188 | top5 = np.zeros((npts, 5), dtype=int) 189 | retreivaled_index = [] 190 | for index in range(npts): 191 | inds = np.argsort(sims[index])[::-1] 192 | retreivaled_index.append(inds) 193 | # Score 194 | rank = 1e20 195 | for i in range(per_captions * index, per_captions * index + per_captions, 1): 196 | tmp = np.where(inds == i)[0][0] 197 | if tmp < rank: 198 | rank = tmp 199 | ranks[index] = rank 200 | top1[index] = inds[0] 201 | top5[index] = inds[0:5] 202 | 203 | # Compute metrics 204 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 205 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 206 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 207 | medr = np.floor(np.median(ranks)) + 1 208 | meanr = ranks.mean() + 1 209 | if return_ranks: 210 | return (r1, r5, r10, medr, meanr), (ranks, top1, top5, retreivaled_index) 211 | else: 212 | return (r1, r5, r10, medr, meanr) 213 | 214 | 215 | def validation(opt, val_loader, model, fold=False): 216 | # compute the encoding for all the validation images and captions 217 | if opt.data_name == 'cc152k_precomp': 218 | per_captions = 1 219 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']: 220 | per_captions = 5 221 | else: 222 | print(f"No dataset") 223 | return 0 224 | 225 | model.val_start() 226 | print('Encoding with model') 227 | img_embs, cap_embs, cap_lens = encode_data(model.similarity_model, val_loader) 228 | # clear duplicate 5*images and keep 1*images FIXME 229 | if not 
fold: 230 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)]) 231 | # record computation time of validation 232 | print('Computing similarity from model') 233 | sims_mean = shard_attn_scores(model.similarity_model, img_embs, cap_embs, cap_lens, opt, shard_size=1000, 234 | mode="sim") 235 | print("Calculate similarity time with model") 236 | (r1, r5, r10, medr, meanr) = i2t(img_embs.shape[0], sims_mean, per_captions, return_ranks=False) 237 | print("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3)) 238 | print("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr)) 239 | # image retrieval 240 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs.shape[0], sims_mean, per_captions, return_ranks=False) 241 | print("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3)) 242 | print("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr)) 243 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i 244 | print("Sum of Recall: %.2f" % (r_sum)) 245 | else: 246 | # 5fold cross-validation, only for MSCOCO 247 | results = [] 248 | for i in range(5): 249 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5] 250 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000] 251 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000] 252 | sims = shard_attn_scores(model.similarity_model, img_embs_shard, cap_embs_shard, cap_lens_shard, opt, 253 | shard_size=1000, 254 | mode="not sim") 255 | 256 | print('Computing similarity from model') 257 | r, rt = i2t(img_embs_shard.shape[0], sims, per_captions, return_ranks=True) 258 | ri, rti = t2i(img_embs_shard.shape[0], sims, per_captions, return_ranks=True) 259 | 260 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 261 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 262 | ar = (r[0] + r[1] + r[2]) / 3 263 | ari = (ri[0] + ri[1] + ri[2]) / 3 264 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 265 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 266 | results += [list(r) + list(ri) + [ar, ari, rsum]] 267 | print("-----------------------------------") 268 | print("Mean metrics: ") 269 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 270 | a = np.array(mean_metrics) 271 | print("Average i2t Recall: %.1f" % mean_metrics[11]) 272 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 273 | mean_metrics[:5]) 274 | print("Average t2i Recall: %.1f" % mean_metrics[12]) 275 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 276 | mean_metrics[5:10]) 277 | print("rsum: %.1f" % (a[0:3].sum() + a[5:8].sum())) 278 | 279 | 280 | def validation_dul(opt, val_loader, models, fold=False): 281 | # compute the encoding for all the validation images and captions 282 | if opt.data_name == 'cc152k_precomp': 283 | per_captions = 1 284 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']: 285 | per_captions = 5 286 | else: 287 | print(f"No dataset") 288 | return 0 289 | 290 | models[0].val_start() 291 | models[1].val_start() 292 | print('Encoding with model') 293 | img_embs, cap_embs, cap_lens = encode_data(models[0].similarity_model, val_loader, opt.log_step) 294 | img_embs1, cap_embs1, cap_lens1 = encode_data(models[1].similarity_model, val_loader, opt.log_step) 295 | if not fold: 296 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)]) 297 | img_embs1 = np.array([img_embs1[i] for i in range(0, len(img_embs1), per_captions)]) 298 | # record computation time of validation 299 | print('Computing similarity from model') 300 | sims_mean = 
shard_attn_scores(models[0].similarity_model, img_embs, cap_embs, cap_lens, opt, shard_size=1000, 301 | mode="sim") 302 | sims_mean += shard_attn_scores(models[1].similarity_model, img_embs1, cap_embs1, cap_lens1, opt, 303 | shard_size=1000, mode="sim") 304 | sims_mean /= 2 305 | print("Calculate similarity time with model") 306 | # caption retrieval 307 | (r1, r5, r10, medr, meanr) = i2t(img_embs.shape[0], sims_mean, per_captions, return_ranks=False) 308 | print("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3)) 309 | print("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr)) 310 | # image retrieval 311 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs.shape[0], sims_mean, per_captions, return_ranks=False) 312 | print("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3)) 313 | print("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr)) 314 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i 315 | print("Sum of Recall: %.2f" % (r_sum)) 316 | return r_sum 317 | else: 318 | # 5fold cross-validation, only for MSCOCO 319 | results = [] 320 | for i in range(5): 321 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5] 322 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000] 323 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000] 324 | 325 | img_embs_shard1 = img_embs1[i * 5000:(i + 1) * 5000:5] 326 | cap_embs_shard1 = cap_embs1[i * 5000:(i + 1) * 5000] 327 | cap_lens_shard1 = cap_lens1[i * 5000:(i + 1) * 5000] 328 | sims = shard_attn_scores(models[0].similarity_model, img_embs_shard, cap_embs_shard, cap_lens_shard, opt, 329 | shard_size=1000, 330 | mode="not sim") 331 | sims += shard_attn_scores(models[1].similarity_model, img_embs_shard1, cap_embs_shard1, cap_lens_shard1, 332 | opt, 333 | shard_size=1000, 334 | mode="not sim") 335 | sims /= 2 336 | 337 | print('Computing similarity from model') 338 | r, rt0 = i2t(img_embs_shard.shape[0], sims, per_captions, return_ranks=True) 339 | ri, rti0 = t2i(img_embs_shard.shape[0], sims, per_captions, return_ranks=True) 340 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 341 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 342 | if i == 0: 343 | rt, rti = rt0, rti0 344 | ar = (r[0] + r[1] + r[2]) / 3 345 | ari = (ri[0] + ri[1] + ri[2]) / 3 346 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 347 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 348 | results += [list(r) + list(ri) + [ar, ari, rsum]] 349 | 350 | print("-----------------------------------") 351 | print("Mean metrics: ") 352 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 353 | a = np.array(mean_metrics) 354 | 355 | print("Average i2t Recall: %.1f" % mean_metrics[11]) 356 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 357 | mean_metrics[:5]) 358 | print("Average t2i Recall: %.1f" % mean_metrics[12]) 359 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 360 | mean_metrics[5:10]) 361 | print("rsum: %.1f" % (a[0:3].sum() + a[5:8].sum())) 362 | 363 | 364 | def eval_DECL(checkpoint_paths, avg_SGRAF=True, data_path=None, vocab_path=None): 365 | if avg_SGRAF is False: 366 | print(f"Load checkpoint from '{checkpoint_paths[0]}'") 367 | checkpoint = torch.load(checkpoint_paths[0]) 368 | opt = checkpoint['opt'] 369 | opt.ssl = False 370 | print( 371 | f"Noise ratio is {opt.noise_ratio}, module is {opt.module_name}, best validation epoch is {checkpoint['epoch']} ({checkpoint['best_rsum']})") 372 | if vocab_path != None: 373 | opt.vocab_path = vocab_path 374 | if data_path != None: 375 | opt.data_path = 
data_path 376 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name)) 377 | model = RINCE(opt) 378 | model.load_state_dict(checkpoint['model']) 379 | if 'coco' in opt.data_name: 380 | test_loader = get_test_loader('testall', opt.data_name, vocab, 100, 0, opt) 381 | validation(opt, test_loader, model=model, fold=True) 382 | validation(opt, test_loader, model=model, fold=False) 383 | else: 384 | test_loader = get_test_loader('test', opt.data_name, vocab, 100, 0, opt) 385 | validation(opt, test_loader, model=model, fold=False) 386 | else: 387 | assert len(checkpoint_paths) == 2 388 | print(f"Load checkpoint from '{checkpoint_paths}'") 389 | checkpoint0 = torch.load(checkpoint_paths[0]) 390 | checkpoint1 = torch.load(checkpoint_paths[1]) 391 | opt0 = checkpoint0['opt'] 392 | opt1 = checkpoint1['opt'] 393 | print( 394 | f"Noise ratios are {opt0.noise_ratio} and {opt1.noise_ratio}, " 395 | f"modules are {opt0.module_name} and {opt1.module_name}, best validation epochs are {checkpoint0['epoch']}" 396 | f" ({checkpoint0['best_rsum']}) and {checkpoint1['epoch']} ({checkpoint1['best_rsum']})") 397 | vocab = deserialize_vocab(os.path.join(opt0.vocab_path, '%s_vocab.json' % opt0.data_name)) 398 | model0 = RINCE(opt0) 399 | model1 = RINCE(opt1) 400 | 401 | model0.load_state_dict(checkpoint0['model']) 402 | model1.load_state_dict(checkpoint1['model']) 403 | if 'coco' in opt0.data_name: 404 | test_loader = get_test_loader('testall', opt0.data_name, vocab, 100, 0, opt0) 405 | print(f'=====>model {opt0.module_name} fold:True') 406 | validation(opt0, test_loader, model0, fold=True) 407 | print(f'=====>model {opt1.module_name} fold:True') 408 | validation(opt0, test_loader, model1, fold=True) 409 | print(f'=====>model SGRAF fold:True') 410 | validation_dul(opt0, test_loader, models=[model0, model1], fold=True) 411 | 412 | print(f'=====>model {opt0.module_name} fold:False') 413 | validation(opt0, test_loader, model0, fold=False) 414 | print(f'=====>model {opt1.module_name} fold:False') 415 | validation(opt0, test_loader, model1, fold=False) 416 | print('=====>model SGRAF fold:False') 417 | validation_dul(opt0, test_loader, models=[model0, model1], fold=False) 418 | 419 | else: 420 | test_loader = get_test_loader('test', opt0.data_name, vocab, 100, 0, opt0) 421 | print(f'=====>model {opt0.module_name} fold:False') 422 | validation(opt0, test_loader, model0, fold=False) 423 | print(f'=====>model {opt1.module_name} fold:False') 424 | validation(opt0, test_loader, model1, fold=False) 425 | print('=====>model SGRAF fold:False') 426 | validation_dul(opt0, test_loader, models=[model0, model1], fold=False) 427 | -------------------------------------------------------------------------------- /GSC+/evaluation.py: -------------------------------------------------------------------------------- 1 | """Evaluation""" 2 | 3 | from __future__ import print_function 4 | import os 5 | import sys 6 | import torch 7 | import numpy as np 8 | 9 | from model.RINCE import RINCE 10 | from data import get_test_loader 11 | from vocab import deserialize_vocab 12 | from collections import OrderedDict 13 | 14 | 15 | class AverageMeter(object): 16 | """Computes and stores the average and current value""" 17 | 18 | def __init__(self): 19 | self.reset() 20 | 21 | def reset(self): 22 | self.val = 0 23 | self.avg = 0 24 | self.sum = 0 25 | self.count = 0 26 | 27 | def update(self, val, n=0): 28 | self.val = val 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / (.0001 + self.count) 32 | 
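As eval_DECL above shows, the checkpoints written by train.py are ordinary dictionaries, which makes them easy to inspect without running a full evaluation. A minimal sketch (the path is a placeholder; GSC checkpoints store a single network under 'model', while GSC+ checkpoints store two under 'model_A' and 'model_B'):

import torch

checkpoint = torch.load('/path/to/model_best.pth.tar', map_location='cpu')  # placeholder path
print(checkpoint['epoch'], checkpoint.get('best_rsum'))  # keys written by save_checkpoint in train.py
opt = checkpoint['opt']                                   # the argparse Namespace saved at training time
# model = RINCE(opt); model.load_state_dict(checkpoint['model'])                                   # GSC
# model_A.load_state_dict(checkpoint['model_A']); model_B.load_state_dict(checkpoint['model_B'])   # GSC+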
33 | def __str__(self): 34 | """String representation for logging 35 | """ 36 | # for values that should be recorded exactly e.g. iteration number 37 | if self.count == 0: 38 | return str(self.val) 39 | # for stats 40 | return '%.4f (%.4f)' % (self.val, self.avg) 41 | 42 | 43 | class LogCollector(object): 44 | """A collection of logging objects that can change from train to val""" 45 | 46 | def __init__(self): 47 | # to keep the order of logged variables deterministic 48 | self.meters = OrderedDict() 49 | 50 | def update(self, k, v, n=0): 51 | # create a new meter if previously not recorded 52 | if k not in self.meters: 53 | self.meters[k] = AverageMeter() 54 | self.meters[k].update(v, n) 55 | 56 | def __str__(self): 57 | """Concatenate the meters in one log line 58 | """ 59 | s = '' 60 | for i, (k, v) in enumerate(self.meters.items()): 61 | if i > 0: 62 | s += ' ' 63 | s += k + ' ' + str(v) 64 | return s 65 | 66 | def tb_log(self, tb_logger, prefix='', step=None): 67 | """Log using tensorboard 68 | """ 69 | for k, v in self.meters.items(): 70 | tb_logger.log_value(prefix + k, v.val, step=step) 71 | 72 | 73 | def encode_data(model, data_loader, log_step=10, logging=print, sub=0): 74 | """Encode all images and captions loadable by `data_loader` 75 | """ 76 | val_logger = LogCollector() 77 | 78 | # switch to evaluate mode 79 | model.val_start() 80 | 81 | # np array to keep all the embeddings 82 | img_embs = None 83 | cap_embs = None 84 | 85 | max_n_word = 0 86 | for i, (images, captions, lengths, ids, _) in enumerate(data_loader): 87 | max_n_word = max(max_n_word, max(lengths)) 88 | ids_ = [] 89 | for i, (images, captions, lengths, ids, _) in enumerate(data_loader): 90 | # make sure val logger is used 91 | model.logger = val_logger 92 | ids_ += ids 93 | # compute the embeddings 94 | with torch.no_grad(): 95 | img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths) 96 | if img_embs is None: 97 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2))) 98 | cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2))) 99 | cap_lens = [0] * len(data_loader.dataset) 100 | # cache embeddings 101 | img_embs[ids, :, :] = img_emb.data.cpu().numpy().copy() 102 | cap_embs[ids, :max(lengths), :] = cap_emb.data.cpu().numpy().copy() 103 | 104 | for j, nid in enumerate(ids): 105 | cap_lens[nid] = cap_len[j] 106 | 107 | del images, captions 108 | if sub > 0: 109 | print(f"===>batch {i}") 110 | if sub > 0 and i > sub: 111 | break 112 | if sub > 0: 113 | return np.array(img_embs)[ids_].tolist(), np.array(cap_embs)[ids_].tolist(), np.array(cap_lens)[ 114 | ids_].tolist(), ids_ 115 | else: 116 | return img_embs, cap_embs, cap_lens 117 | 118 | 119 | def shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=100, mode="sim"): 120 | n_im_shard = (len(img_embs) - 1) // shard_size + 1 121 | n_cap_shard = (len(cap_embs) - 1) // shard_size + 1 122 | 123 | sims = np.zeros((len(img_embs), len(cap_embs))) 124 | for i in range(n_im_shard): 125 | im_start, im_end = shard_size * i, min(shard_size * (i + 1), len(img_embs)) 126 | for j in range(n_cap_shard): 127 | sys.stdout.write('\r>> shard_attn_scores batch (%d,%d)' % (i, j)) 128 | ca_start, ca_end = shard_size * j, min(shard_size * (j + 1), len(cap_embs)) 129 | with torch.no_grad(): 130 | im = torch.from_numpy(img_embs[im_start:im_end]).float().cuda() 131 | ca = torch.from_numpy(cap_embs[ca_start:ca_end]).float().cuda() 132 | l = cap_lens[ca_start:ca_end] 133 | if mode == "sim": 134 | sim = 
model.forward_sim(im, ca, l, mode) 135 | else: 136 | _, sim, _ = model.forward_sim(im, ca, l, mode) # Calculate evidence for retrieval 137 | 138 | sims[im_start:im_end, ca_start:ca_end] = sim.data.cpu().numpy() 139 | sys.stdout.write('\n') 140 | return sims 141 | 142 | 143 | def t2i(npts, sims, per_captions=1, return_ranks=False): 144 | """ 145 | Text->Images (Image Search) 146 | Images: (N, n_region, d) matrix of images 147 | Captions: (per_captions * N, max_n_word, d) matrix of captions 148 | CapLens: (per_captions * N) array of caption lengths 149 | sims: (N, per_captions * N) matrix of similarity im-cap 150 | """ 151 | ranks = np.zeros(per_captions * npts) 152 | top1 = np.zeros(per_captions * npts) 153 | top5 = np.zeros((per_captions * npts, 5), dtype=int) 154 | 155 | # --> (per_captions * N(caption), N(image)) 156 | sims = sims.T 157 | retreivaled_index = [] 158 | for index in range(npts): 159 | for i in range(per_captions): 160 | inds = np.argsort(sims[per_captions * index + i])[::-1] 161 | retreivaled_index.append(inds) 162 | ranks[per_captions * index + i] = np.where(inds == index)[0][0] 163 | top1[per_captions * index + i] = inds[0] 164 | top5[per_captions * index + i] = inds[0:5] 165 | 166 | # Compute metrics 167 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 168 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 169 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 170 | medr = np.floor(np.median(ranks)) + 1 171 | meanr = ranks.mean() + 1 172 | if return_ranks: 173 | return (r1, r5, r10, medr, meanr), (ranks, top1, top5, retreivaled_index) 174 | else: 175 | return (r1, r5, r10, medr, meanr) 176 | 177 | 178 | def i2t(npts, sims, per_captions=1, return_ranks=False): 179 | """ 180 | Images->Text (Image Annotation) 181 | Images: (N, n_region, d) matrix of images 182 | Captions: (per_captions * N, max_n_word, d) matrix of captions 183 | CapLens: (per_captions * N) array of caption lengths 184 | sims: (N, per_captions * N) matrix of similarity im-cap 185 | """ 186 | ranks = np.zeros(npts) 187 | top1 = np.zeros(npts) 188 | top5 = np.zeros((npts, 5), dtype=int) 189 | retreivaled_index = [] 190 | for index in range(npts): 191 | inds = np.argsort(sims[index])[::-1] 192 | retreivaled_index.append(inds) 193 | # Score 194 | rank = 1e20 195 | for i in range(per_captions * index, per_captions * index + per_captions, 1): 196 | tmp = np.where(inds == i)[0][0] 197 | if tmp < rank: 198 | rank = tmp 199 | ranks[index] = rank 200 | top1[index] = inds[0] 201 | top5[index] = inds[0:5] 202 | 203 | # Compute metrics 204 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 205 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 206 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 207 | medr = np.floor(np.median(ranks)) + 1 208 | meanr = ranks.mean() + 1 209 | if return_ranks: 210 | return (r1, r5, r10, medr, meanr), (ranks, top1, top5, retreivaled_index) 211 | else: 212 | return (r1, r5, r10, medr, meanr) 213 | 214 | 215 | def validation(opt, val_loader, model_A, model_B, fold=False): 216 | # compute the encoding for all the validation images and captions 217 | if opt.data_name == 'cc152k_precomp': 218 | per_captions = 1 219 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']: 220 | per_captions = 5 221 | else: 222 | print(f"No dataset") 223 | return 0 224 | 225 | model_A.val_start() 226 | model_B.val_start() 227 | print('Encoding with model') 228 | img_embs_A, cap_embs_A, cap_lens_A = encode_data(model_A.similarity_model, val_loader, opt.log_step) 229 | if not 
fold: 230 | img_embs_A = np.array([img_embs_A[i] for i in range(0, len(img_embs_A), per_captions)]) 231 | 232 | img_embs_B, cap_embs_B, cap_lens_B = encode_data(model_B.similarity_model, val_loader, opt.log_step) 233 | if not fold: 234 | img_embs_B = np.array([img_embs_B[i] for i in range(0, len(img_embs_B), per_captions)]) 235 | 236 | if not fold: 237 | # record computation time of validation 238 | print('Computing similarity from model') 239 | sims_A = shard_attn_scores(model_A.similarity_model, img_embs_A, cap_embs_A, cap_lens_A, opt, shard_size=1000) 240 | sims_B = shard_attn_scores(model_B.similarity_model, img_embs_B, cap_embs_B, cap_lens_B, opt, shard_size=1000) 241 | sims = (sims_A + sims_B) / 2 242 | # sims = sims_B 243 | print("Calculate similarity time with model") 244 | (r1, r5, r10, medr, meanr) = i2t(img_embs_A.shape[0], sims, per_captions, return_ranks=False) 245 | print("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3)) 246 | print("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr)) 247 | # image retrieval 248 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs_A.shape[0], sims, per_captions, return_ranks=False) 249 | print("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3)) 250 | print("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr)) 251 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i 252 | print("Sum of Recall: %.2f" % (r_sum)) 253 | else: 254 | # 5fold cross-validation, only for MSCOCO 255 | results = [] 256 | for i in range(5): 257 | img_embs_shard_A = img_embs_A[i * 5000:(i + 1) * 5000:5] 258 | cap_embs_shard_A = cap_embs_A[i * 5000:(i + 1) * 5000] 259 | cap_lens_shard_A = cap_lens_A[i * 5000:(i + 1) * 5000] 260 | sims_A = shard_attn_scores(model_A.similarity_model, img_embs_shard_A, cap_embs_shard_A, cap_lens_shard_A, opt, 261 | shard_size=1000, 262 | mode="sim") 263 | 264 | img_embs_shard_B = img_embs_B[i * 5000:(i + 1) * 5000:5] 265 | cap_embs_shard_B = cap_embs_B[i * 5000:(i + 1) * 5000] 266 | cap_lens_shard_B = cap_lens_B[i * 5000:(i + 1) * 5000] 267 | sims_B = shard_attn_scores(model_B.similarity_model, img_embs_shard_B, cap_embs_shard_B, cap_lens_shard_B, opt, 268 | shard_size=1000, 269 | mode="sim") 270 | 271 | sims = (sims_A + sims_B) / 2 272 | 273 | print('Computing similarity from model') 274 | r, rt = i2t(img_embs_shard_A.shape[0], sims, per_captions, return_ranks=True) 275 | ri, rti = t2i(img_embs_shard_A.shape[0], sims, per_captions, return_ranks=True) 276 | 277 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 278 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 279 | ar = (r[0] + r[1] + r[2]) / 3 280 | ari = (ri[0] + ri[1] + ri[2]) / 3 281 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 282 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 283 | results += [list(r) + list(ri) + [ar, ari, rsum]] 284 | print("-----------------------------------") 285 | print("Mean metrics: ") 286 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 287 | a = np.array(mean_metrics) 288 | print("Average i2t Recall: %.1f" % mean_metrics[11]) 289 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 290 | mean_metrics[:5]) 291 | print("Average t2i Recall: %.1f" % mean_metrics[12]) 292 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 293 | mean_metrics[5:10]) 294 | print("rsum: %.1f" % (a[0:3].sum() + a[5:8].sum())) 295 | 296 | 297 | def validation_dul(opt, val_loader, models, fold=False): 298 | # compute the encoding for all the validation images and captions 299 | if opt.data_name 
== 'cc152k_precomp': 300 | per_captions = 1 301 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']: 302 | per_captions = 5 303 | else: 304 | print(f"No dataset") 305 | return 0 306 | 307 | models[0].val_start() 308 | models[1].val_start() 309 | models[2].val_start() 310 | models[3].val_start() 311 | print('Encoding with model') 312 | img_embs, cap_embs, cap_lens = encode_data(models[0].similarity_model, val_loader, opt.log_step) 313 | img_embs1, cap_embs1, cap_lens1 = encode_data(models[1].similarity_model, val_loader, opt.log_step) 314 | img_embs2, cap_embs2, cap_lens2 = encode_data(models[2].similarity_model, val_loader, opt.log_step) 315 | img_embs3, cap_embs3, cap_lens3 = encode_data(models[3].similarity_model, val_loader, opt.log_step) 316 | if not fold: 317 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)]) 318 | img_embs1 = np.array([img_embs1[i] for i in range(0, len(img_embs1), per_captions)]) 319 | img_embs2 = np.array([img_embs2[i] for i in range(0, len(img_embs2), per_captions)]) 320 | img_embs3 = np.array([img_embs3[i] for i in range(0, len(img_embs3), per_captions)]) 321 | # record computation time of validation 322 | print('Computing similarity from model') 323 | sims_mean = shard_attn_scores(models[0].similarity_model, img_embs, cap_embs, cap_lens, opt, shard_size=1000, 324 | mode="sim") 325 | sims_mean += shard_attn_scores(models[1].similarity_model, img_embs1, cap_embs1, cap_lens1, opt, 326 | shard_size=1000, mode="sim") 327 | sims_mean += shard_attn_scores(models[2].similarity_model, img_embs2, cap_embs2, cap_lens2, opt, 328 | shard_size=1000, mode="sim") 329 | sims_mean += shard_attn_scores(models[3].similarity_model, img_embs3, cap_embs3, cap_lens3, opt, 330 | shard_size=1000, mode="sim") 331 | sims_mean /= 4 332 | print("Calculate similarity time with model") 333 | # caption retrieval 334 | (r1, r5, r10, medr, meanr) = i2t(img_embs.shape[0], sims_mean, per_captions, return_ranks=False) 335 | print("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3)) 336 | print("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr)) 337 | # image retrieval 338 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs.shape[0], sims_mean, per_captions, return_ranks=False) 339 | print("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3)) 340 | print("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr)) 341 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i 342 | print("Sum of Recall: %.2f" % (r_sum)) 343 | return r_sum 344 | else: 345 | # 5fold cross-validation, only for MSCOCO 346 | results = [] 347 | for i in range(5): 348 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5] 349 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000] 350 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000] 351 | 352 | img_embs_shard1 = img_embs1[i * 5000:(i + 1) * 5000:5] 353 | cap_embs_shard1 = cap_embs1[i * 5000:(i + 1) * 5000] 354 | cap_lens_shard1 = cap_lens1[i * 5000:(i + 1) * 5000] 355 | sims = shard_attn_scores(models[0].similarity_model, img_embs_shard, cap_embs_shard, cap_lens_shard, opt, 356 | shard_size=1000, 357 | mode="not sim") 358 | sims += shard_attn_scores(models[1].similarity_model, img_embs_shard1, cap_embs_shard1, cap_lens_shard1, 359 | opt, 360 | shard_size=1000, 361 | mode="not sim") 362 | sims /= 2 363 | 364 | print('Computing similarity from model') 365 | r, rt0 = i2t(img_embs_shard.shape[0], sims, per_captions, return_ranks=True) 366 | ri, rti0 = t2i(img_embs_shard.shape[0], sims, per_captions, 
return_ranks=True) 367 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 368 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 369 | if i == 0: 370 | rt, rti = rt0, rti0 371 | ar = (r[0] + r[1] + r[2]) / 3 372 | ari = (ri[0] + ri[1] + ri[2]) / 3 373 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 374 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 375 | results += [list(r) + list(ri) + [ar, ari, rsum]] 376 | 377 | print("-----------------------------------") 378 | print("Mean metrics: ") 379 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 380 | a = np.array(mean_metrics) 381 | 382 | print("Average i2t Recall: %.1f" % mean_metrics[11]) 383 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 384 | mean_metrics[:5]) 385 | print("Average t2i Recall: %.1f" % mean_metrics[12]) 386 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 387 | mean_metrics[5:10]) 388 | print("rsum: %.1f" % (a[0:3].sum() + a[5:8].sum())) 389 | 390 | 391 | def eval_DECL(checkpoint_paths, avg_SGRAF=True, data_path=None, vocab_path=None): 392 | if avg_SGRAF is False: 393 | print(f"Load checkpoint from '{checkpoint_paths[0]}'") 394 | checkpoint = torch.load(checkpoint_paths[0]) 395 | opt = checkpoint['opt'] 396 | opt.ssl = False 397 | print( 398 | f"Noise ratio is {opt.noise_ratio}, module is {opt.module_name}, best validation epoch is {checkpoint['epoch']} ({checkpoint['best_rsum']})") 399 | if vocab_path != None: 400 | opt.vocab_path = vocab_path 401 | if data_path != None: 402 | opt.data_path = data_path 403 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name)) 404 | model_A = RINCE(opt) 405 | model_B = RINCE(opt) 406 | model_A.load_state_dict(checkpoint['model_A']) 407 | model_B.load_state_dict(checkpoint['model_B']) 408 | 409 | if 'coco' in opt.data_name: 410 | test_loader = get_test_loader('testall', opt.data_name, vocab, 100, 0, opt) 411 | validation(opt, test_loader, model_A=model_A, model_B=model_B, fold=True) 412 | validation(opt, test_loader, model_A=model_A, model_B=model_B, fold=False) 413 | else: 414 | test_loader = get_test_loader('test', opt.data_name, vocab, 100, 0, opt) 415 | validation(opt, test_loader, model_A=model_A, model_B=model_B, fold=False) 416 | else: 417 | assert len(checkpoint_paths) == 2 418 | print(f"Load checkpoint from '{checkpoint_paths}'") 419 | checkpoint0 = torch.load(checkpoint_paths[0]) 420 | checkpoint1 = torch.load(checkpoint_paths[1]) 421 | opt0 = checkpoint0['opt'] 422 | opt1 = checkpoint1['opt'] 423 | print( 424 | f"Noise ratios are {opt0.noise_ratio} and {opt1.noise_ratio}, " 425 | f"modules are {opt0.module_name} and {opt1.module_name}, best validation epochs are {checkpoint0['epoch']}" 426 | f" ({checkpoint0['best_rsum']}) and {checkpoint1['epoch']} ({checkpoint1['best_rsum']})") 427 | vocab = deserialize_vocab(os.path.join(opt0.vocab_path, '%s_vocab.json' % opt0.data_name)) 428 | model0_A = RINCE(opt0) 429 | model0_B = RINCE(opt0) 430 | model1_A = RINCE(opt1) 431 | model1_B = RINCE(opt1) 432 | 433 | model0_A.load_state_dict(checkpoint0['model_A']) 434 | model0_B.load_state_dict(checkpoint0['model_B']) 435 | model1_A.load_state_dict(checkpoint1['model_A']) 436 | model1_B.load_state_dict(checkpoint1['model_B']) 437 | if 'coco' in opt0.data_name: 438 | test_loader = get_test_loader('testall', opt0.data_name, vocab, 100, 0, opt0) 439 | print(f'=====>model {opt0.module_name} fold:True') 440 | validation(opt0, test_loader, model0_A, model0_B, fold=True) 441 | print(f'=====>model 
{opt1.module_name} fold:True') 442 | validation(opt0, test_loader, model1_A, model1_B, fold=True) 443 | print(f'=====>model SGRAF fold:True') 444 | validation_dul(opt0, test_loader, models=[model0_A, model0_B, model1_A, model1_B], fold=True) 445 | 446 | print(f'=====>model {opt0.module_name} fold:False') 447 | validation(opt0, test_loader, model0_A, model0_B, fold=False) 448 | print(f'=====>model {opt1.module_name} fold:False') 449 | validation(opt0, test_loader, model1_A, model1_B, fold=False) 450 | print('=====>model SGRAF fold:False') 451 | validation_dul(opt0, test_loader, models=[model0_A, model0_B, model1_A, model1_B], fold=False) 452 | 453 | else: 454 | test_loader = get_test_loader('test', opt0.data_name, vocab, 100, 0, opt0) 455 | print(f'=====>model {opt0.module_name} fold:False') 456 | validation(opt0, test_loader, model0_A, model0_B, fold=False) 457 | print(f'=====>model {opt1.module_name} fold:False') 458 | validation(opt0, test_loader, model1_A, model1_B, fold=False) 459 | print('=====>model SGRAF fold:False') 460 | validation_dul(opt0, test_loader, models=[model0_A, model0_B, model1_A, model1_B], fold=False) -------------------------------------------------------------------------------- /GSC/model/SGRAF.py: -------------------------------------------------------------------------------- 1 | """SGRAF model""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import torch.nn.functional as F 7 | 8 | import torch.backends.cudnn as cudnn 9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 10 | from torch.nn.utils.clip_grad import clip_grad_norm_ 11 | 12 | import numpy as np 13 | from collections import OrderedDict 14 | 15 | 16 | def l1norm(X, dim, eps=1e-8): 17 | """L1-normalize columns of X""" 18 | norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps 19 | X = torch.div(X, norm) 20 | return X 21 | 22 | 23 | def l2norm(X, dim=-1, eps=1e-8): 24 | """L2-normalize columns of X""" 25 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 26 | X = torch.div(X, norm) 27 | return X 28 | 29 | 30 | def cosine_sim(x1, x2, dim=-1, eps=1e-8): 31 | """Returns cosine similarity between x1 and x2, computed along dim.""" 32 | w12 = torch.sum(x1 * x2, dim) 33 | w1 = torch.norm(x1, 2, dim) 34 | w2 = torch.norm(x2, 2, dim) 35 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 36 | 37 | 38 | class EncoderImage(nn.Module): 39 | """ 40 | Build local region representations by common-used FC-layer. 41 | Args: - images: raw local detected regions, shape: (batch_size, 36, 2048). 42 | Returns: - img_emb: finial local region embeddings, shape: (batch_size, 36, 1024). 43 | """ 44 | 45 | def __init__(self, opt, img_dim, embed_size, no_imgnorm=False): 46 | super(EncoderImage, self).__init__() 47 | self.opt = opt 48 | self.embed_size = embed_size 49 | self.no_imgnorm = no_imgnorm 50 | self.fc = nn.Linear(img_dim, embed_size) 51 | 52 | self.init_weights(self.fc) 53 | 54 | def init_weights(self, fc): 55 | """Xavier initialization for the fully connected layer""" 56 | r = np.sqrt(6.)
/ np.sqrt(fc.in_features + 57 | fc.out_features) 58 | fc.weight.data.uniform_(-r, r) 59 | fc.bias.data.fill_(0) 60 | 61 | def forward(self, images): 62 | """Extract image feature vectors.""" 63 | # assuming that the precomputed features are already l2-normalized 64 | img_emb = self.fc(images) 65 | 66 | # normalize in the joint embedding space 67 | if not self.no_imgnorm: 68 | img_emb = l2norm(img_emb, dim=-1) 69 | 70 | return img_emb 71 | 72 | def load_state_dict(self, state_dict): 73 | """Overwrite the default one to accept state_dict from Full model""" 74 | own_state = self.state_dict() 75 | new_state = OrderedDict() 76 | for name, param in state_dict.items(): 77 | if name in own_state: 78 | new_state[name] = param 79 | 80 | super(EncoderImage, self).load_state_dict(new_state) 81 | 82 | 83 | class EncoderText(nn.Module): 84 | """ 85 | Build local word representations by common-used Bi-GRU or GRU. 86 | Args: - images: raw local word ids, shape: (batch_size, L). 87 | Returns: - img_emb: final local word embeddings, shape: (batch_size, L, 1024). 88 | """ 89 | 90 | def __init__(self, opt, vocab_size, word_dim, embed_size, num_layers, 91 | use_bi_gru=False, no_txtnorm=False): 92 | super(EncoderText, self).__init__() 93 | self.opt = opt 94 | self.embed_size = embed_size 95 | self.no_txtnorm = no_txtnorm 96 | 97 | # word embedding 98 | self.embed = nn.Embedding(vocab_size, word_dim) 99 | self.dropout = nn.Dropout(0.4) 100 | 101 | # caption embedding 102 | self.use_bi_gru = use_bi_gru 103 | self.cap_rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru) 104 | 105 | self.init_weights() 106 | 107 | def init_weights(self): 108 | self.embed.weight.data.uniform_(-0.1, 0.1) 109 | 110 | def forward(self, captions, lengths): 111 | """Handles variable size captions""" 112 | # embed word ids to vectors 113 | cap_emb = self.embed(captions) 114 | cap_emb = self.dropout(cap_emb) 115 | 116 | # pack the caption 117 | packed = pack_padded_sequence(cap_emb, lengths, batch_first=True) 118 | 119 | # forward propagate RNN 120 | out, _ = self.cap_rnn(packed) 121 | 122 | # reshape output to (batch_size, hidden_size) 123 | cap_emb, _ = pad_packed_sequence(out, batch_first=True) 124 | 125 | if self.use_bi_gru: 126 | cap_emb = (cap_emb[:, :, :cap_emb.size(2) // 2] + cap_emb[:, :, cap_emb.size(2) // 2:]) / 2 127 | 128 | # normalization in the joint embedding space 129 | if not self.no_txtnorm: 130 | cap_emb = l2norm(cap_emb, dim=-1) 131 | 132 | return cap_emb 133 | 134 | 135 | class VisualSA(nn.Module): 136 | """ 137 | Build global image representations by self-attention. 138 | Args: - local: local region embeddings, shape: (batch_size, 36, 1024) 139 | - raw_global: raw image by averaging regions, shape: (batch_size, 1024) 140 | Returns: - new_global: final image by self-attention, shape: (batch_size, 1024). 
141 | """ 142 | 143 | def __init__(self, embed_dim, dropout_rate, num_region): 144 | super(VisualSA, self).__init__() 145 | 146 | self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim), 147 | nn.BatchNorm1d(num_region), 148 | nn.Tanh(), nn.Dropout(dropout_rate)) 149 | self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim), 150 | nn.BatchNorm1d(embed_dim), 151 | nn.Tanh(), nn.Dropout(dropout_rate)) 152 | self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1)) 153 | 154 | self.init_weights() 155 | self.softmax = nn.Softmax(dim=1) 156 | 157 | def init_weights(self): 158 | for embeddings in self.children(): 159 | for m in embeddings: 160 | if isinstance(m, nn.Linear): 161 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features) 162 | m.weight.data.uniform_(-r, r) 163 | m.bias.data.fill_(0) 164 | elif isinstance(m, nn.BatchNorm1d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | 168 | def forward(self, local, raw_global): 169 | # compute embedding of local regions and raw global image 170 | if len(local) == 1 and len(raw_global) == 1: 171 | l_emb = self.embedding_local[0](local) 172 | l_emb = self.embedding_local[2](l_emb) 173 | l_emb = self.embedding_local[3](l_emb) 174 | g_emb = self.embedding_global[0](raw_global) 175 | g_emb = self.embedding_global[2](g_emb) 176 | g_emb = self.embedding_global[3](g_emb) 177 | else: 178 | l_emb = self.embedding_local(local) 179 | g_emb = self.embedding_global(raw_global) 180 | 181 | # compute the normalized weights, shape: (batch_size, 36) 182 | g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1) 183 | common = l_emb.mul(g_emb) 184 | weights = self.embedding_common(common).squeeze(2) 185 | weights = self.softmax(weights) 186 | 187 | # compute final image, shape: (batch_size, 1024) 188 | new_global = (weights.unsqueeze(2) * local).sum(dim=1) 189 | new_global = l2norm(new_global, dim=-1) 190 | 191 | return new_global 192 | 193 | 194 | class TextSA(nn.Module): 195 | """ 196 | Build global text representations by self-attention. 197 | Args: - local: local word embeddings, shape: (batch_size, L, 1024) 198 | - raw_global: raw text by averaging words, shape: (batch_size, 1024) 199 | Returns: - new_global: final text by self-attention, shape: (batch_size, 1024). 200 | """ 201 | 202 | def __init__(self, embed_dim, dropout_rate): 203 | super(TextSA, self).__init__() 204 | 205 | self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim), 206 | nn.Tanh(), nn.Dropout(dropout_rate)) 207 | self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim), 208 | nn.Tanh(), nn.Dropout(dropout_rate)) 209 | self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1)) 210 | 211 | self.init_weights() 212 | self.softmax = nn.Softmax(dim=1) 213 | 214 | def init_weights(self): 215 | for embeddings in self.children(): 216 | for m in embeddings: 217 | if isinstance(m, nn.Linear): 218 | r = np.sqrt(6.) 
/ np.sqrt(m.in_features + m.out_features) 219 | m.weight.data.uniform_(-r, r) 220 | m.bias.data.fill_(0) 221 | elif isinstance(m, nn.BatchNorm1d): 222 | m.weight.data.fill_(1) 223 | m.bias.data.zero_() 224 | 225 | def forward(self, local, raw_global): 226 | # compute embedding of local words and raw global text 227 | l_emb = self.embedding_local(local) 228 | g_emb = self.embedding_global(raw_global) 229 | 230 | # compute the normalized weights, shape: (batch_size, L) 231 | g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1) 232 | common = l_emb.mul(g_emb) 233 | weights = self.embedding_common(common).squeeze(2) 234 | weights = self.softmax(weights) 235 | 236 | # compute final text, shape: (batch_size, 1024) 237 | new_global = (weights.unsqueeze(2) * local).sum(dim=1) 238 | new_global = l2norm(new_global, dim=-1) 239 | 240 | return new_global 241 | 242 | 243 | class GraphReasoning(nn.Module): 244 | """ 245 | Perform the similarity graph reasoning with a full-connected graph 246 | Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256) 247 | Returns; - sim_sgr: reasoned graph nodes after several steps, shape: (batch_size, L+1, 256) 248 | """ 249 | 250 | def __init__(self, sim_dim): 251 | super(GraphReasoning, self).__init__() 252 | 253 | self.graph_query_w = nn.Linear(sim_dim, sim_dim) 254 | self.graph_key_w = nn.Linear(sim_dim, sim_dim) 255 | self.sim_graph_w = nn.Linear(sim_dim, sim_dim) 256 | # self.relu = nn.LeakyReLU(negative_slope=0.1) 257 | self.relu = nn.ReLU() 258 | 259 | self.init_weights() 260 | 261 | def forward(self, sim_emb): 262 | sim_query = self.graph_query_w(sim_emb) 263 | sim_key = self.graph_key_w(sim_emb) 264 | sim_edge = torch.softmax(torch.bmm(sim_query, sim_key.permute(0, 2, 1)), dim=-1) 265 | sim_sgr = torch.bmm(sim_edge, sim_emb) 266 | sim_sgr = self.relu(self.sim_graph_w(sim_sgr)) 267 | return sim_sgr 268 | 269 | def init_weights(self): 270 | for m in self.children(): 271 | if isinstance(m, nn.Linear): 272 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features) 273 | m.weight.data.uniform_(-r, r) 274 | m.bias.data.fill_(0) 275 | elif isinstance(m, nn.BatchNorm1d): 276 | m.weight.data.fill_(1) 277 | m.bias.data.zero_() 278 | 279 | 280 | class AttentionFiltration(nn.Module): 281 | """ 282 | Perform the similarity Attention Filtration with a gate-based attention 283 | Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256) 284 | Returns; - sim_saf: aggregated alignment after attention filtration, shape: (batch_size, 256) 285 | """ 286 | 287 | def __init__(self, sim_dim): 288 | super(AttentionFiltration, self).__init__() 289 | 290 | self.attn_sim_w = nn.Linear(sim_dim, 1) 291 | self.bn = nn.BatchNorm1d(1) 292 | 293 | self.init_weights() 294 | 295 | def forward(self, sim_emb): 296 | sim_attn = l1norm(torch.sigmoid(self.bn(self.attn_sim_w(sim_emb).permute(0, 2, 1))), dim=-1) 297 | sim_saf = torch.matmul(sim_attn, sim_emb) 298 | sim_saf = l2norm(sim_saf.squeeze(1), dim=-1) 299 | return sim_saf 300 | 301 | def init_weights(self): 302 | for m in self.children(): 303 | if isinstance(m, nn.Linear): 304 | r = np.sqrt(6.) 
/ np.sqrt(m.in_features + m.out_features) 305 | m.weight.data.uniform_(-r, r) 306 | m.bias.data.fill_(0) 307 | elif isinstance(m, nn.BatchNorm1d): 308 | m.weight.data.fill_(1) 309 | m.bias.data.zero_() 310 | 311 | 312 | class EncoderSimilarity(nn.Module): 313 | """ 314 | Compute the image-text similarity by SGR, SAF, AVE 315 | Args: - img_emb: local region embeddings, shape: (batch_size, 36, 1024) 316 | - cap_emb: local word embeddings, shape: (batch_size, L, 1024) 317 | Returns: 318 | - sim_all: final image-text similarities, shape: (batch_size, batch_size). 319 | """ 320 | 321 | def __init__(self, opt, embed_size, sim_dim, module_name='AVE', sgr_step=3): 322 | super(EncoderSimilarity, self).__init__() 323 | self.opt = opt 324 | self.module_name = module_name 325 | 326 | self.v_global_w = VisualSA(embed_size, 0.4, 36) 327 | self.t_global_w = TextSA(embed_size, 0.4) 328 | 329 | self.sim_tranloc_w = nn.Linear(embed_size, sim_dim) 330 | self.sim_tranglo_w = nn.Linear(embed_size, sim_dim) 331 | 332 | self.sim_eval_w = nn.Linear(sim_dim, 1) 333 | self.sigmoid = nn.Sigmoid() 334 | 335 | if module_name == 'SGR': 336 | self.SGR_module = nn.ModuleList([GraphReasoning(sim_dim) for i in range(sgr_step)]) 337 | elif module_name == 'SAF': 338 | self.SAF_module = AttentionFiltration(sim_dim) 339 | else: 340 | raise ValueError('Invalid input of opt.module_name in opts.py') 341 | 342 | self.init_weights() 343 | 344 | # Modifications by extending DECL 345 | def forward(self, img_emb, cap_emb, cap_lens): 346 | sim_all = [] 347 | n_image = img_emb.size(0) 348 | n_caption = cap_emb.size(0) 349 | 350 | # get enhanced global images by self-attention 351 | img_ave = torch.mean(img_emb, 1) 352 | img_glo = self.v_global_w(img_emb, img_ave) 353 | 354 | for i in range(n_caption): 355 | # get the i-th sentence 356 | n_word = cap_lens[i] 357 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0) 358 | cap_i_expand = cap_i.repeat(n_image, 1, 1) 359 | 360 | # get enhanced global i-th text by self-attention 361 | cap_ave_i = torch.mean(cap_i, 1) 362 | cap_glo_i = self.t_global_w(cap_i, cap_ave_i) 363 | 364 | # local-global alignment construction 365 | Context_img = SCAN_attention(cap_i_expand, img_emb, smooth=9.0) 366 | sim_loc = torch.pow(torch.sub(Context_img, cap_i_expand), 2) 367 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1) 368 | 369 | sim_glo = torch.pow(torch.sub(img_glo, cap_glo_i), 2) 370 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1) 371 | 372 | # concat the global and local alignments 373 | sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1) 374 | 375 | # compute the final similarity vector 376 | if self.module_name == 'SGR': 377 | for module in self.SGR_module: 378 | sim_emb = module(sim_emb) 379 | sim_vec = sim_emb[:, 0, :] 380 | else: 381 | sim_vec = self.SAF_module(sim_emb) 382 | 383 | # compute the final similarity score 384 | # sim_i = self.sigmoid(self.sim_eval_w(sim_vec)) 385 | # sim_all.append(sim_i) 386 | sim_all.append(self.sim_eval_w(sim_vec)) 387 | 388 | # (n_image, n_caption) 389 | sim_all = torch.cat(sim_all, 1) 390 | 391 | return sim_all 392 | 393 | def forward_topology(self, img_emb, cap_emb, cap_lens): 394 | sim_img = [] 395 | n_image = img_emb.size(0) 396 | 397 | img_ave = torch.mean(img_emb, 1) 398 | img_glo = self.v_global_w(img_emb, img_ave) 399 | 400 | for i in range(n_image): 401 | img_i = img_emb[i].unsqueeze(0) 402 | img_i_expand = img_i.repeat(n_image, 1, 1) 403 | 404 | img_glo_i = img_glo[i].unsqueeze(0) 405 | 406 | sim_loc = torch.pow(torch.sub(img_emb, 
img_i_expand), 2) + 1e-6 407 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1) 408 | 409 | sim_glo = torch.pow(torch.sub(img_glo, img_glo_i), 2) + 1e-6 410 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1) 411 | 412 | sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1) 413 | 414 | if self.module_name == 'SGR': 415 | for module in self.SGR_module: 416 | sim_emb = module(sim_emb) 417 | sim_vec = sim_emb[:, 0, :] 418 | else: 419 | sim_vec = self.SAF_module(sim_emb) 420 | 421 | sim_img.append(self.sim_eval_w(sim_vec)) 422 | 423 | sim_img = torch.cat(sim_img, 1) 424 | 425 | sim_cap = [] 426 | n_caption = cap_emb.size(0) 427 | 428 | cap_glo = [] 429 | for i in range (n_caption): 430 | n_word = cap_lens[i] 431 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0) 432 | cap_ave_i = torch.mean(cap_i, 1) 433 | cap_glo_i = self.t_global_w(cap_i, cap_ave_i) 434 | cap_glo.append(cap_glo_i) 435 | cap_glo = torch.cat(cap_glo, 0) 436 | 437 | for i in range(n_caption): 438 | n_word = cap_lens[i] 439 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0) 440 | cap_i_expand = cap_i.repeat(n_caption, 1, 1) 441 | cap_glo_i = cap_glo[i].unsqueeze(0) 442 | 443 | Context_cap = SCAN_attention(cap_i_expand, cap_emb + 1e-6, smooth=9.0, lens=cap_lens) 444 | sim_loc = torch.pow(torch.sub(Context_cap, cap_i_expand), 2) 445 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1) 446 | 447 | sim_glo = torch.pow(torch.sub(cap_glo, cap_glo_i), 2) + 1e-6 448 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1) 449 | 450 | sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1) 451 | 452 | if self.module_name == 'SGR': 453 | for module in self.SGR_module: 454 | sim_emb = module(sim_emb) 455 | sim_vec = sim_emb[:, 0, :] 456 | else: 457 | sim_vec = self.SAF_module(sim_emb) 458 | 459 | sim_cap.append(self.sim_eval_w(sim_vec)) 460 | 461 | sim_cap = torch.cat(sim_cap, 1) 462 | 463 | return sim_img, sim_cap 464 | 465 | 466 | def init_weights(self): 467 | for m in self.children(): 468 | if isinstance(m, nn.Linear): 469 | r = np.sqrt(6.) 
/ np.sqrt(m.in_features + m.out_features) 470 | m.weight.data.uniform_(-r, r) 471 | m.bias.data.fill_(0) 472 | elif isinstance(m, nn.BatchNorm1d): 473 | m.weight.data.fill_(1) 474 | m.bias.data.zero_() 475 | 476 | 477 | def SCAN_attention(query, context, smooth, eps=1e-8, lens=None): 478 | """ 479 | query: (n_context, queryL, d) 480 | context: (n_context, sourceL, d) 481 | """ 482 | # --> (batch, d, queryL) 483 | queryT = torch.transpose(query, 1, 2) 484 | 485 | # (batch, sourceL, d)(batch, d, queryL) 486 | # --> (batch, sourceL, queryL) 487 | attn = torch.bmm(context, queryT) 488 | 489 | attn = nn.LeakyReLU(0.1)(attn) 490 | attn = l2norm(attn, 2) 491 | 492 | if lens is not None: 493 | for i in range(len(lens)): 494 | attn[i, lens[i]:] = -1e9 495 | 496 | # --> (batch, queryL, sourceL) 497 | attn = torch.transpose(attn, 1, 2).contiguous() 498 | # --> (batch, queryL, sourceL 499 | attn = F.softmax(attn * smooth, dim=2) 500 | 501 | # --> (batch, sourceL, queryL) 502 | attnT = torch.transpose(attn, 1, 2).contiguous() 503 | 504 | # --> (batch, d, sourceL) 505 | contextT = torch.transpose(context, 1, 2) 506 | # (batch x d x sourceL)(batch x sourceL x queryL) 507 | # --> (batch, d, queryL) 508 | weightedContext = torch.bmm(contextT, attnT) 509 | # --> (batch, queryL, d) 510 | weightedContext = torch.transpose(weightedContext, 1, 2) 511 | weightedContext = l2norm(weightedContext, dim=-1) 512 | 513 | return weightedContext 514 | 515 | class SGRAF(object): 516 | """ 517 | Similarity Reasoning and Filtration (SGRAF) Network 518 | """ 519 | 520 | def __init__(self, opt): 521 | # Build Models 522 | self.opt = opt 523 | self.grad_clip = opt.grad_clip 524 | self.img_enc = EncoderImage(opt, opt.img_dim, opt.embed_size, 525 | no_imgnorm=opt.no_imgnorm) 526 | self.txt_enc = EncoderText(opt, opt.vocab_size, opt.word_dim, 527 | opt.embed_size, opt.num_layers, 528 | use_bi_gru=opt.bi_gru, 529 | no_txtnorm=opt.no_txtnorm) 530 | self.sim_enc = EncoderSimilarity(opt, opt.embed_size, opt.sim_dim, 531 | opt.module_name, opt.sgr_step) 532 | 533 | if torch.cuda.is_available(): 534 | self.img_enc.cuda() 535 | self.txt_enc.cuda() 536 | self.sim_enc.cuda() 537 | cudnn.benchmark = True 538 | 539 | # Loss and Optimizer 540 | params = list(self.txt_enc.parameters()) 541 | params += list(self.img_enc.parameters()) 542 | params += list(self.sim_enc.parameters()) 543 | 544 | self.params = params 545 | 546 | self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate) 547 | self.Eiters = 0 548 | 549 | if 'f30k' in self.opt.data_name: 550 | self.targets = torch.ones(145000).cuda() 551 | self.preds = torch.ones(145000).cuda() 552 | self.preds_targets = torch.ones(145000).cuda() 553 | self.topos = torch.ones(145000).cuda() 554 | self.topo_targets = torch.ones(145000).cuda() 555 | elif 'coco' in self.opt.data_name: 556 | self.targets = torch.ones(566435).cuda() 557 | self.preds = torch.ones(566435).cuda() 558 | self.preds_targets = torch.ones(566435).cuda() 559 | self.topos = torch.ones(566435).cuda() 560 | self.topo_targets = torch.ones(566435).cuda() 561 | elif 'cc152k' in self.opt.data_name: 562 | self.targets = torch.ones(150000).cuda() 563 | self.preds = torch.ones(150000).cuda() 564 | self.preds_targets = torch.ones(150000).cuda() 565 | self.topos = torch.ones(150000).cuda() 566 | self.topo_targets = torch.ones(150000).cuda() 567 | 568 | def state_dict(self): 569 | state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict(), self.sim_enc.state_dict(), self.targets, self.topos, self.topo_targets] 570 | return 
state_dict 571 | 572 | def load_state_dict(self, state_dict): 573 | self.img_enc.load_state_dict(state_dict[0]) 574 | self.txt_enc.load_state_dict(state_dict[1]) 575 | self.sim_enc.load_state_dict(state_dict[2]) 576 | if len(state_dict) > 3: 577 | self.targets = state_dict[3] 578 | self.topos = state_dict[4] 579 | self.topo_targets = state_dict[5] 580 | 581 | def train_start(self): 582 | """switch to train mode""" 583 | self.img_enc.train() 584 | self.txt_enc.train() 585 | self.sim_enc.train() 586 | 587 | def val_start(self): 588 | """switch to evaluate mode""" 589 | self.img_enc.eval() 590 | self.txt_enc.eval() 591 | self.sim_enc.eval() 592 | 593 | def forward_emb(self, images, captions, lengths): 594 | """Compute the image and caption embeddings""" 595 | if torch.cuda.is_available(): 596 | images = images.cuda() 597 | captions = captions.cuda() 598 | 599 | # Forward feature encoding 600 | img_embs = self.img_enc(images) 601 | cap_embs = self.txt_enc(captions, lengths) 602 | return img_embs, cap_embs, lengths 603 | 604 | # Modifications by extending DECL 605 | def forward_sim(self, img_embs, cap_embs, cap_lens, mode='sim'): 606 | # Forward similarity encoding 607 | raw_sims = self.sim_enc(img_embs, cap_embs, cap_lens) 608 | if mode == 'sim': 609 | return torch.sigmoid(raw_sims) 610 | 611 | def train_emb(self, images, captions, lengths, ids=None, *args): 612 | """One training step given images and captions. 613 | """ 614 | self.Eiters += 1 615 | self.logger.update('Eit', self.Eiters) 616 | self.logger.update('lr', self.optimizer.param_groups[0]['lr']) 617 | 618 | # compute the embeddings 619 | img_embs, cap_embs, cap_lens = self.forward_emb(images, captions, lengths) 620 | sims = self.forward_sim(img_embs, cap_embs, cap_lens) 621 | 622 | # measure accuracy and record loss 623 | self.optimizer.zero_grad() 624 | loss = self.forward_loss(sims) 625 | 626 | # compute gradient and do SGD step 627 | loss.backward() 628 | if self.grad_clip > 0: 629 | clip_grad_norm_(self.params, self.grad_clip) 630 | self.optimizer.step() 631 | -------------------------------------------------------------------------------- /GSC+/model/SGRAF.py: -------------------------------------------------------------------------------- 1 | """SGRAF model""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import torch.nn.functional as F 7 | 8 | import torch.backends.cudnn as cudnn 9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 10 | from torch.nn.utils.clip_grad import clip_grad_norm_ 11 | 12 | import numpy as np 13 | from collections import OrderedDict 14 | 15 | 16 | def l1norm(X, dim, eps=1e-8): 17 | """L1-normalize columns of X""" 18 | norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps 19 | X = torch.div(X, norm) 20 | return X 21 | 22 | 23 | def l2norm(X, dim=-1, eps=1e-8): 24 | """L2-normalize columns of X""" 25 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 26 | X = torch.div(X, norm) 27 | return X 28 | 29 | 30 | def cosine_sim(x1, x2, dim=-1, eps=1e-8): 31 | """Returns cosine similarity between x1 and x2, computed along dim.""" 32 | w12 = torch.sum(x1 * x2, dim) 33 | w1 = torch.norm(x1, 2, dim) 34 | w2 = torch.norm(x2, 2, dim) 35 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 36 | 37 | 38 | class EncoderImage(nn.Module): 39 | """ 40 | Build local region representations by common-used FC-layer. 41 | Args: - images: raw local detected regions, shape: (batch_size, 36, 2048). 42 | Returns: - img_emb: finial local region embeddings, shape: (batch_size, 36, 1024). 
43 | """ 44 | 45 | def __init__(self, opt, img_dim, embed_size, no_imgnorm=False): 46 | super(EncoderImage, self).__init__() 47 | self.opt = opt 48 | self.embed_size = embed_size 49 | self.no_imgnorm = no_imgnorm 50 | self.fc = nn.Linear(img_dim, embed_size) 51 | 52 | self.init_weights(self.fc) 53 | 54 | def init_weights(self, fc): 55 | """Xavier initialization for the fully connected layer""" 56 | r = np.sqrt(6.) / np.sqrt(fc.in_features + 57 | fc.out_features) 58 | fc.weight.data.uniform_(-r, r) 59 | fc.bias.data.fill_(0) 60 | 61 | def forward(self, images): 62 | """Extract image feature vectors.""" 63 | # assuming that the precomputed features are already l2-normalized 64 | img_emb = self.fc(images) 65 | 66 | # normalize in the joint embedding space 67 | if not self.no_imgnorm: 68 | img_emb = l2norm(img_emb, dim=-1) 69 | 70 | return img_emb 71 | 72 | def load_state_dict(self, state_dict): 73 | """Overwrite the default one to accept state_dict from Full model""" 74 | own_state = self.state_dict() 75 | new_state = OrderedDict() 76 | for name, param in state_dict.items(): 77 | if name in own_state: 78 | new_state[name] = param 79 | 80 | super(EncoderImage, self).load_state_dict(new_state) 81 | 82 | 83 | class EncoderText(nn.Module): 84 | """ 85 | Build local word representations by common-used Bi-GRU or GRU. 86 | Args: - images: raw local word ids, shape: (batch_size, L). 87 | Returns: - img_emb: final local word embeddings, shape: (batch_size, L, 1024). 88 | """ 89 | 90 | def __init__(self, opt, vocab_size, word_dim, embed_size, num_layers, 91 | use_bi_gru=False, no_txtnorm=False): 92 | super(EncoderText, self).__init__() 93 | self.opt = opt 94 | self.embed_size = embed_size 95 | self.no_txtnorm = no_txtnorm 96 | 97 | # word embedding 98 | self.embed = nn.Embedding(vocab_size, word_dim) 99 | self.dropout = nn.Dropout(0.4) 100 | 101 | # caption embedding 102 | self.use_bi_gru = use_bi_gru 103 | self.cap_rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru) 104 | 105 | self.init_weights() 106 | 107 | def init_weights(self): 108 | self.embed.weight.data.uniform_(-0.1, 0.1) 109 | 110 | def forward(self, captions, lengths): 111 | """Handles variable size captions""" 112 | # embed word ids to vectors 113 | cap_emb = self.embed(captions) 114 | cap_emb = self.dropout(cap_emb) 115 | 116 | # pack the caption 117 | packed = pack_padded_sequence(cap_emb, lengths, batch_first=True) 118 | 119 | # forward propagate RNN 120 | out, _ = self.cap_rnn(packed) 121 | 122 | # reshape output to (batch_size, hidden_size) 123 | cap_emb, _ = pad_packed_sequence(out, batch_first=True) 124 | 125 | if self.use_bi_gru: 126 | cap_emb = (cap_emb[:, :, :cap_emb.size(2) // 2] + cap_emb[:, :, cap_emb.size(2) // 2:]) / 2 127 | 128 | # normalization in the joint embedding space 129 | if not self.no_txtnorm: 130 | cap_emb = l2norm(cap_emb, dim=-1) 131 | 132 | return cap_emb 133 | 134 | 135 | class VisualSA(nn.Module): 136 | """ 137 | Build global image representations by self-attention. 138 | Args: - local: local region embeddings, shape: (batch_size, 36, 1024) 139 | - raw_global: raw image by averaging regions, shape: (batch_size, 1024) 140 | Returns: - new_global: final image by self-attention, shape: (batch_size, 1024). 
141 | """ 142 | 143 | def __init__(self, embed_dim, dropout_rate, num_region): 144 | super(VisualSA, self).__init__() 145 | 146 | self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim), 147 | nn.BatchNorm1d(num_region), 148 | nn.Tanh(), nn.Dropout(dropout_rate)) 149 | self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim), 150 | nn.BatchNorm1d(embed_dim), 151 | nn.Tanh(), nn.Dropout(dropout_rate)) 152 | self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1)) 153 | 154 | self.init_weights() 155 | self.softmax = nn.Softmax(dim=1) 156 | 157 | def init_weights(self): 158 | for embeddings in self.children(): 159 | for m in embeddings: 160 | if isinstance(m, nn.Linear): 161 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features) 162 | m.weight.data.uniform_(-r, r) 163 | m.bias.data.fill_(0) 164 | elif isinstance(m, nn.BatchNorm1d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | 168 | def forward(self, local, raw_global): 169 | # compute embedding of local regions and raw global image 170 | if len(local) == 1 and len(raw_global) == 1: 171 | l_emb = self.embedding_local[0](local) 172 | l_emb = self.embedding_local[2](l_emb) 173 | l_emb = self.embedding_local[3](l_emb) 174 | g_emb = self.embedding_global[0](raw_global) 175 | g_emb = self.embedding_global[2](g_emb) 176 | g_emb = self.embedding_global[3](g_emb) 177 | else: 178 | l_emb = self.embedding_local(local) 179 | g_emb = self.embedding_global(raw_global) 180 | 181 | # compute the normalized weights, shape: (batch_size, 36) 182 | g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1) 183 | common = l_emb.mul(g_emb) 184 | weights = self.embedding_common(common).squeeze(2) 185 | weights = self.softmax(weights) 186 | 187 | # compute final image, shape: (batch_size, 1024) 188 | new_global = (weights.unsqueeze(2) * local).sum(dim=1) 189 | new_global = l2norm(new_global, dim=-1) 190 | 191 | return new_global 192 | 193 | 194 | class TextSA(nn.Module): 195 | """ 196 | Build global text representations by self-attention. 197 | Args: - local: local word embeddings, shape: (batch_size, L, 1024) 198 | - raw_global: raw text by averaging words, shape: (batch_size, 1024) 199 | Returns: - new_global: final text by self-attention, shape: (batch_size, 1024). 200 | """ 201 | 202 | def __init__(self, embed_dim, dropout_rate): 203 | super(TextSA, self).__init__() 204 | 205 | self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim), 206 | nn.Tanh(), nn.Dropout(dropout_rate)) 207 | self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim), 208 | nn.Tanh(), nn.Dropout(dropout_rate)) 209 | self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1)) 210 | 211 | self.init_weights() 212 | self.softmax = nn.Softmax(dim=1) 213 | 214 | def init_weights(self): 215 | for embeddings in self.children(): 216 | for m in embeddings: 217 | if isinstance(m, nn.Linear): 218 | r = np.sqrt(6.) 
/ np.sqrt(m.in_features + m.out_features) 219 | m.weight.data.uniform_(-r, r) 220 | m.bias.data.fill_(0) 221 | elif isinstance(m, nn.BatchNorm1d): 222 | m.weight.data.fill_(1) 223 | m.bias.data.zero_() 224 | 225 | def forward(self, local, raw_global): 226 | # compute embedding of local words and raw global text 227 | l_emb = self.embedding_local(local) 228 | g_emb = self.embedding_global(raw_global) 229 | 230 | # compute the normalized weights, shape: (batch_size, L) 231 | g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1) 232 | common = l_emb.mul(g_emb) 233 | weights = self.embedding_common(common).squeeze(2) 234 | weights = self.softmax(weights) 235 | 236 | # compute final text, shape: (batch_size, 1024) 237 | new_global = (weights.unsqueeze(2) * local).sum(dim=1) 238 | new_global = l2norm(new_global, dim=-1) 239 | 240 | return new_global 241 | 242 | 243 | class GraphReasoning(nn.Module): 244 | """ 245 | Perform the similarity graph reasoning with a full-connected graph 246 | Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256) 247 | Returns; - sim_sgr: reasoned graph nodes after several steps, shape: (batch_size, L+1, 256) 248 | """ 249 | 250 | def __init__(self, sim_dim): 251 | super(GraphReasoning, self).__init__() 252 | 253 | self.graph_query_w = nn.Linear(sim_dim, sim_dim) 254 | self.graph_key_w = nn.Linear(sim_dim, sim_dim) 255 | self.sim_graph_w = nn.Linear(sim_dim, sim_dim) 256 | # self.relu = nn.LeakyReLU(negative_slope=0.1) 257 | self.relu = nn.ReLU() 258 | 259 | self.init_weights() 260 | 261 | def forward(self, sim_emb): 262 | sim_query = self.graph_query_w(sim_emb) 263 | sim_key = self.graph_key_w(sim_emb) 264 | sim_edge = torch.softmax(torch.bmm(sim_query, sim_key.permute(0, 2, 1)), dim=-1) 265 | sim_sgr = torch.bmm(sim_edge, sim_emb) 266 | sim_sgr = self.relu(self.sim_graph_w(sim_sgr)) 267 | return sim_sgr 268 | 269 | def init_weights(self): 270 | for m in self.children(): 271 | if isinstance(m, nn.Linear): 272 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features) 273 | m.weight.data.uniform_(-r, r) 274 | m.bias.data.fill_(0) 275 | elif isinstance(m, nn.BatchNorm1d): 276 | m.weight.data.fill_(1) 277 | m.bias.data.zero_() 278 | 279 | 280 | class AttentionFiltration(nn.Module): 281 | """ 282 | Perform the similarity Attention Filtration with a gate-based attention 283 | Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256) 284 | Returns; - sim_saf: aggregated alignment after attention filtration, shape: (batch_size, 256) 285 | """ 286 | 287 | def __init__(self, sim_dim): 288 | super(AttentionFiltration, self).__init__() 289 | 290 | self.attn_sim_w = nn.Linear(sim_dim, 1) 291 | self.bn = nn.BatchNorm1d(1) 292 | 293 | self.init_weights() 294 | 295 | def forward(self, sim_emb): 296 | sim_attn = l1norm(torch.sigmoid(self.bn(self.attn_sim_w(sim_emb).permute(0, 2, 1))), dim=-1) 297 | sim_saf = torch.matmul(sim_attn, sim_emb) 298 | sim_saf = l2norm(sim_saf.squeeze(1), dim=-1) 299 | return sim_saf 300 | 301 | def init_weights(self): 302 | for m in self.children(): 303 | if isinstance(m, nn.Linear): 304 | r = np.sqrt(6.) 
/ np.sqrt(m.in_features + m.out_features) 305 | m.weight.data.uniform_(-r, r) 306 | m.bias.data.fill_(0) 307 | elif isinstance(m, nn.BatchNorm1d): 308 | m.weight.data.fill_(1) 309 | m.bias.data.zero_() 310 | 311 | 312 | class EncoderSimilarity(nn.Module): 313 | """ 314 | Compute the image-text similarity by SGR, SAF, AVE 315 | Args: - img_emb: local region embeddings, shape: (batch_size, 36, 1024) 316 | - cap_emb: local word embeddings, shape: (batch_size, L, 1024) 317 | Returns: 318 | - sim_all: final image-text similarities, shape: (batch_size, batch_size). 319 | """ 320 | 321 | def __init__(self, opt, embed_size, sim_dim, module_name='AVE', sgr_step=3): 322 | super(EncoderSimilarity, self).__init__() 323 | self.opt = opt 324 | self.module_name = module_name 325 | 326 | self.v_global_w = VisualSA(embed_size, 0.4, 36) 327 | self.t_global_w = TextSA(embed_size, 0.4) 328 | 329 | self.sim_tranloc_w = nn.Linear(embed_size, sim_dim) 330 | self.sim_tranglo_w = nn.Linear(embed_size, sim_dim) 331 | 332 | self.sim_eval_w = nn.Linear(sim_dim, 1) 333 | self.sigmoid = nn.Sigmoid() 334 | 335 | if module_name == 'SGR': 336 | self.SGR_module = nn.ModuleList([GraphReasoning(sim_dim) for i in range(sgr_step)]) 337 | elif module_name == 'SAF': 338 | self.SAF_module = AttentionFiltration(sim_dim) 339 | else: 340 | raise ValueError('Invalid input of opt.module_name in opts.py') 341 | 342 | self.init_weights() 343 | 344 | # Modifications by extending DECL 345 | def forward(self, img_emb, cap_emb, cap_lens): 346 | sim_all = [] 347 | sim_img = [] 348 | sim_cap = [] 349 | n_image = img_emb.size(0) 350 | n_caption = cap_emb.size(0) 351 | 352 | # get enhanced global images by self-attention 353 | img_ave = torch.mean(img_emb, 1) 354 | img_glo = self.v_global_w(img_emb, img_ave) 355 | 356 | for i in range(n_caption): 357 | # get the i-th sentence 358 | n_word = cap_lens[i] 359 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0) 360 | cap_i_expand = cap_i.repeat(n_image, 1, 1) 361 | 362 | # get enhanced global i-th text by self-attention 363 | cap_ave_i = torch.mean(cap_i, 1) 364 | cap_glo_i = self.t_global_w(cap_i, cap_ave_i) 365 | 366 | # local-global alignment construction 367 | Context_img = SCAN_attention(cap_i_expand, img_emb, smooth=9.0) 368 | sim_loc = torch.pow(torch.sub(Context_img, cap_i_expand), 2) 369 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1) 370 | 371 | sim_glo = torch.pow(torch.sub(img_glo, cap_glo_i), 2) 372 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1) 373 | 374 | # concat the global and local alignments 375 | sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1) 376 | 377 | # compute the final similarity vector 378 | if self.module_name == 'SGR': 379 | for module in self.SGR_module: 380 | sim_emb = module(sim_emb) 381 | sim_vec = sim_emb[:, 0, :] 382 | else: 383 | sim_vec = self.SAF_module(sim_emb) 384 | 385 | # compute the final similarity score 386 | # sim_i = self.sigmoid(self.sim_eval_w(sim_vec)) 387 | # sim_all.append(sim_i) 388 | sim_all.append(self.sim_eval_w(sim_vec)) 389 | 390 | sim_all = torch.cat(sim_all, 1) 391 | return sim_all 392 | 393 | # cap_glo = [] 394 | # for i in range (n_caption): 395 | # n_word = cap_lens[i] 396 | # cap_i = cap_emb[i, :n_word, :].unsqueeze(0) 397 | # cap_ave_i = torch.mean(cap_i, 1) 398 | # cap_glo_i = self.t_global_w(cap_i, cap_ave_i) 399 | # cap_glo.append(cap_glo_i) 400 | # cap_glo = torch.cat(cap_glo, 0) 401 | 402 | # for i in range(n_image): 403 | # img_i = img_emb[i].unsqueeze(0) 404 | # img_i_expand = img_i.repeat(n_image, 1, 1) 405 
| 406 | # img_glo_i = img_glo[i].unsqueeze(0) 407 | 408 | # sim_loc = torch.pow(torch.sub(img_emb, img_i_expand), 2) + 1e-6 409 | # sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1) 410 | 411 | # sim_glo = torch.pow(torch.sub(img_glo, img_glo_i), 2) + 1e-6 412 | # sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1) 413 | 414 | # sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1) 415 | 416 | # Context_cap = SCAN_attention(img_i_expand, cap_emb + 1e-6, smooth=9.0, lens=cap_lens) 417 | # sim_loc_cap = torch.pow(torch.sub(Context_cap, img_i_expand), 2) 418 | # sim_loc_cap = l2norm(self.sim_tranloc_w(sim_loc_cap), dim=-1) 419 | 420 | # sim_glo_cap = torch.pow(torch.sub(cap_glo, img_glo_i), 2) 421 | # sim_glo_cap = l2norm(self.sim_tranglo_w(sim_glo_cap), dim=-1) 422 | 423 | # sim_emb_cap = torch.cat([sim_glo_cap.unsqueeze(1), sim_loc_cap], 1) 424 | 425 | # sim_emb = torch.cat([sim_emb, sim_emb_cap], 0) 426 | 427 | # if self.module_name == 'SGR': 428 | # for module in self.SGR_module: 429 | # sim_emb = module(sim_emb) 430 | # sim_vec = sim_emb[:, 0, :] 431 | # else: 432 | # sim_vec = self.SAF_module(sim_emb) 433 | 434 | # # sim_img.append(self.sim_eval_w(sim_vec)) 435 | # sim_img_cap = self.sim_eval_w(sim_vec) 436 | # sim_img.append(sim_img_cap[: n_image]) 437 | # sim_all.append(sim_img_cap[n_image: ]) 438 | 439 | # # (n_image, n_caption) 440 | # sim_all = torch.cat(sim_all, 1) 441 | # sim_img = torch.cat(sim_img, 1) 442 | 443 | # for i in range(n_caption): 444 | # n_word = cap_lens[i] 445 | # cap_i = cap_emb[i, :n_word, :].unsqueeze(0) 446 | # cap_i_expand = cap_i.repeat(n_caption, 1, 1) 447 | # cap_glo_i = cap_glo[i].unsqueeze(0) 448 | 449 | # Context_cap = SCAN_attention(cap_i_expand, cap_emb + 1e-6, smooth=9.0, lens=cap_lens) 450 | # sim_loc = torch.pow(torch.sub(Context_cap, cap_i_expand), 2) 451 | # sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1) 452 | 453 | # sim_glo = torch.pow(torch.sub(cap_glo, cap_glo_i), 2) + 1e-6 454 | # sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1) 455 | 456 | # sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1) 457 | 458 | # if self.module_name == 'SGR': 459 | # for module in self.SGR_module: 460 | # sim_emb = module(sim_emb) 461 | # sim_vec = sim_emb[:, 0, :] 462 | # else: 463 | # sim_vec = self.SAF_module(sim_emb) 464 | 465 | # sim_cap.append(self.sim_eval_w(sim_vec)) 466 | 467 | # sim_cap = torch.cat(sim_cap, 1) 468 | 469 | # return sim_all, sim_img, sim_cap 470 | 471 | def forward_topology(self, img_emb, cap_emb, cap_lens): 472 | sim_img = [] 473 | n_image = img_emb.size(0) 474 | 475 | img_ave = torch.mean(img_emb, 1) 476 | img_glo = self.v_global_w(img_emb, img_ave) 477 | 478 | img_emb_n = img_emb.unsqueeze(0).repeat(n_image, 1, 1, 1) 479 | img_emb_ni = img_emb.unsqueeze(1).repeat(1, n_image, 1, 1) 480 | sim_loc = torch.pow(torch.sub(img_emb_n, img_emb_ni), 2) + 1e-6 481 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1) 482 | 483 | img_glo_n = img_glo.unsqueeze(0).repeat(n_image, 1, 1) 484 | img_glo_ni = img_glo.unsqueeze(1).repeat(1, n_image, 1) 485 | sim_glo = torch.pow(torch.sub(img_glo_n, img_glo_ni), 2) + 1e-6 486 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1) 487 | 488 | sim_emb = torch.cat([sim_glo.unsqueeze(2), sim_loc], 2) 489 | sim_emb = sim_emb.view(-1, sim_emb.size(2), sim_emb.size(3)) 490 | if self.module_name == 'SGR': 491 | for module in self.SGR_module: 492 | sim_emb = module(sim_emb) 493 | sim_vec = sim_emb[:, 0, :] 494 | else: 495 | sim_vec = self.SAF_module(sim_emb) 496 | sim_img = 
self.sim_eval_w(sim_vec) 497 | sim_img = sim_img.view(n_image, n_image) 498 | 499 | sim_cap = [] 500 | n_caption = cap_emb.size(0) 501 | 502 | cap_glo = [] 503 | for i in range (n_caption): 504 | n_word = cap_lens[i] 505 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0) 506 | cap_ave_i = torch.mean(cap_i, 1) 507 | cap_glo_i = self.t_global_w(cap_i, cap_ave_i) 508 | cap_glo.append(cap_glo_i) 509 | cap_glo = torch.cat(cap_glo, 0) 510 | 511 | cap_emb_n = cap_emb.unsqueeze(0).repeat(n_image, 1, 1, 1) 512 | cap_emb_ni = cap_emb.unsqueeze(1).repeat(1, n_image, 1, 1) 513 | sim_loc = torch.pow(torch.sub(cap_emb_n, cap_emb_ni), 2) + 1e-6 514 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1) 515 | 516 | cap_glo_n = cap_glo.unsqueeze(0).repeat(n_image, 1, 1) 517 | cap_glo_ni = cap_glo.unsqueeze(1).repeat(1, n_image, 1) 518 | sim_glo = torch.pow(torch.sub(cap_glo_n, cap_glo_ni), 2) + 1e-6 519 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1) 520 | 521 | sim_emb = torch.cat([sim_glo.unsqueeze(2), sim_loc], 2) 522 | sim_emb = sim_emb.view(-1, sim_emb.size(2), sim_emb.size(3)) 523 | if self.module_name == 'SGR': 524 | for module in self.SGR_module: 525 | sim_emb = module(sim_emb) 526 | sim_vec = sim_emb[:, 0, :] 527 | else: 528 | sim_vec = self.SAF_module(sim_emb) 529 | sim_cap = self.sim_eval_w(sim_vec) 530 | sim_cap = sim_cap.view(n_image, n_image) 531 | 532 | # for i in range(n_caption): 533 | # n_word = cap_lens[i] 534 | # cap_i = cap_emb[i, :n_word, :].unsqueeze(0) 535 | # cap_i_expand = cap_i.repeat(n_caption, 1, 1) 536 | # cap_glo_i = cap_glo[i].unsqueeze(0) 537 | 538 | # Context_cap = SCAN_attention(cap_i_expand, cap_emb + 1e-6, smooth=9.0, lens=cap_lens) 539 | # sim_loc = torch.pow(torch.sub(Context_cap, cap_i_expand), 2) 540 | # sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1) 541 | 542 | # sim_glo = torch.pow(torch.sub(cap_glo, cap_glo_i), 2) + 1e-6 543 | # sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1) 544 | 545 | # sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1) 546 | 547 | # if self.module_name == 'SGR': 548 | # for module in self.SGR_module: 549 | # sim_emb = module(sim_emb) 550 | # sim_vec = sim_emb[:, 0, :] 551 | # else: 552 | # sim_vec = self.SAF_module(sim_emb) 553 | 554 | # sim_cap.append(self.sim_eval_w(sim_vec)) 555 | 556 | # sim_cap = torch.cat(sim_cap, 1) 557 | 558 | return sim_img, sim_cap 559 | 560 | 561 | def init_weights(self): 562 | for m in self.children(): 563 | if isinstance(m, nn.Linear): 564 | r = np.sqrt(6.) 
/ np.sqrt(m.in_features + m.out_features) 565 | m.weight.data.uniform_(-r, r) 566 | m.bias.data.fill_(0) 567 | elif isinstance(m, nn.BatchNorm1d): 568 | m.weight.data.fill_(1) 569 | m.bias.data.zero_() 570 | 571 | 572 | def SCAN_attention(query, context, smooth, eps=1e-8, lens=None): 573 | """ 574 | query: (n_context, queryL, d) 575 | context: (n_context, sourceL, d) 576 | """ 577 | # --> (batch, d, queryL) 578 | queryT = torch.transpose(query, 1, 2) 579 | 580 | # (batch, sourceL, d)(batch, d, queryL) 581 | # --> (batch, sourceL, queryL) 582 | attn = torch.bmm(context, queryT) 583 | 584 | attn = nn.LeakyReLU(0.1)(attn) 585 | attn = l2norm(attn, 2) 586 | 587 | if lens is not None: 588 | for i in range(len(lens)): 589 | attn[i, lens[i]:] = -1e9 590 | 591 | # --> (batch, queryL, sourceL) 592 | attn = torch.transpose(attn, 1, 2).contiguous() 593 | # --> (batch, queryL, sourceL 594 | attn = F.softmax(attn * smooth, dim=2) 595 | 596 | # --> (batch, sourceL, queryL) 597 | attnT = torch.transpose(attn, 1, 2).contiguous() 598 | 599 | # --> (batch, d, sourceL) 600 | contextT = torch.transpose(context, 1, 2) 601 | # (batch x d x sourceL)(batch x sourceL x queryL) 602 | # --> (batch, d, queryL) 603 | weightedContext = torch.bmm(contextT, attnT) 604 | # --> (batch, queryL, d) 605 | weightedContext = torch.transpose(weightedContext, 1, 2) 606 | weightedContext = l2norm(weightedContext, dim=-1) 607 | 608 | return weightedContext 609 | 610 | class SGRAF(object): 611 | """ 612 | Similarity Reasoning and Filtration (SGRAF) Network 613 | """ 614 | 615 | def __init__(self, opt): 616 | # Build Models 617 | self.opt = opt 618 | self.grad_clip = opt.grad_clip 619 | self.img_enc = EncoderImage(opt, opt.img_dim, opt.embed_size, 620 | no_imgnorm=opt.no_imgnorm) 621 | self.txt_enc = EncoderText(opt, opt.vocab_size, opt.word_dim, 622 | opt.embed_size, opt.num_layers, 623 | use_bi_gru=opt.bi_gru, 624 | no_txtnorm=opt.no_txtnorm) 625 | self.sim_enc = EncoderSimilarity(opt, opt.embed_size, opt.sim_dim, 626 | opt.module_name, opt.sgr_step) 627 | 628 | if torch.cuda.is_available(): 629 | self.img_enc.cuda() 630 | self.txt_enc.cuda() 631 | self.sim_enc.cuda() 632 | cudnn.benchmark = True 633 | 634 | # Loss and Optimizer 635 | params = list(self.txt_enc.parameters()) 636 | params += list(self.img_enc.parameters()) 637 | params += list(self.sim_enc.parameters()) 638 | 639 | self.params = params 640 | 641 | self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate) 642 | self.Eiters = 0 643 | 644 | if 'f30k' in self.opt.data_name: 645 | self.targets = torch.ones(145000).cuda() 646 | self.preds = torch.ones(145000).cuda() 647 | self.preds_targets = torch.ones(145000).cuda() 648 | self.topos = torch.ones(145000).cuda() 649 | self.topo_targets = torch.ones(145000).cuda() 650 | elif 'coco' in self.opt.data_name: 651 | self.targets = torch.ones(566435).cuda() 652 | self.preds = torch.ones(566435).cuda() 653 | self.preds_targets = torch.ones(566435).cuda() 654 | self.topos = torch.ones(566435).cuda() 655 | self.topo_targets = torch.ones(566435).cuda() 656 | elif 'cc152k' in self.opt.data_name: 657 | self.targets = torch.ones(150000).cuda() 658 | self.preds = torch.ones(150000).cuda() 659 | self.preds_targets = torch.ones(150000).cuda() 660 | self.topos = torch.ones(150000).cuda() 661 | self.topo_targets = torch.ones(150000).cuda() 662 | 663 | def state_dict(self): 664 | state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict(), self.sim_enc.state_dict(), self.targets, self.topos, self.topo_targets] 665 | return 
state_dict 666 | 667 | def load_state_dict(self, state_dict): 668 | self.img_enc.load_state_dict(state_dict[0]) 669 | self.txt_enc.load_state_dict(state_dict[1]) 670 | self.sim_enc.load_state_dict(state_dict[2]) 671 | if len(state_dict) > 3: 672 | self.targets = state_dict[3] 673 | self.topos = state_dict[4] 674 | self.topo_targets = state_dict[5] 675 | 676 | def train_start(self): 677 | """switch to train mode""" 678 | self.img_enc.train() 679 | self.txt_enc.train() 680 | self.sim_enc.train() 681 | 682 | def val_start(self): 683 | """switch to evaluate mode""" 684 | self.img_enc.eval() 685 | self.txt_enc.eval() 686 | self.sim_enc.eval() 687 | 688 | def forward_emb(self, images, captions, lengths): 689 | """Compute the image and caption embeddings""" 690 | if torch.cuda.is_available(): 691 | images = images.cuda() 692 | captions = captions.cuda() 693 | 694 | # Forward feature encoding 695 | img_embs = self.img_enc(images) 696 | cap_embs = self.txt_enc(captions, lengths) 697 | return img_embs, cap_embs, lengths 698 | 699 | # Modifications by extending DECL 700 | def forward_sim(self, img_embs, cap_embs, cap_lens, mode='sim'): 701 | # Forward similarity encoding 702 | # raw_sims, raw_sims_img, raw_sims_cap = self.sim_enc(img_embs, cap_embs, cap_lens) 703 | raw_sims = self.sim_enc(img_embs, cap_embs, cap_lens) 704 | if mode == 'sim': 705 | # return torch.sigmoid(raw_sims), raw_sims_img, raw_sims_cap 706 | return torch.sigmoid(raw_sims) 707 | 708 | def train_emb(self, images, captions, lengths, ids=None, *args): 709 | """One training step given images and captions. 710 | """ 711 | self.Eiters += 1 712 | self.logger.update('Eit', self.Eiters) 713 | self.logger.update('lr', self.optimizer.param_groups[0]['lr']) 714 | 715 | # compute the embeddings 716 | img_embs, cap_embs, cap_lens = self.forward_emb(images, captions, lengths) 717 | sims = self.forward_sim(img_embs, cap_embs, cap_lens) 718 | 719 | # measure accuracy and record loss 720 | self.optimizer.zero_grad() 721 | loss = self.forward_loss(sims) 722 | 723 | # compute gradient and do SGD step 724 | loss.backward() 725 | if self.grad_clip > 0: 726 | clip_grad_norm_(self.params, self.grad_clip) 727 | self.optimizer.step() 728 | --------------------------------------------------------------------------------
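
Note (added sketch, not a file from the repository): the fold=True branches of validation() and validation_dul() rely on the MSCOCO "testall" layout in which the loader yields one row per caption, so every image embedding is repeated per_captions (=5) times. The toy script below only illustrates that 5-fold slicing and the two-model ensemble averaging; the random arrays, the small embed_dim, and the dot-product similarity are stand-ins for the real SGRAF encoders and shard_attn_scores.

import numpy as np

per_captions = 5
n_rows = 25000            # MSCOCO 'testall': 5000 images x 5 captions, one row per caption
embed_dim = 8             # toy dimension instead of 1024

img_embs = np.random.rand(n_rows, embed_dim)   # image embedding repeated for every caption
cap_embs = np.random.rand(n_rows, embed_dim)

for i in range(5):
    # keep every 5th row of the fold -> 1000 unique images (same slicing as validation())
    img_shard = img_embs[i * 5000:(i + 1) * 5000:5]
    # keep the whole fold -> 5000 captions
    cap_shard = cap_embs[i * 5000:(i + 1) * 5000]

    # stand-in for shard_attn_scores: two models produce (1000, 5000) similarity
    # matrices that are averaged before the i2t/t2i recalls are computed
    sims_A = img_shard @ cap_shard.T
    sims_B = img_shard @ cap_shard.T
    sims = (sims_A + sims_B) / 2
    assert sims.shape == (1000, 5000)

For fold=False, the code instead deduplicates images once up front (equivalent to img_embs[::per_captions]) and scores a single (n_images, n_captions) similarity matrix.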