├── .DS_Store
├── GSC
│   ├── model
│   │   ├── __pycache__
│   │   │   ├── DECL.cpython-37.pyc
│   │   │   ├── DECL.cpython-39.pyc
│   │   │   ├── RINCE.cpython-37.pyc
│   │   │   ├── RINCE.cpython-39.pyc
│   │   │   ├── SGRAF.cpython-37.pyc
│   │   │   └── SGRAF.cpython-39.pyc
│   │   ├── RINCE.py
│   │   └── SGRAF.py
│   ├── eval.py
│   ├── utils.py
│   ├── train_coco.sh
│   ├── train_f30k.sh
│   ├── train_cc152k.sh
│   ├── vocab.py
│   ├── data.py
│   ├── opts.py
│   ├── train.py
│   └── evaluation.py
├── GSC+
│   ├── model
│   │   ├── __pycache__
│   │   │   ├── DECL.cpython-37.pyc
│   │   │   ├── DECL.cpython-39.pyc
│   │   │   ├── RINCE.cpython-37.pyc
│   │   │   ├── RINCE.cpython-39.pyc
│   │   │   ├── SGRAF.cpython-37.pyc
│   │   │   └── SGRAF.cpython-39.pyc
│   │   ├── RINCE.py
│   │   └── SGRAF.py
│   ├── eval.py
│   ├── utils.py
│   ├── train_f30k.sh
│   ├── train_coco.sh
│   ├── train_cc152k.sh
│   ├── vocab.py
│   ├── data.py
│   ├── opts.py
│   ├── train.py
│   └── evaluation.py
├── README.md
└── requirements.txt
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/.DS_Store
--------------------------------------------------------------------------------
/GSC/model/__pycache__/DECL.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/DECL.cpython-37.pyc
--------------------------------------------------------------------------------
/GSC/model/__pycache__/DECL.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/DECL.cpython-39.pyc
--------------------------------------------------------------------------------
/GSC+/model/__pycache__/DECL.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/DECL.cpython-37.pyc
--------------------------------------------------------------------------------
/GSC+/model/__pycache__/DECL.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/DECL.cpython-39.pyc
--------------------------------------------------------------------------------
/GSC+/model/__pycache__/RINCE.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/RINCE.cpython-37.pyc
--------------------------------------------------------------------------------
/GSC+/model/__pycache__/RINCE.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/RINCE.cpython-39.pyc
--------------------------------------------------------------------------------
/GSC+/model/__pycache__/SGRAF.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/SGRAF.cpython-37.pyc
--------------------------------------------------------------------------------
/GSC+/model/__pycache__/SGRAF.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC+/model/__pycache__/SGRAF.cpython-39.pyc
--------------------------------------------------------------------------------
/GSC/model/__pycache__/RINCE.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/RINCE.cpython-37.pyc
--------------------------------------------------------------------------------
/GSC/model/__pycache__/RINCE.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/RINCE.cpython-39.pyc
--------------------------------------------------------------------------------
/GSC/model/__pycache__/SGRAF.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/SGRAF.cpython-37.pyc
--------------------------------------------------------------------------------
/GSC/model/__pycache__/SGRAF.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediaBrain-SJTU/GSC/HEAD/GSC/model/__pycache__/SGRAF.cpython-39.pyc
--------------------------------------------------------------------------------
/GSC/eval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | from evaluation import eval_DECL
4 |
5 | if __name__ == '__main__':
6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
7 | # DECL-SGRAF.
8 | avg_SGRAF = False
9 |
10 | data_path= '/remote-home/share/zhaozh/NC_Datasets/data'
11 | vocab_path= '/remote-home/share/zhaozh/NC_Datasets/vocab'
12 | ns = [0]
13 | for n in ns:
14 | print(f"\n============================>f30k noise:{n}")
15 | # checkpoint_paths = [
16 | # f'./f30K_SAF_noise{n}_best.tar',
17 | # f'./f30K_SGR_noise{n}_best.tar']
18 | checkpoint_paths = [f'/remote-home/zhaozh/NC/ELCL_TOPO/runs/f30k_gamma/r_corr_0/gamma_01/checkpoint_dir/model_best.pth.tar']
19 | eval_DECL(checkpoint_paths, avg_SGRAF, data_path=data_path, vocab_path=vocab_path)
20 |
21 | # for n in ns:
22 | # print(f"\n============================>coco noise:{n}")
23 | # # checkpoint_paths = [
24 | # # f'./coco_SAF_noise{n}_best.tar',
25 | # # f'./coco_SGR_noise{n}_best.tar']
26 | # checkpoint_paths = [f'/remote-home/zhaozh/NC/DECL/runs/coco/r_corr_{n}/checkpoint_dir/model_best.pth.tar']
27 | # eval_DECL(checkpoint_paths, avg_SGRAF, data_path=data_path, vocab_path=vocab_path)
28 |
29 | # print(f"\n============================>cc152k")
30 | # eval_DECL(['./cc152k_SAF_best.tar', './cc152k_SGR_best.tar'], avg_SGRAF=avg_SGRAF)
31 |
--------------------------------------------------------------------------------
/GSC+/eval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | from evaluation import eval_DECL
4 |
5 | if __name__ == '__main__':
6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
7 | # DECL-SGRAF.
8 | avg_SGRAF = True
9 |
10 | data_path= '/remote-home/share/zhaozh/NC_Datasets/data'
11 | vocab_path= '/remote-home/share/zhaozh/NC_Datasets/vocab'
12 | ns = [4]
13 | for n in ns:
14 | print(f"\n============================>f30k noise:{n}")
15 | # checkpoint_paths = [
16 | # f'./f30K_SAF_noise{n}_best.tar',
17 | # f'./f30K_SGR_noise{n}_best.tar']
18 | checkpoint_paths = [f'/remote-home/zhaozh/NC/ELCL_TOPO+/runs/f30k_batch128/r_corr_4/gamma_001_biinfo/checkpoint_dir/model_best.pth.tar',
19 | f'/remote-home/zhaozh/NC/ELCL_TOPO+/runs/f30k_batch128/r_corr_4/gamma_001/checkpoint_dir/model_best.pth.tar']
20 | eval_DECL(checkpoint_paths, avg_SGRAF, data_path=data_path, vocab_path=vocab_path)
21 |
22 | # for n in ns:
23 | # print(f"\n============================>coco noise:{n}")
24 | # # checkpoint_paths = [
25 | # # f'./coco_SAF_noise{n}_best.tar',
26 | # # f'./coco_SGR_noise{n}_best.tar']
27 | # checkpoint_paths = [f'/remote-home/zhaozh/NC/ELCL_TOPO+/runs/coco_batch128/r_corr_2/gamma_002/checkpoint_dir/model_best.pth.tar']
28 | # eval_DECL(checkpoint_paths, avg_SGRAF, data_path=data_path, vocab_path=vocab_path)
29 |
30 | # print(f"\n============================>cc152k")
31 | # eval_DECL(['./cc152k_SAF_best.tar', './cc152k_SGR_best.tar'], avg_SGRAF=avg_SGRAF)
32 |
--------------------------------------------------------------------------------
/GSC+/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 |
7 | def __init__(self, name="", fmt=":f"):
8 | self.name = name
9 | self.fmt = fmt
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 |
24 | def __str__(self):
25 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
26 | return fmtstr.format(**self.__dict__)
27 |
28 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix='',ckpt=True):
29 | tries = 15
30 | error = None
31 | # deal with unstable I/O. Usually not necessary.
32 | while tries:
33 | try:
34 | if ckpt:
35 | torch.save(state, prefix + filename)
36 | if is_best:
37 | torch.save(state, prefix + 'model_best.pth.tar')
38 | except IOError as e:
39 | error = e
40 | tries -= 1
41 | else:
42 | break
43 | print('model save {} failed, remaining {} trials'.format(filename, tries))
44 | if not tries:
45 | raise error
46 |
47 |
48 | def save_config(opt, file_path):
49 | with open(file_path, "w") as f:
50 | json.dump(opt.__dict__, f, indent=2)
51 |
52 |
53 | def load_config(opt, file_path):
54 | with open(file_path, "r") as f:
55 | opt.__dict__ = json.load(f)
56 |
--------------------------------------------------------------------------------
/GSC/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 |
7 | def __init__(self, name="", fmt=":f"):
8 | self.name = name
9 | self.fmt = fmt
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 |
24 | def __str__(self):
25 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
26 | return fmtstr.format(**self.__dict__)
27 |
28 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix='',ckpt=True):
29 | tries = 15
30 | error = None
31 | # deal with unstable I/O. Usually not necessary.
32 | while tries:
33 | try:
34 | if ckpt:
35 | torch.save(state, prefix + filename)
36 | if is_best:
37 | torch.save(state, prefix + 'model_best.pth.tar')
38 | except IOError as e:
39 | error = e
40 | tries -= 1
41 | else:
42 | break
43 | print('model save {} failed, remaining {} trials'.format(filename, tries))
44 | if not tries:
45 | raise error
46 |
47 |
48 | def save_config(opt, file_path):
49 | with open(file_path, "w") as f:
50 | json.dump(opt.__dict__, f, indent=2)
51 |
52 |
53 | def load_config(opt, file_path):
54 | with open(file_path, "r") as f:
55 | opt.__dict__ = json.load(f)
56 |
--------------------------------------------------------------------------------
/GSC/train_coco.sh:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | data_name=coco_precomp
3 | data_path=/remote-home/share/zhaozh/NC_Datasets/data
4 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab
5 |
6 | # noise settings
7 | noise_ratio=0.6
8 | noise_file=noise_file_by_caption/coco/coco_precomp_0.6.npy
9 |
10 | # loss settings
11 | contrastive_loss=InfoNCE
12 | temp=0.07
13 | lam=0.01
14 | q=0.01
15 |
16 | # train settings
17 | gpu=0
18 | warmup_epoch=0
19 | num_epochs=60
20 | lr_update=20
21 | batch_size=128
22 | module_name=SGR
23 | folder_name=coco/r_corr_2
24 |
25 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \
26 | --warmup_epoch $warmup_epoch \
27 | --folder_name $folder_name \
28 | --noise_ratio $noise_ratio \
29 | --noise_file $noise_file \
30 | --num_epochs $num_epochs \
31 | --batch_size $batch_size \
32 | --lr_update $lr_update \
33 | --module_name $module_name \
34 | --learning_rate 0.0002 \
35 | --data_name $data_name \
36 | --data_path $data_path \
37 | --vocab_path $vocab_path \
38 | --contrastive_loss $contrastive_loss \
39 | --temp $temp \
40 | --lam $lam \
41 | --q $q \
42 |
43 |
--------------------------------------------------------------------------------
/GSC+/train_f30k.sh:
--------------------------------------------------------------------------------
1 | export OPENBLAS_NUM_THREADS=1
2 | export GOTO_NUM_THREADS=1
3 | export OMP_NUM_THREADS=1
4 | # dataset settings
5 | data_name=f30k_precomp
6 | data_path=/remote-home/share/zhaozh/NC_Datasets/data
7 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab
8 |
9 | # noise settings
10 | noise_ratio=0.2
11 | noise_file=noise_file_by_caption/f30k/f30k_precomp_0.2.npy
12 |
13 | # loss settings
14 | contrastive_loss=InfoNCE
15 | temp=0.07
16 | lam=0.01
17 | q=0.01
18 |
19 | beta=0.7
20 | gamma=0.01
21 |
22 | # train settings
23 | gpu=0
24 | warmup_epoch=0
25 | num_epochs=60
26 | lr_update=15
27 | batch_size=128
28 | module_name=SGR
29 | folder_name=f30k/r_corr_2
30 |
31 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \
32 | --warmup_epoch $warmup_epoch \
33 | --folder_name $folder_name \
34 | --noise_ratio $noise_ratio \
35 | --noise_file $noise_file \
36 | --num_epochs $num_epochs \
37 | --batch_size $batch_size \
38 | --lr_update $lr_update \
39 | --module_name $module_name \
40 | --learning_rate 0.0002 \
41 | --data_name $data_name \
42 | --data_path $data_path \
43 | --vocab_path $vocab_path \
44 | --contrastive_loss $contrastive_loss \
45 | --temp $temp \
46 | --lam $lam \
47 | --q $q \
48 | --beta $beta \
49 | --gamma $gamma \
50 |
51 |
--------------------------------------------------------------------------------
/GSC/train_f30k.sh:
--------------------------------------------------------------------------------
1 | export OPENBLAS_NUM_THREADS=1
2 | export GOTO_NUM_THREADS=1
3 | export OMP_NUM_THREADS=1
4 | # dataset settings
5 | data_name=f30k_precomp
6 | data_path=/remote-home/share/zhaozh/NC_Datasets/data
7 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab
8 |
9 | # noise settings
10 | noise_ratio=0.2
11 | noise_file=noise_file_by_caption/f30k/f30k_precomp_0.2.npy
12 |
13 | # loss settings
14 | contrastive_loss=InfoNCE
15 | temp=0.07
16 | lam=0.01
17 | q=0.01
18 |
19 | beta=0.7
20 | gamma=0.01
21 |
22 | # train settings
23 | gpu=0
24 | warmup_epoch=0
25 | num_epochs=60
26 | lr_update=15
27 | batch_size=128
28 | module_name=SGR
29 | folder_name=f30k/r_corr_2
30 |
31 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \
32 | --warmup_epoch $warmup_epoch \
33 | --folder_name $folder_name \
34 | --noise_ratio $noise_ratio \
35 | --noise_file $noise_file \
36 | --num_epochs $num_epochs \
37 | --batch_size $batch_size \
38 | --lr_update $lr_update \
39 | --module_name $module_name \
40 | --learning_rate 0.0002 \
41 | --data_name $data_name \
42 | --data_path $data_path \
43 | --vocab_path $vocab_path \
44 | --contrastive_loss $contrastive_loss \
45 | --temp $temp \
46 | --lam $lam \
47 | --q $q \
48 | --beta $beta \
49 | --gamma $gamma \
50 |
51 |
--------------------------------------------------------------------------------
/GSC/train_cc152k.sh:
--------------------------------------------------------------------------------
1 | export OPENBLAS_NUM_THREADS=1
2 | export GOTO_NUM_THREADS=1
3 | export OMP_NUM_THREADS=1
4 | # dataset settings
5 | data_name=cc152k_precomp
6 | data_path=/remote-home/share/zhaozh/NC_Datasets/data
7 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab
8 |
9 | # noise settings
10 | noise_ratio=0.0
11 | noise_file=/remote-home/share/zhaozh/NC_Datasets/data
12 |
13 | # loss settings
14 | contrastive_loss=InfoNCE
15 | temp=0.07
16 | lam=0.01
17 | q=0.01
18 |
19 | beta=0.7
20 | gamma=0.01
21 |
22 | # train settings
23 | gpu=0
24 | warmup_epoch=0
25 | num_epochs=60
26 | lr_update=15
27 | batch_size=128
28 | module_name=SGR
29 | folder_name=cc152k/r_corr_2
30 |
31 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \
32 | --warmup_epoch $warmup_epoch \
33 | --folder_name $folder_name \
34 | --noise_ratio $noise_ratio \
35 | --noise_file $noise_file \
36 | --num_epochs $num_epochs \
37 | --batch_size $batch_size \
38 | --lr_update $lr_update \
39 | --module_name $module_name \
40 | --learning_rate 0.0002 \
41 | --data_name $data_name \
42 | --data_path $data_path \
43 | --vocab_path $vocab_path \
44 | --contrastive_loss $contrastive_loss \
45 | --temp $temp \
46 | --lam $lam \
47 | --q $q \
48 | --beta $beta \
49 | --gamma $gamma \
50 |
51 |
52 |
--------------------------------------------------------------------------------
/GSC+/train_coco.sh:
--------------------------------------------------------------------------------
1 | export OPENBLAS_NUM_THREADS=1
2 | export GOTO_NUM_THREADS=1
3 | export OMP_NUM_THREADS=1
4 | # dataset settings
5 | data_name=coco_precomp
6 | data_path=/remote-home/share/zhaozh/NC_Datasets/data
7 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab
8 |
9 | # noise settings
10 | noise_ratio=0.2
11 | noise_file=noise_file_by_caption/coco/coco_precomp_0.2.npy
12 |
13 | # loss settings
14 | contrastive_loss=InfoNCE
15 | temp=0.07
16 | lam=0.01
17 | q=0.01
18 |
19 | beta=0.7
20 | gamma=0.05
21 |
22 | # train settings
23 | gpu=0
24 | warmup_epoch=0
25 | num_epochs=30
26 | lr_update=15
27 | batch_size=128
28 | module_name=SGR
29 | folder_name=coco/r_corr_2
30 |
31 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \
32 | --warmup_epoch $warmup_epoch \
33 | --folder_name $folder_name \
34 | --noise_ratio $noise_ratio \
35 | --noise_file $noise_file \
36 | --num_epochs $num_epochs \
37 | --batch_size $batch_size \
38 | --lr_update $lr_update \
39 | --module_name $module_name \
40 | --learning_rate 0.0002 \
41 | --data_name $data_name \
42 | --data_path $data_path \
43 | --vocab_path $vocab_path \
44 | --contrastive_loss $contrastive_loss \
45 | --temp $temp \
46 | --lam $lam \
47 | --q $q \
48 | --beta $beta \
49 | --gamma $gamma \
50 |
51 |
--------------------------------------------------------------------------------
/GSC+/train_cc152k.sh:
--------------------------------------------------------------------------------
1 | export OPENBLAS_NUM_THREADS=1
2 | export GOTO_NUM_THREADS=1
3 | export OMP_NUM_THREADS=1
4 | # dataset settings
5 | data_name=cc152k_precomp
6 | data_path=/remote-home/share/zhaozh/NC_Datasets/data
7 | vocab_path=/remote-home/share/zhaozh/NC_Datasets/vocab
8 |
9 | # noise settings
10 | noise_ratio=0.0
11 | noise_file=/remote-home/share/zhaozh/NC_Datasets/data
12 |
13 | # loss settings
14 | contrastive_loss=InfoNCE
15 | temp=0.07
16 | lam=0.01
17 | q=0.01
18 |
19 | beta=0.7
20 | gamma=0.02
21 |
22 | # train settings
23 | gpu=0
24 | warmup_epoch=0
25 | num_epochs=60
26 | lr_update=20
27 | batch_size=128
28 | module_name=SGR
29 | folder_name=cc152k/r_corr_0
30 |
31 | CUDA_VISIBLE_DEVICES=$gpu python train.py --gpu $gpu \
32 | --warmup_epoch $warmup_epoch \
33 | --folder_name $folder_name \
34 | --noise_ratio $noise_ratio \
35 | --noise_file $noise_file \
36 | --num_epochs $num_epochs \
37 | --batch_size $batch_size \
38 | --lr_update $lr_update \
39 | --module_name $module_name \
40 | --learning_rate 0.0002 \
41 | --data_name $data_name \
42 | --data_path $data_path \
43 | --vocab_path $vocab_path \
44 | --contrastive_loss $contrastive_loss \
45 | --temp $temp \
46 | --lam $lam \
47 | --q $q \
48 | --beta $beta \
49 | --gamma $gamma \
50 |
51 |
52 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mitigating Noisy Correspondence by Geometrical Structure Consistency Learning
2 |
3 |
5 |
6 | by Zihua Zhao, Mengxi Chen, Tianjie Dai, Jiangchao Yao, Bo Han, Ya Zhang, and Yanfeng Wang, from the Cooperative Medianet Innovation Center at Shanghai Jiao Tong University, Shanghai Artificial Intelligence Laboratory, and Hong Kong Baptist University.
7 |
8 | Computer Vision and Pattern Recognition (CVPR), 2024.
9 |
10 | This repo is the official PyTorch implementation of GSC.
11 |
12 | ⚠️ This repository is still being organized and updated; please note that this version is not the final release.
13 |
14 | ## Environment
15 |
16 | The code is built on the official DECL implementation (https://github.com/QinYang79/DECL), which provides the standard retrieval framework, data loader, and evaluation metrics.
17 |
18 | Then create the environment for running our code:
19 |
20 | ```bash
21 | conda create --name GSC python=3.9.13
22 | conda activate GSC
23 | pip install -r requirements.txt
24 | ```
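
Both `vocab.py` and `data.py` tokenize captions with `nltk.tokenize.word_tokenize`, which relies on NLTK's `punkt` tokenizer data. If it is not already present, a one-time download should be enough (a small helper, not part of the original scripts):

```python
# one-off download of the tokenizer data used by nltk.word_tokenize
import nltk
nltk.download('punkt')
```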
25 |
26 | ## Data Preparation
27 |
28 | For all dataset settings we follow DECL, including the data splits, sample preprocessing, and noise simulation. The expected data layout is sketched below.
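
As a rough sketch (inferred from `data.py` and `vocab.py` in this repo, so treat the exact names as assumptions), the loaders expect something like:

```
data/
├── f30k_precomp/                 # likewise coco_precomp, cc152k_precomp
│   ├── train_caps.txt            # *_caps.tsv for cc152k_precomp
│   ├── train_ims.npy
│   ├── dev_caps.txt
│   ├── dev_ims.npy
│   ├── test_caps.txt
│   └── test_ims.npy
└── noise_file_by_caption/        # precomputed noise indices referenced by the train_*.sh scripts
vocab/
└── f30k_precomp_vocab.json       # produced by vocab.py
```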
29 |
30 | #### MS-COCO and Flickr30K
31 |
32 | MS-COCO and Flickr30K are two datasets with simulated noisy correspondence: a fraction of the captions is randomly re-paired with wrong images. The images are first extracted into features by SCAN (https://github.com/kuanghuei/SCAN). A minimal sketch of the noise simulation follows.
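
The sketch below mirrors the noise-index construction in `data.py`; the function name and sizes are illustrative only, not part of the original code:

```python
import numpy as np

def simulate_noise(num_captions, noise_ratio, im_div=5, seed=0):
    """Pair caption i with image i // im_div, then re-point a random
    subset of captions to shuffled image indices (noisy correspondence)."""
    rng = np.random.default_rng(seed)
    noisy_inx = np.arange(num_captions) // im_div                     # clean pairing
    picked = rng.permutation(num_captions)[: int(noise_ratio * num_captions)]
    shuffled = picked // im_div
    rng.shuffle(shuffled)                                             # wrong images for the picked captions
    noisy_inx[picked] = shuffled
    return noisy_inx                                                  # saved as the *.npy noise file

# e.g. noisy_inx = simulate_noise(num_captions=145000, noise_ratio=0.2)  # hypothetical sizes
```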
33 |
34 | #### CC152K
35 |
36 | CC152K is a real-world noisy correspondence dataset. The images are also extracted into features by SCAN.
37 |
38 | Note that SCAN is implemented with Python 2.7 and Caffe, which may not be supported on recent CUDA versions. We extracted the features on machines rented from the AutoDL platform (https://www.autodl.com/login?url=/console/homepage/personal).
39 |
40 | ## Running
41 |
42 | For training on the different datasets, we provide ready-to-use bash scripts. Evaluation metrics are implemented in eval.py; see the note below the example for the paths it expects.
43 |
44 | ```bash
45 | # for training and evaluation on Flickr30K as an example
46 | conda activate GSC
47 | sh train_f30k.sh
48 | python eval.py
49 | ```
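
Note that `eval.py` hard-codes the dataset and checkpoint locations, so edit them to match your setup before running; the paths below are placeholders:

```python
# in eval.py (placeholder paths, adjust to your environment)
data_path = '/path/to/NC_Datasets/data'
vocab_path = '/path/to/NC_Datasets/vocab'
checkpoint_paths = ['runs/f30k/r_corr_2/checkpoint_dir/model_best.pth.tar']
avg_SGRAF = False  # in GSC+/eval.py, True averages two SGRAF-style checkpoints
eval_DECL(checkpoint_paths, avg_SGRAF, data_path=data_path, vocab_path=vocab_path)
```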
50 |
51 | ## Citation
52 |
53 | If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation.
54 |
55 | ```
56 | @inproceedings{zhao2024mitigating,
57 | title={Mitigating Noisy Correspondence by Geometrical Structure Consistency Learning},
58 | author={Zhao, Zihua and Chen, Mengxi and Dai, Tianjie and Yao, Jiangchao and Han, Bo and Zhang, Ya and Wang, Yanfeng},
59 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
60 | pages={27381--27390},
61 | year={2024}
62 | }
63 | ```
64 |
65 | If you have any problem with this code, please feel free to contact **[sjtuszzh@sjtu.edu.cn](mailto:sjtuszzh@sjtu.edu.cn)**.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | alabaster==0.7.13
3 | argon2-cffi==21.1.0
4 | async-generator==1.10
5 | attrs==21.2.0
6 | Babel==2.12.1
7 | backcall==0.2.0
8 | bleach==4.1.0
9 | cachetools==5.3.0
10 | certifi==2021.5.30
11 | cffi==1.15.0
12 | charset-normalizer==2.0.7
13 | click==8.1.3
14 | cloudpickle==1.6.0
15 | cycler==0.10.0
16 | Cython==0.29.34
17 | dask==2021.3.0
18 | dataclasses==0.6
19 | decorator==4.4.2
20 | defusedxml==0.7.1
21 | docopt==0.6.2
22 | docutils==0.19
23 | entrypoints==0.3
24 | filelock==3.12.0
25 | fsspec==2023.1.0
26 | google-auth==2.17.3
27 | google-auth-oauthlib==0.4.6
28 | greenlet==1.1.2
29 | grpcio==1.41.1
30 | huggingface-hub==0.14.0
31 | idna==3.3
32 | imageio==2.9.0
33 | imagesize==1.4.1
34 | importlib-metadata==4.8.2
35 | ipykernel==5.5.6
36 | ipython==7.16.1
37 | ipython-genutils==0.2.0
38 | ipywidgets==7.6.5
39 | jedi==0.17.0
40 | Jinja2==3.0.3
41 | joblib==1.0.1
42 | jsonschema==3.2.0
43 | jupyter==1.0.0
44 | jupyter-client==7.0.6
45 | jupyter-console==6.4.0
46 | jupyter-core==4.9.1
47 | jupyterlab-pygments==0.1.2
48 | jupyterlab-widgets==1.0.2
49 | kiwisolver==1.3.1
50 | lmdb==1.4.1
51 | Markdown==3.3.4
52 | MarkupSafe==2.0.1
53 | matplotlib==3.3.4
54 | MedPy==0.4.0
55 | mistune==0.8.4
56 | monai==0.7.0
57 | nbclient==0.5.8
58 | nbconvert==6.0.7
59 | nbformat==5.1.3
60 | nest-asyncio==1.5.1
61 | networkx==2.5.1
62 | nibabel==3.2.1
63 | nlpaug==1.1.7
64 | nltk==3.8.1
65 | notebook==6.4.5
66 | numpy==1.21.6
67 | oauthlib==3.2.2
68 | olefile==0.46
69 | opencv-contrib-python==4.5.2.52
70 | opencv-python-headless==4.5.4.58
71 | packaging==21.2
72 | pandas==1.1.5
73 | pandocfilters==1.5.0
74 | parso==0.8.2
75 | pexpect==4.8.0
76 | pickleshare==0.7.5
77 | Pillow==8.4.0
78 | pipreqs==0.4.11
79 | prefetch-generator==1.0.3
80 | prometheus-client==0.12.0
81 | prompt-toolkit==3.0.22
82 | protobuf==3.17.3
83 | ptyprocess==0.7.0
84 | pyasn1==0.5.0
85 | pyasn1-modules==0.3.0
86 | pycparser==2.21
87 | pydicom==2.2.2
88 | Pygments==2.15.1
89 | pylidc==0.2.2
90 | pyparsing==2.4.7
91 | pyrsistent==0.18.0
92 | python-dateutil==2.8.2
93 | pytz==2021.1
94 | PyWavelets==1.1.1
95 | PyYAML==5.4.1
96 | pyzmq==20.0.0
97 | qtconsole==5.2.0
98 | QtPy==1.9.0
99 | regex==2022.10.31
100 | requests==2.26.0
101 | requests-oauthlib==1.3.1
102 | rsa==4.9
103 | seaborn==0.12.2
104 | Send2Trash==1.8.0
105 | setuptools==58.0.4
106 | SimpleITK==2.0.2
107 | six==1.16.0
108 | snowballstemmer==2.2.0
109 | Sphinx==5.3.0
110 | sphinx-gallery==0.13.0
111 | sphinxcontrib-applehelp==1.0.2
112 | sphinxcontrib-devhelp==1.0.2
113 | sphinxcontrib-htmlhelp==2.0.0
114 | sphinxcontrib-jsmath==1.0.1
115 | sphinxcontrib-qthelp==1.0.3
116 | sphinxcontrib-serializinghtml==1.1.5
117 | SQLAlchemy==1.2.0
118 | tensorboard==2.11.2
119 | tensorboard-data-server==0.6.1
120 | tensorboard-logger==0.1.0
121 | tensorboard-plugin-wit==1.8.1
122 | tensorboardX==2.4
123 | terminado==0.9.2
124 | testpath==0.5.0
125 | threadpoolctl==2.1.0
126 | tifffile==2020.9.3
127 | tokenizers==0.13.3
128 | toolz==0.11.1
129 | torch-tb-profiler==0.4.1
130 | tornado==6.1
131 | tqdm==4.62.3
132 | traitlets==4.3.3
133 | transformers==4.28.1
134 | typing-extensions==3.7.4.3
135 | urllib3==1.26.7
136 | wcwidth==0.2.5
137 | webencodings==0.5.1
138 | Werkzeug==2.0.2
139 | wheel==0.37.0
140 | widgetsnbextension==3.5.1
141 | yarg==0.1.9
142 | zipp==3.6.0
143 |
--------------------------------------------------------------------------------
/GSC+/vocab.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Stacked Cross Attention Network implementation based on
3 | # https://arxiv.org/abs/1803.08024.
4 | # "Stacked Cross Attention for Image-Text Matching"
5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He
6 | #
7 | # Written by Kuang-Huei Lee, 2018
8 | # ---------------------------------------------------------------
9 | """Vocabulary wrapper"""
10 |
11 | import nltk
12 | from collections import Counter
13 | import argparse
14 | import os
15 | import json
16 | import csv
17 |
18 | annotations = {
19 | 'coco_precomp': ['train_caps.txt', 'dev_caps.txt'],
20 | 'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'],
21 | 'cc152k_precomp': ['train_caps.tsv', 'dev_caps.tsv'],
22 | }
23 |
24 |
25 | class Vocabulary(object):
26 | """Simple vocabulary wrapper."""
27 |
28 | def __init__(self):
29 | self.word2idx = {}
30 | self.idx2word = {}
31 | self.idx = 0
32 |
33 | def add_word(self, word):
34 | if word not in self.word2idx:
35 | self.word2idx[word] = self.idx
36 | self.idx2word[self.idx] = word
37 | self.idx += 1
38 |
39 | def __call__(self, word):
40 | if word not in self.word2idx:
41 | return self.word2idx['<unk>']
42 | return self.word2idx[word]
43 |
44 | def __len__(self):
45 | return len(self.word2idx)
46 |
47 |
48 | def serialize_vocab(vocab, dest):
49 | d = {}
50 | d['word2idx'] = vocab.word2idx
51 | d['idx2word'] = vocab.idx2word
52 | d['idx'] = vocab.idx
53 | with open(dest, "w") as f:
54 | json.dump(d, f)
55 |
56 |
57 | def deserialize_vocab(src):
58 | with open(src) as f:
59 | d = json.load(f)
60 | vocab = Vocabulary()
61 | vocab.word2idx = d['word2idx']
62 | vocab.idx2word = d['idx2word']
63 | vocab.idx = d['idx']
64 | return vocab
65 |
66 |
67 | def from_txt(txt):
68 | captions = []
69 | with open(txt, 'r') as f:
70 | for line in f:
71 | captions.append(line.strip())
72 | return captions
73 |
74 | def from_tsv(tsv):
75 | captions = []
76 | img_ids = []
77 | with open(tsv) as f:
78 | tsvreader = csv.reader(f, delimiter='\t')
79 | for line in tsvreader:
80 | captions.append(line[1])
81 | img_ids.append(line[0])
82 | return captions
83 |
84 | def build_vocab(data_path, data_name, caption_file, threshold):
85 | """Build a simple vocabulary wrapper."""
86 | counter = Counter()
87 | for path in caption_file[data_name]:
88 | if data_name == 'cc152k_precomp':
89 | full_path = os.path.join(os.path.join(data_path, data_name), path)
90 | captions = from_tsv(full_path)
91 | elif data_name in ['coco_precomp', 'f30k_precomp']:
92 | full_path = os.path.join(os.path.join(data_path, data_name), path)
93 | captions = from_txt(full_path)
94 | else:
95 | raise NotImplementedError('Not support!')
96 |
97 | for i, caption in enumerate(captions):
98 | tokens = nltk.tokenize.word_tokenize(caption.lower())
99 | counter.update(tokens)
100 |
101 | if i % 1000 == 0:
102 | print("[%d/%d] tokenized the captions." % (i, len(captions)))
103 |
104 | # Discard if the occurrence of the word is less than min_word_cnt.
105 | words = [word for word, cnt in counter.items() if cnt >= threshold]
106 |
107 | # Create a vocab wrapper and add some special tokens.
108 | vocab = Vocabulary()
109 | vocab.add_word('<pad>')
110 | vocab.add_word('<start>')
111 | vocab.add_word('<end>')
112 | vocab.add_word('<unk>')
113 |
114 | # Add words to the vocabulary.
115 | for i, word in enumerate(words):
116 | vocab.add_word(word)
117 | return vocab
118 |
119 |
120 | def main(data_path, data_name):
121 | vocab = build_vocab(data_path, data_name, caption_file=annotations, threshold=4)
122 | serialize_vocab(vocab, './%s_vocab.json' % data_name)
123 | print("Saved vocabulary file to ", './%s_vocab.json' % data_name)
124 |
125 |
126 |
127 | if __name__ == '__main__':
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--data_path', default='/data/RR/data')
130 | parser.add_argument('--data_name', default='cc152k_precomp',
131 | help='{coco,f30k,cc152k}_precomp')
132 | opt = parser.parse_args()
133 | main(opt.data_path, opt.data_name)
134 |
--------------------------------------------------------------------------------
/GSC/vocab.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Stacked Cross Attention Network implementation based on
3 | # https://arxiv.org/abs/1803.08024.
4 | # "Stacked Cross Attention for Image-Text Matching"
5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He
6 | #
7 | # Written by Kuang-Huei Lee, 2018
8 | # ---------------------------------------------------------------
9 | """Vocabulary wrapper"""
10 |
11 | import nltk
12 | from collections import Counter
13 | import argparse
14 | import os
15 | import json
16 | import csv
17 |
18 | annotations = {
19 | 'coco_precomp': ['train_caps.txt', 'dev_caps.txt'],
20 | 'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'],
21 | 'cc152k_precomp': ['train_caps.tsv', 'dev_caps.tsv'],
22 | }
23 |
24 |
25 | class Vocabulary(object):
26 | """Simple vocabulary wrapper."""
27 |
28 | def __init__(self):
29 | self.word2idx = {}
30 | self.idx2word = {}
31 | self.idx = 0
32 |
33 | def add_word(self, word):
34 | if word not in self.word2idx:
35 | self.word2idx[word] = self.idx
36 | self.idx2word[self.idx] = word
37 | self.idx += 1
38 |
39 | def __call__(self, word):
40 | if word not in self.word2idx:
41 | return self.word2idx['<unk>']
42 | return self.word2idx[word]
43 |
44 | def __len__(self):
45 | return len(self.word2idx)
46 |
47 |
48 | def serialize_vocab(vocab, dest):
49 | d = {}
50 | d['word2idx'] = vocab.word2idx
51 | d['idx2word'] = vocab.idx2word
52 | d['idx'] = vocab.idx
53 | with open(dest, "w") as f:
54 | json.dump(d, f)
55 |
56 |
57 | def deserialize_vocab(src):
58 | with open(src) as f:
59 | d = json.load(f)
60 | vocab = Vocabulary()
61 | vocab.word2idx = d['word2idx']
62 | vocab.idx2word = d['idx2word']
63 | vocab.idx = d['idx']
64 | return vocab
65 |
66 |
67 | def from_txt(txt):
68 | captions = []
69 | with open(txt, 'r') as f:
70 | for line in f:
71 | captions.append(line.strip())
72 | return captions
73 |
74 | def from_tsv(tsv):
75 | captions = []
76 | img_ids = []
77 | with open(tsv) as f:
78 | tsvreader = csv.reader(f, delimiter='\t')
79 | for line in tsvreader:
80 | captions.append(line[1])
81 | img_ids.append(line[0])
82 | return captions
83 |
84 | def build_vocab(data_path, data_name, caption_file, threshold):
85 | """Build a simple vocabulary wrapper."""
86 | counter = Counter()
87 | for path in caption_file[data_name]:
88 | if data_name == 'cc152k_precomp':
89 | full_path = os.path.join(os.path.join(data_path, data_name), path)
90 | captions = from_tsv(full_path)
91 | elif data_name in ['coco_precomp', 'f30k_precomp']:
92 | full_path = os.path.join(os.path.join(data_path, data_name), path)
93 | captions = from_txt(full_path)
94 | else:
95 | raise NotImplementedError('Not support!')
96 |
97 | for i, caption in enumerate(captions):
98 | tokens = nltk.tokenize.word_tokenize(caption.lower())
99 | counter.update(tokens)
100 |
101 | if i % 1000 == 0:
102 | print("[%d/%d] tokenized the captions." % (i, len(captions)))
103 |
104 | # Discard if the occurrence of the word is less than min_word_cnt.
105 | words = [word for word, cnt in counter.items() if cnt >= threshold]
106 |
107 | # Create a vocab wrapper and add some special tokens.
108 | vocab = Vocabulary()
109 | vocab.add_word('<pad>')
110 | vocab.add_word('<start>')
111 | vocab.add_word('<end>')
112 | vocab.add_word('<unk>')
113 |
114 | # Add words to the vocabulary.
115 | for i, word in enumerate(words):
116 | vocab.add_word(word)
117 | return vocab
118 |
119 |
120 | def main(data_path, data_name):
121 | vocab = build_vocab(data_path, data_name, caption_file=annotations, threshold=4)
122 | serialize_vocab(vocab, './%s_vocab.json' % data_name)
123 | print("Saved vocabulary file to ", './%s_vocab.json' % data_name)
124 |
125 |
126 |
127 | if __name__ == '__main__':
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--data_path', default='/data/RR/data')
130 | parser.add_argument('--data_name', default='cc152k_precomp',
131 | help='{coco,f30k,cc152k}_precomp')
132 | opt = parser.parse_args()
133 | main(opt.data_path, opt.data_name)
134 |
--------------------------------------------------------------------------------
/GSC/data.py:
--------------------------------------------------------------------------------
1 | """Dataloader"""
2 |
3 | import csv
4 | import torch
5 | import torch.utils.data as data
6 | import os
7 | import nltk
8 | import numpy as np
9 |
10 |
11 | class PrecompDataset(data.Dataset):
12 | """
13 | Load precomputed captions and image features
14 | Possible options: f30k_precomp, coco_precomp, cc152k_precomp
15 | """
16 |
17 | def __init__(self, data_path, data_split, vocab, opt=None):
18 | self.vocab = vocab
19 | loc = data_path + '/'
20 |
21 | # load the raw captions
22 | self.captions = []
23 |
24 | if 'cc152k' in opt.data_name:
25 | with open(loc + '%s_caps.tsv' % data_split) as f:
26 | tsvreader = csv.reader(f, delimiter='\t')
27 | for line in tsvreader:
28 | self.captions.append(line[1].strip())
29 | else:
30 | with open(loc + '%s_caps.txt' % data_split, 'r', encoding='utf-8') as f:
31 | for line in f.readlines():
32 | self.captions.append(line.strip())
33 |
34 | # load the image features
35 | self.images = np.load(loc + '%s_ims.npy' % data_split)
36 | img_len = self.images.shape[0]
37 | cap_len = len(self.captions)
38 |
39 | # rkiros data has redundancy in images, we divide by 5
40 | if img_len != cap_len:
41 | self.im_div = 5
42 | else:
43 | self.im_div = 1
44 |
45 | # if use noise index by image: self.noisy_inx = np.arange(img_len)
46 | self.noisy_inx = np.arange(cap_len) // self.im_div
47 | if data_split == 'train' and opt.noise_ratio > 0.0:
48 | noise_file = opt.noise_file
49 | if os.path.exists(noise_file):
50 | print('=> load noisy index from {}'.format(noise_file))
51 | self.noisy_inx = np.load(noise_file)
52 | else:
53 | noise_ratio = opt.noise_ratio
54 | inx = np.arange(img_len)
55 | np.random.shuffle(inx)
56 | noisy_inx = inx[0: int(noise_ratio * img_len)]
57 | shuffle_noisy_inx = np.array(noisy_inx)
58 | np.random.shuffle(shuffle_noisy_inx)
59 | self.noisy_inx[noisy_inx] = shuffle_noisy_inx
60 | print('Noisy ratio: %g' % noise_ratio)
61 | self.length = len(self.captions)
62 | self.corrs = (self.noisy_inx == (np.arange(self.length) // self.im_div)).astype(int)  # np.int is removed in recent NumPy
63 |
64 | # the development set for coco is large and so validation would be slow
65 | if data_split == 'dev':
66 | self.length = 5000
67 | if 'cc152k' in opt.data_name:
68 | self.length = 1000
69 |
70 | def __getitem__(self, index):
71 | # handle the image redundancy
72 | # if use noise index by images: img_id = self.noisy_inx[int(index / self.im_div)]
73 | img_id = self.noisy_inx[index]
74 | image = torch.Tensor(self.images[img_id])
75 | caption = self.captions[index]
76 | vocab = self.vocab
77 |
78 | # convert caption (string) to word ids.
79 | tokens = nltk.tokenize.word_tokenize(caption.lower())
80 | caption = []
81 | caption.append(vocab('<start>'))
82 | caption.extend([vocab(token) for token in tokens])
83 | caption.append(vocab('<end>'))
84 | target = torch.Tensor(caption)
85 | if img_id == index // self.im_div:
86 | label = 1
87 | else:
88 | label = 0
89 | return image, target, index, img_id, label
90 |
91 | def __len__(self):
92 | return self.length
93 |
94 |
95 | def collate_fn(data):
96 | # Sort a data list by caption length
97 | data.sort(key=lambda x: len(x[1]), reverse=True)
98 | images, captions, ids, img_ids, label = zip(*data)
99 |
100 | # Merge images (convert tuple of 2D tensor to 3D tensor)
101 | images = torch.stack(images, 0)
102 |
103 | # Merge captions (convert tuple of 1D tensor to 2D tensor)
104 | lengths = [len(cap) for cap in captions]
105 | targets = torch.zeros(len(captions), max(lengths)).long()
106 | for i, cap in enumerate(captions):
107 | end = lengths[i]
108 | targets[i, :end] = cap[:end]
109 |
110 | return images, targets, lengths, list(ids), list(label)
111 |
112 |
113 | def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100,
114 | shuffle=True, num_workers=2):
115 | dset = PrecompDataset(data_path, data_split, vocab, opt)
116 |
117 | data_loader = torch.utils.data.DataLoader(dataset=dset,
118 | batch_size=batch_size,
119 | shuffle=shuffle,
120 | pin_memory=False,
121 | num_workers=num_workers,
122 | collate_fn=collate_fn)
123 | return data_loader
124 |
125 |
126 | def get_loaders(data_name, vocab, batch_size, workers, opt):
127 | # get the data path
128 | dpath = os.path.join(opt.data_path, data_name)
129 |
130 | # get the train_loader
131 | train_loader = get_precomp_loader(dpath, 'train', vocab, opt,
132 | batch_size, True, workers)
133 | # get the val_loader
134 | val_loader = get_precomp_loader(dpath, 'dev', vocab, opt,
135 | batch_size, False, workers)
136 | # get the test_loader
137 | test_loader = get_precomp_loader(dpath, 'test', vocab, opt,
138 | batch_size, False, workers)
139 | return train_loader, val_loader, test_loader
140 |
141 |
142 | def get_test_loader(split_name, data_name, vocab, batch_size, workers, opt):
143 | # get the data path
144 | dpath = os.path.join(opt.data_path, data_name)
145 |
146 | # get the test_loader
147 | test_loader = get_precomp_loader(dpath, split_name, vocab, opt,
148 | batch_size, False, workers)
149 | return test_loader
150 |
--------------------------------------------------------------------------------
/GSC/opts.py:
--------------------------------------------------------------------------------
1 | """Argument parser"""
2 |
3 | import argparse
4 | import os
5 | import random
6 | import sys
7 | import numpy as np
8 | import torch
9 |
10 | from utils import save_config
11 |
12 |
13 | def parse_opt():
14 | parser = argparse.ArgumentParser()
15 | # --------------------------- data path -------------------------#
16 | parser.add_argument('--data_name', default='cc152k_precomp',
17 | help='{coco,f30k,cc152k}_precomp')
18 | parser.add_argument('--data_path', default='/remote-home/share/zhaozh/NC_Datasets/data',
19 | help='path to datasets')
20 | parser.add_argument('--vocab_path', default='/remote-home/share/zhaozh/NC_Datasets/vocab',
21 | help='Path to saved vocabulary json files.')
22 | parser.add_argument('--folder_name', default='debug', help='Folder name to save the running results')
23 |
24 | # ----------------------- training setting ----------------------#
25 | parser.add_argument('--batch_size', default=128, type=int,
26 | help='Size of a training mini-batch.')
27 | parser.add_argument('--num_epochs', default=60, type=int,
28 | help='Number of training epochs.')
29 | parser.add_argument('--lr_update', default=20, type=int,
30 | help='Number of epochs to update the learning rate.')
31 | parser.add_argument('--learning_rate', default=2e-4, type=float,
32 | help='Initial learning rate.')
33 | parser.add_argument('--workers', default=4, type=int,
34 | help='Number of data loader workers.')
35 | parser.add_argument('--log_step', default=100, type=int,
36 | help='Number of steps to print and record the log.')
37 | parser.add_argument('--val_step', default=1000, type=int,
38 | help='Number of steps to run validation.')
39 | parser.add_argument('--grad_clip', default=2., type=float,
40 | help='Gradient clipping threshold.')
41 |
42 | # ------------------------- model settings (SGR, SAF, and other similarity models) -----------------------#
43 | parser.add_argument('--margin', default=0.2, type=float,
44 | help='Rank loss margin.')
45 | parser.add_argument('--max_violation', action='store_false',
46 | help='Use max instead of sum in the rank loss.')
47 | parser.add_argument('--img_dim', default=2048, type=int,
48 | help='Dimensionality of the image embedding.')
49 | parser.add_argument('--word_dim', default=300, type=int,
50 | help='Dimensionality of the word embedding.')
51 | parser.add_argument('--embed_size', default=1024, type=int,
52 | help='Dimensionality of the joint embedding.')
53 | parser.add_argument('--sim_dim', default=256, type=int,
54 | help='Dimensionality of the sim embedding.')
55 | parser.add_argument('--num_layers', default=1, type=int,
56 | help='Number of GRU layers.')
57 | parser.add_argument('--bi_gru', action='store_false',
58 | help='Use bidirectional GRU.')
59 | parser.add_argument('--no_imgnorm', action='store_true',
60 | help='Do not normalize the image embeddings.')
61 | parser.add_argument('--no_txtnorm', action='store_true',
62 | help='Do not normalize the text embeddings.')
63 | parser.add_argument('--module_name', default='SGR', type=str,
64 | help='SGR, SAF')
65 | parser.add_argument('--sgr_step', default=3, type=int,
66 | help='Step of the SGR.')
67 |
68 | # ------------------------- our DECL settings -----------------------#
69 | parser.add_argument('--warmup_if', default=False, type=bool,
70 | help='Need warmup training?')
71 | parser.add_argument('--warmup_epochs', default=0, type=int,
72 | help='Number of warmup epochs.')
73 | parser.add_argument('--warmup_model_path', default='',
74 | help='Path to load a pre-model')
75 | # noise settings
76 | parser.add_argument('--noise_ratio', default=0.0, type=float,
77 | help='Noisy ratio')
78 | parser.add_argument('--noise_file', default='/remote-home/share/zhaozh/NC_Datasets/data/noise_file_by_caption/f30k/f30k_precomp_0.0.npy',
79 | help='Path to noise index file')
80 | parser.add_argument("--seed", default=random.randint(0, 100), type=int, help="Random seed.")
81 |
82 | # loss settings
83 | parser.add_argument('--contrastive_loss', default='InfoNCE', help='Choose in Triplet/InfoNCE/RINCE.')
84 | parser.add_argument('--temp', default=0.07, type=float, help='Temperature for Info NCE Loss.')
85 | parser.add_argument('--lam', default=0.01, type=float, help='Parameter for RINCE.')
86 | parser.add_argument('--q', default=0.1, type=float, help='Parameter for RINCE.')
87 |
88 | # elr settings
89 | parser.add_argument('--beta', default=0.7, type=float)
90 | parser.add_argument('--alpha', default=1.0, type=float)
91 | parser.add_argument('--gamma', default=1.0, type=float)
92 |
93 | # gpu settings
94 | parser.add_argument('--gpu', default='0', help='Which gpu to use.')
95 |
96 | opt = parser.parse_args()
97 | project_path = str(sys.path[0])
98 | opt.log_dir = f'{project_path}/runs/{opt.folder_name}/log_dir'
99 | opt.checkpoint_dir = f'{project_path}/runs/{opt.folder_name}/checkpoint_dir'
100 | if not os.path.isdir(opt.log_dir):
101 | os.makedirs(opt.log_dir)
102 | if not os.path.isdir(opt.checkpoint_dir):
103 | os.makedirs(opt.checkpoint_dir)
104 |
105 | # save opt
106 | save_config(opt, os.path.join(opt.log_dir, "config.json"))
107 |
108 | # set random seed
109 | random.seed(opt.seed)
110 | np.random.seed(opt.seed)
111 | torch.manual_seed(opt.seed)
112 | torch.random.manual_seed(opt.seed)
113 | if torch.cuda.is_available():
114 | torch.cuda.manual_seed(opt.seed)
115 | torch.backends.cudnn.benchmark = True
116 | return opt
117 |
--------------------------------------------------------------------------------
/GSC+/data.py:
--------------------------------------------------------------------------------
1 | """Dataloader"""
2 |
3 | import csv
4 | import torch
5 | import torch.utils.data as data
6 | import os
7 | import nltk
8 | import numpy as np
9 |
10 |
11 | class PrecompDataset(data.Dataset):
12 | """
13 | Load precomputed captions and image features
14 | Possible options: f30k_precomp, coco_precomp, cc152k_precomp
15 | """
16 |
17 | def __init__(self, data_path, data_split, vocab, opt=None):
18 | self.vocab = vocab
19 | loc = data_path + '/'
20 |
21 | # load the raw captions
22 | self.captions = []
23 |
24 | if 'cc152k' in opt.data_name:
25 | with open(loc + '%s_caps.tsv' % data_split) as f:
26 | tsvreader = csv.reader(f, delimiter='\t')
27 | for line in tsvreader:
28 | self.captions.append(line[1].strip())
29 | else:
30 | with open(loc + '%s_caps.txt' % data_split, 'r', encoding='utf-8') as f:
31 | for line in f.readlines():
32 | self.captions.append(line.strip())
33 |
34 | # load the image features
35 | self.images = np.load(loc + '%s_ims.npy' % data_split)
36 | img_len = self.images.shape[0]
37 | cap_len = len(self.captions)
38 |
39 | # rkiros data has redundancy in images, we divide by 5
40 | if img_len != cap_len:
41 | self.im_div = 5
42 | else:
43 | self.im_div = 1
44 |
45 | # if use noise index by image: self.noisy_inx = np.arange(img_len)
46 | self.noisy_inx = np.arange(cap_len) // self.im_div
47 | if data_split == 'train' and opt.noise_ratio > 0.0:
48 | noise_file = opt.noise_file
49 | if os.path.exists(noise_file):
50 | print('=> load noisy index from {}'.format(noise_file))
51 | self.noisy_inx = np.load(noise_file)
52 | else:
53 | noise_ratio = opt.noise_ratio
54 | inx = np.arange(cap_len)
55 | np.random.shuffle(inx)
56 | noisy_inx = inx[0: int(noise_ratio * cap_len)]
57 | shuffle_noisy_inx = np.array(noisy_inx) // 5
58 | np.random.shuffle(shuffle_noisy_inx)
59 | self.noisy_inx[noisy_inx] = shuffle_noisy_inx
60 | np.save(noise_file, self.noisy_inx)
61 | print('Noisy ratio: %g' % noise_ratio)
62 | self.length = len(self.captions)
63 | self.corrs = (self.noisy_inx == (np.arange(self.length) // self.im_div)).astype(int)
64 |
65 | # the development set for coco is large and so validation would be slow
66 | if data_split == 'dev':
67 | self.length = 5000
68 | if 'cc152k' in opt.data_name:
69 | self.length = 1000
70 |
71 | def __getitem__(self, index):
72 | # handle the image redundancy
73 | # if use noise index by images: img_id = self.noisy_inx[int(index / self.im_div)]
74 | img_id = self.noisy_inx[index]
75 | image = torch.Tensor(self.images[img_id])
76 | caption = self.captions[index]
77 | vocab = self.vocab
78 |
79 | # convert caption (string) to word ids.
80 | tokens = nltk.tokenize.word_tokenize(caption.lower())
81 | caption = []
82 | caption.append(vocab('<start>'))
83 | caption.extend([vocab(token) for token in tokens])
84 | caption.append(vocab('<end>'))
85 | target = torch.Tensor(caption)
86 | if img_id == index // self.im_div:
87 | label = 1
88 | else:
89 | label = 0
90 | return image, target, index, img_id, label
91 |
92 | def __len__(self):
93 | return self.length
94 |
95 |
96 | def collate_fn(data):
97 | # Sort a data list by caption length
98 | data.sort(key=lambda x: len(x[1]), reverse=True)
99 | images, captions, ids, img_ids, label = zip(*data)
100 |
101 | # Merge images (convert tuple of 2D tensor to 3D tensor)
102 | images = torch.stack(images, 0)
103 |
104 | # Merge captions (convert tuple of 1D tensor to 2D tensor)
105 | lengths = [len(cap) for cap in captions]
106 | targets = torch.zeros(len(captions), max(lengths)).long()
107 | for i, cap in enumerate(captions):
108 | end = lengths[i]
109 | targets[i, :end] = cap[:end]
110 |
111 | return images, targets, lengths, list(ids), list(label)
112 |
113 |
114 | def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100,
115 | shuffle=True, num_workers=2):
116 | dset = PrecompDataset(data_path, data_split, vocab, opt)
117 |
118 | data_loader = torch.utils.data.DataLoader(dataset=dset,
119 | batch_size=batch_size,
120 | shuffle=shuffle,
121 | pin_memory=False,
122 | num_workers=num_workers,
123 | collate_fn=collate_fn)
124 | return data_loader
125 |
126 |
127 | def get_loaders(data_name, vocab, batch_size, workers, opt):
128 | # get the data path
129 | dpath = os.path.join(opt.data_path, data_name)
130 |
131 | # get the train_loader
132 | train_loader = get_precomp_loader(dpath, 'train', vocab, opt,
133 | batch_size, True, workers)
134 | # get the val_loader
135 | val_loader = get_precomp_loader(dpath, 'dev', vocab, opt,
136 | batch_size, False, workers)
137 | # get the test_loader
138 | test_loader = get_precomp_loader(dpath, 'test', vocab, opt,
139 | batch_size, False, workers)
140 | return train_loader, val_loader, test_loader
141 |
142 |
143 | def get_test_loader(split_name, data_name, vocab, batch_size, workers, opt):
144 | # get the data path
145 | dpath = os.path.join(opt.data_path, data_name)
146 |
147 | # get the test_loader
148 | test_loader = get_precomp_loader(dpath, split_name, vocab, opt,
149 | batch_size, False, workers)
150 | return test_loader
151 |
--------------------------------------------------------------------------------
/GSC+/opts.py:
--------------------------------------------------------------------------------
1 | """Argument parser"""
2 |
3 | import argparse
4 | import os
5 | import random
6 | import sys
7 | import numpy as np
8 | import torch
9 |
10 | from utils import save_config
11 |
12 |
13 | def parse_opt():
14 | parser = argparse.ArgumentParser()
15 | # --------------------------- data path -------------------------#
16 | parser.add_argument('--data_name', default='f30k_precomp',
17 | help='{coco,f30k,cc152k}_precomp')
18 | parser.add_argument('--data_path', default='/remote-home/share/zhaozh/NC_Datasets/data',
19 | help='path to datasets')
20 | parser.add_argument('--vocab_path', default='/remote-home/share/zhaozh/NC_Datasets/vocab',
21 | help='Path to saved vocabulary json files.')
22 | parser.add_argument('--folder_name', default='debug', help='Folder name to save the running results')
23 |
24 | # ----------------------- training setting ----------------------#
25 | parser.add_argument('--batch_size', default=32, type=int,
26 | help='Size of a training mini-batch.')
27 | parser.add_argument('--num_epochs', default=60, type=int,
28 | help='Number of training epochs.')
29 | parser.add_argument('--lr_update', default=20, type=int,
30 | help='Number of epochs to update the learning rate.')
31 | parser.add_argument('--learning_rate', default=2e-4, type=float,
32 | help='Initial learning rate.')
33 | parser.add_argument('--workers', default=4, type=int,
34 | help='Number of data loader workers.')
35 | parser.add_argument('--log_step', default=100, type=int,
36 | help='Number of steps to print and record the log.')
37 | parser.add_argument('--val_step', default=1000, type=int,
38 | help='Number of steps to run validation.')
39 | parser.add_argument('--grad_clip', default=2., type=float,
40 | help='Gradient clipping threshold.')
41 |
42 | # ------------------------- model settings (SGR, SAF, and other similarity models) -----------------------#
43 | parser.add_argument('--margin', default=0.2, type=float,
44 | help='Rank loss margin.')
45 | parser.add_argument('--max_violation', action='store_false',
46 | help='Use max instead of sum in the rank loss.')
47 | parser.add_argument('--img_dim', default=2048, type=int,
48 | help='Dimensionality of the image embedding.')
49 | parser.add_argument('--word_dim', default=300, type=int,
50 | help='Dimensionality of the word embedding.')
51 | parser.add_argument('--embed_size', default=1024, type=int,
52 | help='Dimensionality of the joint embedding.')
53 | parser.add_argument('--sim_dim', default=256, type=int,
54 | help='Dimensionality of the sim embedding.')
55 | parser.add_argument('--num_layers', default=1, type=int,
56 | help='Number of GRU layers.')
57 | parser.add_argument('--bi_gru', action='store_false',
58 | help='Use bidirectional GRU.')
59 | parser.add_argument('--no_imgnorm', action='store_true',
60 | help='Do not normalize the image embeddings.')
61 | parser.add_argument('--no_txtnorm', action='store_true',
62 | help='Do not normalize the text embeddings.')
63 | parser.add_argument('--module_name', default='SGR', type=str,
64 | help='SGR, SAF')
65 | parser.add_argument('--sgr_step', default=3, type=int,
66 | help='Step of the SGR.')
67 |
68 | # ------------------------- our DECL settings -----------------------#
69 | parser.add_argument('--warmup_if', default=False, type=bool,
70 | help='Need warmup training?')
71 | parser.add_argument('--warmup_epochs', default=0, type=int,
72 | help='Number of warmup epochs.')
73 | parser.add_argument('--warmup_model_path', default='/remote-home/zhaozh/NC/ELCL_TOPO/runs_1031/f30k_upper/r_corr_4/gamma_001/checkpoint_dir/model_best.pth.tar',
74 | help='Path to load a pre-model')
75 | # noise settings
76 | parser.add_argument('--noise_ratio', default=0.4, type=float,
77 | help='Noisy ratio')
78 | parser.add_argument('--noise_file', default='/remote-home/share/zhaozh/NC_Datasets/data/noise_file_by_caption/f30k/f30k_precomp_0.4.npy',
79 | help='Path to noise index file')
80 | parser.add_argument("--seed", default=random.randint(0, 100), type=int, help="Random seed.")
81 |
82 | # loss settings
83 | parser.add_argument('--contrastive_loss', default='InfoNCE', help='Choose in Triplet/InfoNCE/RINCE.')
84 | parser.add_argument('--temp', default=0.07, type=float, help='Temperature for Info NCE Loss.')
85 | parser.add_argument('--lam', default=0.01, type=float, help='Parameter for RINCE.')
86 | parser.add_argument('--q', default=0.1, type=float, help='Parameter for RINCE.')
87 |
88 | # elr settings
89 | parser.add_argument('--beta', default=0.7, type=float)
90 | parser.add_argument('--alpha', default=1.0, type=float)
91 | parser.add_argument('--gamma', default=1.0, type=float)
92 |
93 | # gpu settings
94 | parser.add_argument('--gpu', default='0', help='Which gpu to use.')
95 |
96 | opt = parser.parse_args()
97 | project_path = str(sys.path[0])
98 | opt.log_dir = f'{project_path}/runs/{opt.folder_name}/log_dir'
99 | opt.checkpoint_dir = f'{project_path}/runs/{opt.folder_name}/checkpoint_dir'
100 | if not os.path.isdir(opt.log_dir):
101 | os.makedirs(opt.log_dir)
102 | if not os.path.isdir(opt.checkpoint_dir):
103 | os.makedirs(opt.checkpoint_dir)
104 |
105 | # save opt
106 | save_config(opt, os.path.join(opt.log_dir, "config.json"))
107 |
108 | # set random seed
109 | random.seed(opt.seed)
110 | np.random.seed(opt.seed)
111 | torch.manual_seed(opt.seed)
112 | torch.random.manual_seed(opt.seed)
113 | if torch.cuda.is_available():
114 | torch.cuda.manual_seed(opt.seed)
115 | torch.backends.cudnn.benchmark = True
116 | return opt
117 |
--------------------------------------------------------------------------------
/GSC/model/RINCE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import clip_grad_norm_
5 | from model.SGRAF import SGRAF
6 | from sklearn.mixture import GaussianMixture
7 |
8 | import numpy as np
9 | from scipy.spatial.distance import cdist
10 | from utils import AverageMeter
11 |
12 | class RINCE(nn.Module):
13 | def __init__(self, opt):
14 | super(RINCE, self).__init__()
15 | self.opt = opt
16 | self.grad_clip = opt.grad_clip
17 | self.similarity_model = SGRAF(opt)
18 | self.params = list(self.similarity_model.params)
19 | self.optimizer = torch.optim.Adam(self.params, lr=opt.learning_rate)
20 | self.step = 0
21 |
22 | def state_dict(self):
23 | return self.similarity_model.state_dict()
24 |
25 | def load_state_dict(self, state_dict):
26 | self.similarity_model.load_state_dict(state_dict)
27 |
28 | def train_start(self):
29 | """switch to train mode"""
30 | self.similarity_model.train_start()
31 |
32 | def val_start(self):
33 |         """switch to evaluate mode"""
34 | self.similarity_model.val_start()
35 |
36 | def info_nce_loss(self, similarity_matrix, temp=None):
37 | if temp is None:
38 | temp = self.opt.temp
39 | labels = torch.eye(len(similarity_matrix)).float().cuda()
40 |
41 | # select and combine multiple positives
42 | pos = similarity_matrix[labels.bool()].view(similarity_matrix.shape[0], -1)
43 | # preds = pos
44 | pos = torch.exp(pos / temp).squeeze(1)
45 |
46 |         # select only the negatives
47 | neg = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
48 | neg = torch.exp(neg / temp)
49 |
50 | neg_sum = neg.sum(1)
51 |
52 | loss = -(torch.log(pos / (pos + neg_sum)))
53 | preds = pos / (pos + neg_sum)
54 | return loss, preds
55 |
56 | def warmup_batch(self, images, captions, lengths, ids, corrs):
57 | self.step += 1
58 | batch_length = images.size(0)
59 | img_embs, cap_embs, cap_lens = self.similarity_model.forward_emb(images, captions, lengths)
60 | targets_batch = self.similarity_model.targets[ids]
61 | sims = self.similarity_model.forward_sim(img_embs, cap_embs, cap_lens, 'sim')
62 |
63 | if self.opt.contrastive_loss == 'Triplet':
64 | loss_cl = self.similarity_model.forward_loss(sims)
65 | elif self.opt.contrastive_loss == 'InfoNCE':
66 | loss_cl, preds = self.info_nce_loss(sims)
67 |
68 | loss = loss_cl.mean()
69 |
70 | self.optimizer.zero_grad()
71 | loss = loss.mean()
72 | loss.backward()
73 | if self.grad_clip > 0:
74 | clip_grad_norm_(self.params, self.grad_clip)
75 | self.optimizer.step()
76 | self.logger.update('Step', self.step)
77 | self.logger.update('Lr', self.optimizer.param_groups[0]['lr'])
78 | self.logger.update('Loss', loss.item(), batch_length)
79 |
80 | def train_batch(self, images, captions, lengths, ids, corrs):
81 | self.step += 1
82 | batch_length = images.size(0)
83 |
84 | targets_batch = self.similarity_model.targets[ids]
85 |
86 | img_embs, cap_embs, cap_lens = self.similarity_model.forward_emb(images, captions, lengths)
87 | sims = self.similarity_model.forward_sim(img_embs, cap_embs, cap_lens, 'sim')
88 | loss_cl, preds = self.info_nce_loss(sims)
89 | loss_cl_t, preds_t = self.info_nce_loss(sims.t())
90 | loss_cl = (loss_cl + loss_cl_t) * 0.5 * targets_batch
91 | preds = (preds + preds_t) * 0.5
92 |
93 | sims_img, sims_cap = self.similarity_model.sim_enc.forward_topology(img_embs, cap_embs, cap_lens)
94 | sims_img = torch.sigmoid(sims_img) * targets_batch
95 | sims_cap = torch.sigmoid(sims_cap) * targets_batch
96 | sims_topo = sims_img @ sims_cap.t()
97 | loss_topo, _ = self.info_nce_loss(sims_topo, temp=1.0)
98 | loss_topo = loss_topo * targets_batch
99 |
100 | self.similarity_model.topos[ids] = torch.cosine_similarity(sims_img, sims_cap, dim=1).data.detach().float()
101 | self.similarity_model.preds[ids] = preds.data.detach().float()
102 | target_clean = (targets_batch * corrs).sum() / corrs.sum()
103 | target_noise = (targets_batch * (1 - corrs)).sum() / (1 - corrs).sum()
104 |
105 | loss = loss_cl.mean() + loss_topo.mean() * self.opt.gamma
106 |
107 | self.optimizer.zero_grad()
108 | loss.backward()
109 | if self.grad_clip > 0:
110 | clip_grad_norm_(self.params, self.grad_clip)
111 | self.optimizer.step()
112 | self.logger.update('Step', self.step)
113 | self.logger.update('Lr', self.optimizer.param_groups[0]['lr'])
114 | self.logger.update('Loss', loss.item(), batch_length)
115 | self.logger.update('Loss_CL', loss_cl.mean().item(), batch_length)
116 | self.logger.update('Loss_TOPO', loss_topo.mean().item(), batch_length)
117 | self.logger.update('Target Clean', target_clean.item(), batch_length)
118 | self.logger.update('Target Noise', target_noise.item(), batch_length)
119 |
120 | def split_batch(self, corrs=None):
121 | topos = self.similarity_model.topos
122 | topos = (topos - topos.min()) / (topos.max() - topos.min())
123 |
124 | gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
125 | gmm.fit(topos.detach().cpu().numpy().reshape(-1, 1))
126 | prob = gmm.predict_proba(topos.detach().cpu().numpy().reshape(-1, 1))
127 | prob = prob[:, gmm.means_.argmax()]
128 |
129 | self.similarity_model.topo_targets = self.opt.beta * torch.tensor(prob).cuda() + (1 - self.opt.beta) * self.similarity_model.topo_targets
130 | self.similarity_model.preds_targets = self.opt.beta * self.similarity_model.preds + (1 - self.opt.beta) * self.similarity_model.preds_targets
131 | self.similarity_model.targets = torch.minimum(self.similarity_model.preds_targets, self.similarity_model.topo_targets)
132 |
133 | if corrs is not None:
134 | pred = prob > 0.5
135 | clean_acc = (pred * corrs).sum() / corrs.sum()
136 | noise_acc = ((1 - pred) * (1 - corrs)).sum() / (1 - corrs).sum()
137 | print('GMM Clean Acc: ' + str(clean_acc))
138 | print('GMM Noise Acc: ' + str(noise_acc))
--------------------------------------------------------------------------------
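
Note: the following is a minimal, self-contained sketch of the InfoNCE computation implemented by `info_nce_loss` above, assuming a square image-caption similarity matrix whose diagonal entries are the matched pairs. The function name `info_nce` and the toy input are illustrative only, not part of the repository API.

import torch

def info_nce(sim, temp=0.07):
    # Diagonal entries are positives, off-diagonal entries are negatives.
    n = sim.size(0)
    mask = torch.eye(n, dtype=torch.bool, device=sim.device)
    pos = torch.exp(sim[mask] / temp)                   # (n,) positives
    neg = torch.exp(sim[~mask].view(n, -1) / temp)      # (n, n-1) negatives
    preds = pos / (pos + neg.sum(dim=1))                # probability assigned to the positive
    loss = -torch.log(preds)                            # per-row InfoNCE loss
    return loss, preds

if __name__ == '__main__':
    sim = torch.randn(8, 8)
    loss, preds = info_nce(sim)
    print(loss.mean().item(), preds.shape)
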
/GSC+/model/RINCE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import clip_grad_norm_
5 | from model.SGRAF import SGRAF
6 | from sklearn.mixture import GaussianMixture
7 |
8 | import numpy as np
9 | from scipy.spatial.distance import cdist
10 | from utils import AverageMeter
11 |
12 | class RINCE(nn.Module):
13 | def __init__(self, opt):
14 | super(RINCE, self).__init__()
15 | self.opt = opt
16 | self.grad_clip = opt.grad_clip
17 | self.similarity_model = SGRAF(opt)
18 | self.params = list(self.similarity_model.params)
19 | self.optimizer = torch.optim.Adam(self.params, lr=opt.learning_rate)
20 | self.step = 0
21 |
22 | def state_dict(self):
23 | return self.similarity_model.state_dict()
24 |
25 | def load_state_dict(self, state_dict):
26 | self.similarity_model.load_state_dict(state_dict)
27 |
28 | def train_start(self):
29 | """switch to train mode"""
30 | self.similarity_model.train_start()
31 |
32 | def val_start(self):
33 | """switch to valuate mode"""
34 | self.similarity_model.val_start()
35 |
36 | def info_nce_loss(self, similarity_matrix, temp=None):
37 | if temp is None:
38 | temp = self.opt.temp
39 | labels = torch.eye(len(similarity_matrix)).float().cuda()
40 |
41 | # select and combine multiple positives
42 | pos = similarity_matrix[labels.bool()].view(similarity_matrix.shape[0], -1)
43 | pos = torch.exp(pos / temp).squeeze(1)
44 |
45 |         # select only the negatives
46 | neg = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
47 | neg = torch.exp(neg / temp)
48 |
49 | neg_sum = neg.sum(1)
50 |
51 | loss = -(torch.log(pos / (pos + neg_sum)))
52 | preds = pos / (pos + neg_sum)
53 | return loss, preds
54 |
55 | def warmup_batch(self, images, captions, lengths, ids, corrs):
56 | self.step += 1
57 | batch_length = images.size(0)
58 | img_embs, cap_embs, cap_lens = self.similarity_model.forward_emb(images, captions, lengths)
59 | sims = self.similarity_model.forward_sim(img_embs, cap_embs, cap_lens, 'sim')
60 |
61 | loss_cl, preds = self.info_nce_loss(sims)
62 | loss_cl_t, preds_t = self.info_nce_loss(sims.t())
63 | loss_cl = (loss_cl + loss_cl_t) * 0.5
64 |
65 | sims_img, sims_cap = self.similarity_model.sim_enc.forward_topology(img_embs, cap_embs, cap_lens)
66 | sims_img = torch.sigmoid(sims_img)
67 | sims_cap = torch.sigmoid(sims_cap)
68 | sims_topo = sims_img @ sims_cap.t()
69 | loss_topo, _ = self.info_nce_loss(sims_topo, temp=1.0)
70 | loss_topo = loss_topo
71 |
72 | loss = loss_cl.mean() + loss_topo.mean() * self.opt.gamma
73 |
74 | self.optimizer.zero_grad()
75 | loss.backward()
76 | if self.grad_clip > 0:
77 | clip_grad_norm_(self.params, self.grad_clip)
78 | self.optimizer.step()
79 | self.logger.update('Step', self.step)
80 | self.logger.update('Lr', self.optimizer.param_groups[0]['lr'])
81 | self.logger.update('Loss', loss.item(), batch_length)
82 |
83 | def train_batch(self, images, captions, lengths, ids, corrs, targets_batch):
84 | self.step += 1
85 | batch_length = images.size(0)
86 |
87 | img_embs, cap_embs, cap_lens = self.similarity_model.forward_emb(images, captions, lengths)
88 | sims = self.similarity_model.forward_sim(img_embs, cap_embs, cap_lens, 'sim')
89 | loss_cl, preds = self.info_nce_loss(sims)
90 | loss_cl_t, preds_t = self.info_nce_loss(sims.t())
91 | loss_cl = (loss_cl + loss_cl_t) * 0.5 * targets_batch
92 | preds = (preds + preds_t) * 0.5
93 |
94 | sims_img, sims_cap = self.similarity_model.sim_enc.forward_topology(img_embs, cap_embs, cap_lens)
95 | sims_img = torch.sigmoid(sims_img) * targets_batch
96 | sims_cap = torch.sigmoid(sims_cap) * targets_batch
97 | sims_topo = sims_img @ sims_cap.t()
98 | loss_topo, _ = self.info_nce_loss(sims_topo, temp=1.0)
99 | loss_topo = loss_topo * targets_batch
100 |
101 | self.similarity_model.topos[ids] = torch.cosine_similarity(sims_img, sims_cap, dim=1).data.detach().float()
102 | self.similarity_model.preds[ids] = preds.data.detach().float()
103 | target_clean = (targets_batch * corrs).sum() / corrs.sum()
104 | target_noise = (targets_batch * (1 - corrs)).sum() / (1 - corrs).sum()
105 |
106 | loss = loss_cl.mean() + loss_topo.mean() * self.opt.gamma
107 |
108 | self.optimizer.zero_grad()
109 | loss.backward()
110 | if self.grad_clip > 0:
111 | clip_grad_norm_(self.params, self.grad_clip)
112 | self.optimizer.step()
113 | self.logger.update('Step', self.step)
114 | self.logger.update('Lr', self.optimizer.param_groups[0]['lr'])
115 | self.logger.update('Loss', loss.item(), batch_length)
116 | self.logger.update('Loss_CL', loss_cl.mean().item(), batch_length)
117 | self.logger.update('Loss_TOPO', loss_topo.mean().item(), batch_length)
118 | self.logger.update('Target Clean', target_clean.item(), batch_length)
119 | self.logger.update('Target Noise', target_noise.item(), batch_length)
120 |
121 | def split_batch(self, corrs=None):
122 | topos = self.similarity_model.topos
123 | topos = (topos - topos.min()) / (topos.max() - topos.min())
124 |
125 | gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
126 | gmm.fit(topos.detach().cpu().numpy().reshape(-1, 1))
127 | prob = gmm.predict_proba(topos.detach().cpu().numpy().reshape(-1, 1))
128 | prob = prob[:, gmm.means_.argmax()]
129 |
130 | self.similarity_model.preds_targets = self.opt.beta * self.similarity_model.preds + (1 - self.opt.beta) * self.similarity_model.preds_targets
131 | self.similarity_model.topo_targets = self.opt.beta * torch.tensor(prob).cuda() + (1 - self.opt.beta) * self.similarity_model.topo_targets
132 | self.similarity_model.targets = torch.minimum(self.similarity_model.preds_targets, self.similarity_model.topo_targets)
133 |
134 | if corrs is not None:
135 | pred = prob > 0.5
136 | clean_acc = (pred * corrs).sum() / corrs.sum()
137 | noise_acc = ((1 - pred) * (1 - corrs)).sum() / (1 - corrs).sum()
138 | print('GMM Clean Acc: ' + str(clean_acc))
139 | print('GMM Noise Acc: ' + str(noise_acc))
--------------------------------------------------------------------------------
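
Note: a minimal sketch of the sample-selection logic in `split_batch` above, under the same assumptions as the code: per-sample topology scores are min-max normalized, a two-component GaussianMixture is fit, the posterior of the higher-mean component is taken as the clean probability, and the result is EMA-blended with the previous targets before an elementwise minimum with the prediction-based track. The helper names here are illustrative, not part of the repository API.

import numpy as np
import torch
from sklearn.mixture import GaussianMixture

def clean_probability(scores):
    # Min-max normalize, fit a 2-component GMM; the higher-mean component is treated as "clean".
    s = np.asarray(scores, dtype=np.float64).reshape(-1, 1)
    s = (s - s.min()) / (s.max() - s.min() + 1e-12)
    gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4).fit(s)
    return gmm.predict_proba(s)[:, gmm.means_.argmax()]

def update_targets(prob, preds, topo_targets, preds_targets, beta=0.7):
    # EMA-blend the new evidence into each track, then take the elementwise minimum.
    topo_targets = beta * prob + (1 - beta) * topo_targets
    preds_targets = beta * preds + (1 - beta) * preds_targets
    return torch.minimum(preds_targets, topo_targets), topo_targets, preds_targets

if __name__ == '__main__':
    scores = np.concatenate([np.random.normal(0.8, 0.05, 500), np.random.normal(0.3, 0.05, 500)])
    prob = torch.tensor(clean_probability(scores)).float()
    preds = torch.rand(1000)
    targets, _, _ = update_targets(prob, preds, torch.ones(1000), torch.ones(1000))
    print(targets.shape, float(targets.mean()))
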
/GSC/train.py:
--------------------------------------------------------------------------------
1 | """Training script"""
2 |
3 | import logging
4 | import os
5 | import time
6 | from datetime import datetime
7 | import numpy as np
8 | import torch
9 | from sklearn.metrics import confusion_matrix
10 | import tensorboard_logger as tb_logger
11 | import data
12 | import opts
13 | from model.RINCE import RINCE
14 | from evaluation import encode_data, shard_attn_scores, i2t, t2i, AverageMeter, LogCollector
15 | from utils import save_checkpoint
16 | from vocab import deserialize_vocab
17 | import warnings
18 |
19 | warnings.filterwarnings("ignore")
20 |
21 | def adjust_learning_rate(opt, optimizer, epoch):
22 | """
23 | Sets the learning rate to the initial LR
24 |     decayed by a factor of 5 (x0.2) every opt.lr_update epochs
25 | """
26 | lr = opt.learning_rate * (0.2 ** (epoch // opt.lr_update))
27 | for param_group in optimizer.param_groups:
28 | param_group['lr'] = lr
29 |
30 |
31 | def train(opt, data_loader, val_loader, model, epoch, preds=None, mode='warmup', best_rsum=0):
32 | batch_time = AverageMeter()
33 | data_time = AverageMeter()
34 | train_logger = LogCollector()
35 |     num_loader_iter = len(data_loader.dataset) // data_loader.batch_size + 1
36 | end = time.time()
37 | logger.info("=> {mode} epoch: {0}".format(epoch, mode=mode))
38 | for i, (images, captions, lengths, ids, corrs) in enumerate(data_loader):
39 | if images.size(0) == 1:
40 | break
41 | model.train_start()
42 | data_time.update(time.time() - end)
43 | model.logger = train_logger
44 | ids = torch.tensor(ids).cuda()
45 | corrs = torch.tensor(corrs).cuda()
46 | if mode == 'warmup':
47 | model.warmup_batch(images, captions, lengths, ids, corrs)
48 | else:
49 | model.train_batch(images, captions, lengths, ids, corrs)
50 | batch_time.update(time.time() - end)
51 | if model.step % opt.log_step == 0:
52 | logger.info(
53 | 'Epoch ({mode}): [{0}][{1}/{2}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
54 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) \t{loss}'.format(epoch, i, num_loader_iter,
55 | mode=mode,
56 | batch_time=batch_time,
57 | data_time=data_time,
58 | loss=str(model.logger)))
59 | # Record logs in tensorboard
60 | tb_logger.log_value('epoch', epoch, step=model.step)
61 | tb_logger.log_value('step', i, step=model.step)
62 | tb_logger.log_value('batch_time', batch_time.val, step=model.step)
63 | tb_logger.log_value('data_time', data_time.val, step=model.step)
64 | model.logger.tb_log(tb_logger, step=model.step)
65 |
66 |     model.split_batch(data_loader.dataset.corrs)
67 |
68 |
69 | def validation(opt, val_loader, model, test=False):
70 | # compute the encoding for all the validation images and captions
71 | if opt.data_name == 'cc152k_precomp':
72 | per_captions = 1
73 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']:
74 | per_captions = 5
75 | else:
76 |         logger.info(f"Unknown dataset: {opt.data_name}")
77 | return 0
78 | if test:
79 | logger.info(f"=> Test")
80 | else:
81 | logger.info(f"=> Validation")
82 | # compute the encoding for all the validation images and captions
83 | img_embs, cap_embs, cap_lens = encode_data(model.similarity_model, val_loader, opt.log_step)
84 |
85 | # clear duplicate 5*images and keep 1*images
86 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)])
87 |
88 | # record computation time of validation
89 | start = time.time()
90 | sims = shard_attn_scores(model.similarity_model, img_embs, cap_embs, cap_lens, opt, shard_size=1000)
91 | end = time.time()
92 | logger.info(f"calculate similarity time: {end - start}")
93 |
94 | with torch.no_grad():
95 | loss_cl, preds = model.info_nce_loss(torch.tensor(sims[:, :1000]).float().cuda())
96 | loss_cl_t, preds_t = model.info_nce_loss(torch.tensor(sims[:, :1000]).float().cuda().t())
97 | loss_cl = (loss_cl + loss_cl_t) * 0.5
98 |
99 | sims_img, sims_cap = model.similarity_model.sim_enc.forward_topology(torch.tensor(img_embs[:128]).float().cuda(), torch.tensor(cap_embs[:128]).float().cuda(), cap_lens[:128])
100 | sims_img = torch.sigmoid(sims_img)
101 | sims_cap = torch.sigmoid(sims_cap)
102 | sims_topo = sims_img @ sims_cap.t()
103 | loss_topo, _ = model.info_nce_loss(sims_topo, temp=1.0)
104 |
105 | # caption retrieval
106 | (r1, r5, r10, medr, meanr) = i2t(img_embs.shape[0], sims, per_captions, return_ranks=False)
107 | logger.info("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3))
108 | logger.info("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr))
109 | # image retrieval
110 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs.shape[0], sims, per_captions, return_ranks=False)
111 | logger.info("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3))
112 | logger.info("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr))
113 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i
114 |
115 | logger.info("Sum of Recall: %.2f" % (r_sum))
116 | # sum of recalls to be used for early stopping
117 | if test:
118 | # record metrics in tensorboard
119 | tb_logger.log_value('r1', r1, step=model.step)
120 | tb_logger.log_value('r5', r5, step=model.step)
121 | tb_logger.log_value('r10', r10, step=model.step)
122 | tb_logger.log_value('medr', medr, step=model.step)
123 | tb_logger.log_value('meanr', meanr, step=model.step)
124 | tb_logger.log_value('r1i', r1i, step=model.step)
125 | tb_logger.log_value('r5i', r5i, step=model.step)
126 | tb_logger.log_value('r10i', r10i, step=model.step)
127 | tb_logger.log_value('medri', medri, step=model.step)
128 | tb_logger.log_value('meanr', meanr, step=model.step)
129 | tb_logger.log_value('r_sum', r_sum, step=model.step)
130 | tb_logger.log_value('loss_cl', loss_cl.mean().item(), step=model.step)
131 | tb_logger.log_value('loss_topo', loss_topo.mean().item(), step=model.step)
132 | else:
133 | # record metrics in tensorboard
134 | tb_logger.log_value('t-r1', r1, step=model.step)
135 | tb_logger.log_value('t-r5', r5, step=model.step)
136 | tb_logger.log_value('t-r10', r10, step=model.step)
137 | tb_logger.log_value('t-medr', medr, step=model.step)
138 | tb_logger.log_value('t-meanr', meanr, step=model.step)
139 | tb_logger.log_value('t-r1i', r1i, step=model.step)
140 | tb_logger.log_value('t-r5i', r5i, step=model.step)
141 | tb_logger.log_value('t-r10i', r10i, step=model.step)
142 | tb_logger.log_value('t-medri', medri, step=model.step)
143 | tb_logger.log_value('t-meanr', meanr, step=model.step)
144 | tb_logger.log_value('t-r_sum', r_sum, step=model.step)
145 | tb_logger.log_value('t-loss_cl', loss_cl.mean().item(), step=model.step)
146 | tb_logger.log_value('t-loss_topo', loss_topo.mean().item(), step=model.step)
147 | return r_sum
148 |
149 |
150 | def init_logging(log_file_path):
151 | logger = logging.getLogger(__name__)
152 | logger.setLevel(level=logging.DEBUG)
153 | formatter = logging.Formatter('%(asctime)s %(message)s')
154 | file_handler = logging.FileHandler(log_file_path)
155 | file_handler.setLevel(level=logging.INFO)
156 | file_handler.setFormatter(formatter)
157 | stream_handler = logging.StreamHandler()
158 | stream_handler.setLevel(logging.DEBUG)
159 | stream_handler.setFormatter(formatter)
160 | logger.addHandler(file_handler)
161 | logger.addHandler(stream_handler)
162 | return logger
163 |
164 |
165 | if __name__ == '__main__':
166 | opt = opts.parse_opt()
167 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
168 | tb_logger.configure(opt.log_dir, flush_secs=5)
169 | logger = init_logging(opt.log_dir + '/log.txt')
170 | logger.info(opt)
171 | logger.info(f"=> PID:{os.getpid()}, GPU:[{opt.gpu}], Noise ratio: {opt.noise_ratio}")
172 | logger.info(f"=> Log save path: '{opt.log_dir}'")
173 | logger.info(f"=> Checkpoint save path: '{opt.checkpoint_dir}'")
174 | # Load Vocabulary
175 | logger.info(f"=> Load vocabulary from '{opt.vocab_path}'")
176 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
177 | opt.vocab_size = len(vocab)
178 | # Load data loaders
179 | logger.info(f"=> Load loaders from '{opt.data_path}/{opt.data_name}'")
180 | train_loader, val_loader, test_loader = data.get_loaders(opt.data_name, vocab, opt.batch_size,
181 | opt.workers, opt)
182 | # Construct the model (DECL-)
183 | logger.info(f"=> Similarity model is {opt.module_name}")
184 | model = RINCE(opt)
185 | best_rsum = 0
186 | start_epoch = 0
187 | if opt.warmup_if:
188 | if os.path.isfile(opt.warmup_model_path):
189 | checkpoint = torch.load(opt.warmup_model_path)
190 | model.load_state_dict(checkpoint['model'])
191 | logger.info(
192 | "=> Load warmup(pre-) checkpoint '{}' (epoch {})".format(opt.warmup_model_path, checkpoint['epoch']))
193 | if 'best_rsum' in checkpoint:
194 | if 'warmup' not in opt.warmup_model_path:
195 | start_epoch = checkpoint['epoch'] + 1
196 | best_rsum = checkpoint['best_rsum']
197 | model.step = checkpoint['step']
198 | else:
199 | logger.info(f"=> no checkpoint found at '{opt.warmup_model_path}', warmup start!")
200 | for e in range(opt.warmup_epochs):
201 | train(opt, train_loader, val_loader, model, e, mode='warmup')
202 | save_checkpoint({
203 | 'epoch': e,
204 | 'model': model.state_dict(),
205 | 'opt': opt,
206 | 'step': model.step
207 | }, is_best=False, filename='warmup_model_{}.pth.tar'.format(e), prefix=opt.checkpoint_dir + '/')
208 | else:
209 | logger.info("=> No warmup stage")
210 |
211 | for epoch in range(start_epoch, opt.num_epochs):
212 | adjust_learning_rate(opt, model.optimizer, epoch)
213 | start_time = datetime.now()
214 | train(opt, train_loader, val_loader, model, epoch, mode='train', best_rsum=best_rsum)
215 | end_time = datetime.now()
216 | tb_logger.log_value('cost_time', int((end_time - start_time).seconds), step=epoch)
217 | validation(opt, test_loader, model, test=True)
218 | rsum = validation(opt, val_loader, model)
219 | is_best = rsum > best_rsum
220 | best_rsum = max(rsum, best_rsum)
221 | save_checkpoint({
222 | 'epoch': epoch,
223 | 'model': model.state_dict(),
224 | 'best_rsum': best_rsum,
225 | 'opt': opt,
226 | 'step': model.step,
227 | }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.checkpoint_dir + '/')
228 |
--------------------------------------------------------------------------------
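
Note: a quick sketch of the step-decay schedule applied by `adjust_learning_rate` above: the base learning rate is multiplied by 0.2 every `opt.lr_update` epochs. The example values for base_lr and lr_update are illustrative only.

def step_decay_lr(base_lr, epoch, lr_update, factor=0.2):
    # Multiply the base learning rate by `factor` once per `lr_update` epochs.
    return base_lr * (factor ** (epoch // lr_update))

if __name__ == '__main__':
    # e.g. base_lr=2e-4, lr_update=15: epochs 0-14 -> 2e-4, epochs 15-29 -> 4e-5, ...
    for e in (0, 14, 15, 30):
        print(e, step_decay_lr(2e-4, e, 15))
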
/GSC+/train.py:
--------------------------------------------------------------------------------
1 | """Training script"""
2 |
3 | import logging
4 | import os
5 | import time
6 | from datetime import datetime
7 | import numpy as np
8 | import torch
9 | from sklearn.metrics import confusion_matrix
10 | import tensorboard_logger as tb_logger
11 | import data
12 | import opts
13 | from model.RINCE import RINCE
14 | from evaluation import encode_data, shard_attn_scores, i2t, t2i, AverageMeter, LogCollector
15 | from utils import save_checkpoint
16 | from vocab import deserialize_vocab
17 | import warnings
18 |
19 | warnings.filterwarnings("ignore")
20 |
21 | def adjust_learning_rate(opt, optimizer, epoch):
22 | """
23 | Sets the learning rate to the initial LR
24 |     decayed by a factor of 5 (x0.2) every opt.lr_update epochs
25 | """
26 | lr = opt.learning_rate * (0.2 ** (epoch // opt.lr_update))
27 | for param_group in optimizer.param_groups:
28 | param_group['lr'] = lr
29 |
30 |
31 | def train(opt, data_loader, val_loader, model_A, model_B, epoch, preds=None, mode='warmup', best_rsum=0):
32 | batch_time = AverageMeter()
33 | data_time = AverageMeter()
34 | train_logger = LogCollector()
35 |     num_loader_iter = len(data_loader.dataset) // data_loader.batch_size + 1
36 | end = time.time()
37 | logger.info("=> {mode} epoch: {0}".format(epoch, mode=mode))
38 | for i, (images, captions, lengths, ids, corrs) in enumerate(data_loader):
39 | # print(i)
40 | if images.size(0) == 1:
41 | break
42 | model_A.train_start()
43 | model_B.train_start()
44 | data_time.update(time.time() - end)
45 | model_A.logger = train_logger
46 | model_B.logger = train_logger
47 | ids = torch.tensor(ids).cuda()
48 | corrs = torch.tensor(corrs).cuda()
49 | if mode == 'warmup':
50 |             model_A.warmup_batch(images, captions, lengths, ids, corrs)
51 |             model_B.warmup_batch(images, captions, lengths, ids, corrs)
52 | else:
53 | model_A.train_batch(images, captions, lengths, ids, corrs, model_B.similarity_model.targets[ids])
54 | model_B.train_batch(images, captions, lengths, ids, corrs, model_A.similarity_model.targets[ids])
55 | batch_time.update(time.time() - end)
56 | if model_A.step % opt.log_step == 0:
57 | logger.info(
58 | 'Epoch ({mode}): [{0}][{1}/{2}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
59 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) \t{loss}'.format(epoch, i, num_loader_iter,
60 | mode=mode,
61 | batch_time=batch_time,
62 | data_time=data_time,
63 | loss=str(model_A.logger)))
64 | # Record logs in tensorboard
65 | tb_logger.log_value('epoch', epoch, step=model_A.step)
66 | tb_logger.log_value('step', i, step=model_A.step)
67 | tb_logger.log_value('batch_time', batch_time.val, step=model_A.step)
68 | tb_logger.log_value('data_time', data_time.val, step=model_A.step)
69 | model_A.logger.tb_log(tb_logger, step=model_A.step)
70 | model_B.logger.tb_log(tb_logger, step=model_B.step)
71 |
72 | if mode != 'warmup':
73 |         model_A.split_batch(data_loader.dataset.corrs)
74 |         model_B.split_batch(data_loader.dataset.corrs)
75 |
76 |
77 | def validation(opt, val_loader, model_A, model_B, test=False):
78 | # compute the encoding for all the validation images and captions
79 | if opt.data_name == 'cc152k_precomp':
80 | per_captions = 1
81 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']:
82 | per_captions = 5
83 | else:
84 |         logger.info(f"Unknown dataset: {opt.data_name}")
85 | return 0
86 | if test:
87 | logger.info(f"=> Test")
88 | else:
89 | logger.info(f"=> Validation")
90 | # compute the encoding for all the validation images and captions
91 | img_embs_A, cap_embs_A, cap_lens_A = encode_data(model_A.similarity_model, val_loader, opt.log_step)
92 | img_embs_A = np.array([img_embs_A[i] for i in range(0, len(img_embs_A), per_captions)])
93 |
94 | img_embs_B, cap_embs_B, cap_lens_B = encode_data(model_B.similarity_model, val_loader, opt.log_step)
95 | img_embs_B = np.array([img_embs_B[i] for i in range(0, len(img_embs_B), per_captions)])
96 |
97 |
98 | # record computation time of validation
99 | start = time.time()
100 | sims_A = shard_attn_scores(model_A.similarity_model, img_embs_A, cap_embs_A, cap_lens_A, opt, shard_size=1000)
101 | end = time.time()
102 | logger.info(f"calculate similarity time: {end - start}")
103 |
104 | start = time.time()
105 | sims_B = shard_attn_scores(model_B.similarity_model, img_embs_B, cap_embs_B, cap_lens_B, opt, shard_size=1000)
106 | end = time.time()
107 | logger.info(f"calculate similarity time: {end - start}")
108 | sims = (sims_A + sims_B) / 2
109 |
110 | # caption retrieval
111 | (r1, r5, r10, medr, meanr) = i2t(img_embs_A.shape[0], sims, per_captions, return_ranks=False)
112 | logger.info("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3))
113 | logger.info("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr))
114 | # image retrieval
115 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs_A.shape[0], sims, per_captions, return_ranks=False)
116 | logger.info("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3))
117 | logger.info("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr))
118 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i
119 |
120 | logger.info("Sum of Recall: %.2f" % (r_sum))
121 | # sum of recalls to be used for early stopping
122 | if test:
123 | # record metrics in tensorboard
124 | tb_logger.log_value('r1', r1, step=model_A.step)
125 | tb_logger.log_value('r5', r5, step=model_A.step)
126 | tb_logger.log_value('r10', r10, step=model_A.step)
127 | tb_logger.log_value('medr', medr, step=model_A.step)
128 | tb_logger.log_value('meanr', meanr, step=model_A.step)
129 | tb_logger.log_value('r1i', r1i, step=model_A.step)
130 | tb_logger.log_value('r5i', r5i, step=model_A.step)
131 | tb_logger.log_value('r10i', r10i, step=model_A.step)
132 | tb_logger.log_value('medri', medri, step=model_A.step)
133 | tb_logger.log_value('meanr', meanr, step=model_A.step)
134 | tb_logger.log_value('r_sum', r_sum, step=model_A.step)
135 | else:
136 | # record metrics in tensorboard
137 | tb_logger.log_value('t-r1', r1, step=model_A.step)
138 | tb_logger.log_value('t-r5', r5, step=model_A.step)
139 | tb_logger.log_value('t-r10', r10, step=model_A.step)
140 | tb_logger.log_value('t-medr', medr, step=model_A.step)
141 | tb_logger.log_value('t-meanr', meanr, step=model_A.step)
142 | tb_logger.log_value('t-r1i', r1i, step=model_A.step)
143 | tb_logger.log_value('t-r5i', r5i, step=model_A.step)
144 | tb_logger.log_value('t-r10i', r10i, step=model_A.step)
145 | tb_logger.log_value('t-medri', medri, step=model_A.step)
146 | tb_logger.log_value('t-meanr', meanr, step=model_A.step)
147 | tb_logger.log_value('t-r_sum', r_sum, step=model_A.step)
148 | return r_sum
149 |
150 | def init_logging(log_file_path):
151 | logger = logging.getLogger(__name__)
152 | logger.setLevel(level=logging.DEBUG)
153 | formatter = logging.Formatter('%(asctime)s %(message)s')
154 | file_handler = logging.FileHandler(log_file_path)
155 | file_handler.setLevel(level=logging.INFO)
156 | file_handler.setFormatter(formatter)
157 | stream_handler = logging.StreamHandler()
158 | stream_handler.setLevel(logging.DEBUG)
159 | stream_handler.setFormatter(formatter)
160 | logger.addHandler(file_handler)
161 | logger.addHandler(stream_handler)
162 | return logger
163 |
164 |
165 | if __name__ == '__main__':
166 | opt = opts.parse_opt()
167 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
168 | tb_logger.configure(opt.log_dir, flush_secs=5)
169 | logger = init_logging(opt.log_dir + '/log.txt')
170 | logger.info(opt)
171 | logger.info(f"=> PID:{os.getpid()}, GPU:[{opt.gpu}], Noise ratio: {opt.noise_ratio}")
172 | logger.info(f"=> Log save path: '{opt.log_dir}'")
173 | logger.info(f"=> Checkpoint save path: '{opt.checkpoint_dir}'")
174 | # Load Vocabulary
175 | logger.info(f"=> Load vocabulary from '{opt.vocab_path}'")
176 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
177 | opt.vocab_size = len(vocab)
178 | # Load data loaders
179 | logger.info(f"=> Load loaders from '{opt.data_path}/{opt.data_name}'")
180 | train_loader, val_loader, test_loader = data.get_loaders(opt.data_name, vocab, opt.batch_size,
181 | opt.workers, opt)
182 | # Construct the model (DECL-)
183 | logger.info(f"=> Similarity model is {opt.module_name}")
184 | model_A = RINCE(opt)
185 | model_B = RINCE(opt)
186 | best_rsum = 0
187 | start_epoch = 0
188 | if opt.warmup_if:
189 | if os.path.isfile(opt.warmup_model_path):
190 | checkpoint = torch.load(opt.warmup_model_path)
191 | model_A.load_state_dict(checkpoint['model_A'])
192 | model_B.load_state_dict(checkpoint['model_B'])
193 | logger.info(
194 | "=> Load warmup(pre-) checkpoint '{}' (epoch {})".format(opt.warmup_model_path, checkpoint['epoch']))
195 | if 'best_rsum' in checkpoint:
196 | if 'warmup' not in opt.warmup_model_path:
197 | start_epoch = checkpoint['epoch'] + 1
198 | best_rsum = checkpoint['best_rsum']
199 | model_A.step = checkpoint['step']
200 | model_B.step = checkpoint['step']
201 | else:
202 | logger.info(f"=> no checkpoint found at '{opt.warmup_model_path}', warmup start!")
203 | for e in range(opt.warmup_epochs):
204 | train(opt, train_loader, val_loader, model_A, model_B, e, mode='warmup')
205 | save_checkpoint({
206 | 'epoch': e,
207 | 'model_A': model_A.state_dict(),
208 | 'model_B': model_B.state_dict(),
209 | 'opt': opt,
210 | 'step': model_A.step
211 | }, is_best=False, filename='warmup_model_{}.pth.tar'.format(e), prefix=opt.checkpoint_dir + '/')
212 | else:
213 | logger.info("=> No warmup stage")
214 |
215 | for epoch in range(start_epoch, opt.num_epochs):
216 | adjust_learning_rate(opt, model_A.optimizer, epoch)
217 | adjust_learning_rate(opt, model_B.optimizer, epoch)
218 | start_time = datetime.now()
219 | train(opt, train_loader, val_loader, model_A, model_B, epoch, mode='train', best_rsum=best_rsum)
220 | end_time = datetime.now()
221 | tb_logger.log_value('cost_time', int((end_time - start_time).seconds), step=epoch)
222 | validation(opt, test_loader, model_A, model_B, test=True)
223 | rsum = validation(opt, val_loader, model_A, model_B)
224 | is_best = rsum > best_rsum
225 | best_rsum = max(rsum, best_rsum)
226 | save_checkpoint({
227 | 'epoch': epoch,
228 | 'model_A': model_A.state_dict(),
229 | 'model_B': model_B.state_dict(),
230 | 'best_rsum': best_rsum,
231 | 'opt': opt,
232 | 'step': model_A.step,
233 | }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.checkpoint_dir + '/')
234 |
--------------------------------------------------------------------------------
/GSC/evaluation.py:
--------------------------------------------------------------------------------
1 | """Evaluation"""
2 |
3 | from __future__ import print_function
4 | import os
5 | import sys
6 | import torch
7 | import numpy as np
8 |
9 | from model.RINCE import RINCE
10 | from data import get_test_loader
11 | from vocab import deserialize_vocab
12 | from collections import OrderedDict
13 |
14 |
15 | class AverageMeter(object):
16 | """Computes and stores the average and current value"""
17 |
18 | def __init__(self):
19 | self.reset()
20 |
21 | def reset(self):
22 | self.val = 0
23 | self.avg = 0
24 | self.sum = 0
25 | self.count = 0
26 |
27 | def update(self, val, n=0):
28 | self.val = val
29 | self.sum += val * n
30 | self.count += n
31 | self.avg = self.sum / (.0001 + self.count)
32 |
33 | def __str__(self):
34 | """String representation for logging
35 | """
36 | # for values that should be recorded exactly e.g. iteration number
37 | if self.count == 0:
38 | return str(self.val)
39 | # for stats
40 | return '%.4f (%.4f)' % (self.val, self.avg)
41 |
42 |
43 | class LogCollector(object):
44 | """A collection of logging objects that can change from train to val"""
45 |
46 | def __init__(self):
47 | # to keep the order of logged variables deterministic
48 | self.meters = OrderedDict()
49 |
50 | def update(self, k, v, n=0):
51 | # create a new meter if previously not recorded
52 | if k not in self.meters:
53 | self.meters[k] = AverageMeter()
54 | self.meters[k].update(v, n)
55 |
56 | def __str__(self):
57 | """Concatenate the meters in one log line
58 | """
59 | s = ''
60 | for i, (k, v) in enumerate(self.meters.items()):
61 | if i > 0:
62 | s += ' '
63 | s += k + ' ' + str(v)
64 | return s
65 |
66 | def tb_log(self, tb_logger, prefix='', step=None):
67 | """Log using tensorboard
68 | """
69 | for k, v in self.meters.items():
70 | tb_logger.log_value(prefix + k, v.val, step=step)
71 |
72 |
73 | def encode_data(model, data_loader, log_step=10, logging=print, sub=0):
74 | """Encode all images and captions loadable by `data_loader`
75 | """
76 | val_logger = LogCollector()
77 |
78 | # switch to evaluate mode
79 | model.val_start()
80 |
81 | # np array to keep all the embeddings
82 | img_embs = None
83 | cap_embs = None
84 |
85 | max_n_word = 0
86 | for i, (images, captions, lengths, ids, _) in enumerate(data_loader):
87 | max_n_word = max(max_n_word, max(lengths))
88 | ids_ = []
89 | for i, (images, captions, lengths, ids, _) in enumerate(data_loader):
90 | # make sure val logger is used
91 | model.logger = val_logger
92 | ids_ += ids
93 | # compute the embeddings
94 | with torch.no_grad():
95 | img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths)
96 | if img_embs is None:
97 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2)))
98 | cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2)))
99 | cap_lens = [0] * len(data_loader.dataset)
100 | # cache embeddings
101 | img_embs[ids, :, :] = img_emb.data.cpu().numpy().copy()
102 | cap_embs[ids, :max(lengths), :] = cap_emb.data.cpu().numpy().copy()
103 |
104 | for j, nid in enumerate(ids):
105 | cap_lens[nid] = cap_len[j]
106 |
107 | del images, captions
108 | if sub > 0:
109 | print(f"===>batch {i}")
110 | if sub > 0 and i > sub:
111 | break
112 | if sub > 0:
113 | return np.array(img_embs)[ids_].tolist(), np.array(cap_embs)[ids_].tolist(), np.array(cap_lens)[
114 | ids_].tolist(), ids_
115 | else:
116 | return img_embs, cap_embs, cap_lens
117 |
118 |
119 | def shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=100, mode="sim"):
120 | n_im_shard = (len(img_embs) - 1) // shard_size + 1
121 | n_cap_shard = (len(cap_embs) - 1) // shard_size + 1
122 |
123 | sims = np.zeros((len(img_embs), len(cap_embs)))
124 | for i in range(n_im_shard):
125 | im_start, im_end = shard_size * i, min(shard_size * (i + 1), len(img_embs))
126 | for j in range(n_cap_shard):
127 | sys.stdout.write('\r>> shard_attn_scores batch (%d,%d)' % (i, j))
128 | ca_start, ca_end = shard_size * j, min(shard_size * (j + 1), len(cap_embs))
129 | with torch.no_grad():
130 | im = torch.from_numpy(img_embs[im_start:im_end]).float().cuda()
131 | ca = torch.from_numpy(cap_embs[ca_start:ca_end]).float().cuda()
132 | l = cap_lens[ca_start:ca_end]
133 | if mode == "sim":
134 | sim = model.forward_sim(im, ca, l, mode)
135 | else:
136 | _, sim, _ = model.forward_sim(im, ca, l, mode) # Calculate evidence for retrieval
137 |
138 | sims[im_start:im_end, ca_start:ca_end] = sim.data.cpu().numpy()
139 | sys.stdout.write('\n')
140 | return sims
141 |
142 |
143 | def t2i(npts, sims, per_captions=1, return_ranks=False):
144 | """
145 | Text->Images (Image Search)
146 | Images: (N, n_region, d) matrix of images
147 | Captions: (per_captions * N, max_n_word, d) matrix of captions
148 | CapLens: (per_captions * N) array of caption lengths
149 | sims: (N, per_captions * N) matrix of similarity im-cap
150 | """
151 | ranks = np.zeros(per_captions * npts)
152 | top1 = np.zeros(per_captions * npts)
153 | top5 = np.zeros((per_captions * npts, 5), dtype=int)
154 |
155 | # --> (per_captions * N(caption), N(image))
156 | sims = sims.T
157 | retreivaled_index = []
158 | for index in range(npts):
159 | for i in range(per_captions):
160 | inds = np.argsort(sims[per_captions * index + i])[::-1]
161 | retreivaled_index.append(inds)
162 | ranks[per_captions * index + i] = np.where(inds == index)[0][0]
163 | top1[per_captions * index + i] = inds[0]
164 | top5[per_captions * index + i] = inds[0:5]
165 |
166 | # Compute metrics
167 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
168 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
169 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
170 | medr = np.floor(np.median(ranks)) + 1
171 | meanr = ranks.mean() + 1
172 | if return_ranks:
173 | return (r1, r5, r10, medr, meanr), (ranks, top1, top5, retreivaled_index)
174 | else:
175 | return (r1, r5, r10, medr, meanr)
176 |
177 |
178 | def i2t(npts, sims, per_captions=1, return_ranks=False):
179 | """
180 | Images->Text (Image Annotation)
181 | Images: (N, n_region, d) matrix of images
182 | Captions: (per_captions * N, max_n_word, d) matrix of captions
183 | CapLens: (per_captions * N) array of caption lengths
184 | sims: (N, per_captions * N) matrix of similarity im-cap
185 | """
186 | ranks = np.zeros(npts)
187 | top1 = np.zeros(npts)
188 | top5 = np.zeros((npts, 5), dtype=int)
189 | retreivaled_index = []
190 | for index in range(npts):
191 | inds = np.argsort(sims[index])[::-1]
192 | retreivaled_index.append(inds)
193 | # Score
194 | rank = 1e20
195 | for i in range(per_captions * index, per_captions * index + per_captions, 1):
196 | tmp = np.where(inds == i)[0][0]
197 | if tmp < rank:
198 | rank = tmp
199 | ranks[index] = rank
200 | top1[index] = inds[0]
201 | top5[index] = inds[0:5]
202 |
203 | # Compute metrics
204 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
205 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
206 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
207 | medr = np.floor(np.median(ranks)) + 1
208 | meanr = ranks.mean() + 1
209 | if return_ranks:
210 | return (r1, r5, r10, medr, meanr), (ranks, top1, top5, retreivaled_index)
211 | else:
212 | return (r1, r5, r10, medr, meanr)
213 |
214 |
215 | def validation(opt, val_loader, model, fold=False):
216 | # compute the encoding for all the validation images and captions
217 | if opt.data_name == 'cc152k_precomp':
218 | per_captions = 1
219 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']:
220 | per_captions = 5
221 | else:
222 |         print(f"Unknown dataset: {opt.data_name}")
223 | return 0
224 |
225 | model.val_start()
226 | print('Encoding with model')
227 | img_embs, cap_embs, cap_lens = encode_data(model.similarity_model, val_loader)
228 | # clear duplicate 5*images and keep 1*images FIXME
229 | if not fold:
230 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)])
231 | # record computation time of validation
232 | print('Computing similarity from model')
233 | sims_mean = shard_attn_scores(model.similarity_model, img_embs, cap_embs, cap_lens, opt, shard_size=1000,
234 | mode="sim")
235 | print("Calculate similarity time with model")
236 | (r1, r5, r10, medr, meanr) = i2t(img_embs.shape[0], sims_mean, per_captions, return_ranks=False)
237 | print("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3))
238 | print("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr))
239 | # image retrieval
240 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs.shape[0], sims_mean, per_captions, return_ranks=False)
241 | print("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3))
242 | print("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr))
243 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i
244 | print("Sum of Recall: %.2f" % (r_sum))
245 | else:
246 | # 5fold cross-validation, only for MSCOCO
247 | results = []
248 | for i in range(5):
249 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
250 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
251 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
252 | sims = shard_attn_scores(model.similarity_model, img_embs_shard, cap_embs_shard, cap_lens_shard, opt,
253 | shard_size=1000,
254 | mode="not sim")
255 |
256 | print('Computing similarity from model')
257 | r, rt = i2t(img_embs_shard.shape[0], sims, per_captions, return_ranks=True)
258 | ri, rti = t2i(img_embs_shard.shape[0], sims, per_captions, return_ranks=True)
259 |
260 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
261 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
262 | ar = (r[0] + r[1] + r[2]) / 3
263 | ari = (ri[0] + ri[1] + ri[2]) / 3
264 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
265 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
266 | results += [list(r) + list(ri) + [ar, ari, rsum]]
267 | print("-----------------------------------")
268 | print("Mean metrics: ")
269 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
270 | a = np.array(mean_metrics)
271 | print("Average i2t Recall: %.1f" % mean_metrics[11])
272 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
273 | mean_metrics[:5])
274 | print("Average t2i Recall: %.1f" % mean_metrics[12])
275 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
276 | mean_metrics[5:10])
277 | print("rsum: %.1f" % (a[0:3].sum() + a[5:8].sum()))
278 |
279 |
280 | def validation_dul(opt, val_loader, models, fold=False):
281 | # compute the encoding for all the validation images and captions
282 | if opt.data_name == 'cc152k_precomp':
283 | per_captions = 1
284 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']:
285 | per_captions = 5
286 | else:
287 |         print(f"Unknown dataset: {opt.data_name}")
288 | return 0
289 |
290 | models[0].val_start()
291 | models[1].val_start()
292 | print('Encoding with model')
293 | img_embs, cap_embs, cap_lens = encode_data(models[0].similarity_model, val_loader, opt.log_step)
294 | img_embs1, cap_embs1, cap_lens1 = encode_data(models[1].similarity_model, val_loader, opt.log_step)
295 | if not fold:
296 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)])
297 | img_embs1 = np.array([img_embs1[i] for i in range(0, len(img_embs1), per_captions)])
298 | # record computation time of validation
299 | print('Computing similarity from model')
300 | sims_mean = shard_attn_scores(models[0].similarity_model, img_embs, cap_embs, cap_lens, opt, shard_size=1000,
301 | mode="sim")
302 | sims_mean += shard_attn_scores(models[1].similarity_model, img_embs1, cap_embs1, cap_lens1, opt,
303 | shard_size=1000, mode="sim")
304 | sims_mean /= 2
305 | print("Calculate similarity time with model")
306 | # caption retrieval
307 | (r1, r5, r10, medr, meanr) = i2t(img_embs.shape[0], sims_mean, per_captions, return_ranks=False)
308 | print("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3))
309 | print("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr))
310 | # image retrieval
311 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs.shape[0], sims_mean, per_captions, return_ranks=False)
312 | print("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3))
313 | print("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr))
314 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i
315 | print("Sum of Recall: %.2f" % (r_sum))
316 | return r_sum
317 | else:
318 | # 5fold cross-validation, only for MSCOCO
319 | results = []
320 | for i in range(5):
321 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
322 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
323 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
324 |
325 | img_embs_shard1 = img_embs1[i * 5000:(i + 1) * 5000:5]
326 | cap_embs_shard1 = cap_embs1[i * 5000:(i + 1) * 5000]
327 | cap_lens_shard1 = cap_lens1[i * 5000:(i + 1) * 5000]
328 | sims = shard_attn_scores(models[0].similarity_model, img_embs_shard, cap_embs_shard, cap_lens_shard, opt,
329 | shard_size=1000,
330 | mode="not sim")
331 | sims += shard_attn_scores(models[1].similarity_model, img_embs_shard1, cap_embs_shard1, cap_lens_shard1,
332 | opt,
333 | shard_size=1000,
334 | mode="not sim")
335 | sims /= 2
336 |
337 | print('Computing similarity from model')
338 | r, rt0 = i2t(img_embs_shard.shape[0], sims, per_captions, return_ranks=True)
339 | ri, rti0 = t2i(img_embs_shard.shape[0], sims, per_captions, return_ranks=True)
340 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
341 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
342 | if i == 0:
343 | rt, rti = rt0, rti0
344 | ar = (r[0] + r[1] + r[2]) / 3
345 | ari = (ri[0] + ri[1] + ri[2]) / 3
346 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
347 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
348 | results += [list(r) + list(ri) + [ar, ari, rsum]]
349 |
350 | print("-----------------------------------")
351 | print("Mean metrics: ")
352 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
353 | a = np.array(mean_metrics)
354 |
355 | print("Average i2t Recall: %.1f" % mean_metrics[11])
356 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
357 | mean_metrics[:5])
358 | print("Average t2i Recall: %.1f" % mean_metrics[12])
359 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
360 | mean_metrics[5:10])
361 | print("rsum: %.1f" % (a[0:3].sum() + a[5:8].sum()))
362 |
363 |
364 | def eval_DECL(checkpoint_paths, avg_SGRAF=True, data_path=None, vocab_path=None):
365 | if avg_SGRAF is False:
366 | print(f"Load checkpoint from '{checkpoint_paths[0]}'")
367 | checkpoint = torch.load(checkpoint_paths[0])
368 | opt = checkpoint['opt']
369 | opt.ssl = False
370 | print(
371 | f"Noise ratio is {opt.noise_ratio}, module is {opt.module_name}, best validation epoch is {checkpoint['epoch']} ({checkpoint['best_rsum']})")
372 | if vocab_path != None:
373 | opt.vocab_path = vocab_path
374 | if data_path != None:
375 | opt.data_path = data_path
376 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
377 | model = RINCE(opt)
378 | model.load_state_dict(checkpoint['model'])
379 | if 'coco' in opt.data_name:
380 | test_loader = get_test_loader('testall', opt.data_name, vocab, 100, 0, opt)
381 | validation(opt, test_loader, model=model, fold=True)
382 | validation(opt, test_loader, model=model, fold=False)
383 | else:
384 | test_loader = get_test_loader('test', opt.data_name, vocab, 100, 0, opt)
385 | validation(opt, test_loader, model=model, fold=False)
386 | else:
387 | assert len(checkpoint_paths) == 2
388 | print(f"Load checkpoint from '{checkpoint_paths}'")
389 | checkpoint0 = torch.load(checkpoint_paths[0])
390 | checkpoint1 = torch.load(checkpoint_paths[1])
391 | opt0 = checkpoint0['opt']
392 | opt1 = checkpoint1['opt']
393 | print(
394 | f"Noise ratios are {opt0.noise_ratio} and {opt1.noise_ratio}, "
395 | f"modules are {opt0.module_name} and {opt1.module_name}, best validation epochs are {checkpoint0['epoch']}"
396 | f" ({checkpoint0['best_rsum']}) and {checkpoint1['epoch']} ({checkpoint1['best_rsum']})")
397 | vocab = deserialize_vocab(os.path.join(opt0.vocab_path, '%s_vocab.json' % opt0.data_name))
398 | model0 = RINCE(opt0)
399 | model1 = RINCE(opt1)
400 |
401 | model0.load_state_dict(checkpoint0['model'])
402 | model1.load_state_dict(checkpoint1['model'])
403 | if 'coco' in opt0.data_name:
404 | test_loader = get_test_loader('testall', opt0.data_name, vocab, 100, 0, opt0)
405 | print(f'=====>model {opt0.module_name} fold:True')
406 | validation(opt0, test_loader, model0, fold=True)
407 | print(f'=====>model {opt1.module_name} fold:True')
408 | validation(opt0, test_loader, model1, fold=True)
409 | print(f'=====>model SGRAF fold:True')
410 | validation_dul(opt0, test_loader, models=[model0, model1], fold=True)
411 |
412 | print(f'=====>model {opt0.module_name} fold:False')
413 | validation(opt0, test_loader, model0, fold=False)
414 | print(f'=====>model {opt1.module_name} fold:False')
415 | validation(opt0, test_loader, model1, fold=False)
416 | print('=====>model SGRAF fold:False')
417 | validation_dul(opt0, test_loader, models=[model0, model1], fold=False)
418 |
419 | else:
420 | test_loader = get_test_loader('test', opt0.data_name, vocab, 100, 0, opt0)
421 | print(f'=====>model {opt0.module_name} fold:False')
422 | validation(opt0, test_loader, model0, fold=False)
423 | print(f'=====>model {opt1.module_name} fold:False')
424 | validation(opt0, test_loader, model1, fold=False)
425 | print('=====>model SGRAF fold:False')
426 | validation_dul(opt0, test_loader, models=[model0, model1], fold=False)
427 |
--------------------------------------------------------------------------------
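
Note: a minimal sketch of how Recall@K is obtained from a similarity matrix, mirroring the ranking logic in `i2t`/`t2i` above for the simplified case of one ground-truth caption per image (per_captions=1). The function name and the random toy matrix are illustrative only.

import numpy as np

def recall_at_k(sims, ks=(1, 5, 10)):
    # Image-to-text Recall@K assuming caption j is the single ground truth for image j.
    ranks = np.empty(len(sims), dtype=np.int64)
    for i, row in enumerate(sims):
        order = np.argsort(row)[::-1]              # captions sorted by similarity, best first
        ranks[i] = np.where(order == i)[0][0]      # rank of the ground-truth caption
    return {k: 100.0 * float((ranks < k).mean()) for k in ks}

if __name__ == '__main__':
    sims = np.random.rand(100, 100) + 3.0 * np.eye(100)   # diagonal entries are the true matches
    print(recall_at_k(sims))
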
/GSC+/evaluation.py:
--------------------------------------------------------------------------------
1 | """Evaluation"""
2 |
3 | from __future__ import print_function
4 | import os
5 | import sys
6 | import torch
7 | import numpy as np
8 |
9 | from model.RINCE import RINCE
10 | from data import get_test_loader
11 | from vocab import deserialize_vocab
12 | from collections import OrderedDict
13 |
14 |
15 | class AverageMeter(object):
16 | """Computes and stores the average and current value"""
17 |
18 | def __init__(self):
19 | self.reset()
20 |
21 | def reset(self):
22 | self.val = 0
23 | self.avg = 0
24 | self.sum = 0
25 | self.count = 0
26 |
27 | def update(self, val, n=0):
28 | self.val = val
29 | self.sum += val * n
30 | self.count += n
31 | self.avg = self.sum / (.0001 + self.count)
32 |
33 | def __str__(self):
34 | """String representation for logging
35 | """
36 | # for values that should be recorded exactly e.g. iteration number
37 | if self.count == 0:
38 | return str(self.val)
39 | # for stats
40 | return '%.4f (%.4f)' % (self.val, self.avg)
41 |
42 |
43 | class LogCollector(object):
44 | """A collection of logging objects that can change from train to val"""
45 |
46 | def __init__(self):
47 | # to keep the order of logged variables deterministic
48 | self.meters = OrderedDict()
49 |
50 | def update(self, k, v, n=0):
51 | # create a new meter if previously not recorded
52 | if k not in self.meters:
53 | self.meters[k] = AverageMeter()
54 | self.meters[k].update(v, n)
55 |
56 | def __str__(self):
57 | """Concatenate the meters in one log line
58 | """
59 | s = ''
60 | for i, (k, v) in enumerate(self.meters.items()):
61 | if i > 0:
62 | s += ' '
63 | s += k + ' ' + str(v)
64 | return s
65 |
66 | def tb_log(self, tb_logger, prefix='', step=None):
67 | """Log using tensorboard
68 | """
69 | for k, v in self.meters.items():
70 | tb_logger.log_value(prefix + k, v.val, step=step)
71 |
72 |
73 | def encode_data(model, data_loader, log_step=10, logging=print, sub=0):
74 | """Encode all images and captions loadable by `data_loader`
75 | """
76 | val_logger = LogCollector()
77 |
78 | # switch to evaluate mode
79 | model.val_start()
80 |
81 | # np array to keep all the embeddings
82 | img_embs = None
83 | cap_embs = None
84 |
85 | max_n_word = 0
86 | for i, (images, captions, lengths, ids, _) in enumerate(data_loader):
87 | max_n_word = max(max_n_word, max(lengths))
88 | ids_ = []
89 | for i, (images, captions, lengths, ids, _) in enumerate(data_loader):
90 | # make sure val logger is used
91 | model.logger = val_logger
92 | ids_ += ids
93 | # compute the embeddings
94 | with torch.no_grad():
95 | img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths)
96 | if img_embs is None:
97 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2)))
98 | cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2)))
99 | cap_lens = [0] * len(data_loader.dataset)
100 | # cache embeddings
101 | img_embs[ids, :, :] = img_emb.data.cpu().numpy().copy()
102 | cap_embs[ids, :max(lengths), :] = cap_emb.data.cpu().numpy().copy()
103 |
104 | for j, nid in enumerate(ids):
105 | cap_lens[nid] = cap_len[j]
106 |
107 | del images, captions
108 | if sub > 0:
109 | print(f"===>batch {i}")
110 | if sub > 0 and i > sub:
111 | break
112 | if sub > 0:
113 | return np.array(img_embs)[ids_].tolist(), np.array(cap_embs)[ids_].tolist(), np.array(cap_lens)[
114 | ids_].tolist(), ids_
115 | else:
116 | return img_embs, cap_embs, cap_lens
117 |
118 |
119 | def shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=100, mode="sim"):
120 | n_im_shard = (len(img_embs) - 1) // shard_size + 1
121 | n_cap_shard = (len(cap_embs) - 1) // shard_size + 1
122 |
123 | sims = np.zeros((len(img_embs), len(cap_embs)))
124 | for i in range(n_im_shard):
125 | im_start, im_end = shard_size * i, min(shard_size * (i + 1), len(img_embs))
126 | for j in range(n_cap_shard):
127 | sys.stdout.write('\r>> shard_attn_scores batch (%d,%d)' % (i, j))
128 | ca_start, ca_end = shard_size * j, min(shard_size * (j + 1), len(cap_embs))
129 | with torch.no_grad():
130 | im = torch.from_numpy(img_embs[im_start:im_end]).float().cuda()
131 | ca = torch.from_numpy(cap_embs[ca_start:ca_end]).float().cuda()
132 | l = cap_lens[ca_start:ca_end]
133 | if mode == "sim":
134 | sim = model.forward_sim(im, ca, l, mode)
135 | else:
136 | _, sim, _ = model.forward_sim(im, ca, l, mode) # Calculate evidence for retrieval
137 |
138 | sims[im_start:im_end, ca_start:ca_end] = sim.data.cpu().numpy()
139 | sys.stdout.write('\n')
140 | return sims
141 |
142 |
143 | def t2i(npts, sims, per_captions=1, return_ranks=False):
144 | """
145 | Text->Images (Image Search)
146 | Images: (N, n_region, d) matrix of images
147 | Captions: (per_captions * N, max_n_word, d) matrix of captions
148 | CapLens: (per_captions * N) array of caption lengths
149 | sims: (N, per_captions * N) matrix of similarity im-cap
150 | """
151 | ranks = np.zeros(per_captions * npts)
152 | top1 = np.zeros(per_captions * npts)
153 | top5 = np.zeros((per_captions * npts, 5), dtype=int)
154 |
155 | # --> (per_captions * N(caption), N(image))
156 | sims = sims.T
157 | retreivaled_index = []
158 | for index in range(npts):
159 | for i in range(per_captions):
160 | inds = np.argsort(sims[per_captions * index + i])[::-1]
161 | retreivaled_index.append(inds)
162 | ranks[per_captions * index + i] = np.where(inds == index)[0][0]
163 | top1[per_captions * index + i] = inds[0]
164 | top5[per_captions * index + i] = inds[0:5]
165 |
166 | # Compute metrics
167 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
168 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
169 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
170 | medr = np.floor(np.median(ranks)) + 1
171 | meanr = ranks.mean() + 1
172 | if return_ranks:
173 | return (r1, r5, r10, medr, meanr), (ranks, top1, top5, retreivaled_index)
174 | else:
175 | return (r1, r5, r10, medr, meanr)
176 |
177 |
178 | def i2t(npts, sims, per_captions=1, return_ranks=False):
179 | """
180 | Images->Text (Image Annotation)
181 | Images: (N, n_region, d) matrix of images
182 | Captions: (per_captions * N, max_n_word, d) matrix of captions
183 | CapLens: (per_captions * N) array of caption lengths
184 | sims: (N, per_captions * N) matrix of similarity im-cap
185 | """
186 | ranks = np.zeros(npts)
187 | top1 = np.zeros(npts)
188 | top5 = np.zeros((npts, 5), dtype=int)
189 | retreivaled_index = []
190 | for index in range(npts):
191 | inds = np.argsort(sims[index])[::-1]
192 | retreivaled_index.append(inds)
193 | # Score
194 | rank = 1e20
195 | for i in range(per_captions * index, per_captions * index + per_captions, 1):
196 | tmp = np.where(inds == i)[0][0]
197 | if tmp < rank:
198 | rank = tmp
199 | ranks[index] = rank
200 | top1[index] = inds[0]
201 | top5[index] = inds[0:5]
202 |
203 | # Compute metrics
204 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
205 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
206 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
207 | medr = np.floor(np.median(ranks)) + 1
208 | meanr = ranks.mean() + 1
209 | if return_ranks:
210 |         return (r1, r5, r10, medr, meanr), (ranks, top1, top5, retrieved_index)
211 | else:
212 | return (r1, r5, r10, medr, meanr)
213 |
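# Worked example (illustrative): if the matching item for a query sits at
# 0-based position 3 of its ranked list, that query counts toward R@5 and R@10
# but not R@1; medr and meanr report the median and mean rank shifted to 1-based.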
214 |
215 | def validation(opt, val_loader, model_A, model_B, fold=False):
216 | # compute the encoding for all the validation images and captions
217 | if opt.data_name == 'cc152k_precomp':
218 | per_captions = 1
219 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']:
220 | per_captions = 5
221 | else:
222 |         print(f"Unknown dataset: {opt.data_name}")
223 | return 0
224 |
225 | model_A.val_start()
226 | model_B.val_start()
227 | print('Encoding with model')
228 | img_embs_A, cap_embs_A, cap_lens_A = encode_data(model_A.similarity_model, val_loader, opt.log_step)
229 | if not fold:
230 | img_embs_A = np.array([img_embs_A[i] for i in range(0, len(img_embs_A), per_captions)])
231 |
232 | img_embs_B, cap_embs_B, cap_lens_B = encode_data(model_B.similarity_model, val_loader, opt.log_step)
233 | if not fold:
234 | img_embs_B = np.array([img_embs_B[i] for i in range(0, len(img_embs_B), per_captions)])
235 |
236 | if not fold:
237 | # record computation time of validation
238 | print('Computing similarity from model')
239 | sims_A = shard_attn_scores(model_A.similarity_model, img_embs_A, cap_embs_A, cap_lens_A, opt, shard_size=1000)
240 | sims_B = shard_attn_scores(model_B.similarity_model, img_embs_B, cap_embs_B, cap_lens_B, opt, shard_size=1000)
241 | sims = (sims_A + sims_B) / 2
242 | # sims = sims_B
243 |         print("Finished computing similarity")
244 | (r1, r5, r10, medr, meanr) = i2t(img_embs_A.shape[0], sims, per_captions, return_ranks=False)
245 | print("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3))
246 | print("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr))
247 | # image retrieval
248 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs_A.shape[0], sims, per_captions, return_ranks=False)
249 | print("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3))
250 | print("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr))
251 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i
252 | print("Sum of Recall: %.2f" % (r_sum))
253 | else:
254 |         # 5-fold cross-validation (five 1000-image folds of the 5000-image test set), only for MSCOCO
255 | results = []
256 | for i in range(5):
257 | img_embs_shard_A = img_embs_A[i * 5000:(i + 1) * 5000:5]
258 | cap_embs_shard_A = cap_embs_A[i * 5000:(i + 1) * 5000]
259 | cap_lens_shard_A = cap_lens_A[i * 5000:(i + 1) * 5000]
260 | sims_A = shard_attn_scores(model_A.similarity_model, img_embs_shard_A, cap_embs_shard_A, cap_lens_shard_A, opt,
261 | shard_size=1000,
262 | mode="sim")
263 |
264 | img_embs_shard_B = img_embs_B[i * 5000:(i + 1) * 5000:5]
265 | cap_embs_shard_B = cap_embs_B[i * 5000:(i + 1) * 5000]
266 | cap_lens_shard_B = cap_lens_B[i * 5000:(i + 1) * 5000]
267 | sims_B = shard_attn_scores(model_B.similarity_model, img_embs_shard_B, cap_embs_shard_B, cap_lens_shard_B, opt,
268 | shard_size=1000,
269 | mode="sim")
270 |
271 | sims = (sims_A + sims_B) / 2
272 |
273 | print('Computing similarity from model')
274 | r, rt = i2t(img_embs_shard_A.shape[0], sims, per_captions, return_ranks=True)
275 | ri, rti = t2i(img_embs_shard_A.shape[0], sims, per_captions, return_ranks=True)
276 |
277 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
278 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
279 | ar = (r[0] + r[1] + r[2]) / 3
280 | ari = (ri[0] + ri[1] + ri[2]) / 3
281 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
282 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
283 | results += [list(r) + list(ri) + [ar, ari, rsum]]
284 | print("-----------------------------------")
285 | print("Mean metrics: ")
286 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
287 | a = np.array(mean_metrics)
288 | print("Average i2t Recall: %.1f" % mean_metrics[11])
289 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
290 | mean_metrics[:5])
291 | print("Average t2i Recall: %.1f" % mean_metrics[12])
292 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
293 | mean_metrics[5:10])
294 | print("rsum: %.1f" % (a[0:3].sum() + a[5:8].sum()))
295 |
296 |
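# validation_dul ensembles four similarity models: the full-split branch averages
# all four similarity matrices, while the 5-fold MSCOCO branch below averages only
# models[0] and models[1].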
297 | def validation_dul(opt, val_loader, models, fold=False):
298 | # compute the encoding for all the validation images and captions
299 | if opt.data_name == 'cc152k_precomp':
300 | per_captions = 1
301 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']:
302 | per_captions = 5
303 | else:
304 |         print(f"Unknown dataset: {opt.data_name}")
305 | return 0
306 |
307 | models[0].val_start()
308 | models[1].val_start()
309 | models[2].val_start()
310 | models[3].val_start()
311 | print('Encoding with model')
312 | img_embs, cap_embs, cap_lens = encode_data(models[0].similarity_model, val_loader, opt.log_step)
313 | img_embs1, cap_embs1, cap_lens1 = encode_data(models[1].similarity_model, val_loader, opt.log_step)
314 | img_embs2, cap_embs2, cap_lens2 = encode_data(models[2].similarity_model, val_loader, opt.log_step)
315 | img_embs3, cap_embs3, cap_lens3 = encode_data(models[3].similarity_model, val_loader, opt.log_step)
316 | if not fold:
317 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)])
318 | img_embs1 = np.array([img_embs1[i] for i in range(0, len(img_embs1), per_captions)])
319 | img_embs2 = np.array([img_embs2[i] for i in range(0, len(img_embs2), per_captions)])
320 | img_embs3 = np.array([img_embs3[i] for i in range(0, len(img_embs3), per_captions)])
321 | # record computation time of validation
322 | print('Computing similarity from model')
323 | sims_mean = shard_attn_scores(models[0].similarity_model, img_embs, cap_embs, cap_lens, opt, shard_size=1000,
324 | mode="sim")
325 | sims_mean += shard_attn_scores(models[1].similarity_model, img_embs1, cap_embs1, cap_lens1, opt,
326 | shard_size=1000, mode="sim")
327 | sims_mean += shard_attn_scores(models[2].similarity_model, img_embs2, cap_embs2, cap_lens2, opt,
328 | shard_size=1000, mode="sim")
329 | sims_mean += shard_attn_scores(models[3].similarity_model, img_embs3, cap_embs3, cap_lens3, opt,
330 | shard_size=1000, mode="sim")
331 | sims_mean /= 4
332 |         print("Finished computing similarity")
333 | # caption retrieval
334 | (r1, r5, r10, medr, meanr) = i2t(img_embs.shape[0], sims_mean, per_captions, return_ranks=False)
335 | print("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3))
336 | print("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr))
337 | # image retrieval
338 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs.shape[0], sims_mean, per_captions, return_ranks=False)
339 | print("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3))
340 | print("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr))
341 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i
342 | print("Sum of Recall: %.2f" % (r_sum))
343 | return r_sum
344 | else:
345 |         # 5-fold cross-validation, only for MSCOCO
346 | results = []
347 | for i in range(5):
348 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
349 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
350 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
351 |
352 | img_embs_shard1 = img_embs1[i * 5000:(i + 1) * 5000:5]
353 | cap_embs_shard1 = cap_embs1[i * 5000:(i + 1) * 5000]
354 | cap_lens_shard1 = cap_lens1[i * 5000:(i + 1) * 5000]
355 | sims = shard_attn_scores(models[0].similarity_model, img_embs_shard, cap_embs_shard, cap_lens_shard, opt,
356 | shard_size=1000,
357 | mode="not sim")
358 | sims += shard_attn_scores(models[1].similarity_model, img_embs_shard1, cap_embs_shard1, cap_lens_shard1,
359 | opt,
360 | shard_size=1000,
361 | mode="not sim")
362 | sims /= 2
363 |
364 | print('Computing similarity from model')
365 | r, rt0 = i2t(img_embs_shard.shape[0], sims, per_captions, return_ranks=True)
366 | ri, rti0 = t2i(img_embs_shard.shape[0], sims, per_captions, return_ranks=True)
367 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
368 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
369 | if i == 0:
370 | rt, rti = rt0, rti0
371 | ar = (r[0] + r[1] + r[2]) / 3
372 | ari = (ri[0] + ri[1] + ri[2]) / 3
373 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
374 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
375 | results += [list(r) + list(ri) + [ar, ari, rsum]]
376 |
377 | print("-----------------------------------")
378 | print("Mean metrics: ")
379 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
380 | a = np.array(mean_metrics)
381 |
382 | print("Average i2t Recall: %.1f" % mean_metrics[11])
383 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
384 | mean_metrics[:5])
385 | print("Average t2i Recall: %.1f" % mean_metrics[12])
386 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
387 | mean_metrics[5:10])
388 | print("rsum: %.1f" % (a[0:3].sum() + a[5:8].sum()))
389 |
390 |
391 | def eval_DECL(checkpoint_paths, avg_SGRAF=True, data_path=None, vocab_path=None):
392 | if avg_SGRAF is False:
393 | print(f"Load checkpoint from '{checkpoint_paths[0]}'")
394 | checkpoint = torch.load(checkpoint_paths[0])
395 | opt = checkpoint['opt']
396 | opt.ssl = False
397 | print(
398 | f"Noise ratio is {opt.noise_ratio}, module is {opt.module_name}, best validation epoch is {checkpoint['epoch']} ({checkpoint['best_rsum']})")
399 |         if vocab_path is not None:
400 | opt.vocab_path = vocab_path
401 |         if data_path is not None:
402 | opt.data_path = data_path
403 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
404 | model_A = RINCE(opt)
405 | model_B = RINCE(opt)
406 | model_A.load_state_dict(checkpoint['model_A'])
407 | model_B.load_state_dict(checkpoint['model_B'])
408 |
409 | if 'coco' in opt.data_name:
410 | test_loader = get_test_loader('testall', opt.data_name, vocab, 100, 0, opt)
411 | validation(opt, test_loader, model_A=model_A, model_B=model_B, fold=True)
412 | validation(opt, test_loader, model_A=model_A, model_B=model_B, fold=False)
413 | else:
414 | test_loader = get_test_loader('test', opt.data_name, vocab, 100, 0, opt)
415 | validation(opt, test_loader, model_A=model_A, model_B=model_B, fold=False)
416 | else:
417 | assert len(checkpoint_paths) == 2
418 | print(f"Load checkpoint from '{checkpoint_paths}'")
419 | checkpoint0 = torch.load(checkpoint_paths[0])
420 | checkpoint1 = torch.load(checkpoint_paths[1])
421 | opt0 = checkpoint0['opt']
422 | opt1 = checkpoint1['opt']
423 | print(
424 | f"Noise ratios are {opt0.noise_ratio} and {opt1.noise_ratio}, "
425 | f"modules are {opt0.module_name} and {opt1.module_name}, best validation epochs are {checkpoint0['epoch']}"
426 | f" ({checkpoint0['best_rsum']}) and {checkpoint1['epoch']} ({checkpoint1['best_rsum']})")
427 | vocab = deserialize_vocab(os.path.join(opt0.vocab_path, '%s_vocab.json' % opt0.data_name))
428 | model0_A = RINCE(opt0)
429 | model0_B = RINCE(opt0)
430 | model1_A = RINCE(opt1)
431 | model1_B = RINCE(opt1)
432 |
433 | model0_A.load_state_dict(checkpoint0['model_A'])
434 | model0_B.load_state_dict(checkpoint0['model_B'])
435 | model1_A.load_state_dict(checkpoint1['model_A'])
436 | model1_B.load_state_dict(checkpoint1['model_B'])
437 | if 'coco' in opt0.data_name:
438 | test_loader = get_test_loader('testall', opt0.data_name, vocab, 100, 0, opt0)
439 | print(f'=====>model {opt0.module_name} fold:True')
440 | validation(opt0, test_loader, model0_A, model0_B, fold=True)
441 | print(f'=====>model {opt1.module_name} fold:True')
442 | validation(opt0, test_loader, model1_A, model1_B, fold=True)
443 | print(f'=====>model SGRAF fold:True')
444 | validation_dul(opt0, test_loader, models=[model0_A, model0_B, model1_A, model1_B], fold=True)
445 |
446 | print(f'=====>model {opt0.module_name} fold:False')
447 | validation(opt0, test_loader, model0_A, model0_B, fold=False)
448 | print(f'=====>model {opt1.module_name} fold:False')
449 |             validation(opt0, test_loader, model1_A, model1_B, fold=False)
450 | print('=====>model SGRAF fold:False')
451 | validation_dul(opt0, test_loader, models=[model0_A, model0_B, model1_A, model1_B], fold=False)
452 |
453 | else:
454 | test_loader = get_test_loader('test', opt0.data_name, vocab, 100, 0, opt0)
455 | print(f'=====>model {opt0.module_name} fold:False')
456 | validation(opt0, test_loader, model0_A, model0_B, fold=False)
457 | print(f'=====>model {opt1.module_name} fold:False')
458 | validation(opt0, test_loader, model1_A, model1_B, fold=False)
459 | print('=====>model SGRAF fold:False')
460 | validation_dul(opt0, test_loader, models=[model0_A, model0_B, model1_A, model1_B], fold=False)
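
# Hypothetical invocation (checkpoint paths are placeholders, not files shipped with the repo):
#   eval_DECL(['runs/f30k_SGR/model_best.pth.tar'], avg_SGRAF=False,
#             data_path='data/', vocab_path='vocab/')
#   eval_DECL(['runs/f30k_SGR/model_best.pth.tar', 'runs/f30k_SAF/model_best.pth.tar'],
#             avg_SGRAF=True)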
--------------------------------------------------------------------------------
/GSC/model/SGRAF.py:
--------------------------------------------------------------------------------
1 | """SGRAF model"""
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import torch.nn.functional as F
7 |
8 | import torch.backends.cudnn as cudnn
9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
10 | from torch.nn.utils.clip_grad import clip_grad_norm_
11 |
12 | import numpy as np
13 | from collections import OrderedDict
14 |
15 |
16 | def l1norm(X, dim, eps=1e-8):
17 | """L1-normalize columns of X"""
18 | norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
19 | X = torch.div(X, norm)
20 | return X
21 |
22 |
23 | def l2norm(X, dim=-1, eps=1e-8):
24 | """L2-normalize columns of X"""
25 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
26 | X = torch.div(X, norm)
27 | return X
28 |
29 |
30 | def cosine_sim(x1, x2, dim=-1, eps=1e-8):
31 | """Returns cosine similarity between x1 and x2, computed along dim."""
32 | w12 = torch.sum(x1 * x2, dim)
33 | w1 = torch.norm(x1, 2, dim)
34 | w2 = torch.norm(x2, 2, dim)
35 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze()
36 |
37 |
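# Note: l2norm rescales vectors to unit length along `dim`, so for l2-normalized
# inputs cosine_sim reduces to a plain dot product along that dimension; eps in
# both helpers guards against division by a zero norm.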
38 | class EncoderImage(nn.Module):
39 | """
40 |     Build local region representations by a commonly-used FC layer.
41 |     Args: - images: raw local detected regions, shape: (batch_size, 36, 2048).
42 |     Returns: - img_emb: final local region embeddings, shape: (batch_size, 36, 1024).
43 | """
44 |
45 | def __init__(self, opt, img_dim, embed_size, no_imgnorm=False):
46 | super(EncoderImage, self).__init__()
47 | self.opt = opt
48 | self.embed_size = embed_size
49 | self.no_imgnorm = no_imgnorm
50 | self.fc = nn.Linear(img_dim, embed_size)
51 |
52 | self.init_weights(self.fc)
53 |
54 | def init_weights(self, fc):
55 | """Xavier initialization for the fully connected layer"""
56 | r = np.sqrt(6.) / np.sqrt(fc.in_features +
57 | fc.out_features)
58 | fc.weight.data.uniform_(-r, r)
59 | fc.bias.data.fill_(0)
60 |
61 | def forward(self, images):
62 | """Extract image feature vectors."""
63 | # assuming that the precomputed features are already l2-normalized
64 | img_emb = self.fc(images)
65 |
66 | # normalize in the joint embedding space
67 | if not self.no_imgnorm:
68 | img_emb = l2norm(img_emb, dim=-1)
69 |
70 | return img_emb
71 |
72 | def load_state_dict(self, state_dict):
73 | """Overwrite the default one to accept state_dict from Full model"""
74 | own_state = self.state_dict()
75 | new_state = OrderedDict()
76 | for name, param in state_dict.items():
77 | if name in own_state:
78 | new_state[name] = param
79 |
80 | super(EncoderImage, self).load_state_dict(new_state)
81 |
82 |
83 | class EncoderText(nn.Module):
84 | """
85 |     Build local word representations by a commonly-used Bi-GRU or GRU.
86 |     Args: - captions: raw local word ids, shape: (batch_size, L).
87 |     Returns: - cap_emb: final local word embeddings, shape: (batch_size, L, 1024).
88 | """
89 |
90 | def __init__(self, opt, vocab_size, word_dim, embed_size, num_layers,
91 | use_bi_gru=False, no_txtnorm=False):
92 | super(EncoderText, self).__init__()
93 | self.opt = opt
94 | self.embed_size = embed_size
95 | self.no_txtnorm = no_txtnorm
96 |
97 | # word embedding
98 | self.embed = nn.Embedding(vocab_size, word_dim)
99 | self.dropout = nn.Dropout(0.4)
100 |
101 | # caption embedding
102 | self.use_bi_gru = use_bi_gru
103 | self.cap_rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)
104 |
105 | self.init_weights()
106 |
107 | def init_weights(self):
108 | self.embed.weight.data.uniform_(-0.1, 0.1)
109 |
110 | def forward(self, captions, lengths):
111 | """Handles variable size captions"""
112 | # embed word ids to vectors
113 | cap_emb = self.embed(captions)
114 | cap_emb = self.dropout(cap_emb)
115 |
116 | # pack the caption
117 | packed = pack_padded_sequence(cap_emb, lengths, batch_first=True)
118 |
119 | # forward propagate RNN
120 | out, _ = self.cap_rnn(packed)
121 |
122 |         # unpack output to (batch_size, max_len, hidden_size)
123 | cap_emb, _ = pad_packed_sequence(out, batch_first=True)
124 |
125 | if self.use_bi_gru:
126 | cap_emb = (cap_emb[:, :, :cap_emb.size(2) // 2] + cap_emb[:, :, cap_emb.size(2) // 2:]) / 2
127 |
128 | # normalization in the joint embedding space
129 | if not self.no_txtnorm:
130 | cap_emb = l2norm(cap_emb, dim=-1)
131 |
132 | return cap_emb
133 |
134 |
135 | class VisualSA(nn.Module):
136 | """
137 | Build global image representations by self-attention.
138 | Args: - local: local region embeddings, shape: (batch_size, 36, 1024)
139 | - raw_global: raw image by averaging regions, shape: (batch_size, 1024)
140 | Returns: - new_global: final image by self-attention, shape: (batch_size, 1024).
141 | """
142 |
143 | def __init__(self, embed_dim, dropout_rate, num_region):
144 | super(VisualSA, self).__init__()
145 |
146 | self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim),
147 | nn.BatchNorm1d(num_region),
148 | nn.Tanh(), nn.Dropout(dropout_rate))
149 | self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim),
150 | nn.BatchNorm1d(embed_dim),
151 | nn.Tanh(), nn.Dropout(dropout_rate))
152 | self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1))
153 |
154 | self.init_weights()
155 | self.softmax = nn.Softmax(dim=1)
156 |
157 | def init_weights(self):
158 | for embeddings in self.children():
159 | for m in embeddings:
160 | if isinstance(m, nn.Linear):
161 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features)
162 | m.weight.data.uniform_(-r, r)
163 | m.bias.data.fill_(0)
164 | elif isinstance(m, nn.BatchNorm1d):
165 | m.weight.data.fill_(1)
166 | m.bias.data.zero_()
167 |
168 | def forward(self, local, raw_global):
169 | # compute embedding of local regions and raw global image
170 | if len(local) == 1 and len(raw_global) == 1:
171 | l_emb = self.embedding_local[0](local)
172 | l_emb = self.embedding_local[2](l_emb)
173 | l_emb = self.embedding_local[3](l_emb)
174 | g_emb = self.embedding_global[0](raw_global)
175 | g_emb = self.embedding_global[2](g_emb)
176 | g_emb = self.embedding_global[3](g_emb)
177 | else:
178 | l_emb = self.embedding_local(local)
179 | g_emb = self.embedding_global(raw_global)
180 |
181 | # compute the normalized weights, shape: (batch_size, 36)
182 | g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1)
183 | common = l_emb.mul(g_emb)
184 | weights = self.embedding_common(common).squeeze(2)
185 | weights = self.softmax(weights)
186 |
187 | # compute final image, shape: (batch_size, 1024)
188 | new_global = (weights.unsqueeze(2) * local).sum(dim=1)
189 | new_global = l2norm(new_global, dim=-1)
190 |
191 | return new_global
192 |
193 |
194 | class TextSA(nn.Module):
195 | """
196 | Build global text representations by self-attention.
197 | Args: - local: local word embeddings, shape: (batch_size, L, 1024)
198 | - raw_global: raw text by averaging words, shape: (batch_size, 1024)
199 | Returns: - new_global: final text by self-attention, shape: (batch_size, 1024).
200 | """
201 |
202 | def __init__(self, embed_dim, dropout_rate):
203 | super(TextSA, self).__init__()
204 |
205 | self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim),
206 | nn.Tanh(), nn.Dropout(dropout_rate))
207 | self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim),
208 | nn.Tanh(), nn.Dropout(dropout_rate))
209 | self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1))
210 |
211 | self.init_weights()
212 | self.softmax = nn.Softmax(dim=1)
213 |
214 | def init_weights(self):
215 | for embeddings in self.children():
216 | for m in embeddings:
217 | if isinstance(m, nn.Linear):
218 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features)
219 | m.weight.data.uniform_(-r, r)
220 | m.bias.data.fill_(0)
221 | elif isinstance(m, nn.BatchNorm1d):
222 | m.weight.data.fill_(1)
223 | m.bias.data.zero_()
224 |
225 | def forward(self, local, raw_global):
226 | # compute embedding of local words and raw global text
227 | l_emb = self.embedding_local(local)
228 | g_emb = self.embedding_global(raw_global)
229 |
230 | # compute the normalized weights, shape: (batch_size, L)
231 | g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1)
232 | common = l_emb.mul(g_emb)
233 | weights = self.embedding_common(common).squeeze(2)
234 | weights = self.softmax(weights)
235 |
236 | # compute final text, shape: (batch_size, 1024)
237 | new_global = (weights.unsqueeze(2) * local).sum(dim=1)
238 | new_global = l2norm(new_global, dim=-1)
239 |
240 | return new_global
241 |
242 |
243 | class GraphReasoning(nn.Module):
244 | """
245 |     Perform the similarity graph reasoning with a fully-connected graph
246 |     Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256)
247 |     Returns: - sim_sgr: reasoned graph nodes after several steps, shape: (batch_size, L+1, 256)
248 | """
249 |
250 | def __init__(self, sim_dim):
251 | super(GraphReasoning, self).__init__()
252 |
253 | self.graph_query_w = nn.Linear(sim_dim, sim_dim)
254 | self.graph_key_w = nn.Linear(sim_dim, sim_dim)
255 | self.sim_graph_w = nn.Linear(sim_dim, sim_dim)
256 | # self.relu = nn.LeakyReLU(negative_slope=0.1)
257 | self.relu = nn.ReLU()
258 |
259 | self.init_weights()
260 |
261 | def forward(self, sim_emb):
262 | sim_query = self.graph_query_w(sim_emb)
263 | sim_key = self.graph_key_w(sim_emb)
264 | sim_edge = torch.softmax(torch.bmm(sim_query, sim_key.permute(0, 2, 1)), dim=-1)
265 | sim_sgr = torch.bmm(sim_edge, sim_emb)
266 | sim_sgr = self.relu(self.sim_graph_w(sim_sgr))
267 | return sim_sgr
268 |
269 | def init_weights(self):
270 | for m in self.children():
271 | if isinstance(m, nn.Linear):
272 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features)
273 | m.weight.data.uniform_(-r, r)
274 | m.bias.data.fill_(0)
275 | elif isinstance(m, nn.BatchNorm1d):
276 | m.weight.data.fill_(1)
277 | m.bias.data.zero_()
278 |
279 |
280 | class AttentionFiltration(nn.Module):
281 | """
282 | Perform the similarity Attention Filtration with a gate-based attention
283 | Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256)
284 |     Returns: - sim_saf: aggregated alignment after attention filtration, shape: (batch_size, 256)
285 | """
286 |
287 | def __init__(self, sim_dim):
288 | super(AttentionFiltration, self).__init__()
289 |
290 | self.attn_sim_w = nn.Linear(sim_dim, 1)
291 | self.bn = nn.BatchNorm1d(1)
292 |
293 | self.init_weights()
294 |
295 | def forward(self, sim_emb):
296 | sim_attn = l1norm(torch.sigmoid(self.bn(self.attn_sim_w(sim_emb).permute(0, 2, 1))), dim=-1)
297 | sim_saf = torch.matmul(sim_attn, sim_emb)
298 | sim_saf = l2norm(sim_saf.squeeze(1), dim=-1)
299 | return sim_saf
300 |
301 | def init_weights(self):
302 | for m in self.children():
303 | if isinstance(m, nn.Linear):
304 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features)
305 | m.weight.data.uniform_(-r, r)
306 | m.bias.data.fill_(0)
307 | elif isinstance(m, nn.BatchNorm1d):
308 | m.weight.data.fill_(1)
309 | m.bias.data.zero_()
310 |
311 |
312 | class EncoderSimilarity(nn.Module):
313 | """
314 | Compute the image-text similarity by SGR, SAF, AVE
315 | Args: - img_emb: local region embeddings, shape: (batch_size, 36, 1024)
316 | - cap_emb: local word embeddings, shape: (batch_size, L, 1024)
317 | Returns:
318 | - sim_all: final image-text similarities, shape: (batch_size, batch_size).
319 | """
320 |
321 | def __init__(self, opt, embed_size, sim_dim, module_name='AVE', sgr_step=3):
322 | super(EncoderSimilarity, self).__init__()
323 | self.opt = opt
324 | self.module_name = module_name
325 |
326 | self.v_global_w = VisualSA(embed_size, 0.4, 36)
327 | self.t_global_w = TextSA(embed_size, 0.4)
328 |
329 | self.sim_tranloc_w = nn.Linear(embed_size, sim_dim)
330 | self.sim_tranglo_w = nn.Linear(embed_size, sim_dim)
331 |
332 | self.sim_eval_w = nn.Linear(sim_dim, 1)
333 | self.sigmoid = nn.Sigmoid()
334 |
335 | if module_name == 'SGR':
336 | self.SGR_module = nn.ModuleList([GraphReasoning(sim_dim) for i in range(sgr_step)])
337 | elif module_name == 'SAF':
338 | self.SAF_module = AttentionFiltration(sim_dim)
339 | else:
340 | raise ValueError('Invalid input of opt.module_name in opts.py')
341 |
342 | self.init_weights()
343 |
344 | # Modifications by extending DECL
345 | def forward(self, img_emb, cap_emb, cap_lens):
346 | sim_all = []
347 | n_image = img_emb.size(0)
348 | n_caption = cap_emb.size(0)
349 |
350 | # get enhanced global images by self-attention
351 | img_ave = torch.mean(img_emb, 1)
352 | img_glo = self.v_global_w(img_emb, img_ave)
353 |
354 | for i in range(n_caption):
355 | # get the i-th sentence
356 | n_word = cap_lens[i]
357 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0)
358 | cap_i_expand = cap_i.repeat(n_image, 1, 1)
359 |
360 | # get enhanced global i-th text by self-attention
361 | cap_ave_i = torch.mean(cap_i, 1)
362 | cap_glo_i = self.t_global_w(cap_i, cap_ave_i)
363 |
364 | # local-global alignment construction
365 | Context_img = SCAN_attention(cap_i_expand, img_emb, smooth=9.0)
366 | sim_loc = torch.pow(torch.sub(Context_img, cap_i_expand), 2)
367 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)
368 |
369 | sim_glo = torch.pow(torch.sub(img_glo, cap_glo_i), 2)
370 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
371 |
372 | # concat the global and local alignments
373 | sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1)
374 |
375 | # compute the final similarity vector
376 | if self.module_name == 'SGR':
377 | for module in self.SGR_module:
378 | sim_emb = module(sim_emb)
379 | sim_vec = sim_emb[:, 0, :]
380 | else:
381 | sim_vec = self.SAF_module(sim_emb)
382 |
383 | # compute the final similarity score
384 | # sim_i = self.sigmoid(self.sim_eval_w(sim_vec))
385 | # sim_all.append(sim_i)
386 | sim_all.append(self.sim_eval_w(sim_vec))
387 |
388 | # (n_image, n_caption)
389 | sim_all = torch.cat(sim_all, 1)
390 |
391 | return sim_all
392 |
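    # forward_topology builds intra-modal similarity graphs: an image-to-image matrix
    # (sim_img) and a caption-to-caption matrix (sim_cap), reusing the same local/global
    # alignment and SGR/SAF head with one anchor sample per loop iteration.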
393 | def forward_topology(self, img_emb, cap_emb, cap_lens):
394 | sim_img = []
395 | n_image = img_emb.size(0)
396 |
397 | img_ave = torch.mean(img_emb, 1)
398 | img_glo = self.v_global_w(img_emb, img_ave)
399 |
400 | for i in range(n_image):
401 | img_i = img_emb[i].unsqueeze(0)
402 | img_i_expand = img_i.repeat(n_image, 1, 1)
403 |
404 | img_glo_i = img_glo[i].unsqueeze(0)
405 |
406 | sim_loc = torch.pow(torch.sub(img_emb, img_i_expand), 2) + 1e-6
407 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)
408 |
409 | sim_glo = torch.pow(torch.sub(img_glo, img_glo_i), 2) + 1e-6
410 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
411 |
412 | sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1)
413 |
414 | if self.module_name == 'SGR':
415 | for module in self.SGR_module:
416 | sim_emb = module(sim_emb)
417 | sim_vec = sim_emb[:, 0, :]
418 | else:
419 | sim_vec = self.SAF_module(sim_emb)
420 |
421 | sim_img.append(self.sim_eval_w(sim_vec))
422 |
423 | sim_img = torch.cat(sim_img, 1)
424 |
425 | sim_cap = []
426 | n_caption = cap_emb.size(0)
427 |
428 | cap_glo = []
429 |         for i in range(n_caption):
430 | n_word = cap_lens[i]
431 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0)
432 | cap_ave_i = torch.mean(cap_i, 1)
433 | cap_glo_i = self.t_global_w(cap_i, cap_ave_i)
434 | cap_glo.append(cap_glo_i)
435 | cap_glo = torch.cat(cap_glo, 0)
436 |
437 | for i in range(n_caption):
438 | n_word = cap_lens[i]
439 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0)
440 | cap_i_expand = cap_i.repeat(n_caption, 1, 1)
441 | cap_glo_i = cap_glo[i].unsqueeze(0)
442 |
443 | Context_cap = SCAN_attention(cap_i_expand, cap_emb + 1e-6, smooth=9.0, lens=cap_lens)
444 | sim_loc = torch.pow(torch.sub(Context_cap, cap_i_expand), 2)
445 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)
446 |
447 | sim_glo = torch.pow(torch.sub(cap_glo, cap_glo_i), 2) + 1e-6
448 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
449 |
450 | sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1)
451 |
452 | if self.module_name == 'SGR':
453 | for module in self.SGR_module:
454 | sim_emb = module(sim_emb)
455 | sim_vec = sim_emb[:, 0, :]
456 | else:
457 | sim_vec = self.SAF_module(sim_emb)
458 |
459 | sim_cap.append(self.sim_eval_w(sim_vec))
460 |
461 | sim_cap = torch.cat(sim_cap, 1)
462 |
463 | return sim_img, sim_cap
464 |
465 |
466 | def init_weights(self):
467 | for m in self.children():
468 | if isinstance(m, nn.Linear):
469 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features)
470 | m.weight.data.uniform_(-r, r)
471 | m.bias.data.fill_(0)
472 | elif isinstance(m, nn.BatchNorm1d):
473 | m.weight.data.fill_(1)
474 | m.bias.data.zero_()
475 |
476 |
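# SCAN_attention: each query vector attends over all context vectors (softmax sharpened
# by `smooth`) and returns the attention-weighted context of shape (n_context, queryL, d);
# `lens`, when given, masks context positions beyond each caption's true length.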
477 | def SCAN_attention(query, context, smooth, eps=1e-8, lens=None):
478 | """
479 | query: (n_context, queryL, d)
480 | context: (n_context, sourceL, d)
481 | """
482 | # --> (batch, d, queryL)
483 | queryT = torch.transpose(query, 1, 2)
484 |
485 | # (batch, sourceL, d)(batch, d, queryL)
486 | # --> (batch, sourceL, queryL)
487 | attn = torch.bmm(context, queryT)
488 |
489 | attn = nn.LeakyReLU(0.1)(attn)
490 | attn = l2norm(attn, 2)
491 |
492 | if lens is not None:
493 | for i in range(len(lens)):
494 | attn[i, lens[i]:] = -1e9
495 |
496 | # --> (batch, queryL, sourceL)
497 | attn = torch.transpose(attn, 1, 2).contiguous()
498 | # --> (batch, queryL, sourceL
499 | attn = F.softmax(attn * smooth, dim=2)
500 |
501 | # --> (batch, sourceL, queryL)
502 | attnT = torch.transpose(attn, 1, 2).contiguous()
503 |
504 | # --> (batch, d, sourceL)
505 | contextT = torch.transpose(context, 1, 2)
506 | # (batch x d x sourceL)(batch x sourceL x queryL)
507 | # --> (batch, d, queryL)
508 | weightedContext = torch.bmm(contextT, attnT)
509 | # --> (batch, queryL, d)
510 | weightedContext = torch.transpose(weightedContext, 1, 2)
511 | weightedContext = l2norm(weightedContext, dim=-1)
512 |
513 | return weightedContext
514 |
515 | class SGRAF(object):
516 | """
517 | Similarity Reasoning and Filtration (SGRAF) Network
518 | """
519 |
520 | def __init__(self, opt):
521 | # Build Models
522 | self.opt = opt
523 | self.grad_clip = opt.grad_clip
524 | self.img_enc = EncoderImage(opt, opt.img_dim, opt.embed_size,
525 | no_imgnorm=opt.no_imgnorm)
526 | self.txt_enc = EncoderText(opt, opt.vocab_size, opt.word_dim,
527 | opt.embed_size, opt.num_layers,
528 | use_bi_gru=opt.bi_gru,
529 | no_txtnorm=opt.no_txtnorm)
530 | self.sim_enc = EncoderSimilarity(opt, opt.embed_size, opt.sim_dim,
531 | opt.module_name, opt.sgr_step)
532 |
533 | if torch.cuda.is_available():
534 | self.img_enc.cuda()
535 | self.txt_enc.cuda()
536 | self.sim_enc.cuda()
537 | cudnn.benchmark = True
538 |
539 | # Loss and Optimizer
540 | params = list(self.txt_enc.parameters())
541 | params += list(self.img_enc.parameters())
542 | params += list(self.sim_enc.parameters())
543 |
544 | self.params = params
545 |
546 | self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate)
547 | self.Eiters = 0
548 |
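        # Per-sample buffers sized to the number of training captions
        # (145000 for Flickr30K, 566435 for MS-COCO, 150000 for CC152K); targets,
        # topos and topo_targets are also saved and restored through state_dict().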
549 | if 'f30k' in self.opt.data_name:
550 | self.targets = torch.ones(145000).cuda()
551 | self.preds = torch.ones(145000).cuda()
552 | self.preds_targets = torch.ones(145000).cuda()
553 | self.topos = torch.ones(145000).cuda()
554 | self.topo_targets = torch.ones(145000).cuda()
555 | elif 'coco' in self.opt.data_name:
556 | self.targets = torch.ones(566435).cuda()
557 | self.preds = torch.ones(566435).cuda()
558 | self.preds_targets = torch.ones(566435).cuda()
559 | self.topos = torch.ones(566435).cuda()
560 | self.topo_targets = torch.ones(566435).cuda()
561 | elif 'cc152k' in self.opt.data_name:
562 | self.targets = torch.ones(150000).cuda()
563 | self.preds = torch.ones(150000).cuda()
564 | self.preds_targets = torch.ones(150000).cuda()
565 | self.topos = torch.ones(150000).cuda()
566 | self.topo_targets = torch.ones(150000).cuda()
567 |
568 | def state_dict(self):
569 | state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict(), self.sim_enc.state_dict(), self.targets, self.topos, self.topo_targets]
570 | return state_dict
571 |
572 | def load_state_dict(self, state_dict):
573 | self.img_enc.load_state_dict(state_dict[0])
574 | self.txt_enc.load_state_dict(state_dict[1])
575 | self.sim_enc.load_state_dict(state_dict[2])
576 | if len(state_dict) > 3:
577 | self.targets = state_dict[3]
578 | self.topos = state_dict[4]
579 | self.topo_targets = state_dict[5]
580 |
581 | def train_start(self):
582 | """switch to train mode"""
583 | self.img_enc.train()
584 | self.txt_enc.train()
585 | self.sim_enc.train()
586 |
587 | def val_start(self):
588 | """switch to evaluate mode"""
589 | self.img_enc.eval()
590 | self.txt_enc.eval()
591 | self.sim_enc.eval()
592 |
593 | def forward_emb(self, images, captions, lengths):
594 | """Compute the image and caption embeddings"""
595 | if torch.cuda.is_available():
596 | images = images.cuda()
597 | captions = captions.cuda()
598 |
599 | # Forward feature encoding
600 | img_embs = self.img_enc(images)
601 | cap_embs = self.txt_enc(captions, lengths)
602 | return img_embs, cap_embs, lengths
603 |
604 | # Modifications by extending DECL
605 | def forward_sim(self, img_embs, cap_embs, cap_lens, mode='sim'):
606 | # Forward similarity encoding
607 | raw_sims = self.sim_enc(img_embs, cap_embs, cap_lens)
608 | if mode == 'sim':
609 | return torch.sigmoid(raw_sims)
610 |
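    # Note: forward_sim only handles mode == 'sim' (sigmoid-squashed similarity scores);
    # any other mode falls through and returns None in this class.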
611 | def train_emb(self, images, captions, lengths, ids=None, *args):
612 | """One training step given images and captions.
613 | """
614 | self.Eiters += 1
615 | self.logger.update('Eit', self.Eiters)
616 | self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
617 |
618 | # compute the embeddings
619 | img_embs, cap_embs, cap_lens = self.forward_emb(images, captions, lengths)
620 | sims = self.forward_sim(img_embs, cap_embs, cap_lens)
621 |
622 | # measure accuracy and record loss
623 | self.optimizer.zero_grad()
624 | loss = self.forward_loss(sims)
625 |
626 | # compute gradient and do SGD step
627 | loss.backward()
628 | if self.grad_clip > 0:
629 | clip_grad_norm_(self.params, self.grad_clip)
630 | self.optimizer.step()
631 |
--------------------------------------------------------------------------------
/GSC+/model/SGRAF.py:
--------------------------------------------------------------------------------
1 | """SGRAF model"""
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import torch.nn.functional as F
7 |
8 | import torch.backends.cudnn as cudnn
9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
10 | from torch.nn.utils.clip_grad import clip_grad_norm_
11 |
12 | import numpy as np
13 | from collections import OrderedDict
14 |
15 |
16 | def l1norm(X, dim, eps=1e-8):
17 | """L1-normalize columns of X"""
18 | norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
19 | X = torch.div(X, norm)
20 | return X
21 |
22 |
23 | def l2norm(X, dim=-1, eps=1e-8):
24 | """L2-normalize columns of X"""
25 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
26 | X = torch.div(X, norm)
27 | return X
28 |
29 |
30 | def cosine_sim(x1, x2, dim=-1, eps=1e-8):
31 | """Returns cosine similarity between x1 and x2, computed along dim."""
32 | w12 = torch.sum(x1 * x2, dim)
33 | w1 = torch.norm(x1, 2, dim)
34 | w2 = torch.norm(x2, 2, dim)
35 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze()
36 |
37 |
38 | class EncoderImage(nn.Module):
39 | """
40 |     Build local region representations by a commonly-used FC layer.
41 |     Args: - images: raw local detected regions, shape: (batch_size, 36, 2048).
42 |     Returns: - img_emb: final local region embeddings, shape: (batch_size, 36, 1024).
43 | """
44 |
45 | def __init__(self, opt, img_dim, embed_size, no_imgnorm=False):
46 | super(EncoderImage, self).__init__()
47 | self.opt = opt
48 | self.embed_size = embed_size
49 | self.no_imgnorm = no_imgnorm
50 | self.fc = nn.Linear(img_dim, embed_size)
51 |
52 | self.init_weights(self.fc)
53 |
54 | def init_weights(self, fc):
55 | """Xavier initialization for the fully connected layer"""
56 | r = np.sqrt(6.) / np.sqrt(fc.in_features +
57 | fc.out_features)
58 | fc.weight.data.uniform_(-r, r)
59 | fc.bias.data.fill_(0)
60 |
61 | def forward(self, images):
62 | """Extract image feature vectors."""
63 | # assuming that the precomputed features are already l2-normalized
64 | img_emb = self.fc(images)
65 |
66 | # normalize in the joint embedding space
67 | if not self.no_imgnorm:
68 | img_emb = l2norm(img_emb, dim=-1)
69 |
70 | return img_emb
71 |
72 | def load_state_dict(self, state_dict):
73 | """Overwrite the default one to accept state_dict from Full model"""
74 | own_state = self.state_dict()
75 | new_state = OrderedDict()
76 | for name, param in state_dict.items():
77 | if name in own_state:
78 | new_state[name] = param
79 |
80 | super(EncoderImage, self).load_state_dict(new_state)
81 |
82 |
83 | class EncoderText(nn.Module):
84 | """
85 |     Build local word representations by a commonly-used Bi-GRU or GRU.
86 |     Args: - captions: raw local word ids, shape: (batch_size, L).
87 |     Returns: - cap_emb: final local word embeddings, shape: (batch_size, L, 1024).
88 | """
89 |
90 | def __init__(self, opt, vocab_size, word_dim, embed_size, num_layers,
91 | use_bi_gru=False, no_txtnorm=False):
92 | super(EncoderText, self).__init__()
93 | self.opt = opt
94 | self.embed_size = embed_size
95 | self.no_txtnorm = no_txtnorm
96 |
97 | # word embedding
98 | self.embed = nn.Embedding(vocab_size, word_dim)
99 | self.dropout = nn.Dropout(0.4)
100 |
101 | # caption embedding
102 | self.use_bi_gru = use_bi_gru
103 | self.cap_rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru)
104 |
105 | self.init_weights()
106 |
107 | def init_weights(self):
108 | self.embed.weight.data.uniform_(-0.1, 0.1)
109 |
110 | def forward(self, captions, lengths):
111 | """Handles variable size captions"""
112 | # embed word ids to vectors
113 | cap_emb = self.embed(captions)
114 | cap_emb = self.dropout(cap_emb)
115 |
116 | # pack the caption
117 | packed = pack_padded_sequence(cap_emb, lengths, batch_first=True)
118 |
119 | # forward propagate RNN
120 | out, _ = self.cap_rnn(packed)
121 |
122 |         # unpack output to (batch_size, max_len, hidden_size)
123 | cap_emb, _ = pad_packed_sequence(out, batch_first=True)
124 |
125 | if self.use_bi_gru:
126 | cap_emb = (cap_emb[:, :, :cap_emb.size(2) // 2] + cap_emb[:, :, cap_emb.size(2) // 2:]) / 2
127 |
128 | # normalization in the joint embedding space
129 | if not self.no_txtnorm:
130 | cap_emb = l2norm(cap_emb, dim=-1)
131 |
132 | return cap_emb
133 |
134 |
135 | class VisualSA(nn.Module):
136 | """
137 | Build global image representations by self-attention.
138 | Args: - local: local region embeddings, shape: (batch_size, 36, 1024)
139 | - raw_global: raw image by averaging regions, shape: (batch_size, 1024)
140 | Returns: - new_global: final image by self-attention, shape: (batch_size, 1024).
141 | """
142 |
143 | def __init__(self, embed_dim, dropout_rate, num_region):
144 | super(VisualSA, self).__init__()
145 |
146 | self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim),
147 | nn.BatchNorm1d(num_region),
148 | nn.Tanh(), nn.Dropout(dropout_rate))
149 | self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim),
150 | nn.BatchNorm1d(embed_dim),
151 | nn.Tanh(), nn.Dropout(dropout_rate))
152 | self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1))
153 |
154 | self.init_weights()
155 | self.softmax = nn.Softmax(dim=1)
156 |
157 | def init_weights(self):
158 | for embeddings in self.children():
159 | for m in embeddings:
160 | if isinstance(m, nn.Linear):
161 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features)
162 | m.weight.data.uniform_(-r, r)
163 | m.bias.data.fill_(0)
164 | elif isinstance(m, nn.BatchNorm1d):
165 | m.weight.data.fill_(1)
166 | m.bias.data.zero_()
167 |
168 | def forward(self, local, raw_global):
169 | # compute embedding of local regions and raw global image
170 | if len(local) == 1 and len(raw_global) == 1:
171 | l_emb = self.embedding_local[0](local)
172 | l_emb = self.embedding_local[2](l_emb)
173 | l_emb = self.embedding_local[3](l_emb)
174 | g_emb = self.embedding_global[0](raw_global)
175 | g_emb = self.embedding_global[2](g_emb)
176 | g_emb = self.embedding_global[3](g_emb)
177 | else:
178 | l_emb = self.embedding_local(local)
179 | g_emb = self.embedding_global(raw_global)
180 |
181 | # compute the normalized weights, shape: (batch_size, 36)
182 | g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1)
183 | common = l_emb.mul(g_emb)
184 | weights = self.embedding_common(common).squeeze(2)
185 | weights = self.softmax(weights)
186 |
187 | # compute final image, shape: (batch_size, 1024)
188 | new_global = (weights.unsqueeze(2) * local).sum(dim=1)
189 | new_global = l2norm(new_global, dim=-1)
190 |
191 | return new_global
192 |
193 |
194 | class TextSA(nn.Module):
195 | """
196 | Build global text representations by self-attention.
197 | Args: - local: local word embeddings, shape: (batch_size, L, 1024)
198 | - raw_global: raw text by averaging words, shape: (batch_size, 1024)
199 | Returns: - new_global: final text by self-attention, shape: (batch_size, 1024).
200 | """
201 |
202 | def __init__(self, embed_dim, dropout_rate):
203 | super(TextSA, self).__init__()
204 |
205 | self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim),
206 | nn.Tanh(), nn.Dropout(dropout_rate))
207 | self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim),
208 | nn.Tanh(), nn.Dropout(dropout_rate))
209 | self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1))
210 |
211 | self.init_weights()
212 | self.softmax = nn.Softmax(dim=1)
213 |
214 | def init_weights(self):
215 | for embeddings in self.children():
216 | for m in embeddings:
217 | if isinstance(m, nn.Linear):
218 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features)
219 | m.weight.data.uniform_(-r, r)
220 | m.bias.data.fill_(0)
221 | elif isinstance(m, nn.BatchNorm1d):
222 | m.weight.data.fill_(1)
223 | m.bias.data.zero_()
224 |
225 | def forward(self, local, raw_global):
226 | # compute embedding of local words and raw global text
227 | l_emb = self.embedding_local(local)
228 | g_emb = self.embedding_global(raw_global)
229 |
230 | # compute the normalized weights, shape: (batch_size, L)
231 | g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1)
232 | common = l_emb.mul(g_emb)
233 | weights = self.embedding_common(common).squeeze(2)
234 | weights = self.softmax(weights)
235 |
236 | # compute final text, shape: (batch_size, 1024)
237 | new_global = (weights.unsqueeze(2) * local).sum(dim=1)
238 | new_global = l2norm(new_global, dim=-1)
239 |
240 | return new_global
241 |
242 |
243 | class GraphReasoning(nn.Module):
244 | """
245 |     Perform the similarity graph reasoning with a fully-connected graph
246 |     Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256)
247 |     Returns: - sim_sgr: reasoned graph nodes after several steps, shape: (batch_size, L+1, 256)
248 | """
249 |
250 | def __init__(self, sim_dim):
251 | super(GraphReasoning, self).__init__()
252 |
253 | self.graph_query_w = nn.Linear(sim_dim, sim_dim)
254 | self.graph_key_w = nn.Linear(sim_dim, sim_dim)
255 | self.sim_graph_w = nn.Linear(sim_dim, sim_dim)
256 | # self.relu = nn.LeakyReLU(negative_slope=0.1)
257 | self.relu = nn.ReLU()
258 |
259 | self.init_weights()
260 |
261 | def forward(self, sim_emb):
262 | sim_query = self.graph_query_w(sim_emb)
263 | sim_key = self.graph_key_w(sim_emb)
264 | sim_edge = torch.softmax(torch.bmm(sim_query, sim_key.permute(0, 2, 1)), dim=-1)
265 | sim_sgr = torch.bmm(sim_edge, sim_emb)
266 | sim_sgr = self.relu(self.sim_graph_w(sim_sgr))
267 | return sim_sgr
268 |
269 | def init_weights(self):
270 | for m in self.children():
271 | if isinstance(m, nn.Linear):
272 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features)
273 | m.weight.data.uniform_(-r, r)
274 | m.bias.data.fill_(0)
275 | elif isinstance(m, nn.BatchNorm1d):
276 | m.weight.data.fill_(1)
277 | m.bias.data.zero_()
278 |
279 |
280 | class AttentionFiltration(nn.Module):
281 | """
282 | Perform the similarity Attention Filtration with a gate-based attention
283 | Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256)
284 |     Returns: - sim_saf: aggregated alignment after attention filtration, shape: (batch_size, 256)
285 | """
286 |
287 | def __init__(self, sim_dim):
288 | super(AttentionFiltration, self).__init__()
289 |
290 | self.attn_sim_w = nn.Linear(sim_dim, 1)
291 | self.bn = nn.BatchNorm1d(1)
292 |
293 | self.init_weights()
294 |
295 | def forward(self, sim_emb):
296 | sim_attn = l1norm(torch.sigmoid(self.bn(self.attn_sim_w(sim_emb).permute(0, 2, 1))), dim=-1)
297 | sim_saf = torch.matmul(sim_attn, sim_emb)
298 | sim_saf = l2norm(sim_saf.squeeze(1), dim=-1)
299 | return sim_saf
300 |
301 | def init_weights(self):
302 | for m in self.children():
303 | if isinstance(m, nn.Linear):
304 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features)
305 | m.weight.data.uniform_(-r, r)
306 | m.bias.data.fill_(0)
307 | elif isinstance(m, nn.BatchNorm1d):
308 | m.weight.data.fill_(1)
309 | m.bias.data.zero_()
310 |
311 |
312 | class EncoderSimilarity(nn.Module):
313 | """
314 | Compute the image-text similarity by SGR, SAF, AVE
315 | Args: - img_emb: local region embeddings, shape: (batch_size, 36, 1024)
316 | - cap_emb: local word embeddings, shape: (batch_size, L, 1024)
317 | Returns:
318 | - sim_all: final image-text similarities, shape: (batch_size, batch_size).
319 | """
320 |
321 | def __init__(self, opt, embed_size, sim_dim, module_name='AVE', sgr_step=3):
322 | super(EncoderSimilarity, self).__init__()
323 | self.opt = opt
324 | self.module_name = module_name
325 |
326 | self.v_global_w = VisualSA(embed_size, 0.4, 36)
327 | self.t_global_w = TextSA(embed_size, 0.4)
328 |
329 | self.sim_tranloc_w = nn.Linear(embed_size, sim_dim)
330 | self.sim_tranglo_w = nn.Linear(embed_size, sim_dim)
331 |
332 | self.sim_eval_w = nn.Linear(sim_dim, 1)
333 | self.sigmoid = nn.Sigmoid()
334 |
335 | if module_name == 'SGR':
336 | self.SGR_module = nn.ModuleList([GraphReasoning(sim_dim) for i in range(sgr_step)])
337 | elif module_name == 'SAF':
338 | self.SAF_module = AttentionFiltration(sim_dim)
339 | else:
340 | raise ValueError('Invalid input of opt.module_name in opts.py')
341 |
342 | self.init_weights()
343 |
344 | # Modifications by extending DECL
345 | def forward(self, img_emb, cap_emb, cap_lens):
346 | sim_all = []
347 | sim_img = []
348 | sim_cap = []
349 | n_image = img_emb.size(0)
350 | n_caption = cap_emb.size(0)
351 |
352 | # get enhanced global images by self-attention
353 | img_ave = torch.mean(img_emb, 1)
354 | img_glo = self.v_global_w(img_emb, img_ave)
355 |
356 | for i in range(n_caption):
357 | # get the i-th sentence
358 | n_word = cap_lens[i]
359 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0)
360 | cap_i_expand = cap_i.repeat(n_image, 1, 1)
361 |
362 | # get enhanced global i-th text by self-attention
363 | cap_ave_i = torch.mean(cap_i, 1)
364 | cap_glo_i = self.t_global_w(cap_i, cap_ave_i)
365 |
366 | # local-global alignment construction
367 | Context_img = SCAN_attention(cap_i_expand, img_emb, smooth=9.0)
368 | sim_loc = torch.pow(torch.sub(Context_img, cap_i_expand), 2)
369 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)
370 |
371 | sim_glo = torch.pow(torch.sub(img_glo, cap_glo_i), 2)
372 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
373 |
374 | # concat the global and local alignments
375 | sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1)
376 |
377 | # compute the final similarity vector
378 | if self.module_name == 'SGR':
379 | for module in self.SGR_module:
380 | sim_emb = module(sim_emb)
381 | sim_vec = sim_emb[:, 0, :]
382 | else:
383 | sim_vec = self.SAF_module(sim_emb)
384 |
385 | # compute the final similarity score
386 | # sim_i = self.sigmoid(self.sim_eval_w(sim_vec))
387 | # sim_all.append(sim_i)
388 | sim_all.append(self.sim_eval_w(sim_vec))
389 |
390 | sim_all = torch.cat(sim_all, 1)
391 | return sim_all
392 |
393 | # cap_glo = []
394 | # for i in range (n_caption):
395 | # n_word = cap_lens[i]
396 | # cap_i = cap_emb[i, :n_word, :].unsqueeze(0)
397 | # cap_ave_i = torch.mean(cap_i, 1)
398 | # cap_glo_i = self.t_global_w(cap_i, cap_ave_i)
399 | # cap_glo.append(cap_glo_i)
400 | # cap_glo = torch.cat(cap_glo, 0)
401 |
402 | # for i in range(n_image):
403 | # img_i = img_emb[i].unsqueeze(0)
404 | # img_i_expand = img_i.repeat(n_image, 1, 1)
405 |
406 | # img_glo_i = img_glo[i].unsqueeze(0)
407 |
408 | # sim_loc = torch.pow(torch.sub(img_emb, img_i_expand), 2) + 1e-6
409 | # sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)
410 |
411 | # sim_glo = torch.pow(torch.sub(img_glo, img_glo_i), 2) + 1e-6
412 | # sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
413 |
414 | # sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1)
415 |
416 | # Context_cap = SCAN_attention(img_i_expand, cap_emb + 1e-6, smooth=9.0, lens=cap_lens)
417 | # sim_loc_cap = torch.pow(torch.sub(Context_cap, img_i_expand), 2)
418 | # sim_loc_cap = l2norm(self.sim_tranloc_w(sim_loc_cap), dim=-1)
419 |
420 | # sim_glo_cap = torch.pow(torch.sub(cap_glo, img_glo_i), 2)
421 | # sim_glo_cap = l2norm(self.sim_tranglo_w(sim_glo_cap), dim=-1)
422 |
423 | # sim_emb_cap = torch.cat([sim_glo_cap.unsqueeze(1), sim_loc_cap], 1)
424 |
425 | # sim_emb = torch.cat([sim_emb, sim_emb_cap], 0)
426 |
427 | # if self.module_name == 'SGR':
428 | # for module in self.SGR_module:
429 | # sim_emb = module(sim_emb)
430 | # sim_vec = sim_emb[:, 0, :]
431 | # else:
432 | # sim_vec = self.SAF_module(sim_emb)
433 |
434 | # # sim_img.append(self.sim_eval_w(sim_vec))
435 | # sim_img_cap = self.sim_eval_w(sim_vec)
436 | # sim_img.append(sim_img_cap[: n_image])
437 | # sim_all.append(sim_img_cap[n_image: ])
438 |
439 | # # (n_image, n_caption)
440 | # sim_all = torch.cat(sim_all, 1)
441 | # sim_img = torch.cat(sim_img, 1)
442 |
443 | # for i in range(n_caption):
444 | # n_word = cap_lens[i]
445 | # cap_i = cap_emb[i, :n_word, :].unsqueeze(0)
446 | # cap_i_expand = cap_i.repeat(n_caption, 1, 1)
447 | # cap_glo_i = cap_glo[i].unsqueeze(0)
448 |
449 | # Context_cap = SCAN_attention(cap_i_expand, cap_emb + 1e-6, smooth=9.0, lens=cap_lens)
450 | # sim_loc = torch.pow(torch.sub(Context_cap, cap_i_expand), 2)
451 | # sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)
452 |
453 | # sim_glo = torch.pow(torch.sub(cap_glo, cap_glo_i), 2) + 1e-6
454 | # sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
455 |
456 | # sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1)
457 |
458 | # if self.module_name == 'SGR':
459 | # for module in self.SGR_module:
460 | # sim_emb = module(sim_emb)
461 | # sim_vec = sim_emb[:, 0, :]
462 | # else:
463 | # sim_vec = self.SAF_module(sim_emb)
464 |
465 | # sim_cap.append(self.sim_eval_w(sim_vec))
466 |
467 | # sim_cap = torch.cat(sim_cap, 1)
468 |
469 | # return sim_all, sim_img, sim_cap
470 |
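    # Same intra-modal (image-image and caption-caption) topology graphs as the GSC
    # variant, but computed for all pairs at once via unsqueeze/repeat batching instead
    # of a per-sample loop, then reshaped back to (n, n) score matrices.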
471 | def forward_topology(self, img_emb, cap_emb, cap_lens):
472 | sim_img = []
473 | n_image = img_emb.size(0)
474 |
475 | img_ave = torch.mean(img_emb, 1)
476 | img_glo = self.v_global_w(img_emb, img_ave)
477 |
478 | img_emb_n = img_emb.unsqueeze(0).repeat(n_image, 1, 1, 1)
479 | img_emb_ni = img_emb.unsqueeze(1).repeat(1, n_image, 1, 1)
480 | sim_loc = torch.pow(torch.sub(img_emb_n, img_emb_ni), 2) + 1e-6
481 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)
482 |
483 | img_glo_n = img_glo.unsqueeze(0).repeat(n_image, 1, 1)
484 | img_glo_ni = img_glo.unsqueeze(1).repeat(1, n_image, 1)
485 | sim_glo = torch.pow(torch.sub(img_glo_n, img_glo_ni), 2) + 1e-6
486 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
487 |
488 | sim_emb = torch.cat([sim_glo.unsqueeze(2), sim_loc], 2)
489 | sim_emb = sim_emb.view(-1, sim_emb.size(2), sim_emb.size(3))
490 | if self.module_name == 'SGR':
491 | for module in self.SGR_module:
492 | sim_emb = module(sim_emb)
493 | sim_vec = sim_emb[:, 0, :]
494 | else:
495 | sim_vec = self.SAF_module(sim_emb)
496 | sim_img = self.sim_eval_w(sim_vec)
497 | sim_img = sim_img.view(n_image, n_image)
498 |
499 | sim_cap = []
500 | n_caption = cap_emb.size(0)
501 |
502 | cap_glo = []
503 |         for i in range(n_caption):
504 | n_word = cap_lens[i]
505 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0)
506 | cap_ave_i = torch.mean(cap_i, 1)
507 | cap_glo_i = self.t_global_w(cap_i, cap_ave_i)
508 | cap_glo.append(cap_glo_i)
509 | cap_glo = torch.cat(cap_glo, 0)
510 |
511 | cap_emb_n = cap_emb.unsqueeze(0).repeat(n_image, 1, 1, 1)
512 | cap_emb_ni = cap_emb.unsqueeze(1).repeat(1, n_image, 1, 1)
513 | sim_loc = torch.pow(torch.sub(cap_emb_n, cap_emb_ni), 2) + 1e-6
514 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)
515 |
516 | cap_glo_n = cap_glo.unsqueeze(0).repeat(n_image, 1, 1)
517 | cap_glo_ni = cap_glo.unsqueeze(1).repeat(1, n_image, 1)
518 | sim_glo = torch.pow(torch.sub(cap_glo_n, cap_glo_ni), 2) + 1e-6
519 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
520 |
521 | sim_emb = torch.cat([sim_glo.unsqueeze(2), sim_loc], 2)
522 | sim_emb = sim_emb.view(-1, sim_emb.size(2), sim_emb.size(3))
523 | if self.module_name == 'SGR':
524 | for module in self.SGR_module:
525 | sim_emb = module(sim_emb)
526 | sim_vec = sim_emb[:, 0, :]
527 | else:
528 | sim_vec = self.SAF_module(sim_emb)
529 | sim_cap = self.sim_eval_w(sim_vec)
530 | sim_cap = sim_cap.view(n_image, n_image)
531 |
532 | # for i in range(n_caption):
533 | # n_word = cap_lens[i]
534 | # cap_i = cap_emb[i, :n_word, :].unsqueeze(0)
535 | # cap_i_expand = cap_i.repeat(n_caption, 1, 1)
536 | # cap_glo_i = cap_glo[i].unsqueeze(0)
537 |
538 | # Context_cap = SCAN_attention(cap_i_expand, cap_emb + 1e-6, smooth=9.0, lens=cap_lens)
539 | # sim_loc = torch.pow(torch.sub(Context_cap, cap_i_expand), 2)
540 | # sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1)
541 |
542 | # sim_glo = torch.pow(torch.sub(cap_glo, cap_glo_i), 2) + 1e-6
543 | # sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1)
544 |
545 | # sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1)
546 |
547 | # if self.module_name == 'SGR':
548 | # for module in self.SGR_module:
549 | # sim_emb = module(sim_emb)
550 | # sim_vec = sim_emb[:, 0, :]
551 | # else:
552 | # sim_vec = self.SAF_module(sim_emb)
553 |
554 | # sim_cap.append(self.sim_eval_w(sim_vec))
555 |
556 | # sim_cap = torch.cat(sim_cap, 1)
557 |
558 | return sim_img, sim_cap
559 |
560 |
561 | def init_weights(self):
562 | for m in self.children():
563 | if isinstance(m, nn.Linear):
564 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features)
565 | m.weight.data.uniform_(-r, r)
566 | m.bias.data.fill_(0)
567 | elif isinstance(m, nn.BatchNorm1d):
568 | m.weight.data.fill_(1)
569 | m.bias.data.zero_()
570 |
571 |
572 | def SCAN_attention(query, context, smooth, eps=1e-8, lens=None):
573 | """
574 | query: (n_context, queryL, d)
575 | context: (n_context, sourceL, d)
576 | """
577 | # --> (batch, d, queryL)
578 | queryT = torch.transpose(query, 1, 2)
579 |
580 | # (batch, sourceL, d)(batch, d, queryL)
581 | # --> (batch, sourceL, queryL)
582 | attn = torch.bmm(context, queryT)
583 |
584 | attn = nn.LeakyReLU(0.1)(attn)
585 | attn = l2norm(attn, 2)
586 |
587 | if lens is not None:
588 | for i in range(len(lens)):
589 | attn[i, lens[i]:] = -1e9
590 |
591 | # --> (batch, queryL, sourceL)
592 | attn = torch.transpose(attn, 1, 2).contiguous()
593 |     # --> (batch, queryL, sourceL)
594 | attn = F.softmax(attn * smooth, dim=2)
595 |
596 | # --> (batch, sourceL, queryL)
597 | attnT = torch.transpose(attn, 1, 2).contiguous()
598 |
599 | # --> (batch, d, sourceL)
600 | contextT = torch.transpose(context, 1, 2)
601 | # (batch x d x sourceL)(batch x sourceL x queryL)
602 | # --> (batch, d, queryL)
603 | weightedContext = torch.bmm(contextT, attnT)
604 | # --> (batch, queryL, d)
605 | weightedContext = torch.transpose(weightedContext, 1, 2)
606 | weightedContext = l2norm(weightedContext, dim=-1)
607 |
608 | return weightedContext
609 |
610 | class SGRAF(object):
611 | """
612 | Similarity Reasoning and Filtration (SGRAF) Network
613 | """
614 |
615 | def __init__(self, opt):
616 | # Build Models
617 | self.opt = opt
618 | self.grad_clip = opt.grad_clip
619 | self.img_enc = EncoderImage(opt, opt.img_dim, opt.embed_size,
620 | no_imgnorm=opt.no_imgnorm)
621 | self.txt_enc = EncoderText(opt, opt.vocab_size, opt.word_dim,
622 | opt.embed_size, opt.num_layers,
623 | use_bi_gru=opt.bi_gru,
624 | no_txtnorm=opt.no_txtnorm)
625 | self.sim_enc = EncoderSimilarity(opt, opt.embed_size, opt.sim_dim,
626 | opt.module_name, opt.sgr_step)
627 |
628 | if torch.cuda.is_available():
629 | self.img_enc.cuda()
630 | self.txt_enc.cuda()
631 | self.sim_enc.cuda()
632 | cudnn.benchmark = True
633 |
634 | # Loss and Optimizer
635 | params = list(self.txt_enc.parameters())
636 | params += list(self.img_enc.parameters())
637 | params += list(self.sim_enc.parameters())
638 |
639 | self.params = params
640 |
641 | self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate)
642 | self.Eiters = 0
643 |
644 | if 'f30k' in self.opt.data_name:
645 | self.targets = torch.ones(145000).cuda()
646 | self.preds = torch.ones(145000).cuda()
647 | self.preds_targets = torch.ones(145000).cuda()
648 | self.topos = torch.ones(145000).cuda()
649 | self.topo_targets = torch.ones(145000).cuda()
650 | elif 'coco' in self.opt.data_name:
651 | self.targets = torch.ones(566435).cuda()
652 | self.preds = torch.ones(566435).cuda()
653 | self.preds_targets = torch.ones(566435).cuda()
654 | self.topos = torch.ones(566435).cuda()
655 | self.topo_targets = torch.ones(566435).cuda()
656 | elif 'cc152k' in self.opt.data_name:
657 | self.targets = torch.ones(150000).cuda()
658 | self.preds = torch.ones(150000).cuda()
659 | self.preds_targets = torch.ones(150000).cuda()
660 | self.topos = torch.ones(150000).cuda()
661 | self.topo_targets = torch.ones(150000).cuda()
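# Note (added): the buffers above hold one entry per training sample; the
# hard-coded sizes appear to correspond to the number of training image-caption
# pairs in Flickr30K, MS-COCO and CC152K respectively (this reading is an
# assumption, not stated in the original source).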
662 |
663 | def state_dict(self):
664 | state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict(), self.sim_enc.state_dict(), self.targets, self.topos, self.topo_targets]
665 | return state_dict
666 |
667 | def load_state_dict(self, state_dict):
668 | self.img_enc.load_state_dict(state_dict[0])
669 | self.txt_enc.load_state_dict(state_dict[1])
670 | self.sim_enc.load_state_dict(state_dict[2])
671 | if len(state_dict) > 3:
672 | self.targets = state_dict[3]
673 | self.topos = state_dict[4]
674 | self.topo_targets = state_dict[5]
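# Note (added): checkpoints are saved by `state_dict` as a 6-element list
# (image encoder, text encoder, similarity encoder, plus the `targets` / `topos`
# / `topo_targets` buffers); the `len(state_dict) > 3` guard presumably keeps
# backward compatibility with older checkpoints that stored only the encoders.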
675 |
676 | def train_start(self):
677 | """switch to train mode"""
678 | self.img_enc.train()
679 | self.txt_enc.train()
680 | self.sim_enc.train()
681 |
682 | def val_start(self):
683 | """switch to evaluate mode"""
684 | self.img_enc.eval()
685 | self.txt_enc.eval()
686 | self.sim_enc.eval()
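# Note (added): .train() / .eval() toggle PyTorch's training mode, which changes
# the behaviour of layers such as Dropout and BatchNorm inside the three encoders.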
687 |
688 | def forward_emb(self, images, captions, lengths):
689 | """Compute the image and caption embeddings"""
690 | if torch.cuda.is_available():
691 | images = images.cuda()
692 | captions = captions.cuda()
693 |
694 | # Forward feature encoding
695 | img_embs = self.img_enc(images)
696 | cap_embs = self.txt_enc(captions, lengths)
697 | return img_embs, cap_embs, lengths
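# Note (added): img_embs are the projected per-region image features and
# cap_embs the per-word caption features; `lengths` is passed through unchanged
# so callers can mask padded word positions.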
698 |
699 | # Modified by extending DECL
700 | def forward_sim(self, img_embs, cap_embs, cap_lens, mode='sim'):
701 | # Forward similarity encoding
702 | # raw_sims, raw_sims_img, raw_sims_cap = self.sim_enc(img_embs, cap_embs, cap_lens)
703 | raw_sims = self.sim_enc(img_embs, cap_embs, cap_lens)
704 | if mode == 'sim':
705 | # return torch.sigmoid(raw_sims), raw_sims_img, raw_sims_cap
706 | return torch.sigmoid(raw_sims)
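# Note (added): when mode != 'sim' the method falls through and implicitly
# returns None; the commented-out lines preserve the earlier interface that also
# returned the image- and caption-branch similarities.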
707 |
708 | def train_emb(self, images, captions, lengths, ids=None, *args):
709 | """One training step given images and captions.
710 | """
711 | self.Eiters += 1
712 | self.logger.update('Eit', self.Eiters)
713 | self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
714 |
715 | # compute the embeddings
716 | img_embs, cap_embs, cap_lens = self.forward_emb(images, captions, lengths)
717 | sims = self.forward_sim(img_embs, cap_embs, cap_lens)
718 |
719 | # measure accuracy and record loss
720 | self.optimizer.zero_grad()
721 | loss = self.forward_loss(sims)
722 |
723 | # compute gradient and take an optimizer (Adam) step
724 | loss.backward()
725 | if self.grad_clip > 0:
726 | clip_grad_norm_(self.params, self.grad_clip)
727 | self.optimizer.step()
728 |
--------------------------------------------------------------------------------