├── docker_cfg
│   ├── inputrc
│   └── Dockerfile
├── ATTOP
│   ├── data
│   │   ├── __init__.py
│   │   └── dataset.py
│   ├── models
│   │   ├── __init__.py
│   │   └── models.py
│   └── LICENSE
├── taskmodularnets
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   └── dataset.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── utils.py
│   └── LICENSE
├── Causal.png
├── AO_CLEVr_examples.png
├── .gitignore
├── LICENSE_COSMO
├── LICENSE_HSIC
├── LICENSE
├── HSIC.py
├── experiment_zappos_TMNsplit.py
├── README.md
├── data.py
├── params.py
├── offline_early_stop.py
├── main.py
├── useful_utils.py
├── model.py
├── eval.py
├── train.py
└── COSMO_utils.py
/docker_cfg/inputrc:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ATTOP/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ATTOP/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/taskmodularnets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Causal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/causal_comp/HEAD/Causal.png
--------------------------------------------------------------------------------
/AO_CLEVr_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/causal_comp/HEAD/AO_CLEVr_examples.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb/
2 | outputs/
3 | logs/
4 | logs*/
5 | tmp*/
6 | beat.touch
7 |
8 | notebooks/figures
9 |
10 | *__pycache__
11 |
12 | .ipynb_checkpoints/
13 | .vscode/
14 | .idea/
15 | .ipynb_checkpoints
16 |
--------------------------------------------------------------------------------
/taskmodularnets/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # taskmodularnets/LICENSE file.
6 | #
--------------------------------------------------------------------------------
/taskmodularnets/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # taskmodularnets/LICENSE file.
6 | #
7 |
--------------------------------------------------------------------------------
/ATTOP/data/dataset.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the following repo
5 | # (released under the MIT License).
6 | #
7 | # Source:
8 | # https://github.com/Tushar-N/attributes-as-operators/blob/master/data/dataset.py
9 | #
10 | # The license for the original version of this file can be
11 | # found in ATTOP/LICENSE
12 | # ---------------------------------------------------------------
13 |
14 | import numpy as np
15 |
16 | def sample_negative(self, attr, obj):
17 | new_attr, new_obj = self.train_pairs[np.random.choice(len(self.train_pairs))]
18 | if new_attr == attr and new_obj == obj:
19 | return self.sample_negative(attr, obj)
20 | return (self.attr2idx[new_attr], self.obj2idx[new_obj])
21 |
--------------------------------------------------------------------------------
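`sample_negative` draws a random training pair and retries until it differs from the given positive pair. Note that it is deliberately a module-level function that takes `self`, so other dataset classes can adopt it; `data.py` calls it as `ATTOP_sample_negative(self, attr, obj)`. A minimal sketch of that pattern, with a hypothetical toy dataset:

```python
from ATTOP.data.dataset import sample_negative as ATTOP_sample_negative

class ToyData:
    def __init__(self):
        self.train_pairs = [('red', 'cube'), ('blue', 'sphere')]
        self.attr2idx = {'red': 0, 'blue': 1}
        self.obj2idx = {'cube': 0, 'sphere': 1}

    def sample_negative(self, attr, obj):
        # rejection-sample a (attr, obj) pair different from the positive one
        return ATTOP_sample_negative(self, attr, obj)

print(ToyData().sample_negative('red', 'cube'))  # -> (1, 1), the only other pair
```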
/docker_cfg/Dockerfile:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 |
8 | FROM nvcr.io/nvidia/pytorch:20.01-py3
9 | RUN yes | pip install tensorboardX==1.2
10 | RUN yes | pip install wget
11 | RUN yes | pip install scikit-learn
12 | RUN yes | pip install ipymd
13 | RUN yes | pip install Pillow
14 | RUN yes | apt-get update
15 | RUN yes | apt-get install tmux
16 | RUN yes | apt-get install htop
17 | RUN yes | apt-get install bc
18 | RUN yes | conda install -c conda-forge accimage
19 | RUN yes | pip install notifiers
20 | RUN yes | pip install pyyaml
21 | RUN yes | pip install ruamel.yaml
22 | RUN yes | apt install rsync
23 | RUN yes | apt install git
24 | RUN yes | pip install dataclasses
25 | RUN yes | pip install simple_parsing
26 | RUN yes | pip install munch
27 | COPY inputrc /root/.inputrc
28 |
29 |
30 |
--------------------------------------------------------------------------------
/ATTOP/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LICENSE_COSMO:
--------------------------------------------------------------------------------
1 |
2 | MIT License
3 |
4 | Copyright (c) 2020 Yuval Atzmon
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
24 |
--------------------------------------------------------------------------------
/LICENSE_HSIC:
--------------------------------------------------------------------------------
1 |
2 | MIT License
3 |
4 | Copyright (c) 2020 Tom Beer and Bar Eini-Porat
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/ATTOP/models/models.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the following repo
5 | # (released under the MIT License).
6 | #
7 | # Source:
8 | # https://github.com/Tushar-N/attributes-as-operators/blob/master/models/models.py
9 | #
10 | # The license for the original version of this file can be
11 | # found in ATTOP/LICENSE
12 | # ---------------------------------------------------------------
13 |
14 | import torch.nn as nn
15 |
16 |
17 | class MLP(nn.Module):
18 | def __init__(self, inp_dim, out_dim, num_layers=1, relu=True, bias=True):
19 | super(MLP, self).__init__()
20 | mod = []
21 | for L in range(num_layers-1):
22 | mod.append(nn.Linear(inp_dim, inp_dim, bias=bias))
23 | mod.append(nn.ReLU(True))
24 |
25 | mod.append(nn.Linear(inp_dim, out_dim, bias=bias))
26 | if relu:
27 | mod.append(nn.ReLU(True))
28 |
29 | self.mod = nn.Sequential(*mod)
30 |
31 | def forward(self, x):
32 | output = self.mod(x)
33 | return output
34 |
--------------------------------------------------------------------------------
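`MLP` stacks `num_layers - 1` hidden `Linear+ReLU` blocks of width `inp_dim` before the output layer, with an optional final ReLU. A minimal usage sketch (the dimensions are hypothetical):

```python
import torch
from ATTOP.models.models import MLP

# 2 layers: one 512->512 hidden block with ReLU, then a 512->300 output layer
mlp = MLP(inp_dim=512, out_dim=300, num_layers=2, relu=False)
out = mlp(torch.randn(8, 512))
print(out.shape)  # torch.Size([8, 300])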
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD License
2 |
3 | Copyright 2021 NVIDIA
4 |
5 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
8 |
9 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 |
11 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
12 |
13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
14 |
--------------------------------------------------------------------------------
/HSIC.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from a file in the following repo
5 | # (released under the MIT License).
6 | #
7 | # Source:
8 | # https://github.com/tom-beer/deep-scientific-discovery/blob/master/hsic.py
9 | #
10 | # The license for the original version of this file can be
11 | # found in this directory (LICENSE_HSIC). The modifications
12 | # to this file are subject to the License
13 | # located at the root directory.
14 | # ---------------------------------------------------------------
15 |
16 |
17 | import torch
18 | import numpy as np
19 |
20 | from useful_utils import ns_profiling_label
21 |
22 |
23 | def pairwise_distances(x):
24 | x_distances = torch.sum(x**2,-1).reshape((-1,1))
25 | return -2*torch.mm(x,x.t()) + x_distances + x_distances.t()
26 |
27 | def kernelMatrixGaussian(x, sigma=1):
28 |
29 | pairwise_distances_ = pairwise_distances(x)
30 | gamma = -1.0 / (sigma ** 2)
31 | return torch.exp(gamma * pairwise_distances_)
32 |
33 | def kernelMatrixLinear(x):
34 | return torch.matmul(x,x.t())
35 |
36 | # check
37 | def HSIC(X, Y, kernelX="Gaussian", kernelY="Gaussian", sigmaX=1, sigmaY=1,
38 | log_median_pairwise_distance=False):
39 | m,_ = X.shape
40 | assert(m>1)
41 |
42 | median_pairwise_distanceX, median_pairwise_distanceY = np.nan, np.nan
43 | if log_median_pairwise_distance:
44 | # This calc takes a long time. It is used for debugging and disabled by default.
45 | with ns_profiling_label('dist'):
46 | median_pairwise_distanceX = median_pairwise_distance(X)
47 | median_pairwise_distanceY = median_pairwise_distance(Y)
48 |
49 | with ns_profiling_label('Hkernel'):
50 | K = kernelMatrixGaussian(X,sigmaX) if kernelX == "Gaussian" else kernelMatrixLinear(X)
51 | L = kernelMatrixGaussian(Y,sigmaY) if kernelY == "Gaussian" else kernelMatrixLinear(Y)
52 |
53 | with ns_profiling_label('Hfinal'):
54 | H = torch.eye(m, device='cuda') - 1.0/m * torch.ones((m,m), device='cuda')
55 | H = H.float().cuda()
56 |
57 | Kc = torch.mm(H,torch.mm(K,H))
58 |
59 | HSIC = torch.trace(torch.mm(L,Kc))/((m-1)**2)
60 | return HSIC, median_pairwise_distanceX, median_pairwise_distanceY
61 |
62 |
63 | def median_pairwise_distance(X):
64 | t = pairwise_distances(X).detach()
65 | triu_indices = t.triu(diagonal=1).nonzero().T
66 |
67 | if triu_indices[0].shape[0] == 0 or triu_indices[1].shape[0] == 0:
68 | return 0.
69 | else:
70 | return torch.median(t[triu_indices[0], triu_indices[1]]).item()
71 |
72 |
--------------------------------------------------------------------------------
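`HSIC` above is the (biased) Hilbert-Schmidt Independence Criterion estimator: with centering matrix H = I - (1/m) 1 1^T and kernel matrices K, L, it returns trace(L H K H) / (m-1)^2. The repo version assumes a CUDA device and imports `ns_profiling_label` from `useful_utils`; the following is a self-contained CPU re-implementation of the same estimator, for illustration only:

```python
import torch

def hsic_gaussian(x, y, sigma=1.0):
    """Biased HSIC estimator with Gaussian kernels (CPU sketch)."""
    m = x.shape[0]
    def gram(z):
        sq_dists = torch.cdist(z, z) ** 2         # pairwise squared distances
        return torch.exp(-sq_dists / sigma ** 2)  # Gaussian kernel matrix
    K, L = gram(x), gram(y)
    H = torch.eye(m) - torch.ones(m, m) / m       # centering matrix
    return torch.trace(L @ H @ K @ H) / (m - 1) ** 2

torch.manual_seed(0)
x = torch.randn(512, 3)
print(hsic_gaussian(x, x).item())                    # dependent: clearly positive
print(hsic_gaussian(x, torch.randn(512, 3)).item())  # independent: near zero
```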
/experiment_zappos_TMNsplit.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 |
8 | from pathlib import Path
9 |
10 | from dataclasses import dataclass
11 | from simple_parsing import ArgumentParser, ConflictResolution
12 |
13 | from params import CommandlineArgs
14 | import shlex
15 |
16 | from main import main
17 | from useful_utils import to_simple_parsing_args
18 |
19 |
20 | @dataclass
21 | class ScriptArgs():
22 | output_dir: str = '/tmp/output_results_dir'
23 | """ Directory to save the results """
24 |
25 | data_dir: str = '/local_data/zap_tmn'
26 | """ Directory that holds the data"""
27 |
28 | seed: int = 0
29 | """ Random seed. For zappos, choose \in [0..4] """
30 |
31 | use_wandb: bool = True
32 | """ log results to w&b """
33 |
34 | wandb_user: str = 'none'
35 |
36 | @classmethod
37 | def get_args(cls):
38 | args: cls = to_simple_parsing_args(cls)
39 | return args
40 |
41 | if __name__ == '__main__':
42 | script_args = ScriptArgs.get_args()
43 |
44 | # Seed=0: Unseen=27.4, Seen=33.9, Harmonic=30.3, Closed=56.7, AUC=25.2
45 |
46 | # Set commandline arguments for the experiment.
47 | # see params.py for their documentation and meaning with respect to definitions in the paper
48 |
49 | experiment_commandline_opt_string = f"""--output_dir={script_args.output_dir} --data_dir={script_args.data_dir}
50 | --use_wandb={script_args.use_wandb} --wandb_user={script_args.wandb_user}
51 | --learn_embeddings=1 --nlayers_label=0 --nlayers_joint_ao=2 --h_dim=300 --E_num_hidden=2 --E_num_common_hidden=1
52 | --dataset_name=zap_tmn --dataset_variant=irrelevant --optimizer_name=adam --lr=0.0003 --max_epoch=1000 --alternate_ys=0
53 | --weight_decay=5e-05 --calc_AUC=1 --lambda_feat=0 --mu_img_feat=0 --lambda_aux_disjoint=100 --lambda_aux_img=0 --metadata_from_pkl=0
54 | --lambda_ao_emb=1000 --lambda_aux=0 --seed={script_args.seed} --alphaH=450.0 --HSIC_coeff=0.0045 --balanced_loss=0 --num_split=0
55 | --num_workers=8 --report_imbalanced_metrics=True"""
56 | if script_args.use_wandb:
57 | experiment_commandline_opt_string += """ --instance_name=dev_zap_tmn --project_name=causal_comp_prep_zap_tmn"""
58 |
59 | parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
60 | parser.add_arguments(CommandlineArgs, dest='cfg')
61 | main_args: CommandlineArgs = parser.parse_args(shlex.split(experiment_commandline_opt_string)).cfg
62 | main(main_args)
63 |
64 |
--------------------------------------------------------------------------------
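The script assembles a full option string for `CommandlineArgs` and hands it to `main()`; the `Seed=0` comment above records its reference results. It can also be invoked directly outside docker (paths are placeholders):

```
python experiment_zappos_TMNsplit.py --seed=0 --output_dir=/tmp/out --data_dir=/local_data/zap_tmn --use_wandb=0
```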
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # A causal view of compositional zero-shot recognition
4 | This repository hosts the dataset and source code for the paper "A causal view of compositional zero-shot recognition" by Yuval Atzmon, Felix Kreuk, Uri Shalit, and Gal Chechik, NeurIPS 2020 (spotlight).
5 |
6 |
7 | ## Code
8 |
9 | ### Setup
10 | #### Build docker image:
11 | ```
12 | cd docker_cfg
13 | docker build --network=host -t causal_comp -f Dockerfile .
14 | cd ..
15 | ```
16 |
17 |
18 | #### Prepare UT Zappos50K with the TMN split:
19 | To reproduce the UT Zappos50K results under the TMN evaluation split, prepare the dataset according to
20 | taskmodularnets:
21 | https://github.com/facebookresearch/taskmodularnets
22 |
23 | The UT Zappos50K dataset is for academic, non-commercial use only, released by:
24 |
25 | * A. Yu and K. Grauman. "Fine-Grained Visual Comparisons with Local Learning". In CVPR, 2014.
26 | * A. Yu and K. Grauman. "Semantic Jitter: Dense Supervision for Visual Comparisons via Synthetic Images". In ICCV, 2017.
27 |
28 |
29 | #### Reproduce the Zappos experiment.
30 | Notes:
31 | 1. Set `DATA_DIR` to the directory containing the data.
32 | 2. Set `SEED` to a value in [0..4].
33 | 3. Set `SOURCE_CODE_DIR` to the current project working directory.
34 | 4. Set `OUTPUT_DIR` to the directory where results will be saved.
35 |
36 | ```
37 | SEED=0 # set seed \in [0..4]
38 | SOURCE_CODE_DIR=$HOME/git/causal_comp_prep/
39 | DATA_DIR=/local_data/zap_tmn # SET HERE THE DATA DIR
40 | DATASET=zap_tmn
41 | OUTPUT_DIR=/tmp/output/causal_comp_${DATASET}__seed${SEED}
42 |
43 | # prepare output directory
44 | mkdir -p ${OUTPUT_DIR}
45 | rm -rf ${OUTPUT_DIR}/*
46 | # run experiment
47 | docker run --net=host -v ${SOURCE_CODE_DIR}:/workspace/causal_comp -v ${DATA_DIR}:/data/zap_tmn -v ${OUTPUT_DIR}:/output --user $(id -u):$(id -g) --shm-size=1g --ulimit memlock=-1 --ulimit stack=6710886 --rm -it causal_comp /bin/bash -c "cd /workspace/causal_comp/; python experiment_zappos_TMNsplit.py --seed=${SEED} --output_dir=/output --data_dir=/data/zap_tmn --use_wandb=0"
48 | ```
49 |
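With `SEED=0`, the final summary printed by `main.py` should roughly match the reference results noted in `experiment_zappos_TMNsplit.py`:

```
Reporting IMbalanced metrics
Unseen=27.4, Seen=33.9, Harmonic=30.3, Closed=56.7, AUC=25.2
```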
50 | #### Reproduce AO-CLEVr experiments.
51 | Hyperparameters for reproducing the results will be published soon.
52 |
53 |
54 | ## AO-CLEVr Dataset
55 |
56 | AO-CLEVr is a new synthetic-images dataset containing images of "easy" Attribute-Object categories, based on the CLEVr framework (Johnson et al. CVPR 2017). AO-CLEVr has attribute-object pairs created from 8 attributes: \{red, purple, yellow, blue, green, cyan, gray, brown\} and 3 object shapes \{sphere, cube, cylinder\}, yielding 24 attribute-object pairs. Each pair consists of 7500 images. Each image contains a single object exhibiting the attribute-object pair. The object is randomly assigned one of two sizes (small/large), one of two materials (rubber/metallic), a random position, and random lighting according to CLEVr defaults.
57 |
58 | 
59 |
60 | The dataset can be accessed from [the following url](https://drive.google.com/drive/folders/1BBwW9VqzROgJXmvnfXcOxbLob8FB_jLf).
61 |
62 |
63 | ## Cite the paper
64 | If you use the contents of this project, please cite our paper.
65 |
66 | @inproceedings{neurips2020_causal_comp_atzmon,
67 | author = {Atzmon, Yuval and Kreuk, Felix and Shalit, Uri and Chechik, Gal},
68 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
69 | title = {A causal view of compositional zero-shot recognition},
70 | year = {2020}
71 | }
72 |
73 | For business inquiries, please contact [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com).
74 | For press and other inquiries, please contact Hector Marinez at [hmarinez@nvidia.com](mailto:hmarinez@nvidia.com).
75 |
--------------------------------------------------------------------------------
/taskmodularnets/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # taskmodularnets/LICENSE file.
6 | #
7 | import os
8 | import re
9 | import torch
10 | import torchvision.transforms as transforms
11 | import torchvision.datasets as torchdata
12 | import numpy as np
13 | from torch.autograd import Variable
14 | import itertools
15 | import copy
16 |
17 | # Save the training script and all the arguments
18 | import shutil
19 |
20 |
21 | def save_args(args):
22 | shutil.copy('train_modular.py', args.cv_dir + '/' + args.name + '/')
23 | shutil.copy('archs/models.py', args.cv_dir + '/' + args.name + '/')
24 | with open(args.cv_dir + '/' + args.name + '/args.txt', 'w') as f:
25 | f.write(str(args))
26 |
27 |
28 | class UnNormalizer:
29 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
30 | self.mean = mean
31 | self.std = std
32 |
33 | def __call__(self, tensor):
34 | for b in range(tensor.size(0)):
35 | for t, m, s in zip(tensor[b], self.mean, self.std):
36 | t.mul_(s).add_(m)
37 | return tensor
38 |
39 |
40 | def chunks(l, n):
41 | """Yield successive n-sized chunks from l."""
42 | for i in range(0, len(l), n):
43 | yield l[i:i + n]
44 |
45 |
46 | def flatten(l):
47 | return list(itertools.chain.from_iterable(l))
48 |
49 |
50 | class AverageMeter(object):
51 | """Computes and stores the average and current value"""
52 |
53 | def __init__(self):
54 | self.reset()
55 |
56 | def reset(self):
57 | self.val = 0
58 | self.avg = 0
59 | self.sum = 0
60 | self.count = 0
61 |
62 | def update(self, val, n=1):
63 | self.val = val
64 | self.sum += val * n
65 | self.count += n
66 | self.avg = self.sum / self.count
67 |
68 |
69 | def calc_pr_ovr_noref(counts, out):
70 | """
71 | [P, R, score, ap] = calc_pr_ovr(counts, out, K)
72 | Input :
73 | counts : number of occurrences of this word in the ith image
74 | out : score for this image
75 | Output :
76 | P, R : precision and recall
77 | score : score which corresponds to the particular precision and recall
78 | ap : average precision
79 | """
80 | #binarize counts
81 | out = out.astype(np.float64)
82 | counts = np.array(counts > 0, dtype=np.float32)
83 | tog = np.hstack((counts[:, np.newaxis].astype(np.float64),
84 | out[:, np.newaxis].astype(np.float64)))
85 | ind = np.argsort(out)
86 | ind = ind[::-1]
87 | score = np.array([tog[i, 1] for i in ind])
88 | sortcounts = np.array([tog[i, 0] for i in ind])
89 |
90 | tp = sortcounts
91 | fp = sortcounts.copy()
92 | for i in range(sortcounts.shape[0]):
93 | if sortcounts[i] >= 1:
94 | fp[i] = 0.
95 | elif sortcounts[i] < 1:
96 | fp[i] = 1.
97 |
98 | tp = np.cumsum(tp)
99 | fp = np.cumsum(fp)
100 | # P = np.cumsum(tp)/(np.cumsum(tp) + np.cumsum(fp));
101 | P = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
102 |
103 | numinst = np.sum(counts)
104 |
105 | R = tp / (numinst + 1e-10)
106 |
107 | ap = voc_ap(R, P)
108 | return P, R, score, ap
109 |
110 |
111 | def voc_ap(rec, prec):
112 | # correct AP calculation
113 | # first append sentinel values at the end
114 | mrec = np.concatenate(([0.], rec, [1.]))
115 | mpre = np.concatenate(([0.], prec, [0.]))
116 |
117 | # compute the precision envelope
118 | for i in range(mpre.size - 1, 0, -1):
119 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
120 |
121 | # to calculate area under PR curve, look for points
122 | # where X axis (recall) changes value
123 | i = np.where(mrec[1:] != mrec[:-1])[0]
124 |
125 | # and sum (\Delta recall) * prec
126 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
127 | return ap
128 |
129 |
130 | def roll(x, n):
131 | return torch.cat((x[-n:], x[:-n]))
132 |
133 |
134 | def load_word_embeddings(emb_file, vocab):
135 |
136 | vocab = [v.lower() for v in vocab]
137 |
138 | embeds = {}
139 | for line in open(emb_file, 'r'):
140 | line = line.strip().split(' ')
141 | wvec = torch.FloatTensor(list(map(float, line[1:])))
142 | embeds[line[0]] = wvec
143 |
144 | # for zappos (should account for everything)
145 | custom_map = {
146 | 'Faux.Fur': 'fur',
147 | 'Faux.Leather': 'leather',
148 | 'Full.grain.leather': 'leather',
149 | 'Hair.Calf': 'hair',
150 | 'Patent.Leather': 'leather',
151 | 'Nubuck': 'leather',
152 | 'Boots.Ankle': 'boots',
153 | 'Boots.Knee.High': 'knee-high',
154 | 'Boots.Mid-Calf': 'midcalf',
155 | 'Shoes.Boat.Shoes': 'shoes',
156 | 'Shoes.Clogs.and.Mules': 'clogs',
157 | 'Shoes.Flats': 'flats',
158 | 'Shoes.Heels': 'heels',
159 | 'Shoes.Loafers': 'loafers',
160 | 'Shoes.Oxfords': 'oxfords',
161 | 'Shoes.Sneakers.and.Athletic.Shoes': 'sneakers',
162 | 'traffic_light': 'light',
163 | 'trash_can': 'trashcan'
164 | }
165 | for k in custom_map:
166 | embeds[k.lower()] = embeds[custom_map[k]]
167 |
168 | embeds = [embeds[k] for k in vocab]
169 | embeds = torch.stack(embeds)
170 | print('loaded embeddings', embeds.size())
171 |
172 | return embeds
173 |
--------------------------------------------------------------------------------
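`AverageMeter` keeps a running weighted mean of a metric across batches. A quick sketch of its bookkeeping (assuming torch/torchvision are installed, since the module imports them):

```python
from taskmodularnets.utils.utils import AverageMeter

meter = AverageMeter()
meter.update(0.5, n=32)  # sum=16.0, count=32
meter.update(0.7, n=32)  # sum=38.4, count=64
print(meter.avg)         # 0.6
```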
/data.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 |
8 | from typing import NamedTuple
9 |
10 | import numpy as np
11 |
12 | from pathlib import Path
13 | from ATTOP.data.dataset import sample_negative as ATTOP_sample_negative
14 |
15 | import torch
16 | from torch.utils import data
17 |
18 | from useful_utils import categorical_histogram, get_and_update_num_calls
19 | from COSMO_utils import temporary_random_numpy_seed
20 |
21 |
22 | class DataItem(NamedTuple):
23 | """ A NamedTuple for returning a Dataset item """
24 | feat: torch.Tensor
25 | pos_attr_id: int
26 | pos_obj_id: int
27 | neg_attr_id: int
28 | neg_obj_id: int
29 | image_fname: str
30 |
31 |
32 | class CompDataFromDict():
33 | # noinspection PyMissingConstructor
34 | def __init__(self, dict_data: dict, data_subset: str, data_dir: str):
35 |
36 | # define instance variables to be retrieved from struct_data_dict
37 | self.split: str = 'TBD'
38 | self.phase: str = 'TBD'
39 | self.feat_dim: int = -1
40 | self.objs: list = []
41 | self.attrs: list = []
42 | self.attr2idx: dict = {}
43 | self.obj2idx: dict = {}
44 | self.pair2idx: dict = {}
45 | self.seen_pairs: list = []
46 | self.all_open_pairs: list = []
47 | self.closed_unseen_pairs: list = []
48 | self.unseen_closed_val_pairs: list = []
49 | self.unseen_closed_test_pairs: list = []
50 | self.train_data: tuple = tuple()
51 | self.val_data: tuple = tuple()
52 | self.test_data: tuple = tuple()
53 |
54 | self.data_dir: str = data_dir
55 |
56 | # retrieve instance variables from struct_data_dict
57 | vars(self).update(dict_data)
58 | self.data = dict_data[data_subset]
59 |
60 | self.activations = {}
61 | features_dict = torch.load(Path(data_dir) / 'features.t7')
62 | for i, img_filename in enumerate(features_dict['files']):
63 | self.activations[img_filename] = features_dict['features'][i]
64 |
65 | self.input_shape = (self.feat_dim,)
66 | self.num_objs = len(self.objs)
67 | self.num_attrs = len(self.attrs)
68 | self.num_seen_pairs = len(self.seen_pairs)
69 | self.shape_obj_attr = (self.num_objs, self.num_attrs)
70 |
71 | self.flattened_seen_pairs_mask = self.get_flattened_pairs_mask(self.seen_pairs)
72 | self.flattened_closed_unseen_pairs_mask = self.get_flattened_pairs_mask(self.closed_unseen_pairs)
73 | self.flattened_all_open_pairs_mask = self.get_flattened_pairs_mask(self.all_open_pairs)
74 | self.seen_pairs_joint_class_ids = np.where(self.flattened_seen_pairs_mask)
75 |
76 | self.y1_freqs, self.y2_freqs, self.pairs_freqs = self._calc_freqs()
77 | self._just_load_labels = False
78 |
79 | self.train_pairs = self.seen_pairs
80 |
81 | def sample_negative(self, attr, obj):
82 | return ATTOP_sample_negative(self, attr, obj)
83 |
84 | def get_flattened_pairs_mask(self, pairs):
85 | pairs_ids = np.array([(self.obj2idx[obj], self.attr2idx[attr]) for attr, obj in pairs])
86 | flattened_pairs = np.zeros(self.shape_obj_attr, dtype=bool) # init an array of False
87 | flattened_pairs[tuple(zip(*pairs_ids))] = True
88 | flattened_pairs = flattened_pairs.flatten()
89 | return flattened_pairs
90 |
91 | def just_load_labels(self, just_load_labels=True):
92 | self._just_load_labels = just_load_labels
93 |
94 | def get_all_labels(self):
95 | attrs = []
96 | objs = []
97 | joints = []
98 | self.just_load_labels(True)
99 | for attrs_batch, objs_batch in self:
100 | if isinstance(attrs_batch, torch.Tensor):
101 | attrs_batch = attrs_batch.cpu().numpy()
102 | if isinstance(objs_batch, torch.Tensor):
103 | objs_batch = objs_batch.cpu().numpy()
104 | joint = self.to_joint_label(objs_batch, attrs_batch)
105 |
106 | attrs.append(attrs_batch)
107 | objs.append(objs_batch)
108 | joints.append(joint)
109 |
110 | self.just_load_labels(False)
111 | attrs = np.array(attrs)
112 | objs = np.array(objs)
113 | return attrs, objs, joints
114 |
115 | def _calc_freqs(self):
116 | y2_train, y1_train, ys_joint_train = self.get_all_labels()
117 | y1_freqs = categorical_histogram(y1_train, range(self.num_objs), plot=False, frac=True)
118 | y1_freqs[y1_freqs == 0] = np.nan
119 | y2_freqs = categorical_histogram(y2_train, range(self.num_attrs), plot=False, frac=True)
120 | y2_freqs[y2_freqs == 0] = np.nan
121 |
122 | pairs_freqs = categorical_histogram(ys_joint_train,
123 | range(self.num_objs * self.num_attrs),
124 | plot=False, frac=True)
125 | pairs_freqs[pairs_freqs == 0] = np.nan
126 | return y1_freqs, y2_freqs, pairs_freqs
127 |
128 | def get(self, name):
129 | return vars(self).get(name)
130 |
131 | def __getitem__(self, idx):
132 | image_fname, attr, obj = self.data[idx]
133 | pos_attr_id, pos_obj_id = self.attr2idx[attr], self.obj2idx[obj]
134 | if self._just_load_labels:
135 | return pos_attr_id, pos_obj_id
136 |
137 | num_calls_cnt = get_and_update_num_calls(self.__getitem__)
138 |
139 | negative_attr_id, negative_obj_id = -1, -1 # default values
140 | if self.phase == 'train':
141 |             # We set a temporary numpy seed to work around an issue where the pairs
142 |             # sampled by sample_negative() at __getitem__ could not be reproduced
143 |             # deterministically: at each call to __getitem__ we set the seed to
144 |             # 834276 (an arbitrarily chosen constant) plus the number of calls so far.
145 | with temporary_random_numpy_seed(834276 + num_calls_cnt):
146 | # draw a negative pair
147 | negative_attr_id, negative_obj_id = self.sample_negative(attr, obj)
148 |
149 | item = DataItem(
150 | feat=self.activations[image_fname],
151 | pos_attr_id=pos_attr_id,
152 | pos_obj_id=pos_obj_id,
153 | neg_attr_id=negative_attr_id,
154 | neg_obj_id=negative_obj_id,
155 | image_fname=image_fname,
156 | )
157 | return item
158 |
159 | def __len__(self):
160 | return len(self.data)
161 |
162 | def to_joint_label(self, y1_batch, y2_batch):
163 | return (y1_batch * self.num_attrs + y2_batch)
164 |
165 |
166 | def get_data_loaders(train_dataset, valid_dataset, test_dataset, batch_size,
167 | num_workers=10, test_batchsize=None, shuffle_eval_set=True):
168 | if test_batchsize is None:
169 | test_batchsize = batch_size
170 |
171 | pin_memory = True
172 | if num_workers == 0:
173 | pin_memory = False
174 | print('num_workers = ', num_workers)
175 | print('pin_memory = ', pin_memory)
176 | train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
177 | pin_memory=pin_memory)
178 | valid_loader = None
179 | if valid_dataset is not None and len(valid_dataset) > 0:
180 | valid_loader = data.DataLoader(valid_dataset, batch_size=test_batchsize, shuffle=shuffle_eval_set,
181 | num_workers=num_workers, pin_memory=pin_memory)
182 | test_loader = data.DataLoader(test_dataset, batch_size=test_batchsize, shuffle=shuffle_eval_set,
183 | num_workers=num_workers, pin_memory=pin_memory)
184 | return test_loader, train_loader, valid_loader
185 |
--------------------------------------------------------------------------------
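Note that `get_data_loaders` returns the loaders as `(test, train, valid)`, not in the argument order; `load_data` in `main.py` follows the same convention. A minimal sketch of unpacking it correctly (the three datasets are assumed to already exist as `CompDataFromDict` instances):

```python
from data import get_data_loaders

# mind the return order: (test, train, valid)
test_loader, train_loader, valid_loader = get_data_loaders(
    train_dataset, valid_dataset, test_dataset,
    batch_size=2048, num_workers=8)
```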
/params.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | from collections import namedtuple
11 | from pathlib import Path
12 | import numpy as np
13 |
14 | from dataclasses import dataclass
15 | from simple_parsing import ArgumentParser, ConflictResolution, field
16 |
17 | from useful_utils import to_simple_parsing_args
18 |
19 |
20 | @dataclass
21 | class DataCfg():
22 | # dataset name
23 | dataset_name: str = 'ao_clevr'
24 |
25 | # default: {dataset_basedir}/{dataset_name}
26 | data_dir: str = 'default'
27 |
28 | dataset_basedir: str = 'data'
29 |
30 | # VT_random|UV_random
31 | dataset_variant: str = 'VT_random'
32 |
33 | num_split: int = 5000
34 |
35 | metadata_from_pkl: bool = True
36 |
37 | def __post_init__(self):
38 | if self.data_dir == 'default':
39 | self.data_dir = os.path.join(self.dataset_basedir, self.dataset_name)
40 |
41 | def __getitem__(self, key):
42 | """ Allow accessing instance attributes as dictionary keys """
43 | return getattr(self, key)
44 |
45 |
46 | @dataclass
47 | class MetricsCfg:
48 | calc_AUC: bool = True
49 | """ A flag to calc AUC metric (slower) """
50 |
51 | def __post_init__(self):
52 | assert self.calc_AUC
53 |
54 |
55 | @dataclass(frozen=True)
56 | class EarlyStopMetric:
57 | # metric name
58 | metric: str
59 |
60 | # min: minimize or max: maximize
61 | polarity: str
62 |
63 |
64 |
65 | @dataclass
66 | class ModelCfg():
67 | """" The hyper params to configure the architecture.
68 | Note: For backward compatibility, some of the variable names here are different than those mentioned in the
69 | paper. We explain in comments, how each variable is referenced at the paper.
70 | """
71 |
72 | h_dim: int = 150
73 | """ Number of hidden units in each MLP layer. """
74 |
75 | nlayers_label: int = 2
76 | """ Number of MLP layer, for h_O anb h_A. """
77 |
78 | nlayers_joint_ao: int = 2
79 | """ Number of MLP layers, for g. """
80 |
81 |
82 | E_num_hidden: int = 2
83 | """ Number of layers, for g^-1_A and g^-1_O """
84 |
85 | E_num_common_hidden: int = 0
86 | """ Number of layers, when projecting pretrained image features to the feature space \X
87 | (as explained for Zappos in section C.1)
88 | """
89 |
90 | mlp_activation: str = 'leaky_relu'
91 |
92 | learn_embeddings: bool = True
93 | """ Set to False to train as VisProd, with CE loss rather than an embedding loss. """
94 |
95 | def __post_init__(self):
96 | self.VisProd: bool = not self.learn_embeddings
97 |
98 | assert self.E_num_common_hidden <= self.E_num_hidden
99 |
100 |
101 | @dataclass
102 | class TrainCfg:
103 | """" The hyper params to configure the training. See Supplementary C. "implementation details".
104 | Note: For backward compatibility, some of the variable names here are different than those mentioned in the
105 | paper. We explain in comments, how each variable is referenced at the paper.
106 | """
107 |
108 | metrics: MetricsCfg
109 |
110 |
111 | batch_size: int = 2048
112 | """ batch size """
113 |
114 | lr: float = 0.003
115 | """ initial learning rate """
116 |
117 | max_epoch: int = 5
118 | """ max number of epochs """
119 |
120 | alternate_ys: int = 21
121 | """ Whether and how to use alternate training. 0: no alternation|12: object then attr|21: attr then obj"""
122 |
123 |
124 | lr_step2: float = 3e-05 #
125 | """ Step 2 initial learning rate. Only relevant if alternate_ys != 0 """
126 |
127 | max_epoch_step2: int = 1000
128 | """ Step 2 max number of epochs. Only relevant if alternate_ys != 0 """
129 |
130 | weight_decay: float = 0.1
131 | """ weight_decay """
132 |
133 | HSIC_coeff: float = 10
134 | """ \lambda_rep in paper """
135 |
136 | alphaH: float = 0
137 | """ \lambda_oh in paper """
138 |
139 |     alphaH_step2: float = -1
140 |     """ Step 2 \lambda_oh. Only relevant if alternate_ys != 0. If set to -1, take the --alphaH value. """
141 |
142 | lambda_CE: float = 1
143 | """ a coefficient for L_data """
144 |
145 | lambda_feat: float = 1
146 | """ \lambda_ao in paper """
147 |
148 | lambda_ao_emb: float = 0
149 | """ \lambda_ao when projecting pretrained image features to the feature space \X
150 | (as explained for Zappos in section C.1)
151 |     Note: --lambda_feat and --lambda_ao_emb cannot both be non-zero (we raise an exception in that case).
152 | """
153 |
154 | lambda_aux_disjoint: float = 100
155 | """ \lambda_icore in paper """
156 |
157 | lambda_aux_img: float = 10
158 | """ \lambda_ig in paper, when --lambda_feat>0"""
159 |
160 | lambda_aux: float = 0
161 | """ \lambda_ig in paper, when --lambda_ao_emb>0"""
162 |
163 | mu_img_feat: float = 0.1
164 | """ \lambda_ao at inference time """
165 |
166 | balanced_loss: bool = True
167 | """ Weighed the loss of ||φa−ha||^2 and ||φo−ho||^2 according to the respective attribute and object frequencies in
168 | the training set (Described in supplementary C.2). """
169 |
170 | triplet_loss_margin: float = 0.5
171 | """ The margin for the triplet loss. Same value as used by attributes-as-operators """
172 |
173 | optimizer_name: str = 'Adam'
174 |
175 | seed: int = 0
176 | """ random seed """
177 |
178 | test_batchsize: int = -1
179 | """batch-size for inference; default uses the training batch size"""
180 |
181 | verbose: bool = True
182 | num_workers: int = 8
183 | shuffle_eval_set: bool = True
184 | n_iter: int = field(init=False)
185 | mu_disjoint: float = field(init=False)
186 | mu_ao_emb: float = field(init=False)
187 | primary_early_stop_metric: EarlyStopMetric = field(init=False)
188 | freeze_class1: bool = field(init=False)
189 | freeze_class2: bool = field(init=False)
190 | Y12_balance_coeff: float = field(init=False)
191 |
192 |
193 | def __post_init__(self):
194 | # sanity checks (assertions)
195 | assert (self.alternate_ys in [0, 12, 21])
196 | assert not ((self.lambda_ao_emb > 0) and (self.lambda_feat > 0))
197 | if self.lambda_feat == 0:
198 | assert(self.mu_img_feat == 0)
199 |
200 | # assignments
201 | if self.test_batchsize <= 0:
202 | self.test_batchsize = self.batch_size
203 | if self.alphaH_step2 < 0:
204 | self.alphaH_step2 = self.alphaH
205 |
206 | self.mu_disjoint = self.lambda_CE
207 | self.mu_ao_emb = self.lambda_ao_emb
208 | self.primary_early_stop_metric = EarlyStopMetric('epoch', 'max')
209 | self.Y12_balance_coeff = 0.5
210 | self.freeze_class1 = False
211 | self.freeze_class2 = False
212 | self.n_iter = -1 # Should be updated after data is loaded
213 |
214 |
215 |
216 | def set_n_iter(self, num_train_samples, max_epoch = None):
217 | if max_epoch is None:
218 | max_epoch = self.max_epoch
219 | self.n_iter = int((max_epoch) * np.ceil(num_train_samples / self.batch_size))
220 |
221 | def __getitem__(self, key):
222 | """ Allow accessing instance attributes as dictionary keys """
223 | return getattr(self, key)
224 |
225 | @dataclass
226 | class ExperimentCfg():
227 | output_dir: str = field(alias="-o")
228 | ignore_existing_output_contents: bool = True
229 | gpu: int = 0
230 | use_wandb: bool = True
231 | wandb_user: str = 'none'
232 | project_name: str = 'causal_comp_prep'
233 | experiment_name: str = 'default'
234 | instance_name: str = 'default'
235 | git_hash: str = ''
236 | sync_uid: str = ''
237 | report_imbalanced_metrics: bool = False
238 |
239 | # float precision when logging to CSV file
240 | csv_precision: int = 8
241 |
242 | delete_dumped_preds: bool = True
243 |
244 |
245 | def __post_init__(self):
246 |
247 | # Set default experiment name
248 | self._set_default_experiment_name()
249 |
250 |
251 | def _set_default_experiment_name(self):
252 | at_ngc: bool = ('NGC_JOB_ID' in os.environ.keys())
253 | at_docker = np.in1d(['/opt/conda/bin'], np.array(sys.path))[0]
254 | at_local_docker = at_docker and not at_ngc
255 | name_suffix = '_local'
256 | if at_local_docker:
257 | name_suffix += '_docker'
258 | elif at_ngc:
259 | name_suffix = '_ngc'
260 | if self.experiment_name == 'default':
261 | self.experiment_name = 'dev' + name_suffix
262 | if self.instance_name == 'default':
263 | self.instance_name = 'dev' + name_suffix
264 |
265 | def __getitem__(self, key):
266 | """ Allow accessing instance attributes as dictionary keys """
267 | return getattr(self, key)
268 |
269 |
270 | @dataclass
271 | class CommandlineArgs():
272 | model: ModelCfg
273 | data: DataCfg
274 | train: TrainCfg
275 | exp: ExperimentCfg
276 | device: str = 'cuda'
277 |
278 | @classmethod
279 | def get_args(cls):
280 | args: cls = to_simple_parsing_args(cls)
281 | return args
282 |
--------------------------------------------------------------------------------
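`CommandlineArgs` nests the four config dataclasses, so options from all of them share one namespace (hence `ConflictResolution.NONE`). A minimal sketch of parsing an option string into it, mirroring `experiment_zappos_TMNsplit.py` (the values are placeholders):

```python
import shlex
from simple_parsing import ArgumentParser, ConflictResolution
from params import CommandlineArgs

parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
parser.add_arguments(CommandlineArgs, dest='cfg')
args: CommandlineArgs = parser.parse_args(
    shlex.split('--output_dir=/tmp/out --dataset_name=zap_tmn --seed=1')).cfg
print(args.data.dataset_name, args.train.seed)  # zap_tmn 1
```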
/offline_early_stop.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 |
8 | # from loguru import logger
9 | import argparse
10 | import glob
11 | import os
12 | import tempfile
13 | import warnings
14 | from collections import OrderedDict
15 | from copy import deepcopy
16 | from os.path import join
17 | import json
18 | from shutil import copyfile
19 | from datetime import datetime
20 | import pandas as pd
21 | import sys
22 | import numpy as np
23 | from scipy import signal
24 |
25 | from useful_utils import comma_seperated_str_to_list, wandb_myinit, slice_dict_to_dict, \
26 | fill_missing_by_defaults
27 | from COSMO_utils import run_bash
28 |
29 |
30 | parser = argparse.ArgumentParser(description=__doc__)
31 | parser.add_argument('--dir', required=True, type=str)
32 | parser.add_argument('--early_stop_metrics', required=False, default=None, type=str)
33 | parser.add_argument('--ignore_skipping', default=True, action="store_true")
34 | parser.add_argument('--infer_on_incompleted', default=False, action="store_true")
35 | parser.add_argument('--smooth_val_curve', default=False, type=bool)
36 | parser.add_argument('--smooth_window_width', default=6, type=int)
37 | parser.add_argument('--use_wandb', default=False, type=bool)
38 | parser.add_argument('--wandb_project_name', default=None, type=str)
39 | parser.add_argument('--wandb_subset_metrics', default=False, type=bool)
40 | parser.add_argument('--eval_on_last_epoch', default=False, type=bool)
41 |
42 |
43 | def main(args):
44 | if isinstance(args, dict):
45 | args = fill_missing_by_defaults(args, parser)
46 |
47 | files = glob.glob(args.dir + "/*")
48 | # results_files = [file for file in files if "results" in basename(file)]
49 | print("###############")
50 | print("Starting offline_early_stop")
51 | print(f"running on '{args.dir}'")
52 | # if args.create_results_json:
53 | # if args.metric is None:
54 | # print("if creating empty results.json, must give specific metric")
55 | # with open(join(args.dir, "results.json"), 'w') as f:
56 | # json.dump({"metrics": {}, "train_cfg": {}, "meta_cfg": {}}, f)
57 |     if args.early_stop_metrics is None:
58 |         raise ValueError('--early_stop_metrics is required')
59 | if not args.infer_on_incompleted:
60 | assert (os.path.exists(join(args.dir, 'completed_training.touch')) or os.path.exists(join(args.dir, 'results.json')))
61 | if join(args.dir, "summary.csv") not in files:
62 | raise (RuntimeError("no summary.csv file!\n"))
63 |
64 | if not args.ignore_skipping and os.path.exists(join(args.dir, "lock")):
65 | print("this folder was already processed, skipping!\n")
66 | sys.exit(0)
67 | else:
68 | with open(join(args.dir, "lock"), "w") as f:
69 | f.write("0")
70 | summary_csv = pd.read_csv(join(args.dir, 'summary.csv'), sep='|')
71 |
72 | def smooth_validation_curve(validation_curve):
73 | if args.smooth_val_curve:
74 | win = np.hanning(args.smooth_window_width)
75 | validation_curve = signal.convolve(validation_curve, win, mode='same',
76 | method='direct') / sum(win)
77 | validation_curve = pd.Series(validation_curve)
78 |
79 | return validation_curve
80 |
81 | es_metric_list = comma_seperated_str_to_list(args.early_stop_metrics)
82 | # get run arguments
83 |
84 | args_dict = json.load(open(join(args.dir, "args.json"), "r"))
85 | early_stop_results_dict = OrderedDict()
86 | for i, primary_early_stop_metric in enumerate(es_metric_list):
87 | metric_index = i+1
88 | results = deepcopy(args_dict)
89 | print('')
90 |
91 | new_results_json_file = join(args.dir, f"results{metric_index}.json")
92 | if os.path.exists(new_results_json_file):
93 | backup_file_name = new_results_json_file.replace(".json",
94 | f"_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json")
95 | copyfile(new_results_json_file, backup_file_name)
96 | print(f"backed up '{new_results_json_file}' => '{backup_file_name}'")
97 |
98 | print(f"creating new file: {new_results_json_file}")
99 |
100 | try:
101 | validation_curve = summary_csv[primary_early_stop_metric].copy()
102 | validation_curve = smooth_validation_curve(validation_curve)
103 | best_epoch = validation_curve.idxmax()
104 | if np.isnan(best_epoch):
105 | continue
106 | if args.eval_on_last_epoch:
107 | best_epoch = len(validation_curve) -1
108 | best_epoch_summary = summary_csv.iloc[[best_epoch]]
109 | best_epoch_test_score = best_epoch_summary[primary_early_stop_metric.replace("valid", "test")]
110 | best_epoch_summary = best_epoch_summary.to_dict(orient='index')[best_epoch]
111 | print(f"best epoch is: {best_epoch}")
112 | print(f"test score: {best_epoch_test_score}")
113 |
114 | results['metrics'] = best_epoch_summary
115 | results['train']['primary_early_stop_metric'] = primary_early_stop_metric
116 | json.dump(results, open(new_results_json_file, "w"))
117 | early_stop_results_dict[primary_early_stop_metric] = results
118 | except KeyError as e:
119 | warnings.warn(repr(e))
120 |
121 | if args.use_wandb:
122 | import wandb
123 | offline_log_to_wandb(args.wandb_project_name, args_dict, early_stop_results_dict, summary_csv,
124 | workdir=args.dir,
125 | wandb_log_subset_of_metrics=args.wandb_subset_metrics)
126 | print("done offline_early_stop!\n")
127 |
128 | return early_stop_results_dict
129 |
130 |
131 | def offline_log_to_wandb(project_name, args_dict, early_stop_results_dict, summary_df, workdir=None,
132 | wandb_log_subset_of_metrics=False):
133 |
134 | if project_name is None:
135 | project_name = args_dict['exp']['project_name'] + '_offline'
136 | if wandb_log_subset_of_metrics:
137 | project_name += '_subset'
138 | print(f'Writing to W&B project {project_name}')
139 |
140 | curve_metric_names = None
141 | if wandb_log_subset_of_metrics:
142 | curve_metric_names = get_wandb_curve_metrics()
143 |
144 | print(f'Start dump results to W&B project: {project_name}')
145 | wandb_myinit(project_name=project_name, experiment_name=args_dict['exp']['experiment_name'],
146 | instance_name=args_dict['exp']['instance_name'], config=args_dict, workdir=workdir)
147 |
148 |
149 | global_step_name = 'epoch'
150 | summary_df = summary_df.set_index(global_step_name)
151 | print(f'Dump run curves')
152 | first_iter = True
153 | for global_step, step_metrics in summary_df.iterrows():
154 | if first_iter:
155 | first_iter = False
156 | if curve_metric_names is not None:
157 | for metric in curve_metric_names:
158 | if metric not in step_metrics:
159 | warnings.warn(f"Can't log '{metric}'. It doesn't exists.")
160 |
161 | if wandb_log_subset_of_metrics:
162 | metrics_to_log = slice_dict_to_dict(step_metrics.to_dict(), curve_metric_names, ignore_missing_keys=True)
163 | else:
164 | # log all metrics
165 | metrics_to_log = step_metrics.to_dict()
166 |
167 | metrics_to_log[global_step_name] = global_step
168 | wandb.log(metrics_to_log)
169 |
170 | early_stop_results_to_wandb_summary(early_stop_results_dict)
171 | dump_preds_at_early_stop(early_stop_results_dict, workdir, use_wandb=True)
172 |
173 | # terminate nicely offline w&b run
174 | wandb.join()
175 |
176 | def dump_preds_at_early_stop(early_stop_results_dict, workdir, use_wandb):
177 |     print('Save the dumped predictions at the early-stop epochs')
178 | # dirpath = tempfile.mkdtemp()
179 | for es_metric, results_dict in early_stop_results_dict.items():
180 | for phase_name in ('valid', 'test'):
181 | target_fname_preds = join(workdir, f'preds__{es_metric}_{phase_name}.npz')
182 |
183 | epoch = results_dict['metrics']['epoch']
184 | fname = join(workdir, 'dump_preds', f'epoch_{epoch}', f'dump_preds_{phase_name}.npz')
185 | if os.path.exists(fname):
186 | run_bash(f'cp {fname} {target_fname_preds}')
187 | if use_wandb:
188 | import wandb
189 | wandb.save(target_fname_preds)
190 | print(f'Saved {target_fname_preds}')
191 |
192 |
193 | def early_stop_results_to_wandb_summary(early_stop_results_dict):
194 | print(f'Dump early stop results')
195 | wandb_summary = OrderedDict()
196 | for es_metric, results_dict in early_stop_results_dict.items():
197 | wandb_summary[f'res__{es_metric}'] = results_dict['metrics']
198 | import wandb
199 | wandb.run.summary.update(wandb_summary)
200 |
201 |
202 |
203 |
204 | def get_wandb_curve_metrics():
205 | eval_metric_names = comma_seperated_str_to_list(
206 | 'y_joint_loss_mean, y1_loss_mean, y2_loss_mean'
207 | ', closed_balanced_acc'
208 | ', open_balanced_unseen_acc, open_balanced_seen_acc, open_H'
209 | ', y1_balanced_acc_unseen, y2_balanced_acc_unseen'
210 | ', y1_balanced_acc_seen, y2_balanced_acc_seen'
211 | ', closed_acc'
212 | ', unseen_open_acc, seen_open_acc, open_H_IMB'
213 | ', y1_acc_unseen, y2_acc_unseen'
214 | )
215 |
216 | train_metric_names = comma_seperated_str_to_list('y1_loss, y2_loss, y_loss'#, d_loss'
217 | ', hsic_loss, total_loss'#, d_fool_loss'
218 | ', y1_acc, y2_acc'#, ds1_acc, ds2_acc, current_alpha'
219 | ', HSIC_cond1, HSIC_cond2'
220 | ', loss, leplus_loss, tloss_feat, tloss_ao_emb'
221 | ', tloss_a, tloss_o, loss_aux'
222 | ', loss_aux_disjoint_attr, loss_aux_disjoint_obj')
223 |
224 |
225 | logged_metrics = []
226 | for metric in eval_metric_names:
227 | logged_metrics.append(metric + '_valid')
228 | logged_metrics.append(metric + '_test')
229 |
230 | for metric in train_metric_names:
231 | logged_metrics.append(metric + '_mean')
232 |
233 | return logged_metrics
234 |
235 |
236 | if __name__ == '__main__':
237 | args = parser.parse_args()
238 | main(args)
239 |
--------------------------------------------------------------------------------
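`main()` also accepts a plain dict (missing keys are filled from the parser defaults via `fill_missing_by_defaults`), which is how `main.py` drives it after training. A minimal sketch of the programmatic entry point (the directory is a placeholder and must contain `summary.csv` and `args.json`):

```python
import offline_early_stop

results = offline_early_stop.main({
    'dir': '/tmp/output/causal_comp_zap_tmn__seed0',  # placeholder path
    'early_stop_metrics': 'open_H_valid,AUC_open_valid',
})
```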
/main.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 |
8 | import json
9 | import os, sys
10 | import pickle
11 | from os.path import join
12 | from copy import deepcopy
13 | from pathlib import Path
14 | from sys import argv
15 | import random
16 |
17 | import dataclasses
18 | from munch import Munch
19 |
20 | import useful_utils
21 | import offline_early_stop
22 | import COSMO_utils
23 | from useful_utils import SummaryWriter_withCSV, wandb_myinit
24 |
25 | import torch.nn
26 | import torch
27 | import numpy as np
28 |
29 | import signal
30 |
31 | from data import CompDataFromDict
32 | from params import CommandlineArgs, TrainCfg, ExperimentCfg
33 | from train import train, alternate_training
34 |
35 |
36 | def set_random_seeds(base_seed):
37 | random.seed(base_seed)
38 |     np.random.seed(base_seed+7205)  # 7205 is an arbitrarily chosen constant offset
39 | torch.random.manual_seed(base_seed+1000)
40 | torch.cuda.manual_seed(base_seed+1001)
41 | torch.backends.cudnn.deterministic = True
42 |
43 |
44 | def main(args: CommandlineArgs):
45 | train_cfg: TrainCfg = args.train
46 | exp_cfg: ExperimentCfg = args.exp
47 | set_random_seeds(base_seed=train_cfg.seed)
48 |
49 | # init logging
50 | writer = init_logging(args)
51 |
52 | # load data
53 | test_dataset, train_dataset, valid_dataset = load_data(args)
54 | train_cfg.set_n_iter(len(train_dataset.data))
55 |
56 | # training
57 | with useful_utils.profileblock('complete training'):
58 | if train_cfg.alternate_ys == 0:
59 | ### train both heads jointly ##
60 | train(args, train_dataset, valid_dataset, test_dataset, writer)
61 | else:
62 | ### alternate between heads ###
63 | alternate_training(args, train_dataset, valid_dataset, test_dataset, writer)
64 |
65 | # ----- Finalizing ------
66 | # dump log to csv
67 | writer.dump_to_csv(verbose=1, float_format=f'%.{args.exp.csv_precision}f')
68 | writer.close()
69 |
70 |     # Indicate that the run has completed
71 | COSMO_utils.run_bash(f'touch {join(exp_cfg.output_dir, "completed_training.touch")}')
72 |
73 | # Process offline early stopping according to results at output_dir
74 | early_stop_results_dict = process_offline_early_stopping(exp_cfg)
75 |
76 | # Print results
77 | print_results(early_stop_results_dict, exp_cfg)
78 |
79 | # Delete temporary artifacts from output dir
80 | clear_output_dir(exp_cfg)
81 |
82 | print('Done.\n')
83 |
84 |
85 | def print_results(early_stop_results_dict, exp_cfg):
86 | from munch import munchify
87 | early_stop_results_dict = munchify(early_stop_results_dict)
88 | print('\n\n####################################')
89 | if exp_cfg.report_imbalanced_metrics:
90 | # E.g. Zappos
91 | U = 100 * early_stop_results_dict.open_H_IMB_valid.metrics.unseen_open_acc_test
92 | S = 100 * early_stop_results_dict.open_H_IMB_valid.metrics.seen_open_acc_test
93 | H = 100 * early_stop_results_dict.open_H_IMB_valid.metrics.open_H_IMB_test
94 | closed = 100 * early_stop_results_dict.AUC_open_valid.metrics.closed_acc_test
95 | AUC = 100 * early_stop_results_dict.AUC_open_valid.metrics.AUC_open_test
96 | print('Reporting IMbalanced metrics')
97 | print(f'Unseen={U:.1f}, Seen={S:.1f}, Harmonic={H:.1f}, Closed={closed:.1f}, AUC={AUC:.1f}')
98 |
99 | else:
100 | # e.g. AO-CLEVr
101 | U = 100 * early_stop_results_dict.open_H_valid.metrics.open_balanced_unseen_acc_test
102 | S = 100 * early_stop_results_dict.open_H_valid.metrics.open_balanced_seen_acc_test
103 | H = 100 * early_stop_results_dict.open_H_valid.metrics.open_H_test
104 | closed = 100 * early_stop_results_dict.closed_balanced_acc_valid.metrics.closed_balanced_acc_test
105 | print('Reporting Balanced metrics')
106 | print(f'Unseen={U:.1f}, Seen={S:.1f}, Harmonic={H:.1f}, Closed={closed:.1f}')
107 | print('####################################\n\n')
108 |
109 |
110 | def init_logging(args):
111 | exp_cfg: ExperimentCfg = args.exp
112 | output_dir = Path(exp_cfg.output_dir)
113 | if not exp_cfg.ignore_existing_output_contents and len(list(output_dir.iterdir())) > 0:
114 | raise ValueError(f'Output directory {output_dir} is not empty')
115 |
116 | args_dict = dataclasses.asdict(args)
117 | if exp_cfg.use_wandb:
118 | import wandb
119 | wandb_myinit(project_name=exp_cfg.project_name, experiment_name=exp_cfg.experiment_name,
120 | instance_name=exp_cfg.instance_name, config=args_dict, workdir=exp_cfg.output_dir,
121 | username=exp_cfg.wandb_user)
122 | # printing starts here - after initializing w&b
123 | print('commandline was:')
124 | print(' '.join(argv))
125 | print(vars(args))
126 | writer = SummaryWriter_withCSV(log_dir=exp_cfg.output_dir, suppress_tensorboard=True, wandb=exp_cfg.use_wandb)
127 | writer.set_print_options(pandas_max_columns=500, pandas_max_width=200)
128 | to_json(args_dict, exp_cfg.output_dir, filename='args.json')
129 | return writer
130 |
131 |
132 | def clear_output_dir(exp_cfg):
133 |     # Always delete dumped (per-epoch) logits when done, because they take a lot of space
134 |     delete_dumped_logits(exp_cfg.output_dir)
135 | 
136 |     # Delete dumped (per-epoch) predictions if required
137 |     if exp_cfg.delete_dumped_preds:
138 |         print('Deleting per-epoch dumped predictions')
139 | cmd = f'rm -rf {join(exp_cfg.output_dir, "dump_preds")}'
140 | print(cmd)
141 | COSMO_utils.run_bash(cmd)
142 |
143 |
144 | def process_offline_early_stopping(exp_cfg: ExperimentCfg):
145 | cfg_offline_early_stop = Munch()
146 | cfg_offline_early_stop.dir = exp_cfg.output_dir
147 | cfg_offline_early_stop.early_stop_metrics = 'open_H_valid,closed_balanced_acc_valid,open_H_IMB_valid,AUC_open_valid'
148 | early_stop_results_dict = offline_early_stop.main(cfg_offline_early_stop)
149 | if exp_cfg.use_wandb:
150 |         # dump each early-stop result to the current project
151 | offline_early_stop.early_stop_results_to_wandb_summary(early_stop_results_dict)
152 | # and save the dumped predictions at its epoch
153 | offline_early_stop.dump_preds_at_early_stop(early_stop_results_dict, exp_cfg.output_dir, use_wandb=exp_cfg.use_wandb)
154 | return early_stop_results_dict
155 |
156 |
157 | def load_data(args: CommandlineArgs):
158 | if args.data.metadata_from_pkl:
159 | train_dataset, valid_dataset, test_dataset = load_pickled_metadata(args)
160 |         print('Loaded data from PKL')
161 |     else:
162 |         train_dataset, valid_dataset, test_dataset = load_TMN_data(args)
163 |         print('Loaded data using the TMN project')
164 | return test_dataset, train_dataset, valid_dataset
165 |
166 |
167 | def to_json(args_dict, log_dir, filename):
168 | args_json = os.path.join(log_dir, filename)
169 | with open(args_json, 'w') as f:
170 | json.dump(args_dict, f)
171 | print(f'\nDump configuration to JSON file: {args_json}\n\n')
172 |
173 |
174 | def SIGINT_KeyboardInterrupt_handler(sig, frame):
175 | raise KeyboardInterrupt()
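
# A sketch of the intended registration (an assumption; the call is presumably made
# elsewhere in the project):
#   signal.signal(signal.SIGINT, SIGINT_KeyboardInterrupt_handler)
# This ensures Ctrl-C raises a catchable KeyboardInterrupt even if another
# library replaced the default SIGINT handler.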
176 |
177 |
178 | def load_TMN_data(args: CommandlineArgs):
179 | import sys
180 | sys.path.append('taskmodularnets')
181 | import taskmodularnets.data.dataset as tmn_data
182 |
183 | dict_data = dict()
184 | for subset in ['train', 'val', 'test']:
185 | dTMN = tmn_data.CompositionDatasetActivations(root=args.data.data_dir,
186 | phase=subset,
187 | split='compositional-split-natural')
188 |
189 |
190 | # Add class attributes according to the current project API
191 | dTMN.all_open_pairs, dTMN.seen_pairs = \
192 | dTMN.pairs, dTMN.train_pairs
193 |
194 | # Get TMN unseen pairs, because val/test_pairs include both seen and unseen pairs
195 | dTMN.unseen_closed_val_pairs = list(set(dTMN.val_pairs).difference(dTMN.seen_pairs))
196 | dTMN.unseen_closed_test_pairs = list(set(dTMN.test_pairs).difference(dTMN.seen_pairs))
197 |
198 | dTMN.closed_unseen_pairs = dict(
199 | train=[],
200 | val=dTMN.unseen_closed_val_pairs,
201 | test=dTMN.unseen_closed_test_pairs)[subset]
202 |
203 | dict_data[f'{subset}'] = deepcopy(vars(dTMN))
204 |
205 | train_dataset = CompDataFromDict(dict_data['train'], data_subset='train_data', data_dir=args.data.data_dir)
206 | valid_dataset = CompDataFromDict(dict_data['val'], data_subset='val_data', data_dir=args.data.data_dir)
207 | test_dataset = CompDataFromDict(dict_data['test'], data_subset='test_data', data_dir=args.data.data_dir)
208 |
209 | print('Seen (train) pairs: ', train_dataset.seen_pairs)
210 | print('Unseen (val) pairs: ', train_dataset.unseen_closed_val_pairs)
211 | print('Unseen (test) pairs: ', train_dataset.unseen_closed_test_pairs)
212 |
213 | return train_dataset, valid_dataset, test_dataset
214 |
215 |
216 | def load_pickled_metadata(args: CommandlineArgs):
217 | data_cfg = args.data
218 | dataset_name = deepcopy(data_cfg['dataset_name'])
219 | dataset_variant = deepcopy(data_cfg['dataset_variant'])
220 | meta_path = Path(f"{data_cfg['data_dir']}/metadata_pickles")
221 | random_state_path = Path(f"{data_cfg['data_dir']}/np_random_state_pickles")
222 | meta_path = meta_path.expanduser()
223 |
224 | dict_data = dict()
225 | seen_seed = args.train.seed
226 | for subset in ['train', 'valid', 'test']:
227 | metadata_full_filename = meta_path / f"metadata_{dataset_name}__{dataset_variant}__comp_seed_{data_cfg['num_split']}__seen_seed_{seen_seed}__{subset}.pkl"
228 | dict_data[f'{subset}'] = deepcopy(pickle.load(open(metadata_full_filename, 'rb')))
229 |
230 | np_rnd_state_fname = random_state_path / f"np_random_state_{dataset_name}__{dataset_variant}__comp_seed_{data_cfg['num_split']}__seen_seed_{seen_seed}.pkl"
231 | np_seed_state = pickle.load(open(np_rnd_state_fname, 'rb'))
232 | np.random.set_state(np_seed_state)
233 |
234 | train_dataset = CompDataFromDict(dict_data['train'], data_subset='train_data', data_dir=data_cfg['data_dir'])
235 | valid_dataset = CompDataFromDict(dict_data['valid'], data_subset='val_data', data_dir=data_cfg['data_dir'])
236 | test_dataset = CompDataFromDict(dict_data['test'], data_subset='test_data', data_dir=data_cfg['data_dir'])
237 |
238 | print('Seen (train) pairs: ', train_dataset.seen_pairs)
239 | print('Unseen (val) pairs: ', train_dataset.unseen_closed_val_pairs)
240 | print('Unseen (test) pairs: ', train_dataset.unseen_closed_test_pairs)
241 |
242 | return train_dataset, valid_dataset, test_dataset
243 |
244 |
245 | def delete_dumped_logits(logdir):
246 | # Delete dumped logits
247 |     COSMO_utils.run_bash(f"find {join(logdir, 'dump_preds')} -name 'logits*' -delete")
248 |
249 |
250 | if __name__ == '__main__':
251 | args = CommandlineArgs.get_args()
252 | main(args)
253 |
254 |
--------------------------------------------------------------------------------
/taskmodularnets/data/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # taskmodularnets/LICENSE file.
6 | #
7 | import numpy as np
8 | import torch.utils.data as tdata
9 | import torch
10 | import torchvision.transforms as transforms
11 | from PIL import Image
12 | import glob
13 | import os
14 | import tqdm
15 | import torchvision.models as tmodels
16 | import torch.nn as nn
17 | from torch.autograd import Variable
18 | import torch
19 | import bz2
20 | from utils import utils
21 | import h5py
22 | import pdb
23 | import itertools
24 | import os
25 | import collections
26 | import scipy.io
27 | from sklearn.model_selection import train_test_split
28 |
29 |
30 | class ImageLoader:
31 | def __init__(self, root):
32 | self.img_dir = root
33 |
34 | def __call__(self, img):
35 | file = '%s/%s' % (self.img_dir, img)
36 | img = Image.open(file).convert('RGB')
37 | return img
38 |
39 |
40 | def imagenet_transform(phase):
41 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
42 |
43 | if phase == 'train':
44 | transform = transforms.Compose([
45 | transforms.RandomResizedCrop(224),
46 | transforms.RandomHorizontalFlip(),
47 | transforms.ToTensor(),
48 | transforms.Normalize(mean, std)
49 | ])
50 | elif phase == 'test' or phase == 'val':
51 | transform = transforms.Compose([
52 | transforms.Resize(256),
53 | transforms.CenterCrop(224),
54 | transforms.ToTensor(),
55 | transforms.Normalize(mean, std)
56 | ])
57 |
58 | return transform
59 |
60 |
61 | #------------------------------------------------------------------------------------------------------------------------------------#
62 |
63 |
64 | class CompositionDataset(tdata.Dataset):
65 | def __init__(
66 | self,
67 | root,
68 | phase,
69 | split='compositional-split',
70 | subset=False,
71 | num_negs=1,
72 | pair_dropout=0.0,
73 | ):
74 | self.root = root
75 | self.phase = phase
76 | self.split = split
77 | self.num_negs = num_negs
78 | self.pair_dropout = pair_dropout
79 |
80 | self.feat_dim = None
81 | self.transform = imagenet_transform(phase)
82 | self.loader = ImageLoader(self.root + '/images/')
83 |
84 | self.attrs, self.objs, self.pairs, \
85 | self.train_pairs, self.val_pairs, \
86 | self.test_pairs = self.parse_split()
87 |
88 | self.train_data, self.val_data, self.test_data = self.get_split_info()
89 | if self.phase == 'train':
90 | self.data = self.train_data
91 | elif self.phase == 'val':
92 | self.data = self.val_data
93 | else:
94 | self.data = self.test_data
95 | if subset:
96 | ind = np.arange(len(self.data))
97 | ind = ind[::len(ind) // 1000]
98 | self.data = [self.data[i] for i in ind]
99 |
100 | self.obj2idx = {obj: idx for idx, obj in enumerate(self.objs)}
101 | self.attr2idx = {attr: idx for idx, attr in enumerate(self.attrs)}
102 | self.pair2idx = {pair: idx for idx, pair in enumerate(self.pairs)}
103 |
104 | print('# train pairs: %d | # val pairs: %d | # test pairs: %d' % (len(
105 | self.train_pairs), len(self.val_pairs), len(self.test_pairs)))
106 | print('# train images: %d | # val images: %d | # test images: %d' %
107 | (len(self.train_data), len(self.val_data), len(self.test_data)))
108 |
109 | # fix later -- affordance thing
110 | # return {object: all attrs that occur with obj}
111 | self.obj_affordance = {}
112 | self.train_obj_affordance = {}
113 | for _obj in self.objs:
114 | candidates = [
115 | attr
116 | for (_, attr,
117 | obj) in self.train_data + self.val_data + self.test_data
118 | if obj == _obj
119 | ]
120 | self.obj_affordance[_obj] = sorted(list(set(candidates)))
121 |
122 | candidates = [
123 | attr for (_, attr, obj) in self.train_data if obj == _obj
124 | ]
125 | self.train_obj_affordance[_obj] = sorted(list(set(candidates)))
126 |
127 | self.sample_indices = list(range(len(self.data)))
128 | self.sample_pairs = self.train_pairs
129 |
130 | def reset_dropout(self):
131 | self.sample_indices = list(range(len(self.data)))
132 | self.sample_pairs = self.train_pairs
133 |
134 | shuffled_ind = np.random.permutation(len(self.train_pairs))
135 | n_pairs = int((1 - self.pair_dropout) * len(self.train_pairs))
136 | self.sample_pairs = [
137 | self.train_pairs[pi] for pi in shuffled_ind[:n_pairs]
138 | ]
139 | print('Using {} pairs out of {} pairs right now'.format(
140 | n_pairs, len(self.train_pairs)))
141 | self.sample_indices = [
142 | i for i in range(len(self.data))
143 | if (self.data[i][1], self.data[i][2]) in self.sample_pairs
144 | ]
145 | print('Using {} images out of {} images right now'.format(
146 | len(self.sample_indices), len(self.data)))
147 |
148 | def get_split_info(self):
149 | data = torch.load(self.root + '/metadata_{}.t7'.format(self.split))
150 | train_data, val_data, test_data = [], [], []
151 | for instance in data:
152 | image, attr, obj, settype = instance['image'], instance[
153 | 'attr'], instance['obj'], instance['set']
154 |
155 | if attr == 'NA' or (attr,
156 | obj) not in self.pairs or settype == 'NA':
157 | # ignore instances with unlabeled attributes
158 | # ignore instances that are not in current split
159 | continue
160 |
161 | data_i = [image, attr, obj]
162 | if settype == 'train':
163 | train_data.append(data_i)
164 | elif settype == 'val':
165 | val_data.append(data_i)
166 | else:
167 | test_data.append(data_i)
168 |
169 | return train_data, val_data, test_data
170 |
171 | def parse_split(self):
172 | def parse_pairs(pair_list):
173 | with open(pair_list, 'r') as f:
174 | pairs = f.read().strip().split('\n')
175 | pairs = [t.split() for t in pairs]
176 | pairs = list(map(tuple, pairs))
177 | attrs, objs = zip(*pairs)
178 | return attrs, objs, pairs
179 |
180 | tr_attrs, tr_objs, tr_pairs = parse_pairs(
181 | '%s/%s/train_pairs.txt' % (self.root, self.split))
182 | vl_attrs, vl_objs, vl_pairs = parse_pairs(
183 | '%s/%s/val_pairs.txt' % (self.root, self.split))
184 | ts_attrs, ts_objs, ts_pairs = parse_pairs(
185 | '%s/%s/test_pairs.txt' % (self.root, self.split))
186 |
187 | all_attrs, all_objs = sorted(
188 | list(set(tr_attrs + vl_attrs + ts_attrs))), sorted(
189 | list(set(tr_objs + vl_objs + ts_objs)))
190 | all_pairs = sorted(list(set(tr_pairs + vl_pairs + ts_pairs)))
191 |
192 | return all_attrs, all_objs, all_pairs, tr_pairs, vl_pairs, ts_pairs
193 |
194 | def sample_negative(self, attr, obj):
195 | new_attr, new_obj = self.sample_pairs[np.random.choice(
196 | len(self.sample_pairs))]
197 | if new_attr == attr and new_obj == obj:
198 | return self.sample_negative(attr, obj)
199 | return (self.attr2idx[new_attr], self.obj2idx[new_obj])
200 |
201 | def sample_affordance(self, attr, obj):
202 | new_attr = np.random.choice(self.obj_affordance[obj])
203 | if new_attr == attr:
204 | return self.sample_affordance(attr, obj)
205 | return self.attr2idx[new_attr]
206 |
207 | def sample_train_affordance(self, attr, obj):
208 | new_attr = np.random.choice(self.train_obj_affordance[obj])
209 | if new_attr == attr:
210 | return self.sample_train_affordance(attr, obj)
211 | return self.attr2idx[new_attr]
212 |
213 | def __getitem__(self, index):
214 | index = self.sample_indices[index]
215 | image, attr, obj = self.data[index]
216 | img = self.loader(image)
217 | img = self.transform(img)
218 |
219 | data = [
220 | img, self.attr2idx[attr], self.obj2idx[obj], self.pair2idx[(attr,
221 | obj)]
222 | ]
223 |
224 | if self.phase == 'train':
225 | all_neg_attrs = []
226 | all_neg_objs = []
227 | for _ in range(self.num_negs):
228 | neg_attr, neg_obj = self.sample_negative(
229 | attr, obj) # negative example for triplet loss
230 | all_neg_objs.append(neg_obj)
231 | all_neg_attrs.append(neg_attr)
232 | neg_attr = torch.LongTensor(all_neg_attrs)
233 | neg_obj = torch.LongTensor(all_neg_objs)
234 | inv_attr = self.sample_train_affordance(
235 | attr, obj) # attribute for inverse regularizer
236 | comm_attr = self.sample_affordance(
237 | inv_attr, obj) # attribute for commutative regularizer
238 | data += [neg_attr, neg_obj, inv_attr, comm_attr]
239 | return data
240 |
241 | def __len__(self):
242 | return len(self.sample_indices)
243 |
244 |
245 | #------------------------------------------------------------------------------------------------------------------------------------#
246 |
247 |
248 | class CompositionDatasetActivations(CompositionDataset):
249 | def __init__(
250 | self,
251 | root,
252 | phase,
253 | split,
254 | subset=False,
255 | num_negs=1,
256 | pair_dropout=0.0,
257 | ):
258 | super(CompositionDatasetActivations, self).__init__(
259 | root,
260 | phase,
261 | split,
262 | subset=subset,
263 | num_negs=num_negs,
264 | pair_dropout=pair_dropout,
265 | )
266 |
267 | # precompute the activations -- weird. Fix pls
268 | feat_file = '%s/features.t7' % root
269 | if not os.path.exists(feat_file):
270 | with torch.no_grad():
271 | self.generate_features(feat_file)
272 |
273 | activation_data = torch.load(feat_file)
274 | self.activations = dict(
275 | zip(activation_data['files'], activation_data['features']))
276 | self.feat_dim = activation_data['features'].size(1)
277 |
278 | print('%d activations loaded' % (len(self.activations)))
279 |
280 | def generate_features(self, out_file):
281 |
282 | data = self.train_data + self.val_data + self.test_data
283 | transform = imagenet_transform('test')
284 | feat_extractor = tmodels.resnet18(pretrained=True)
285 | feat_extractor.fc = nn.Sequential()
286 | feat_extractor.eval().cuda()
287 |
288 | image_feats = []
289 | image_files = []
290 | for chunk in tqdm.tqdm(
291 | utils.chunks(data, 512), total=len(data) // 512):
292 | files, attrs, objs = zip(*chunk)
293 | imgs = list(map(self.loader, files))
294 | imgs = list(map(transform, imgs))
295 | feats = feat_extractor(torch.stack(imgs, 0).cuda())
296 | image_feats.append(feats.data.cpu())
297 | image_files += files
298 | image_feats = torch.cat(image_feats, 0)
299 | print('features for %d images generated' % (len(image_files)))
300 |
301 | torch.save({'features': image_feats, 'files': image_files}, out_file)
302 |
303 | def __getitem__(self, index):
304 | data = super(CompositionDatasetActivations, self).__getitem__(index)
305 | index = self.sample_indices[index]
306 | image, attr, obj = self.data[index]
307 | data[0] = self.activations[image]
308 | return data
309 |
--------------------------------------------------------------------------------
/useful_utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 | import re
8 | from collections import OrderedDict, defaultdict
9 | import time
10 | from copy import deepcopy
11 | import os
12 |
13 | import torch
14 |
15 | from tensorboardX import SummaryWriter
16 | import tensorboardX
17 | import pandas as pd
18 | import numpy as np
19 |
20 | from COSMO_utils import run_bash
21 |
22 |
23 | class SummaryWriter_withCSV(SummaryWriter):
24 | """ A wrapper for tensorboard SummaryWriter that is based on pandas,
25 | and writes to CSV and optionally to W&B"""
26 | def __init__(self, log_dir, *args, **kwargs):
27 |
28 | global_step_name = 'epoch'
29 | if 'global_step_name' in kwargs:
30 | global_step_name = kwargs['global_step_name']
31 | kwargs.pop('global_step_name')
32 |
33 | self.global_step_name = global_step_name
34 |
35 | self.log_wandb = False
36 | if 'wandb' in kwargs:
37 | self.log_wandb = kwargs['wandb']
38 | kwargs.pop('wandb')
39 |
40 | if self.log_wandb:
41 | import wandb
42 |
43 | self.suppress_tensorboard = kwargs.get('suppress_tensorboard', False)
44 |
45 | if not self.suppress_tensorboard:
46 | super(SummaryWriter_withCSV, self).__init__(log_dir, *args, **kwargs)
47 |
48 | self.df = pd.DataFrame()
49 | self.df.index.name = self.global_step_name
50 | self._log_dir = log_dir
51 | self.last_global_step = -1
52 |
53 | def add_scalar(self, tag, scalar_value, global_step=None):
54 | if global_step is not None:
55 | self.df.loc[global_step, tag] = scalar_value
56 |
57 | # if finalized last step, dump its metrics to wandb
58 | if self.last_global_step < global_step:
59 | if self.log_wandb and self.last_global_step >= 0:
60 | with ns_profiling_label('wandb_log'):
61 | import wandb
62 | wandb.log(self.df.loc[self.last_global_step, :].to_dict(), sync=False)
63 | self.last_global_step = global_step
64 |
65 | if not self.suppress_tensorboard:
66 | super().add_scalar(tag, scalar_value, global_step=global_step)
67 |
68 | def add_summary(self, summary, global_step=None):
69 | summary_proto = tensorboardX.summary.Summary()
70 |
71 | if isinstance(summary, bytes):
72 | summary_list = [value for value in summary_proto.FromString(summary).value]
73 | else:
74 | summary_list = [value for value in summary.value]
75 | for val in summary_list:
76 | self.add_scalar(val.tag, val.simple_value, global_step=global_step)
77 |
78 |
79 | def set_print_options(self, pandas_max_columns=500, pandas_max_width=1000):
80 | pd.set_option('display.max_columns', pandas_max_columns)
81 | pd.set_option('display.width', pandas_max_width)
82 |
83 | def last_results_as_string(self, regex_filter_out=None):
84 | if regex_filter_out is not None:
85 | s = self.df.loc[:, ~self.df.columns.str.match(regex_filter_out)].iloc[-1, :]
86 | else:
87 | s = self.df.iloc[-1, :]
88 | string = ' '.join([f'{key}={s[key]:.2g}' for key in s.keys()])
89 | return string
90 |
91 | def dump_to_csv(self, fname='summary.csv', sep='|', verbose=0, **kwargs_df_to_csv):
92 | fullfname = os.path.join(self._log_dir, fname)
93 | self.df.to_csv(fullfname, sep=sep, **kwargs_df_to_csv)
94 | if verbose>0:
95 | print(f'Dump history to CSV file: {fullfname}')
96 | if verbose == 2:
97 | print('')
98 | with open(fullfname) as f:
99 | print(f.read())
100 | def close(self):
101 |
102 | # dump last step results to wandb
103 | if self.log_wandb:
104 | with ns_profiling_label('wandb_log'):
105 | import wandb
106 | wandb.log(self.df.loc[self.last_global_step, :].to_dict(), sync=False)
107 |
108 | if not self.suppress_tensorboard:
109 | super().close()
110 |
111 |
112 | def slice_dict_to_dict(d, keys, returned_keys_prefix='', returned_keys_postfix='', ignore_missing_keys=False):
113 |     """ Returns an OrderedDict of dictionary values, ordered and sliced by the given keys.
114 | keys can be a list, or a CSV string
115 | """
116 | if isinstance(keys, str):
117 | keys = keys[:-1] if keys[-1] == ',' else keys
118 | keys = re.split(', |[, ]', keys)
119 |
120 | if returned_keys_prefix != '' or returned_keys_postfix != '':
121 | return OrderedDict((returned_keys_prefix + k + returned_keys_postfix, d[k]) for k in keys)
122 |
123 | if ignore_missing_keys:
124 | return OrderedDict((k, d[k]) for k in keys if k in d)
125 | else:
126 | return OrderedDict((k, d[k]) for k in keys)
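
# Usage sketch (illustrative values, not from the repo):
#   d = {'a': 1, 'b': 2, 'c': 3}
#   slice_dict_to_dict(d, 'a, c')                             # OrderedDict([('a', 1), ('c', 3)])
#   slice_dict_to_dict(d, 'a,c', returned_keys_postfix='_t')  # OrderedDict([('a_t', 1), ('c_t', 3)])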
127 |
128 |
129 | def clone_model(model):
130 | # clone a pytorch model
131 | return deepcopy(model)
132 |
133 |
134 | def is_uncommited_git_repo(list_ignored_regex_pattern_filenames=None, ignore_untracked_files=True):
135 |     """ Check if there are uncommitted files in workdir.
136 | Can ignore specific file names or regex patterns
137 | """
138 |
139 | uncommitted_files_list = get_uncommitted_files(list_ignored_regex_pattern_filenames, ignore_untracked_files)
140 |
141 | if uncommitted_files_list:
142 | return True
143 | else:
144 | return False
145 |
146 |
147 | def get_uncommitted_files(list_ignored_regex_pattern_filenames, ignore_untracked_files):
148 | ignore_untracked_files_str = ''
149 | if ignore_untracked_files:
150 | ignore_untracked_files_str = ' -uno'
151 | if list_ignored_regex_pattern_filenames is None:
152 | list_ignored_regex_pattern_filenames = []
153 | git_status = run_bash('git status --porcelain' + ignore_untracked_files_str)
154 | uncommitted_files = []
155 | for line in git_status.split('\n'):
156 | ignore_current_file = False
157 | if line:
158 |             fname = re.split(r' ?\w +', line)[1]
159 | for ig_file_regex_pattern in list_ignored_regex_pattern_filenames:
160 | if re.match(ig_file_regex_pattern, fname):
161 | ignore_current_file = True
162 | break
163 | if ignore_current_file:
164 | continue
165 | else:
166 | uncommitted_files.append(fname)
167 | return uncommitted_files
168 |
169 |
170 | def list_to_2d_tuple(l):
171 | return tuple(tuple(tup) for tup in l)
172 |
173 |
174 | def categorical_histogram(data, labels_list, plot=True, frac=True, plt_show=False):
175 | import matplotlib.pyplot as plt
176 | s_counts = pd.Series(data).value_counts()
177 | s_frac = s_counts/s_counts.sum()
178 | hist_dict = s_counts.to_dict()
179 | if frac:
180 | hist_dict = s_frac.to_dict()
181 | hist = []
182 | for ix, _ in enumerate(labels_list):
183 | hist.append(hist_dict.get(ix, 0))
184 |
185 | if plot:
186 | pd.Series(hist, index=labels_list).plot(kind='bar')
187 | if frac:
188 | plt.ylim((0,1))
189 | if plt_show:
190 | plt.show()
191 | else:
192 | return np.array(hist, dtype='float32')
193 |
194 |
195 | def to_torch(array, device):
196 |     array = np.asanyarray(array)  # cast to an array if not already one; otherwise do nothing
197 | return torch.from_numpy(array).to(device)
198 |
199 |
200 | def comma_seperated_str_to_list(comma_seperated_str, regex_sep=r', |[, ]'):
201 | return re.split(regex_sep, comma_seperated_str)
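
# e.g. comma_seperated_str_to_list('y1_loss, y2_loss,y_sum_loss') == ['y1_loss', 'y2_loss', 'y_sum_loss']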
202 |
203 |
204 |
205 | class profileblock(object):
206 | """
207 | Usage example:
208 | with profileblock(label='abc'):
209 | time.sleep(0.1)
210 |
211 | """
212 | def __init__(self, label=None, disable=False):
213 | self.disable = disable
214 | self.label = ''
215 | if label is not None:
216 | self.label = label + ': '
217 |
218 | def __enter__(self):
219 | if self.disable:
220 | return
221 | self.tic = time.time()
222 | return self
223 |
224 | def __exit__(self, type, value, traceback):
225 | if self.disable:
226 | return
227 | elapsed = np.round(time.time() - self.tic, 2)
228 | print(f'{self.label} Elapsed {elapsed} sec')
229 |
230 | ###########
231 |
232 |
233 | ##################
234 |
235 | from contextlib import contextmanager
236 | from torch.cuda import nvtx
237 |
238 |
239 | # @contextmanager
240 | # def ns_profiling_label(label, disable=False):
241 | # """
242 | # Wraps a code block with a label for profiling using "nsight-systems"
243 | # Usage example:
244 | # with ns_profiling_label('epoch %d'%epoch):
245 | # << CODE FOR TRAINING AN EPOCH >>
246 | #
247 | # """
248 | # if not disable:
249 | # nvtx.range_push(label)
250 | # try:
251 | # yield None
252 | # finally:
253 | # if not disable:
254 | # nvtx.range_pop()
255 |
256 |
257 | @contextmanager
258 | def ns_profiling_label(label):
259 | """
260 | A do nothing version of ns_profiling_label()
261 |
262 | """
263 | try:
264 | yield None
265 | finally:
266 | pass
267 |
268 |
269 | def torch_nans(*args, **kwargs):
270 | return np.nan*torch.zeros(*args, **kwargs)
271 |
272 |
273 |
274 | class batch_torch_logger():
275 | def __init__(self, cs_str_args=None, num_batches=None, nanmean_args_cs_str=None, device=None):
276 | """
277 | cs_str_args: arguments list as a comma separated string
278 | """
279 | assert(num_batches is not None)
280 | self.device = device
281 | self.loggers = {}
282 | if cs_str_args is not None:
283 | args_list = comma_seperated_str_to_list(cs_str_args)
284 | self.loggers = dict((arg_name, torch_nans(num_batches, device=device)) for arg_name in args_list)
285 |
286 | self.nanmean_args_list = []
287 | if nanmean_args_cs_str is not None:
288 | self.nanmean_args_list = comma_seperated_str_to_list(nanmean_args_cs_str)
289 |
290 |
291 | self.cnt = -1
292 |
293 | def new_batch(self):
294 | self.cnt += 1
295 |
296 | def log(self, locals_dict):
297 | for arg in self.loggers.keys():
298 |             self.log_arg(arg, locals_dict.get(arg, torch.tensor(-9999.).to(self.device)))  # -9999. marks args missing from locals_dict
299 |
300 | def log_arg(self, arg, value):
301 | with ns_profiling_label(f'log arg'):
302 | try:
303 | if type(value) == float or isinstance(value, np.number):
304 | value = torch.FloatTensor([value])
305 | self.loggers[arg][self.cnt] = value.detach()
306 | except AttributeError:
307 | print(f'Error: arg name = {arg}')
308 | raise
309 |
310 | def mean(self, arg):
311 | if arg in self.nanmean_args_list:
312 | return torch_nanmean(self.loggers[arg][:(self.cnt + 1)]).detach().item()
313 | else:
314 | return self.loggers[arg][:(self.cnt+1)].mean().detach().item()
315 |
316 | def get_means(self):
317 | return OrderedDict((arg + '_mean', self.mean(arg)) for arg in self.loggers.keys())
318 |
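# Minimal usage sketch (illustrative; mirrors how forward_pass_data() in eval.py uses it):
#   logger = batch_torch_logger(cs_str_args='y1_loss, y2_loss', num_batches=len(loader), device=device)
#   for batch in loader:
#       logger.new_batch()
#       y1_loss, y2_loss = ...        # per-batch tensors named like the logged args
#       logger.log(locals())          # looks the args up in locals()
#   means = logger.get_means()        # OrderedDict with keys 'y1_loss_mean', 'y2_loss_mean'
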
319 | def torch_nanmean(x):
320 | return x[~torch.isnan(x)].mean()
321 |
322 |
323 | def wandb_myinit(project_name, experiment_name, instance_name, config, workdir=None, wandb_output_dir='/tmp/wandb',
324 | reinit=True, username='user'):
325 |
326 | import wandb
327 | tags = [experiment_name]
328 | if experiment_name.startswith('qa_'):
329 | tags.append('qa')
330 | config['workdir'] = workdir
331 | wandb.init(project=project_name, name=instance_name, config=config, tags=tags, dir=wandb_output_dir, reinit=reinit,
332 | entity=username)
333 |
334 |
335 | def get_all_argparse_keys(parser):
336 | return [action.dest for action in parser._actions]
337 |
338 |
339 | def fill_missing_by_defaults(args_dict, argparse_parser):
340 | all_arg_keys = get_all_argparse_keys(argparse_parser)
341 | for key in all_arg_keys:
342 | if key not in args_dict:
343 | args_dict[key] = argparse_parser.get_default(key)
344 | return args_dict
345 |
346 |
347 | def get_and_update_num_calls(func_ptr):
348 | try:
349 | get_and_update_num_calls.num_calls_cnt[func_ptr] += 1
350 | except AttributeError as e:
351 | if 'num_calls_cnt' in repr(e):
352 | get_and_update_num_calls.num_calls_cnt = defaultdict(int)
353 | else:
354 | raise
355 |
356 | return get_and_update_num_calls.num_calls_cnt[func_ptr]
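
# Behavior sketch (illustrative): returns how many times func_ptr was passed in *before*
# the current call, i.e. 0 on the first call, 1 on the second, and so on.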
357 |
358 |
359 | def duplicate(x, times):
360 | return tuple(deepcopy(x) for _ in range(times))
361 |
362 |
363 | def to_simple_parsing_args(some_dataclass_type):
364 | """
365 | Add this as a classmethod to some dataclass in order to make its arguments accessible from commandline
366 | Example:
367 |
368 | @classmethod
369 | def get_args(cls):
370 | args: cls = to_simple_parsing_args(cls)
371 | return args
372 |
373 | """
374 | from simple_parsing import ArgumentParser, ConflictResolution
375 | parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
376 | parser.add_arguments(some_dataclass_type, dest='cfg')
377 | args: some_dataclass_type = parser.parse_args().cfg
378 | return args
379 |
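# End-to-end sketch with a hypothetical dataclass (assumes the `simple_parsing` package):
#   from dataclasses import dataclass
#
#   @dataclass
#   class DemoCfg:
#       lr: float = 1e-3
#       seed: int = 0
#
#   args = to_simple_parsing_args(DemoCfg)
#   # running `python demo.py --lr 0.01 --seed 5` yields DemoCfg(lr=0.01, seed=5)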
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 |
8 | from copy import deepcopy
9 | from typing import Union
10 |
11 | import torch
12 | from torch.nn import functional as F
13 | from torch import nn
14 | from torch.nn.functional import triplet_margin_loss
15 |
16 | from data import CompDataFromDict
17 | from useful_utils import ns_profiling_label
18 | from params import CommandlineArgs
19 | from ATTOP.models.models import MLP as ATTOP_MLP
20 |
21 | def get_model(args: CommandlineArgs, dataset: CompDataFromDict):
22 | cfg = args.model
23 | if cfg.E_num_common_hidden==0:
24 | ECommon = nn.Sequential()
25 | ECommon.output_shape = lambda: (None, dataset.input_shape[0])
26 | else:
27 | ECommon = ATTOP_MLP(dataset.input_shape[0], cfg.h_dim, num_layers=cfg.E_num_common_hidden - 1, relu=True, bias=True).to(args.device)
28 | ECommon.output_shape = lambda: (None, cfg.h_dim)
29 |
30 |
31 | if not cfg.VisProd:
32 | h_A = Label_MLP_embed(dataset.num_attrs, cfg.h_dim,
33 | num_layers=cfg.nlayers_label).to(args.device)
34 | h_O = Label_MLP_embed(dataset.num_objs, cfg.h_dim,
35 | num_layers=cfg.nlayers_label).to(args.device)
36 | g1_emb_to_hidden_feat = ATTOP_MLP(2 * cfg.h_dim, cfg.h_dim,
37 | num_layers=cfg.nlayers_joint_ao).to(args.device)
38 | g2_feat_to_image_feat = ATTOP_MLP(cfg.h_dim, dataset.input_shape[0], num_layers=cfg.nlayers_joint_ao
39 | ).to(args.device)
40 |
41 | g_inv_O = MLP_Encoder(ECommon.output_shape()[1], h_dim=cfg.h_dim,
42 | E_num_common_hidden=cfg.E_num_hidden - cfg.E_num_common_hidden,
43 | mlp_activation='leaky_relu', BN=True).to(args.device) # category
44 | g_inv_A = MLP_Encoder(ECommon.output_shape()[1], h_dim=cfg.h_dim,
45 | E_num_common_hidden=cfg.E_num_hidden - cfg.E_num_common_hidden,
46 | mlp_activation='leaky_relu', BN=True).to(args.device) # category
47 |
48 | emb_cf_O = EmbeddingClassifier(h_O, image_feat_dim=g_inv_O.h_dim, device=args.device).to(args.device)
49 | # redundant historic call - just to make sure random-number-gen is kept aligned with original codebase
50 | _ = EmbeddingClassifier(h_A, image_feat_dim=g_inv_A.h_dim, device=args.device).to(args.device)
51 |
52 | emb_cf_A = EmbeddingClassifier(h_A, image_feat_dim=g_inv_A.h_dim, device=args.device).to(args.device)
53 | # redundant historic call - just to make sure random-number-gen is kept aligned with original codebase
54 | _ = EmbeddingClassifier(h_O, image_feat_dim=g_inv_O.h_dim, device=args.device).to(args.device)
55 |
56 | model = CompModel(ECommon, g_inv_O, g_inv_A, emb_cf_O, emb_cf_A, h_O, h_A, g1_emb_to_hidden_feat, g2_feat_to_image_feat,
57 | args, dataset).to(args.device)
58 |
59 | return model
60 |
61 |
62 | def nll_sum_loss(attr_logits, obj_logits, attr_gt, obj_gt, nll_loss_funcs):
63 | nll_loss_attr = nll_loss_funcs.y2(attr_logits, attr_gt)
64 | nll_loss_obj = nll_loss_funcs.y1(obj_logits, obj_gt)
65 | return nll_loss_attr + nll_loss_obj
66 |
67 |
68 | def get_activation_layer(activation):
69 | return dict(leaky_relu=nn.LeakyReLU(),
70 | relu=nn.ReLU(),
71 | )[activation]
72 |
73 | class MLP_block(nn.Module):
74 | def __init__(self, input_dim, output_dim, mlp_activation='leaky_relu', BN=True):
75 | super(MLP_block, self).__init__()
76 | layers_list = []
77 | layers_list += [nn.Linear(input_dim, output_dim)]
78 | if BN:
79 | layers_list += [nn.BatchNorm1d(num_features=output_dim)]
80 | layers_list += [get_activation_layer(mlp_activation)]
81 | self.NN = nn.Sequential(*layers_list)
82 |
83 | def forward(self, x):
84 | return self.NN(x)
85 |
86 |
87 | class Label_MLP_embed(nn.Module):
88 | def __init__(self, num_embeddings, embedding_dim, num_layers=0):
89 | super(Label_MLP_embed, self).__init__()
90 | self.lin_emb = nn.Embedding(num_embeddings, embedding_dim)
91 | layers_list = []
92 | layers_list += [self.lin_emb]
93 | if num_layers > 0:
94 | layers_list += [ATTOP_MLP(embedding_dim, embedding_dim, num_layers=num_layers)]
95 | self.NN = nn.Sequential(*layers_list)
96 | self.embedding_dim = embedding_dim
97 |
98 | def forward(self, tokens):
99 | output = self.NN(tokens)
100 | return output
101 |
102 |
103 |
104 | class MLP_Encoder(nn.Module):
105 | def __init__(self, input_dim, h_dim=150, E_num_common_hidden=1, mlp_activation='leaky_relu', BN=True,
106 | linear_last_layer=False, **kwargs):
107 | super().__init__()
108 |
109 | self.h_dim = h_dim
110 | layers_list = []
111 | prev_dim = input_dim
112 |         for _ in range(E_num_common_hidden):
113 |             # each hidden layer maps to h_dim; the first layer maps from
114 |             # input_dim and subsequent layers map from h_dim
115 |             layers_list += [MLP_block(prev_dim, h_dim,
116 |                                       mlp_activation=mlp_activation, BN=BN)]
117 |             prev_dim = h_dim
118 |
119 |
120 | if linear_last_layer:
121 | layers_list += [nn.Linear(prev_dim, prev_dim)]
122 |
123 | self._output_shape = prev_dim
124 | self.layers_list = layers_list
125 | self.NN = nn.Sequential(*layers_list)
126 |
127 |
128 | def forward(self, x):
129 | return self.NN(x)
130 |
131 | def output_shape(self):
132 | return (None, self._output_shape)
133 |
134 |
135 | class EmbeddingClassifier(nn.Module):
136 | def __init__(self, embeddings, image_feat_dim, device):
137 | super().__init__()
138 | self.num_emb_class, _ = list(embeddings.parameters())[0].shape
139 | self.image_feat_dim = image_feat_dim
140 | self.device = device
141 |
142 | self.embeddings = embeddings
143 | self.emb_alignment = nn.Linear(image_feat_dim, self.embeddings.embedding_dim)
144 | def forward(self, feat):
145 |         feat = self.emb_alignment(feat)  # project image features into the label-embedding space
146 |         emb_per_class = self.embeddings(torch.arange(0, self.num_emb_class, dtype=int).to(self.device)).T  # [emb_dim, num_classes]
147 |         out = -((feat[:, :, None] - emb_per_class) ** 2).sum(1)  # negative squared Euclidean distance to each class embedding
148 | return out
149 |
150 |
151 | def LinearSoftmaxLogits(input_dim, output_dim):
152 | return nn.Sequential(nn.Linear(input_dim, output_dim),
153 | nn.LogSoftmax())
154 |
155 |
156 | class CompModel(nn.Module):
157 | def __init__(self, ECommon, g_inv_O, g_inv_A, emb_cf_O, emb_cf_A, h_O, h_A, g1_emb_to_hidden_feat, g2_feat_to_image_feat,
158 | args: CommandlineArgs, dataset):
159 |         """
160 |         Input:
161 |             ECommon: shared encoder for both heads
162 |             g_inv_O, g_inv_A: object / attribute feature extractors
163 |             emb_cf_O, emb_cf_A: embedding-based object / attribute classifiers
164 |             h_O, h_A: object / attribute label-embedding networks
165 |             g1_emb_to_hidden_feat, g2_feat_to_image_feat: map label embeddings to hidden / image features
166 |         """
167 | super(CompModel, self).__init__()
168 | model_cfg = args.model
169 | self.args = args
170 | self.is_mutual = True
171 | self.ECommon = ECommon
172 | self.g_inv_O = g_inv_O
173 | self.g_inv_A = g_inv_A
174 | self.emb_cf_O = emb_cf_O
175 | self.emb_cf_A = emb_cf_A
176 | self.h_O = h_O
177 | self.h_A = h_A
178 | self.g1_emb_to_hidden_feat = g1_emb_to_hidden_feat
179 | self.g2_feat_to_image_feat = g2_feat_to_image_feat
180 | self.loader = None
181 | self.num_objs = dataset.num_objs
182 | self.num_attrs = dataset.num_attrs
183 | self.attrs_idxs, self.objs_idxs = self._get_ao_outerprod_idxs()
184 | self.last_feature_common = None
185 |
186 | if not model_cfg.VisProd:
187 | self.obj_inv_core_logits = LinearSoftmaxLogits(model_cfg.h_dim, self.num_objs).to(args.device)
188 | self.attr_inv_core_logits = LinearSoftmaxLogits(model_cfg.h_dim, self.num_attrs).to(args.device)
189 |
190 | self.obj_inv_g_hidden_logits = LinearSoftmaxLogits(model_cfg.h_dim, self.num_objs).to(args.device)
191 | self.attr_inv_g_hidden_logits = LinearSoftmaxLogits(model_cfg.h_dim, self.num_attrs).to(args.device)
192 |
193 | self.obj_inv_g_imgfeat_logits = LinearSoftmaxLogits(dataset.input_shape[0], self.num_objs).to(args.device)
194 | self.attr_inv_g_imgfeat_logits = LinearSoftmaxLogits(dataset.input_shape[0], self.num_attrs).to(args.device)
195 |
196 | self.device = args.device
197 | # check (rename)
198 | self.mu_disjoint = args.train.mu_disjoint
199 | self.mu_ao_emb = args.train.mu_ao_emb
200 | self.mu_img_feat = args.train.mu_img_feat
201 |
202 | def encode(self, input_data, freeze_class1=False, freeze_class2=False):
203 | feature_common = self.ECommon(input_data)
204 | self.last_feature_common = feature_common
205 | if feature_common is input_data:
206 | self.last_feature_common = None
207 |
208 | if freeze_class1:
209 | with torch.no_grad():
210 | feature1 = self.g_inv_O(feature_common).detach()
211 | else:
212 | feature1 = self.g_inv_O(feature_common)
213 |
214 |
215 | if freeze_class2:
216 | with torch.no_grad():
217 | feature2 = self.g_inv_A(feature_common).detach()
218 | else:
219 | feature2 = self.g_inv_A(feature_common)
220 |
221 | return feature1, feature2, feature_common
222 |
223 | def forward(self, input_data,
224 | freeze_class1=False, freeze_class2=False):
225 |
226 | ### init and definitions
227 | freeze_class = freeze_class1, freeze_class2
228 | classifiers = self.emb_cf_O, self.emb_cf_A
229 | class_outputs = [None, None]
230 |
231 | def set_grad_disabled(condition):
232 | return torch.set_grad_enabled(not condition)
233 | ### end init and definitions
234 |
235 | feature1, feature2, feature_common = self.encode(input_data, freeze_class1, freeze_class2)
236 |
237 | for m, feature in enumerate([feature1, feature2]):
238 | with set_grad_disabled(freeze_class[m]):
239 | class_outputs[m] = classifiers[m](feature)
240 | if freeze_class[m]:
241 | class_outputs[m] = class_outputs[m].detach()
242 |
243 |         joint_output = (class_outputs[0][..., None] + class_outputs[1][..., None, :])  # outer sum -> [batch, num_objs, num_attrs]
244 |         joint_output = torch.flatten(joint_output, start_dim=1)  # flatten to joint pair scores: [batch, num_objs * num_attrs]
245 |
246 |         # inference: at eval time, fuse the disjoint-head scores with the embedding-based scores
247 | if not self.training and not self.args.model.VisProd:
248 | joint_output = self.mu_disjoint * joint_output
249 |
250 | if self.mu_img_feat > 0 or self.mu_ao_emb > 0:
251 | flattened_ao_emb_joint_scores, flattened_img_emb_scores = \
252 | self.get_joint_embed_classification_scores(self.attrs_idxs, self.objs_idxs,
253 | self.last_feature_common, input_data)
254 |
255 | joint_output += self.mu_ao_emb * flattened_ao_emb_joint_scores
256 | joint_output += self.mu_img_feat * flattened_img_emb_scores
257 |
258 | scores_emb = joint_output.view((-1, self.num_objs, self.num_attrs))
259 |             class_outputs[0] = scores_emb.max(axis=2)[0].detach()  # max over attributes -> object scores
260 |             class_outputs[1] = scores_emb.max(axis=1)[0].detach()  # max over objects -> attribute scores
261 |
262 | return tuple(class_outputs + [feature1, feature2, joint_output])
263 |
264 | def eval_pair_embed_losses(self, args: CommandlineArgs, img_feat, img_hidden_emb, attr_labels, obj_labels,
265 | neg_attr_labels, neg_obj_labels, nll_loss_funcs):
266 | device = args.device
267 |
268 | with ns_profiling_label('labels_to_embeddings'):
269 | h_A_pos, h_O_pos, g_hidden_pos, g_img_pos = self.labels_to_embeddings(attr_labels, obj_labels)
270 | _, _, g_hidden_neg, g_img_neg = self.labels_to_embeddings(neg_attr_labels, neg_obj_labels)
271 |
272 | tloss_g_imgfeat = torch.tensor(0.).to(device)
273 | if args.train.lambda_feat > 0:
274 | with ns_profiling_label('tloss_g_imgfeat'):
275 | tloss_g_imgfeat = triplet_margin_loss(img_feat, g_img_pos, g_img_neg,
276 | margin=args.train.triplet_loss_margin)
277 |
278 | tloss_g_hidden = torch.tensor(0.).to(device)
279 | if args.train.lambda_ao_emb > 0:
280 | with ns_profiling_label('tloss_g_hidden'):
281 | tloss_g_hidden = triplet_margin_loss(img_hidden_emb, g_hidden_pos, g_hidden_neg,
282 | margin=args.train.triplet_loss_margin)
283 |
284 |
285 | # Loss_invert terms
286 | loss_inv_core = torch.tensor(0.).to(device)
287 | if args.train.lambda_aux_disjoint > 0: # check hp name
288 | with ns_profiling_label('loss_inv_core'):
289 | loss_inv_core = nll_sum_loss(self.attr_inv_core_logits(h_A_pos),
290 | self.obj_inv_core_logits(h_O_pos),
291 | attr_labels, obj_labels, nll_loss_funcs)
292 |
293 | loss_inv_g_imgfeat = torch.tensor(0.).to(device)
294 | if args.train.lambda_aux_img > 0: # check hp name
295 | with ns_profiling_label('loss_inv_g_imgfeat'):
296 | loss_inv_g_imgfeat = nll_sum_loss(self.attr_inv_g_imgfeat_logits(g_img_pos),
297 | self.obj_inv_g_imgfeat_logits(g_img_pos),
298 | attr_labels, obj_labels, nll_loss_funcs)
299 |
300 | loss_inv_g_hidden = torch.tensor(0.).to(device)
301 | if args.train.lambda_aux > 0: # check hp name
302 | with ns_profiling_label('loss_inv_g_hidden'):
303 | loss_inv_g_hidden = nll_sum_loss(self.attr_inv_g_hidden_logits(g_hidden_pos),
304 | self.obj_inv_g_hidden_logits(g_hidden_pos),
305 | attr_labels, obj_labels, nll_loss_funcs)
306 |
307 |
308 | return tloss_g_hidden, tloss_g_imgfeat, loss_inv_core, loss_inv_g_hidden, loss_inv_g_imgfeat
309 |
310 | def labels_to_embeddings(self, attr_labels, obj_labels):
311 | """
312 | h_A
313 | |
314 | attr_labels -> h_A ->
315 | > g1_emb_to_hidden_feat -> g2_feat_to_image_feat ->
316 | obj_labels -> h_O -> | |
317 | | g_hidden g_img
318 | h_O
319 |
320 | """
321 | h_A = self.h_A(attr_labels)
322 | h_O = self.h_O(obj_labels)
323 |
324 | g_hidden = self.g1_emb_to_hidden_feat(torch.cat((h_A, h_O), dim=1))
325 | g_img = self.g2_feat_to_image_feat(g_hidden)
326 |
327 | return h_A, h_O, g_hidden, g_img
328 |
329 | def get_joint_embed_classification_scores(self, attrs, objs, common_emb_feat, img_feat):
330 | _, _, g_hidden, g_img = self.labels_to_embeddings(attrs, objs)
331 | vec_dist_img_emb = ((img_feat[:, :, None] - g_img.T[None, :, :]))
332 | flattened_img_emb_scores = -((vec_dist_img_emb ** 2).sum(1))
333 |
334 | if common_emb_feat is not None:
335 | vec_dist_joint_ao_emb = ((common_emb_feat[:, :, None] - g_hidden.T[None, :, :]))
336 | flattened_joint_scores = -((vec_dist_joint_ao_emb ** 2).sum(1))
337 | else:
338 | flattened_joint_scores = 0*flattened_img_emb_scores.detach()
339 |
340 | return flattened_joint_scores, flattened_img_emb_scores
341 |
342 | def _get_ao_outerprod_idxs(self):
343 | device = self.args.device
344 | outerprod_pairs = torch.cartesian_prod(torch.arange(0, self.num_objs, device=device),
345 | torch.arange(0, self.num_attrs, device=device))
346 | objs_idxs = outerprod_pairs[:, 0]
347 | attrs_idxs = outerprod_pairs[:, 1]
348 | return attrs_idxs, objs_idxs
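
# Illustrative example: with num_objs=2 and num_attrs=3, torch.cartesian_prod yields
#   objs_idxs  = [0, 0, 0, 1, 1, 1]
#   attrs_idxs = [0, 1, 2, 0, 1, 2]
# which matches the obj-major flattening used by joint_output.view((-1, num_objs, num_attrs)).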
349 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 |
8 | import os
9 |
10 | import torch
11 | from torch import Tensor
12 | import numpy as np
13 |
14 | from model import CompModel
15 | from data import CompDataFromDict
16 | from useful_utils import ns_profiling_label, batch_torch_logger, list_to_2d_tuple, slice_dict_to_dict, to_torch
17 | from COSMO_utils import calc_cs_ausuc, per_class_balanced_accuracy
18 |
19 |
20 | def evaluation_step(model, valid_loader, test_loader, loss_funcs, writer, epoch, n_epochs, curr_epoch_metrics,
21 | early_stop_metric_name, best_ES_metric_value, calc_AUC=False):
22 | """
23 |     Forward-passes the validation and test data, updates the metrics, and tracks the best epoch according to the validation metric.
24 | """
25 | model.eval()
26 | # a path for dumping logits and predictions of each epoch.
27 | dump_preds_path = os.path.join(writer._log_dir, 'dump_preds', f'epoch_{epoch}')
28 |
29 | # get metrics on validation set
30 | with ns_profiling_label('eval val set'):
31 | metrics_valid = eval_model_with_dataloader(model, valid_loader, loss_funcs, phase_name='valid',
32 | calc_AUC=calc_AUC, dump_to_fs_basedir=dump_preds_path)
33 | metrics_valid['epoch'] = epoch # log epoch number as a metric
34 | # update metrics dictionary with validation metrics
35 | curr_epoch_metrics.update(metrics_valid)
36 |
37 | # get metrics on test set
38 | with ns_profiling_label('eval test set'):
39 | metrics_test = eval_model_with_dataloader(model, test_loader, loss_funcs, phase_name='test', calc_AUC=calc_AUC,
40 | dump_to_fs_basedir=dump_preds_path)
41 | # update metrics dictionary with test metrics
42 | curr_epoch_metrics.update(metrics_test)
43 |
44 |
45 | # Early Stop (ES) monitoring
46 | current_ES_metric_value = metrics_valid[early_stop_metric_name.metric]
47 |     ES_metric_polarity = 2 * (early_stop_metric_name.polarity == 'max') - 1  # 'max' -> +1, otherwise -1
48 | is_best = False
49 | if best_ES_metric_value*ES_metric_polarity <= current_ES_metric_value*ES_metric_polarity:
50 | is_best = True
51 | best_ES_metric_value = current_ES_metric_value
52 |
53 | model.train()
54 |
55 | return best_ES_metric_value, is_best
56 |
57 |
58 | def eval_model_with_dataloader(model, data_loader, loss_funcs, phase_name,
59 | calc_AUC=False,
60 | dump_to_fs_basedir=None):
61 | if dump_to_fs_basedir is not None:
62 | os.makedirs(dump_to_fs_basedir, exist_ok=True)
63 | print(f'\n\n\nEvaluating metrics for {phase_name} phase:')
64 |
65 | # get model logits and ground-truth for evaluation data
66 | with ns_profiling_label('forward_pass_data'):
67 | logits_pred_ys_1, logits_pred_ys_2, logits_ys_joint, y1, y2, ys_joint, logger_means, logits_filenames \
68 | = forward_pass_data(model, data_loader, loss_funcs)
69 |
70 | preds_dump_dict, results = calc_metrics_from_logits(data_loader.dataset, logits_ys_joint, ys_joint, y1, y2,
71 | logits_filenames, phase_name, logger_means, calc_AUC)
72 |
73 | ### dump predictions to filesystem
74 | if dump_to_fs_basedir is not None:
75 | # cast to numpy
76 | preds_dump_dict.update(
77 | dict((k, v.detach().cpu().numpy()) for k, v in preds_dump_dict.items() if isinstance(v, torch.Tensor)))
78 | fname = os.path.join(dump_to_fs_basedir, 'dump_preds' + f'_{phase_name}') + '.npz'
79 | np.savez_compressed(fname, **preds_dump_dict)
80 | print(f'dumped predictions to: {fname}')
81 |
82 | return results
83 |
84 |
85 | def forward_pass_data(model: CompModel, data_loader, loss_funcs):
86 | device = model.device
87 | with torch.no_grad():
88 | num_batches = len(data_loader)
89 | batch_metrics_logger = batch_torch_logger(num_batches=num_batches,
90 | cs_str_args='y1_loss, y2_loss, y_sum_loss',
91 | device=device)
92 |
93 | # predefine lists that aggregates data across batches
94 | ys_1 = [None] * num_batches
95 | ys_2 = [None] * num_batches
96 | fname_list = [None] * num_batches
97 | logits_pred_ys_1 = [None] * num_batches
98 | logits_pred_ys_2 = [None] * num_batches
99 | ys_joint = [None] * num_batches
100 | logits_ys_joint = [None] * num_batches
101 | data_iter = iter(data_loader)
102 |
103 | # iterate on all the data - forward-pass it through the model and aggregate the ground-truth labels,
104 | # the logits, and losses
105 | for i in range(num_batches):
106 | X_batch, y2_batch, y1_batch, _, _, fname_batch = next(data_iter)
107 |
108 | with ns_profiling_label('process batch'):
109 | batch_metrics_logger.new_batch()
110 |
111 | with torch.no_grad():
112 |
113 | with ns_profiling_label('copy to gpu'):
114 | X_batch = X_batch.float().to(device)
115 | y1_batch = y1_batch.long().to(device)
116 | y2_batch = y2_batch.long().to(device)
117 |
118 | with ns_profiling_label('fwd pass'):
119 | y1_pred, y2_pred, _, _, joint_pred = model(X_batch)
120 | y1_pred = y1_pred.detach()
121 | y2_pred = y2_pred.detach()
122 |
123 | # calc a joint label
124 | y_joint = data_loader.dataset.to_joint_label(y1_batch, y2_batch).detach()
125 |
126 | # evaluate the loss for the current eval batch
127 | if loss_funcs is not None:
128 | y1_loss = loss_funcs.y1(y1_pred, y1_batch).detach()
129 | y2_loss = loss_funcs.y2(y2_pred, y2_batch).detach()
130 | y_sum_loss = (y1_loss + y2_loss).detach()
131 |
132 | batch_metrics_logger.log(locals_dict=locals())
133 |
134 | # aggregate labels and logits of current batch.
135 | ys_1[i] = y1_batch
136 | ys_2[i] = y2_batch
137 | ys_joint[i] = y_joint
138 |
139 | fname_list[i] = fname_batch
140 |
141 | logits_pred_ys_1[i] = y1_pred
142 | logits_pred_ys_2[i] = y2_pred
143 | logits_ys_joint[i] = joint_pred
144 |
145 | # calc the mean of each metric across all batches
146 | logger_means = batch_metrics_logger.get_means()
147 |
148 | # concat per-batch ground-truth labels list to tensors
149 | y1 = torch.cat(ys_1).detach()
150 | y2 = torch.cat(ys_2).detach()
151 | ys_joint = torch.cat(ys_joint).detach()
152 |
153 |     # concat per-batch filenames list to a single tuple
154 |     fname = sum(list_to_2d_tuple(fname_list), tuple())
155 | 
156 |     # concat per-batch logits list to tensors
157 | logits_pred_ys_1 = torch.cat(logits_pred_ys_1).detach()
158 | logits_pred_ys_2 = torch.cat(logits_pred_ys_2).detach()
159 | logits_ys_joint = torch.cat(logits_ys_joint).detach()
160 | return logits_pred_ys_1, logits_pred_ys_2, logits_ys_joint, y1, y2, ys_joint, logger_means, fname
161 |
162 |
163 | def calc_metrics_from_logits(dataset, logits_ys_joint, ys_joint, y1, y2, logits_fnames, phase_name, logger_means={},
164 | calc_AUC=False, device=None):
165 |
166 | """
167 |     Notes: 1. This function assumes that the logits contain all (cartesian-product) combinations of attribute and object,
168 |     i.e. including combinations that are not part of the open set.
169 |     2. calc_AUC is optional, because it slows down the evaluation step
170 | """
171 |
172 | # Prepare variables
173 | fname = logits_fnames # filenames
174 |
175 |
176 | # Get:
177 | # (1) logits based scores for objects and for attributes
178 | # (2) pairs indexing variables
179 | logits_pred_ys_1, logits_pred_ys_2 = \
180 | prepare_logits_for_metrics_calc(logits_ys_joint, dataset)
181 |
182 |
183 | shape_obj_attr = dataset.shape_obj_attr
184 | flattened_seen_pairs_mask = dataset.flattened_seen_pairs_mask
185 | flattened_closed_unseen_pairs_mask = dataset.flattened_closed_unseen_pairs_mask
186 | flattened_all_open_pairs_mask = dataset.flattened_all_open_pairs_mask
187 |
188 | # get indices of seen samples and of unseen samples
189 | seen_pairs_joint_class_ids = dataset.seen_pairs_joint_class_ids
190 | ids_seen_samples = np.where(np.in1d(ys_joint.cpu(), seen_pairs_joint_class_ids[0]))[0]
191 | ids_unseen_samples = np.where(~np.in1d(ys_joint.cpu(), seen_pairs_joint_class_ids[0]))[0]
192 |
193 | # ======== Open-Set Accuracy Metrics ====================
194 |
195 | # filter-out logits of irrelevant pairs by setting their scores to -inf
196 | logits_ys_open = logits_ys_joint.clone()
197 | logits_ys_open[:, ~flattened_all_open_pairs_mask] = -np.inf
198 |
199 | # Calc standard (imbalanced) accuracy metrics for open set
200 | unseen_open_acc = acc_from_logits(logits_ys_open, ys_joint, ids_unseen_samples)
201 | seen_open_acc = acc_from_logits(logits_ys_open, ys_joint, ids_seen_samples)
202 | open_H_IMB = 2 * (unseen_open_acc * seen_open_acc) / (unseen_open_acc + seen_open_acc + 1e-7)
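    # e.g. (illustrative): unseen=0.2, seen=0.6 -> H = 2*0.2*0.6 / (0.2+0.6) = 0.3;
    # the harmonic mean penalizes trading off unseen accuracy for seen accuracy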
203 |
204 | # Calc balanced accuracy metrics for open set
205 | num_class_joint_unseen = len(np.unique(ys_joint[ids_unseen_samples].cpu()))
206 | num_class_joint_seen = len(ys_joint[ids_seen_samples].unique())
207 | pred_open_unseen = logits_ys_open.argmax(1)[ids_unseen_samples]
208 | open_balanced_unseen_acc = per_class_balanced_accuracy(ys_joint[ids_unseen_samples].cpu().numpy(),
209 | pred_open_unseen.cpu().numpy(), num_class_joint_unseen)
210 | pred_open_seen = logits_ys_open.argmax(1)[ids_seen_samples]
211 | open_balanced_seen_acc = per_class_balanced_accuracy(ys_joint[ids_seen_samples].cpu().numpy(),
212 | pred_open_seen.cpu().numpy(), num_class_joint_seen)
213 | # harmonic accuracy
214 | open_H = 2 * (open_balanced_unseen_acc * open_balanced_seen_acc) / (
215 | open_balanced_unseen_acc + open_balanced_seen_acc)
216 | logits_ys_open = None # release memory
217 |
218 |
219 | # ======== Closed-Set Accuracy Metrics ====================
220 |
221 | # filter-out logits of irrelevant pairs by setting their scores to -inf
222 | logits_ZS_ys_closed = logits_ys_joint.clone()
223 | logits_ZS_ys_closed[:, ~flattened_closed_unseen_pairs_mask] = -np.inf
224 |
225 | # Calc standard (imbalanced) accuracy metrics for closed set
226 | closed_acc = acc_from_logits(logits_ZS_ys_closed, ys_joint, ids_unseen_samples)
227 | # Calc balanced accuracy metrics for closed set
228 | pred_closed_unseen = logits_ZS_ys_closed.argmax(1)[ids_unseen_samples]
229 | closed_balanced_acc = per_class_balanced_accuracy(ys_joint[ids_unseen_samples].cpu().numpy(),
230 | pred_closed_unseen.cpu().numpy(), num_class_joint_unseen)
231 | closed_balanced_acc_random = calc_random_balanced_baseline_by_logits_neginf(logits_ZS_ys_closed)
232 | pred_closed_unseen, logits_ZS_ys_closed = None, None # release memory
233 |
234 | # ===== Unseen Accuracy metrics for objects (y1) or attributes (y2) =====
235 | # Note 'unseen' indicates that y1 resides in an unseen *combination* of (y1,y2), not that y1 is unseen
236 | # (and similarly for y2)
237 | pred_y_1_unseen = logits_pred_ys_1.argmax(1)[ids_unseen_samples]
238 | pred_y_2_unseen = logits_pred_ys_2.argmax(1)[ids_unseen_samples]
239 | pred_joint_unseen = logits_ys_joint.argmax(1)[ids_unseen_samples]
240 | y1_acc_unseen = (pred_y_1_unseen == y1[ids_unseen_samples]).sum().float() / len(y1[ids_unseen_samples])
241 | y2_acc_unseen = (pred_y_2_unseen == y2[ids_unseen_samples]).sum().float() / len(y2[ids_unseen_samples])
242 |
243 | # balanced accuracy metric.
244 | # NOTE !!: balanced accuracy metric ignores classes that don't participate in the set (relevant to val set)
245 | num_class_y1_unseen = len(y1[ids_unseen_samples].unique())
246 | y1_balanced_acc_unseen = per_class_balanced_accuracy(y1[ids_unseen_samples], pred_y_1_unseen, num_class_y1_unseen)
247 | y1_balanced_acc_unseen_random = calc_random_balanced_baseline_by_logits_neginf(logits_pred_ys_1)
248 | num_class_y2_unseen = len(y2[ids_unseen_samples].unique())
249 | y2_balanced_acc_unseen = per_class_balanced_accuracy(y2[ids_unseen_samples], pred_y_2_unseen, num_class_y2_unseen)
250 | y2_balanced_acc_unseen_random = calc_random_balanced_baseline_by_logits_neginf(logits_pred_ys_2)
251 |
252 | ##
253 | # Init a dictionary (preds_dump_dict) that holds variables to dump to file system
254 | preds_dump_dict = slice_dict_to_dict(locals(), 'flattened_seen_pairs_mask, seen_pairs_joint_class_ids, shape_obj_attr, '
255 | 'y1, y2, ys_joint, fname, '
256 | 'ids_seen_samples, ids_unseen_samples, '
257 | 'flattened_closed_unseen_pairs_mask, '
258 | 'num_class_joint_unseen, num_class_joint_seen, '
259 | 'flattened_all_open_pairs_mask, pred_open_unseen, pred_open_seen, '
260 | 'pred_closed_unseen, '
261 | 'pred_y_1_unseen, pred_y_2_unseen, pred_joint_unseen, num_class_y1_unseen, '
262 | 'num_class_y2_unseen')
263 |
264 | if calc_AUC:
265 | AUC_open, AUC_open_balanced = calc_ausuc_from_logits(preds_dump_dict, logits_ys_joint, dataset, device)
266 |
267 | results = build_results_dict(locals(), logger_means, phase_name, calc_AUC)
268 | return preds_dump_dict, results
269 |
270 |
271 | def build_results_dict(local_vars_dict, logger_means, phase_name, calc_AUC):
272 | _phase_name = f'_{phase_name}'
273 | results = slice_dict_to_dict(
274 | local_vars_dict,
275 | ['closed_acc',
276 | 'closed_balanced_acc',
277 | 'open_H',
278 | 'open_H_IMB',
279 | 'open_balanced_seen_acc',
280 | 'open_balanced_unseen_acc',
281 | 'seen_open_acc',
282 | 'unseen_open_acc',
283 | 'y1_acc_unseen',
284 | 'y1_balanced_acc_unseen',
285 | 'y1_balanced_acc_unseen_random',
286 | 'y2_acc_unseen',
287 | 'y2_balanced_acc_unseen',
288 | 'y2_balanced_acc_unseen_random'],
289 | returned_keys_postfix=_phase_name)
290 | logger_means_with_postfix = slice_dict_to_dict(logger_means, logger_means.keys(),
291 | returned_keys_postfix=_phase_name)
292 | results.update(logger_means_with_postfix)
293 |     # cast torch & numpy scalars to ordinary python scalars (to be compatible with json.dump())
294 | results.update(dict((k, v.cpu().detach().item()) for k, v in results.items() if isinstance(v, torch.Tensor)))
295 | results.update(dict((k, v.item()) for k, v in results.items() if isinstance(v, np.number)))
296 | if calc_AUC:
297 | results['AUC_open' + _phase_name] = local_vars_dict['AUC_open']
298 | results['AUC_open_balanced' + _phase_name] = local_vars_dict['AUC_open_balanced']
299 | return results
300 |
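# slice_dict_to_dict() is imported from useful_utils.py. For readability, here is a
# minimal sketch that is consistent with how it is used above -- an assumption about
# its semantics, not the actual implementation:
#
#   def slice_dict_to_dict(d, keys, returned_keys_postfix=''):
#       if isinstance(keys, str):
#           keys = [k.strip() for k in keys.split(',')]
#       return {k + returned_keys_postfix: d[k] for k in keys}
#
#   slice_dict_to_dict(dict(a=1, b=2, c=3), 'a, c', returned_keys_postfix='_test')
#   # -> {'a_test': 1, 'c_test': 3}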
301 |
302 |
303 | def prepare_logits_for_metrics_calc(logits_ys_joint, dataset):
304 |     """ Returns logits-based scores for objects (logits_pred_ys_1) and for
305 |         attributes (logits_pred_ys_2), obtained by max-marginalizing the joint
306 |         (object, attribute) logits over the other factor.
307 |     """
308 | unflattened_logits = logits_ys_joint.view((-1, dataset.num_objs, dataset.num_attrs))
309 | logits_pred_ys_1 = unflattened_logits.max(dim=2)[0].detach()
310 | logits_pred_ys_2 = unflattened_logits.max(dim=1)[0].detach()
311 | return logits_pred_ys_1, logits_pred_ys_2
312 |
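# For intuition, a small sketch of the max-marginalization above: with 2 objects
# and 3 attributes, joint logits of shape (batch, 6) are viewed as
# (batch, num_objs, num_attrs); an object's score is its best attribute column,
# and vice versa (toy numbers, not from the datasets):
#
#   joint = torch.tensor([[0.1, 0.9, 0.2, 0.4, 0.3, 0.8]])  # (1, 6)
#   unflat = joint.view(-1, 2, 3)                            # (1, 2 objs, 3 attrs)
#   obj_scores = unflat.max(dim=2)[0]    # tensor([[0.9000, 0.8000]])
#   attr_scores = unflat.max(dim=1)[0]   # tensor([[0.4000, 0.9000, 0.8000]])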
313 |
314 | def accuracy(predictions: Tensor, labels: Tensor, return_tensor=False):
315 | acc = (predictions == labels).float().mean()
316 | if not return_tensor:
317 | acc = acc.item()
318 | return acc
319 |
320 |
321 | def acc_from_logits(logits: Tensor, labels: Tensor, subset_ids=None, return_tensor=False):
322 | predictions = logits.argmax(1)
323 | if subset_ids is None:
324 | return accuracy(predictions, labels, return_tensor=return_tensor)
325 | else:
326 | if len(subset_ids) == 0:
327 | return np.nan
328 | return accuracy(predictions[subset_ids], labels[subset_ids], return_tensor=return_tensor)
329 |
330 |
331 | def calc_ausuc_from_logits(preds_dump_dict, logits_ys_joint, dataset: CompDataFromDict, device):
332 | # AUSUC when evaluating on open pairs.
333 | # We use the code from https://github.com/yuvalatzmon/COSMO/blob/master/src/metrics.py (CVPR 2019)
334 |     # To do so, we first rearrange our logits and labels to match that API
335 |
336 | ys_joint = preds_dump_dict['ys_joint']
337 | seen_pairs_joint_class_ids = preds_dump_dict['seen_pairs_joint_class_ids']
338 | flattened_closed_unseen_pairs = preds_dump_dict['flattened_closed_unseen_pairs_mask']
339 | flattened_all_open_pairs_mask = preds_dump_dict['flattened_all_open_pairs_mask']
340 | # shape_obj_attr = preds_dump_dict['shape_obj_attr']
341 |
342 | # filter-in only the logits of the open pairs
343 | ids_open_pairs = np.where(flattened_all_open_pairs_mask)[0]
344 | logits_open = logits_ys_joint[:, ids_open_pairs]
345 | logits_ys_joint = None # release memory
346 |
347 |     # get a mapping from indices of the (y1, y2) cartesian product to open-pair indices
348 | inv_ids_open_pairs = (0 * flattened_all_open_pairs_mask) # init an array of zeros
349 |     inv_ids_open_pairs[ids_open_pairs] = np.arange(len(ids_open_pairs))
350 |
351 | if not isinstance(logits_open, torch.Tensor):
352 | logits_open = to_torch(logits_open, device)
353 | logits_open = logits_open.to(device)
354 |
355 | if isinstance(ys_joint, torch.Tensor):
356 | ys_joint = ys_joint.cpu()
357 |
358 | ys_open = to_torch(inv_ids_open_pairs[ys_joint], device)
359 | seen_pairs = inv_ids_open_pairs[seen_pairs_joint_class_ids[0]]
360 | unseen_pairs = inv_ids_open_pairs[np.where(flattened_closed_unseen_pairs)[0]]
361 |
362 | AUSUC_open = calc_cs_ausuc(logits_open, ys_open, seen_pairs, unseen_pairs, use_balanced_accuracy=False)
363 | AUSUC_open_balanced = calc_cs_ausuc(logits_open, ys_open, seen_pairs, unseen_pairs, use_balanced_accuracy=True)
364 |
365 | return AUSUC_open, AUSUC_open_balanced
366 |
367 |
368 | def calc_random_balanced_baseline_by_logits_neginf(logits):
369 | num_active = (logits[0, :] != -np.inf).sum().item()
370 | if num_active == 0:
371 | return np.nan
372 | else:
373 | return 1./num_active
374 |
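# Usage sketch (toy values): with inactive classes masked to -inf, the chance
# level of the *balanced* accuracy is 1 / (number of active classes):
#
#   logits = torch.tensor([[0.3, -np.inf, 1.2, -np.inf]])    # 2 active classes
#   calc_random_balanced_baseline_by_logits_neginf(logits)   # -> 0.5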
375 |
376 |
377 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 |
8 | from collections import OrderedDict
9 | from typing import NamedTuple
10 |
11 | import numpy as np
12 | import torch
13 | from torch import nn, optim
14 | from torch.nn.functional import one_hot
15 |
16 | from HSIC import HSIC
17 | from model import get_model
18 |
19 | from useful_utils import to_torch, batch_torch_logger, ns_profiling_label, profileblock, clone_model
20 | from data import get_data_loaders
21 | from pprint import pprint
22 | from eval import evaluation_step, acc_from_logits, eval_model_with_dataloader
23 |
24 | from params import EarlyStopMetric, CommandlineArgs
25 | from model import CompModel
26 |
27 |
28 | def train(args: CommandlineArgs, train_dataset, valid_dataset, test_dataset, writer, model: CompModel = None):
29 | # Init
30 | train_cfg = args.train
31 |
32 | best_metrics = {}
33 | epoch = -1
34 | start_epoch = 0
35 | device = args.device
36 | if len(writer.df) > 0:
37 | start_epoch = writer.df.index.max()
38 |
39 | # Get pytorch data loaders
40 | test_loader, train_loader, valid_loader = get_data_loaders(train_dataset, valid_dataset, test_dataset,
41 | train_cfg.batch_size, train_cfg.num_workers,
42 | test_batchsize=train_cfg.test_batchsize,
43 | shuffle_eval_set=train_cfg.shuffle_eval_set)
44 |
45 | if model is None:
46 | model: CompModel = get_model(args, train_dataset)
47 | best_model = clone_model(model)
48 |
49 |     ## NOTE:
50 |     # y1 refers to object labels
51 |     # y2 refers to attribute labels
52 | num_classes1 = train_dataset.num_objs
53 | num_classes2 = train_dataset.num_attrs
54 |
55 | class NLLLossFuncs(NamedTuple):
56 | y1: nn.NLLLoss
57 | y2: nn.NLLLoss
58 |
59 | nll_loss_funcs = NLLLossFuncs(y1=nn.NLLLoss(), y2=nn.NLLLoss())
60 | if train_cfg.balanced_loss:
61 | nll_loss_funcs=NLLLossFuncs(y1=nn.NLLLoss(weight=to_torch(1 / train_dataset.y1_freqs, device)),
62 | y2=nn.NLLLoss(weight=to_torch(1 / train_dataset.y2_freqs, device)))
63 |
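    # A brief numeric illustration of the inverse-frequency weighting above
    # (hypothetical frequencies): with y1_freqs = [0.75, 0.25], the class weights
    # become 1 / y1_freqs = [1.33, 4.0], e.g.
    #   nn.NLLLoss(weight=torch.tensor([1 / 0.75, 1 / 0.25]))
    # so errors on the rarer class count ~3x more, matching the balanced-accuracy
    # metrics used for early stopping.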
64 | itr_per_epoch = len(train_loader)
65 | n_epochs = train_cfg.n_iter // itr_per_epoch
66 |
67 |     best_primary_metric = np.inf * (2 * (train_cfg.primary_early_stop_metric.polarity == 'min') - 1)  # +inf if the early-stop metric is minimized, -inf if maximized
68 |
69 | optimizer = get_optimizer(train_cfg.optimizer_name, train_cfg.lr, train_cfg.weight_decay, model, args)
70 |
71 | epoch_range = range(start_epoch + 1, start_epoch + n_epochs + 1)
72 | data_iterator = iter(train_loader)
73 |
74 | for epoch in epoch_range:
75 | with profileblock(label='Epoch train step'):
76 | # Select which tensors to log. Taking an average on all batches per epoch.
77 | logger = batch_torch_logger(num_batches=len(train_loader),
78 | cs_str_args='y1_loss, y2_loss, y_loss, '
79 | 'L_rep, '
80 | 'y1_acc, y2_acc, '
81 | 'HSIC_cond1, HSIC_cond2, '
82 | 'pairwise_dist_cond1_repr1, '
83 | 'pairwise_dist_cond1_repr2, '
84 | 'pairwise_dist_cond2_repr1, '
85 | 'pairwise_dist_cond2_repr2, '
86 | 'HSIC_label_cond1, HSIC_label_cond2',
87 | nanmean_args_cs_str = 'pairwise_dist_cond1_repr1, '
88 | 'pairwise_dist_cond1_repr2, '
89 | 'pairwise_dist_cond2_repr1, '
90 | 'pairwise_dist_cond2_repr2, '
91 | 'tloss_a, tloss_o, tloss_g_imgfeat, '
92 | 'loss_inv_core, loss_inv_g_hidden, loss_inv_g_imgfeat',
93 | device=device
94 | )
95 |
96 |
97 | for batch_cnt in range(len(train_loader)):
98 | logger.new_batch()
99 |
100 | optimizer.zero_grad()
101 |
102 | with ns_profiling_label('fetch batch'):
103 | try:
104 | batch = next(data_iterator)
105 | except StopIteration:
106 | data_iterator = iter(train_loader)
107 | batch = next(data_iterator)
108 |
109 | with ns_profiling_label('send to gpu'):
110 | X, y2, y1 = batch[0], batch[1], batch[2]
111 | neg_attrs, neg_objs = batch[3].to(device), batch[4].to(device)
112 | X = X.float().to(device) # images
113 | y1 = y1.long().to(device) # object labels
114 | y2 = y2.long().to(device) # attribute labels
115 |
116 | with ns_profiling_label('forward pass'):
117 | # y1_scores, y2_scores are logits of negative-squared-distances at the embedding space
118 |                         # repr1, repr2 are phi_hat1, phi_hat2 in the paper
119 | y1_scores, y2_scores, repr1, repr2, _ = \
120 | model(X, freeze_class1=train_cfg.freeze_class1,
121 | freeze_class2=train_cfg.freeze_class2)
122 |
123 | y1_loss = nll_loss_funcs.y1(y1_scores, y1)
124 | y2_loss = nll_loss_funcs.y2(y2_scores, y2)
125 | y_loss = y1_loss * train_cfg.Y12_balance_coeff + y2_loss * (1 - train_cfg.Y12_balance_coeff)
126 | L_data = train_cfg.lambda_CE * y_loss
127 |
128 | L_invert = 0.
129 | if not args.model.VisProd:
130 | # pair embedding losses
131 | tloss_g_hidden, tloss_g_imgfeat, loss_inv_core, loss_inv_g_hidden, loss_inv_g_imgfeat = \
132 | model.eval_pair_embed_losses(args, X, model.last_feature_common, y2, y1, neg_attrs,
133 | neg_objs, nll_loss_funcs)
134 |
135 | # aggregate triplet loss into L_data
136 | L_data += train_cfg.lambda_ao_emb * tloss_g_hidden
137 | L_data += train_cfg.lambda_feat * tloss_g_imgfeat
138 |
139 | # aggregate components of L_invert
140 | L_invert += train_cfg.lambda_aux_disjoint * loss_inv_core
141 | L_invert += train_cfg.lambda_aux * loss_inv_g_hidden
142 | L_invert += train_cfg.lambda_aux_img * loss_inv_g_imgfeat
143 |
144 |
145 | ys = (y1, y2)
146 | L_rep, HSIC_rep_loss_terms, HSIC_mean_of_median_pairwise_dist_terms = \
147 | conditional_indep_losses(repr1, repr2, ys, train_cfg.HSIC_coeff, indep_coeff2=train_cfg.HSIC_coeff,
148 | num_classes1=num_classes1,
149 | num_classes2=num_classes2, log_median_pairwise_distance=False,
150 | device=device)
151 |
152 | ohy1 = one_hot(y1, num_classes1)
153 | ohy2 = one_hot(y2, num_classes2)
154 | L_oh1, HSIC_oh_loss_terms1, _ = \
155 | conditional_indep_losses(ohy2, repr1, ys, train_cfg.alphaH, indep_coeff2=0, num_classes1=num_classes1,
156 | num_classes2=num_classes2, log_median_pairwise_distance=False,
157 | device=device)
158 |
159 | L_oh2, HSIC_oh_loss_terms2, _ = \
160 | conditional_indep_losses(ohy1, repr2, ys, 0, indep_coeff2=train_cfg.alphaH, num_classes1=num_classes1,
161 | num_classes2=num_classes2, log_median_pairwise_distance=False,
162 | device=device)
163 |
164 | L_indep = L_rep + L_oh1 + L_oh2
165 |
166 | loss = L_data + L_indep + L_invert
167 |
168 |
169 | with ns_profiling_label('loss and update'):
170 | loss.backward()
171 | optimizer.step()
172 |
173 | # log the metrics
174 | with ns_profiling_label('log batch'):
175 |
176 | # extract indep loss terms from lists for logging
177 | HSIC_cond1, HSIC_cond2, pairwise_dist_cond1_repr1, pairwise_dist_cond1_repr2, \
178 | pairwise_dist_cond2_repr1, pairwise_dist_cond2_repr2 = \
179 | HSIC_logging_terms(HSIC_rep_loss_terms, HSIC_mean_of_median_pairwise_dist_terms)
180 |
181 | HSIC_label_cond1 = HSIC_oh_loss_terms1[0]
182 | HSIC_label_cond2 = HSIC_oh_loss_terms2[1]
183 |
184 | with ns_profiling_label('calc y1 train acc'):
185 | y1_acc = acc_from_logits(y1_scores, y1, return_tensor=True).detach()
186 | with ns_profiling_label('calc y2 train acc'):
187 | y2_acc = acc_from_logits(y2_scores, y2, return_tensor=True).detach()
188 |
189 | logger.log(locals_dict=locals())
190 |
191 | curr_epoch_metrics = OrderedDict()
192 | curr_epoch_metrics.update(logger.get_means())
193 | with profileblock(label='Evaluation step'):
194 |
195 | best_primary_metric, is_best = evaluation_step(model, valid_loader, test_loader, nll_loss_funcs, writer, epoch,
196 | n_epochs, curr_epoch_metrics,
197 | early_stop_metric_name=train_cfg.primary_early_stop_metric,
198 | best_ES_metric_value=best_primary_metric,
199 | calc_AUC=train_cfg.metrics.calc_AUC)
200 |
201 | # write current epoch metrics to metrics logger
202 | with ns_profiling_label('write eval step metrics'):
203 | for metric_key, value in curr_epoch_metrics.items():
204 | writer.add_scalar(f'{metric_key}', value, epoch)
205 | # dump collected metrics to csv
206 | writer.dump_to_csv()
207 |
208 | # print all columns
209 | last_results_as_string = writer.last_results_as_string()
210 | last_results_as_string = '\n '.join(last_results_as_string.split('\n'))
211 | if train_cfg.verbose:
212 | print('\n[%d/%d]' % (epoch, n_epochs), last_results_as_string)
213 |
214 | if is_best:
215 | best_model = clone_model(model)
216 | best_metrics = writer.df.iloc[-1, :].to_dict()
217 | best_metrics['epoch'] = int(writer.df.iloc[-1, :].name)
218 | if train_cfg.verbose:
219 | print(f'Best! (@epoch {epoch})')
220 |
221 |
222 | model = best_model
223 | model.eval()
224 |
225 | print('Best epoch was: ', best_metrics['epoch'])
226 | print(f'Primary early stop monitor was {train_cfg.primary_early_stop_metric}')
227 |
228 | val_metrics = eval_model_with_dataloader(model, valid_loader, nll_loss_funcs, phase_name='valid')
229 | best_metrics.update(val_metrics)
230 | print('Val metrics on best val epoch :')
231 | pprint([(k, v) for k, v in val_metrics.items()])
232 |
233 | test_metrics = eval_model_with_dataloader(model, test_loader, nll_loss_funcs, phase_name='test')
234 | best_metrics.update(test_metrics)
235 | print('\n\nTest metrics on best val epoch :')
236 | pprint([(k, v) for k, v in test_metrics.items()])
237 |
238 |     # cast numpy scalars to plain python types
239 | for k, v in best_metrics.items():
240 | if isinstance(v, np.number):
241 | best_metrics[k] = v.item()
242 |
243 |     #### two redundant evaluation calls, kept to align the random-number generator with the original training script
244 |     # TODO: consider removing these
245 | _ = eval_model_with_dataloader(model, valid_loader, nll_loss_funcs, phase_name='valid')
246 | _ = eval_model_with_dataloader(model, test_loader, nll_loss_funcs, phase_name='test')
247 |
248 |
249 | return model, best_metrics
250 |
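# For reference, the objective assembled inside the training loop of train()
# (names follow the code; coefficients come from train_cfg):
#
#   loss = L_data + L_indep + L_invert
#   L_data   = lambda_CE * (c * y1_loss + (1 - c) * y2_loss),  c = Y12_balance_coeff
#              [+ lambda_ao_emb * tloss_g_hidden + lambda_feat * tloss_g_imgfeat  when not VisProd]
#   L_indep  = L_rep + L_oh1 + L_oh2  (HSIC conditional-independence penalties)
#   L_invert = lambda_aux_disjoint * loss_inv_core + lambda_aux * loss_inv_g_hidden
#              + lambda_aux_img * loss_inv_g_imgfeat  (only when not VisProd)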
251 |
252 | def alternate_training(args, train_dataset, valid_dataset, test_dataset, writer):
253 | train_cfg = args.train
254 | ### alternate between heads ###
255 | ## train first head #
256 |     # Save 'HSIC_coeff' for use during step 2. For step 1, set HSIC_coeff to 0.
257 | HSIC_coeff_step2 = train_cfg.HSIC_coeff
258 | train_cfg.HSIC_coeff = 0
259 | if train_cfg.alternate_ys == 12:
260 | train_cfg.Y12_balance_coeff = 1
261 | train_cfg.primary_early_stop_metric = EarlyStopMetric('y1_balanced_acc_unseen_valid', 'max')
262 | train_cfg.freeze_class1 = False
263 | train_cfg.freeze_class2 = True
264 | elif train_cfg.alternate_ys == 21:
265 | train_cfg.Y12_balance_coeff = 0
266 | train_cfg.primary_early_stop_metric = EarlyStopMetric('y2_balanced_acc_unseen_valid', 'max')
267 | train_cfg.freeze_class1 = True
268 | train_cfg.freeze_class2 = False
269 | else:
270 |         raise ValueError(f"Unsupported train_cfg.alternate_ys = {train_cfg.alternate_ys}")
271 | print(
272 | f"first iter ay={train_cfg.alternate_ys}, primary_early_stop_metric={train_cfg.primary_early_stop_metric.metric}")
273 | model, best_metrics_dict1 = train(args, train_dataset, valid_dataset, test_dataset, writer)
274 | print('step1 best metrics epoch = ', best_metrics_dict1['epoch'])
275 |
276 | ## train 2nd head #
277 | model.train()
278 | train_cfg.set_n_iter(len(train_dataset.data), (1 + train_cfg.max_epoch_step2))
279 | train_cfg.lr = train_cfg.lr_step2
280 | if train_cfg.alternate_ys == 12:
281 | train_cfg.Y12_balance_coeff = 0
282 | train_cfg.primary_early_stop_metric = EarlyStopMetric('epoch', 'max')
283 | train_cfg.freeze_class1 = True
284 | train_cfg.freeze_class2 = False
285 | elif train_cfg.alternate_ys == 21:
286 | train_cfg.Y12_balance_coeff = 1
287 | train_cfg.primary_early_stop_metric = EarlyStopMetric('epoch', 'max')
288 | train_cfg.freeze_class1 = False
289 | train_cfg.freeze_class2 = True
290 | train_cfg.HSIC_coeff = HSIC_coeff_step2
291 | train(args, train_dataset, valid_dataset, test_dataset, writer, model=model)
292 |
293 |
294 |
295 | def get_optimizer(optimizer_name, lr, weight_decay, model, args: CommandlineArgs):
296 |     """ Returns an optimizer instance """
297 |
298 | # list the weights to optimize
299 | obj_related_weights = list(model.g_inv_O.parameters()) + list(model.emb_cf_O.parameters())
300 | attr_related_weights = list(model.g_inv_A.parameters()) + list(model.emb_cf_A.parameters())
301 | pair_related_weights = []
302 | if not args.model.VisProd:
303 | obj_related_weights += list(model.obj_inv_core_logits.parameters())
304 | attr_related_weights += list(model.attr_inv_core_logits.parameters())
305 |
306 | obj_related_weights += list(model.obj_inv_g_hidden_logits.parameters())
307 | attr_related_weights += list(model.attr_inv_g_hidden_logits.parameters())
308 |
309 | obj_related_weights += list(model.obj_inv_g_imgfeat_logits.parameters())
310 | attr_related_weights += list(model.attr_inv_g_imgfeat_logits.parameters())
311 |
312 | pair_related_weights = list(model.g1_emb_to_hidden_feat.parameters()) + list(model.g2_feat_to_image_feat.parameters())
313 | all_weights = obj_related_weights + attr_related_weights + pair_related_weights + list(model.ECommon.parameters())
314 |
315 | # set optimizer hyper param
316 | optimizer_kwargs = dict(lr=lr, weight_decay=weight_decay)
317 | if optimizer_name.lower() == 'nest':
318 | optimizer_kwargs.update(momentum=0.9, nesterov=True)
319 |
320 | # choose optimizer
321 | optimizer_class = dict(adam=optim.Adam, sgd=optim.SGD, nest=optim.SGD)[optimizer_name.lower()]
322 |
323 | # initialize an optimizer instance
324 | optimizer = optimizer_class(all_weights, **optimizer_kwargs)
325 |
326 | return optimizer
327 |
328 |
329 | def conditional_indep_losses(repr1, repr2, ys, indep_coeff, indep_coeff2=None, num_classes1=None, num_classes2=None,
330 | Hkernel='L', Hkernel_sigma_obj=None, Hkernel_sigma_attr=None,
331 | log_median_pairwise_distance=False, device=None):
332 |     """ HSIC-based conditional-independence loss of (repr1, repr2), computed within each class of ys[0] and of ys[1] (i.e., conditioned on each label). """
333 |
334 | normalize_to_mean = (num_classes1, num_classes2)
335 |
336 | if indep_coeff2 is None:
337 | indep_coeff2 = indep_coeff
338 |
339 | HSIC_loss_terms = []
340 | HSIC_mean_of_median_pairwise_dist_terms = []
341 | with ns_profiling_label('HSIC/d loss calc'):
342 | # iterate on both heads
343 | for m, num_class in enumerate((num_classes1, num_classes2)):
344 | with ns_profiling_label(f'iter m={m}'):
345 | HSIC_tmp_loss = 0.
346 | HSIC_median_pw_y1 = []
347 | HSIC_median_pw_y2 = []
348 |
349 | labels_in_batch_sorted, indices = torch.sort(ys[m])
350 | unique_ixs = 1 + (labels_in_batch_sorted[1:] - labels_in_batch_sorted[:-1]).nonzero()
351 | unique_ixs = [0] + unique_ixs.flatten().cpu().numpy().tolist() + [len(ys[m])]
352 |
353 | for j in range(len(unique_ixs)-1):
354 | current_class_indices = unique_ixs[j], unique_ixs[j + 1]
355 | count = current_class_indices[1] - current_class_indices[0]
356 | if count < 2:
357 | continue
358 | curr_class_slice = slice(*current_class_indices)
359 | curr_class_indices = indices[curr_class_slice].sort()[0]
360 |
361 | with ns_profiling_label(f'iter j={j}'):
362 | HSIC_kernel = dict(G='Gaussian', L='Linear')[Hkernel]
363 | with ns_profiling_label('HSIC call'):
364 | hsic_loss_i, median_pairwise_distance_y1, median_pairwise_distance_y2 = \
365 | HSIC(repr1[curr_class_indices, :].float(), repr2[curr_class_indices, :].float(),
366 | kernelX=HSIC_kernel, kernelY=HSIC_kernel,
367 | sigmaX=Hkernel_sigma_obj, sigmaY=Hkernel_sigma_attr,
368 | log_median_pairwise_distance=log_median_pairwise_distance)
369 | HSIC_tmp_loss += hsic_loss_i
370 | HSIC_median_pw_y1.append(median_pairwise_distance_y1)
371 | HSIC_median_pw_y2.append(median_pairwise_distance_y2)
372 |
373 | HSIC_tmp_loss = HSIC_tmp_loss / normalize_to_mean[m]
374 | HSIC_loss_terms.append(HSIC_tmp_loss)
375 | HSIC_mean_of_median_pairwise_dist_terms.append([np.mean(HSIC_median_pw_y1), np.mean(HSIC_median_pw_y2)])
376 |
377 |     indep_loss = torch.tensor(0.).to(device)
378 |     if indep_coeff > 0 or indep_coeff2 > 0:  # check both coefficients, so calls with indep_coeff=0 but indep_coeff2>0 still contribute
379 | indep_loss = (indep_coeff * HSIC_loss_terms[0] + indep_coeff2 * HSIC_loss_terms[1]) / 2
380 | return indep_loss, HSIC_loss_terms, HSIC_mean_of_median_pairwise_dist_terms
381 |
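# The per-class grouping above avoids a Python dict by sorting the labels and
# finding class boundaries; a standalone sketch of the trick (toy labels):
#
#   ys_m = torch.tensor([2, 0, 2, 1, 0])
#   labels_sorted, indices = torch.sort(ys_m)                       # [0, 0, 1, 2, 2]
#   boundaries = 1 + (labels_sorted[1:] - labels_sorted[:-1]).nonzero()
#   unique_ixs = [0] + boundaries.flatten().tolist() + [len(ys_m)]  # [0, 2, 3, 5]
#   # class 0 samples: indices[0:2]; class 1: indices[2:3]; class 2: indices[3:5]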
382 |
383 |
384 |
385 | def HSIC_logging_terms(HSIC_loss_terms, HSIC_mean_of_median_pairwise_dist_terms):
386 | """ This is just a utility function for naming monitored values of HSIC loss """
387 | HSIC_cond1, HSIC_cond2, pairwise_dist_cond1_repr1, pairwise_dist_cond1_repr2, \
388 | pairwise_dist_cond2_repr1, pairwise_dist_cond2_repr2 = [np.nan] * 6
389 |
390 | if HSIC_loss_terms:
391 | HSIC_cond1 = HSIC_loss_terms[0]
392 | HSIC_cond2 = HSIC_loss_terms[1]
393 | pairwise_dist_cond1_repr1 = HSIC_mean_of_median_pairwise_dist_terms[0][0]
394 | pairwise_dist_cond1_repr2 = HSIC_mean_of_median_pairwise_dist_terms[0][1]
395 | pairwise_dist_cond2_repr1 = HSIC_mean_of_median_pairwise_dist_terms[1][0]
396 | pairwise_dist_cond2_repr2 = HSIC_mean_of_median_pairwise_dist_terms[1][1]
397 |
398 | return HSIC_cond1, HSIC_cond2, pairwise_dist_cond1_repr1, pairwise_dist_cond1_repr2, \
399 | pairwise_dist_cond2_repr1, pairwise_dist_cond2_repr2
400 |
--------------------------------------------------------------------------------
/taskmodularnets/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 | Section 1 -- Definitions.
71 |
72 | a. Adapted Material means material subject to Copyright and Similar
73 | Rights that is derived from or based upon the Licensed Material
74 | and in which the Licensed Material is translated, altered,
75 | arranged, transformed, or otherwise modified in a manner requiring
76 | permission under the Copyright and Similar Rights held by the
77 | Licensor. For purposes of this Public License, where the Licensed
78 | Material is a musical work, performance, or sound recording,
79 | Adapted Material is always produced where the Licensed Material is
80 | synched in timed relation with a moving image.
81 |
82 | b. Adapter's License means the license You apply to Your Copyright
83 | and Similar Rights in Your contributions to Adapted Material in
84 | accordance with the terms and conditions of this Public License.
85 |
86 | c. Copyright and Similar Rights means copyright and/or similar rights
87 | closely related to copyright including, without limitation,
88 | performance, broadcast, sound recording, and Sui Generis Database
89 | Rights, without regard to how the rights are labeled or
90 | categorized. For purposes of this Public License, the rights
91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
92 | Rights.
93 | d. Effective Technological Measures means those measures that, in the
94 | absence of proper authority, may not be circumvented under laws
95 | fulfilling obligations under Article 11 of the WIPO Copyright
96 | Treaty adopted on December 20, 1996, and/or similar international
97 | agreements.
98 |
99 | e. Exceptions and Limitations means fair use, fair dealing, and/or
100 | any other exception or limitation to Copyright and Similar Rights
101 | that applies to Your use of the Licensed Material.
102 |
103 | f. Licensed Material means the artistic or literary work, database,
104 | or other material to which the Licensor applied this Public
105 | License.
106 |
107 | g. Licensed Rights means the rights granted to You subject to the
108 | terms and conditions of this Public License, which are limited to
109 | all Copyright and Similar Rights that apply to Your use of the
110 | Licensed Material and that the Licensor has authority to license.
111 |
112 | h. Licensor means the individual(s) or entity(ies) granting rights
113 | under this Public License.
114 |
115 | i. NonCommercial means not primarily intended for or directed towards
116 | commercial advantage or monetary compensation. For purposes of
117 | this Public License, the exchange of the Licensed Material for
118 | other material subject to Copyright and Similar Rights by digital
119 | file-sharing or similar means is NonCommercial provided there is
120 | no payment of monetary compensation in connection with the
121 | exchange.
122 |
123 | j. Share means to provide material to the public by any means or
124 | process that requires permission under the Licensed Rights, such
125 | as reproduction, public display, public performance, distribution,
126 | dissemination, communication, or importation, and to make material
127 | available to the public including in ways that members of the
128 | public may access the material from a place and at a time
129 | individually chosen by them.
130 |
131 | k. Sui Generis Database Rights means rights other than copyright
132 | resulting from Directive 96/9/EC of the European Parliament and of
133 | the Council of 11 March 1996 on the legal protection of databases,
134 | as amended and/or succeeded, as well as other essentially
135 | equivalent rights anywhere in the world.
136 |
137 | l. You means the individual or entity exercising the Licensed Rights
138 | under this Public License. Your has a corresponding meaning.
139 |
140 | Section 2 -- Scope.
141 |
142 | a. License grant.
143 |
144 | 1. Subject to the terms and conditions of this Public License,
145 | the Licensor hereby grants You a worldwide, royalty-free,
146 | non-sublicensable, non-exclusive, irrevocable license to
147 | exercise the Licensed Rights in the Licensed Material to:
148 |
149 | a. reproduce and Share the Licensed Material, in whole or
150 | in part, for NonCommercial purposes only; and
151 |
152 | b. produce, reproduce, and Share Adapted Material for
153 | NonCommercial purposes only.
154 |
155 | 2. Exceptions and Limitations. For the avoidance of doubt, where
156 | Exceptions and Limitations apply to Your use, this Public
157 | License does not apply, and You do not need to comply with
158 | its terms and conditions.
159 |
160 | 3. Term. The term of this Public License is specified in Section
161 | 6(a).
162 |
163 | 4. Media and formats; technical modifications allowed. The
164 | Licensor authorizes You to exercise the Licensed Rights in
165 | all media and formats whether now known or hereafter created,
166 | and to make technical modifications necessary to do so. The
167 | Licensor waives and/or agrees not to assert any right or
168 | authority to forbid You from making technical modifications
169 | necessary to exercise the Licensed Rights, including
170 | technical modifications necessary to circumvent Effective
171 | Technological Measures. For purposes of this Public License,
172 | simply making modifications authorized by this Section 2(a)
173 | (4) never produces Adapted Material.
174 |
175 | 5. Downstream recipients.
176 |
177 | a. Offer from the Licensor -- Licensed Material. Every
178 | recipient of the Licensed Material automatically
179 | receives an offer from the Licensor to exercise the
180 | Licensed Rights under the terms and conditions of this
181 | Public License.
182 |
183 | b. No downstream restrictions. You may not offer or impose
184 | any additional or different terms or conditions on, or
185 | apply any Effective Technological Measures to, the
186 | Licensed Material if doing so restricts exercise of the
187 | Licensed Rights by any recipient of the Licensed
188 | Material.
189 |
190 | 6. No endorsement. Nothing in this Public License constitutes or
191 | may be construed as permission to assert or imply that You
192 | are, or that Your use of the Licensed Material is, connected
193 | with, or sponsored, endorsed, or granted official status by,
194 | the Licensor or others designated to receive attribution as
195 | provided in Section 3(a)(1)(A)(i).
196 |
197 | b. Other rights.
198 |
199 | 1. Moral rights, such as the right of integrity, are not
200 | licensed under this Public License, nor are publicity,
201 | privacy, and/or other similar personality rights; however, to
202 | the extent possible, the Licensor waives and/or agrees not to
203 | assert any such rights held by the Licensor to the limited
204 | extent necessary to allow You to exercise the Licensed
205 | Rights, but not otherwise.
206 |
207 | 2. Patent and trademark rights are not licensed under this
208 | Public License.
209 |
210 | 3. To the extent possible, the Licensor waives any right to
211 | collect royalties from You for the exercise of the Licensed
212 | Rights, whether directly or through a collecting society
213 | under any voluntary or waivable statutory or compulsory
214 | licensing scheme. In all other cases the Licensor expressly
215 | reserves any right to collect such royalties, including when
216 | the Licensed Material is used other than for NonCommercial
217 | purposes.
218 |
219 | Section 3 -- License Conditions.
220 |
221 | Your exercise of the Licensed Rights is expressly made subject to the
222 | following conditions.
223 |
224 | a. Attribution.
225 |
226 | 1. If You Share the Licensed Material (including in modified
227 | form), You must:
228 |
229 | a. retain the following if it is supplied by the Licensor
230 | with the Licensed Material:
231 |
232 | i. identification of the creator(s) of the Licensed
233 | Material and any others designated to receive
234 | attribution, in any reasonable manner requested by
235 | the Licensor (including by pseudonym if
236 | designated);
237 |
238 | ii. a copyright notice;
239 |
240 | iii. a notice that refers to this Public License;
241 |
242 | iv. a notice that refers to the disclaimer of
243 | warranties;
244 |
245 | v. a URI or hyperlink to the Licensed Material to the
246 | extent reasonably practicable;
247 |
248 | b. indicate if You modified the Licensed Material and
249 | retain an indication of any previous modifications; and
250 |
251 | c. indicate the Licensed Material is licensed under this
252 | Public License, and include the text of, or the URI or
253 | hyperlink to, this Public License.
254 |
255 | 2. You may satisfy the conditions in Section 3(a)(1) in any
256 | reasonable manner based on the medium, means, and context in
257 | which You Share the Licensed Material. For example, it may be
258 | reasonable to satisfy the conditions by providing a URI or
259 | hyperlink to a resource that includes the required
260 | information.
261 |
262 | 3. If requested by the Licensor, You must remove any of the
263 | information required by Section 3(a)(1)(A) to the extent
264 | reasonably practicable.
265 |
266 | 4. If You Share Adapted Material You produce, the Adapter's
267 | License You apply must not prevent recipients of the Adapted
268 | Material from complying with this Public License.
269 |
270 | Section 4 -- Sui Generis Database Rights.
271 |
272 | Where the Licensed Rights include Sui Generis Database Rights that
273 | apply to Your use of the Licensed Material:
274 |
275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276 | to extract, reuse, reproduce, and Share all or a substantial
277 | portion of the contents of the database for NonCommercial purposes
278 | only;
279 |
280 | b. if You include all or a substantial portion of the database
281 | contents in a database in which You have Sui Generis Database
282 | Rights, then the database in which You have Sui Generis Database
283 | Rights (but not its individual contents) is Adapted Material; and
284 |
285 | c. You must comply with the conditions in Section 3(a) if You Share
286 | all or a substantial portion of the contents of the database.
287 |
288 | For the avoidance of doubt, this Section 4 supplements and does not
289 | replace Your obligations under this Public License where the Licensed
290 | Rights include other Copyright and Similar Rights.
291 |
292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293 |
294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304 |
305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314 |
315 | c. The disclaimer of warranties and limitation of liability provided
316 | above shall be interpreted in a manner that, to the extent
317 | possible, most closely approximates an absolute disclaimer and
318 | waiver of all liability.
319 |
320 | Section 6 -- Term and Termination.
321 |
322 | a. This Public License applies for the term of the Copyright and
323 | Similar Rights licensed here. However, if You fail to comply with
324 | this Public License, then Your rights under this Public License
325 | terminate automatically.
326 |
327 | b. Where Your right to use the Licensed Material has terminated under
328 | Section 6(a), it reinstates:
329 |
330 | 1. automatically as of the date the violation is cured, provided
331 | it is cured within 30 days of Your discovery of the
332 | violation; or
333 |
334 | 2. upon express reinstatement by the Licensor.
335 |
336 | For the avoidance of doubt, this Section 6(b) does not affect any
337 | right the Licensor may have to seek remedies for Your violations
338 | of this Public License.
339 |
340 | c. For the avoidance of doubt, the Licensor may also offer the
341 | Licensed Material under separate terms or conditions or stop
342 | distributing the Licensed Material at any time; however, doing so
343 | will not terminate this Public License.
344 |
345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346 | License.
347 |
348 | Section 7 -- Other Terms and Conditions.
349 |
350 | a. The Licensor shall not be bound by any additional or different
351 | terms or conditions communicated by You unless expressly agreed.
352 |
353 | b. Any arrangements, understandings, or agreements regarding the
354 | Licensed Material not stated herein are separate from and
355 | independent of the terms and conditions of this Public License.
356 |
357 | Section 8 -- Interpretation.
358 |
359 | a. For the avoidance of doubt, this Public License does not, and
360 | shall not be interpreted to, reduce, limit, restrict, or impose
361 | conditions on any use of the Licensed Material that could lawfully
362 | be made without permission under this Public License.
363 |
364 | b. To the extent possible, if any provision of this Public License is
365 | deemed unenforceable, it shall be automatically reformed to the
366 | minimum extent necessary to make it enforceable. If the provision
367 | cannot be reformed, it shall be severed from this Public License
368 | without affecting the enforceability of the remaining terms and
369 | conditions.
370 |
371 | c. No term or condition of this Public License will be waived and no
372 | failure to comply consented to unless expressly agreed to by the
373 | Licensor.
374 |
375 | d. Nothing in this Public License constitutes or may be interpreted
376 | as a limitation upon, or waiver of, any privileges and immunities
377 | that apply to the Licensor or You, including from the legal
378 | processes of any jurisdiction or authority.
379 |
380 | =======================================================================
381 |
382 | Creative Commons is not a party to its public
383 | licenses. Notwithstanding, Creative Commons may elect to apply one of
384 | its public licenses to material it publishes and in those instances
385 | will be considered the “Licensor.” The text of the Creative Commons
386 | public licenses is dedicated to the public domain under the CC0 Public
387 | Domain Dedication. Except for the limited purpose of indicating that
388 | material is shared under a Creative Commons public license or as
389 | otherwise permitted by the Creative Commons policies published at
390 | creativecommons.org/policies, Creative Commons does not authorize the
391 | use of the trademark "Creative Commons" or any other trademark or logo
392 | of Creative Commons without its prior written consent including,
393 | without limitation, in connection with any unauthorized modifications
394 | to any of its public licenses or any other arrangements,
395 | understandings, or agreements concerning use of licensed material. For
396 | the avoidance of doubt, this paragraph does not form part of the
397 | public licenses.
398 |
399 | Creative Commons may be contacted at creativecommons.org.
400 |
--------------------------------------------------------------------------------
/COSMO_utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from the following repository:
5 | # https://github.com/yuvalatzmon/COSMO/
6 | #
7 | # The license for the original version of this file can be
8 | # found in this directory (LICENSE_COSMO). The modifications
9 | # to this file are subject to the License
10 | # located at the root directory.
11 | # ---------------------------------------------------------------
12 |
13 |
14 |
15 | import subprocess
16 | import sys
17 | from contextlib import contextmanager
18 |
19 | import numpy as np
20 | import pandas as pd
21 | import torch
22 | from sklearn.metrics import auc
23 |
24 |
25 | def _AUSUC(Acc_tr, Acc_ts):
26 |
27 | """ Calc area under seen-unseen curve
28 | Source: https://github.com/yuvalatzmon/COSMO/blob/master/src/metrics.py
29 | """
30 |
31 | # Sort by X axis
32 | X_sorted_arg = np.argsort(Acc_tr)
33 | sorted_X = np.array(Acc_tr)[X_sorted_arg]
34 | sorted_Y = np.array(Acc_ts)[X_sorted_arg]
35 |
36 | # zero pad
37 | leftmost_X, leftmost_Y = 0, sorted_Y[0]
38 | rightmost_X, rightmost_Y = sorted_X[-1], 0
39 | sorted_X = np.block([np.array([leftmost_X]), sorted_X, np.array([rightmost_X])])
40 | sorted_Y = np.block([np.array([leftmost_Y]), sorted_Y, np.array([rightmost_Y])])
41 |
42 | # eval AUC
43 | AUSUC = auc(sorted_X, sorted_Y)
44 |
45 | return AUSUC
46 |
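# Toy usage example (made-up numbers): a monotone seen/unseen trade-off curve;
# after the zero-padding above, the trapezoidal area is
#
#   _AUSUC(Acc_tr=[0.0, 0.5, 1.0], Acc_ts=[0.8, 0.4, 0.0])  # -> 0.4
#   # segments: (0 -> 0.5): (0.8 + 0.4)/2 * 0.5 = 0.3 ; (0.5 -> 1.0): (0.4 + 0)/2 * 0.5 = 0.1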
47 |
48 | def calc_cs_ausuc(pred, y_gt, seen_classses, unseen_classses, use_balanced_accuracy=True, gamma_range=None, verbose=True):
49 | """ Calc area under seen-unseen curve, according to calibrated stacking (Chao et al. 2016)
50 | Adapted from: https://github.com/yuvalatzmon/COSMO/blob/master/src/metrics.py
51 |
52 |     NOTE: pred must not contain -np.inf values
53 | """
54 |
55 | assert(pred.min() != -np.inf)
56 |
57 | # make numbers positive with min = 0
58 | # pred = pred.copy() - pred.min()
59 |
60 | if gamma_range is None:
61 | # make a log spaced search range
62 | gamma = abs((pred[:, unseen_classses].max(dim=1)[0] - pred[:, seen_classses].max(dim=1)[0]).min()) + abs(
63 | (pred[:, unseen_classses].max(dim=1)[0] - pred[:, seen_classses].max(dim=1)[0]).max())
64 | gamma = gamma.item()
65 |         neg_range = (-(np.logspace(0, np.log10(gamma + 1), 50) - 1)).tolist()  # negative gammas
66 |         pos_range = (np.logspace(0, np.log10(gamma + 1), 50) - 1).tolist()  # positive gammas
67 |         gamma_range = sorted(pos_range + neg_range)
68 |
69 | # torch.cuda.empty_cache()
70 | Acc_tr_values, Acc_ts_values = calc_acc_tr_ts_over_gamma_range(gamma_range, pred, y_gt, seen_classses, unseen_classses, use_balanced_accuracy)
71 | # torch.cuda.empty_cache()
72 |
73 | cs_ausuc = _AUSUC(Acc_tr=Acc_tr_values, Acc_ts=Acc_ts_values)
74 | if min(Acc_tr_values) > 0.01:
75 | print(f'CS AUSUC ERROR: Increase gamma range (add low values), because min(Acc_tr_values) equals {min(Acc_tr_values)}')
76 | return np.nan
77 | if min(Acc_ts_values) > 0.01:
78 | print(f'CS AUSUC ERROR: Increase gamma range (add high values), because min(Acc_ts_values) equals {min(Acc_ts_values)}')
79 | return np.nan
80 | if verbose:
81 | print('AUSUC: max(acc_seen) = ', max(Acc_tr_values))
82 | print('AUSUC: max(acc_unseen)', max(Acc_ts_values))
83 | print(f'AUSUC (by Calibrated Stacking) = {cs_ausuc:.3f}')
84 | return cs_ausuc
85 |
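# The calibrated-stacking operating point (Chao et al. 2016) used below boils
# down to subtracting a calibration factor gamma from the seen-class scores
# before the argmax, then sweeping gamma to trace the seen/unseen curve:
#
#   cs_pred = pred.clone()
#   cs_pred[:, seen_classses] -= gamma   # gamma > 0 penalizes seen classes
#   y_hat = cs_pred.argmax(dim=1)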
86 |
87 | def calc_acc_tr_ts_over_gamma_range(gamma_range, pred, y_gt, seen_classses, unseen_classses, use_balanced_accuracy):
88 | Acc_tr_values = []
89 | Acc_ts_values = []
90 |
91 | for gamma in gamma_range:
92 | params = gamma, pred, seen_classses, unseen_classses, use_balanced_accuracy, y_gt
93 | Acc_tr, Acc_ts = cs_at_single_operating_point(params)
94 | Acc_tr_values.append(Acc_tr)
95 | Acc_ts_values.append(Acc_ts)
96 |
97 | return Acc_tr_values, Acc_ts_values
98 |
99 |
100 | def cs_at_single_operating_point(params):
101 | gamma, pred, seen_classses, unseen_classses, use_balanced_accuracy, y_gt = params
102 | if isinstance(pred, torch.Tensor):
103 | cs_pred = pred.clone()
104 | else:
105 | cs_pred = pred.copy()
106 |
107 | cs_pred[:, seen_classses] -= gamma
108 | zs_metrics = ZSL_Metrics(seen_classses, unseen_classses)
109 | Acc_ts, Acc_tr, H = zs_metrics.generlized_scores(y_gt.cpu().numpy(), cs_pred, num_class_according_y_true=True,
110 | use_balanced_accuracy=use_balanced_accuracy)
111 | cs_pred = None # release memory
112 | # torch.cuda.empty_cache() # release memory
113 |
114 | return Acc_tr, Acc_ts
115 |
116 |
117 | class ZSL_Metrics():
118 | def __init__(self, seen_classes, unseen_classes, report_entropy=False):
119 | self._seen_classes = np.sort(seen_classes)
120 | self._unseen_classes = np.sort(unseen_classes)
121 | self._n_seen = len(seen_classes)
122 | self._n_unseen = len(unseen_classes)
123 | self._report_entropy = report_entropy
124 |
125 | assert(self._n_seen == len(np.unique(seen_classes))) # sanity check
126 | assert(self._n_unseen == len(np.unique(unseen_classes))) # sanity check
127 |
128 |
129 | def unseen_balanced_accuracy(self, y_true, pred_softmax):
130 | Acc_zs, Ent_zs = self._subset_classes_balanced_accuracy(y_true, pred_softmax,
131 | self._unseen_classes)
132 | if self._report_entropy:
133 | return Acc_zs, Ent_zs
134 | else:
135 | return Acc_zs
136 |
137 | def seen_balanced_accuracy(self, y_true, pred_softmax):
138 | Acc_seen, Ent_seen = self._subset_classes_balanced_accuracy(y_true,
139 | pred_softmax,
140 | self._seen_classes)
141 | if self._report_entropy:
142 | return Acc_seen, Ent_seen
143 | else:
144 | return Acc_seen
145 |
146 | def generlized_scores(self, y_true_cpu, pred_softmax, num_class_according_y_true=True, use_balanced_accuracy=True):
147 |
148 | Acc_ts, Ent_ts = self._generalized_unseen_accuracy(y_true_cpu, pred_softmax, num_class_according_y_true,
149 | use_balanced_accuracy)
150 | Acc_tr, Ent_tr = self._generalized_seen_accuracy(y_true_cpu, pred_softmax, num_class_according_y_true,
151 | use_balanced_accuracy)
152 | H = 2*Acc_tr*Acc_ts/(Acc_tr + Acc_ts + 1e-8)
153 | Ent_H = 2*Ent_tr*Ent_ts/(Ent_tr + Ent_ts + 1e-8)
154 |
155 | if self._report_entropy:
156 | return Acc_ts, Acc_tr, H, Ent_ts, Ent_tr, Ent_H
157 | else:
158 | return Acc_ts, Acc_tr, H
159 |
160 | def _generalized_unseen_accuracy(self, y_true_cpu, pred_softmax, num_class_according_y_true, use_balanced_accuracy):
161 | return self._generalized_subset_accuracy(y_true_cpu, pred_softmax,
162 | self._unseen_classes, num_class_according_y_true, use_balanced_accuracy)
163 |
164 | def _generalized_seen_accuracy(self, y_true_cpu, pred_softmax, num_class_according_y_true, use_balanced_accuracy):
165 | return self._generalized_subset_accuracy(y_true_cpu, pred_softmax,
166 | self._seen_classes, num_class_according_y_true, use_balanced_accuracy)
167 |
168 | def _generalized_subset_accuracy(self, y_true_cpu, pred_softmax, subset_classes, num_class_according_y_true,
169 | use_balanced_accuracy):
170 | is_member = np.in1d # np.in1d is like MATLAB's ismember
171 | ix_subset_samples = is_member(y_true_cpu, subset_classes)
172 |
173 | y_true_subset = y_true_cpu[ix_subset_samples]
174 | all_classes = np.sort(np.block([self._seen_classes, self._unseen_classes]))
175 |
176 | if isinstance(pred_softmax, torch.Tensor):
177 | amax = (pred_softmax[:, all_classes]).argmax(1).cpu()
178 | else:
179 | amax = (pred_softmax[:, all_classes]).argmax(1)
180 |
181 | y_pred = all_classes[amax]
182 | y_pred_subset = y_pred[ix_subset_samples]
183 |
184 | if use_balanced_accuracy:
185 |
186 | num_class = len(subset_classes)
187 | # infer number of classes according to unique(y_true_subset)
188 | if num_class_according_y_true:
189 | num_class = len(np.unique(y_true_subset))
190 | Acc = float(per_class_balanced_accuracy(y_true_subset, y_pred_subset, num_class))
191 | else:
192 | Acc = (y_true_subset == y_pred_subset).mean()
193 |
194 | # Ent = float(entropy2(pred_softmax[ix_subset_samples, :][:, all_classes]).mean())
195 |         Ent = 0*Acc + 1e-3  # entropy reporting disabled because it's too slow
196 | return Acc, Ent
197 |
198 | def _subset_classes_balanced_accuracy(self, y_true, pred_softmax, subset_classes):
199 | is_member = np.in1d # np.in1d is like MATLAB's ismember
200 | ix_subset_samples = is_member(y_true, subset_classes)
201 |
202 | y_true_zs = y_true[ix_subset_samples]
203 | y_pred = subset_classes[(pred_softmax[:, subset_classes]).argmax(dim=1)]
204 | y_pred_zs = y_pred[ix_subset_samples]
205 |
206 | Acc = float(per_class_balanced_accuracy(y_true_zs, y_pred_zs, len(subset_classes)))
207 | # Ent = float(entropy2(pred_softmax[:, subset_classes]).mean())
208 |         Ent = 0*Acc + 1e-3  # entropy reporting disabled because it's too slow
209 | return Acc, Ent
210 |
211 |
212 | def per_class_balanced_accuracy(y_true, y_pred, num_class=None):
213 |     """ A balanced accuracy metric as in Xian et al. (CVPR 2017). Accuracy is
214 |     evaluated individually per class, and then uniformly averaged across
215 |     classes.
216 | """
217 | if len(y_true) == 0 or num_class == 0:
218 | return np.nan
219 |
220 | if isinstance(y_true, torch.Tensor):
221 | y_true = y_true.flatten().cpu().numpy().astype('int32')
222 | else:
223 | y_true = y_true.flatten().astype('int32')
224 |
225 | if isinstance(y_pred, torch.Tensor):
226 | y_pred = y_pred.cpu().numpy()
227 |
228 |
229 | if num_class is None:
230 | num_class = len(np.unique(np.block([y_true, y_pred])))
231 |
232 | max_class_id = 1+max([num_class, y_true.max(), y_pred.max()])
233 |
234 | counts_per_class_s = pd.Series(y_true).value_counts()
235 | counts_per_class = np.zeros((max_class_id,))
236 | counts_per_class[counts_per_class_s.index] = counts_per_class_s.values
237 |
238 | accuracy = (1.*(y_pred == y_true) / counts_per_class[y_true]).sum() / num_class
239 | return accuracy.astype('float32')
240 |
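# Worked example: y_true = [0, 0, 1], y_pred = [0, 1, 1].
# Class 0: 1/2 correct (0.5); class 1: 1/1 correct (1.0);
# balanced accuracy = (0.5 + 1.0) / 2 = 0.75, whereas the plain
# sample-averaged accuracy would be 2/3:
#
#   per_class_balanced_accuracy(np.array([0, 0, 1]), np.array([0, 1, 1]))  # -> 0.75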
241 |
242 |
243 | def run_bash(cmd, raise_on_err=True, raise_on_warning=False, versbose=True, return_exist_code=False, err_ind_by_exitcode=False):
244 |     """ This function takes a Bash command and returns its stdout
245 | Returns: string (stdout)
246 | :type cmd: string
247 | """
248 | p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
249 | stderr=subprocess.PIPE, executable='/bin/bash')
250 | # p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
251 | out, err = p.communicate()
252 | out = out.strip().decode('utf-8')
253 | err = err.strip().decode('utf-8')
254 | exit_code = p.returncode
255 | is_err = err != ''
256 | if err_ind_by_exitcode:
257 | is_err = (exit_code != 0)
258 |
259 | if is_err and raise_on_err:
260 | do_raise = True
261 | if 'warning' in err.lower():
262 | do_raise = raise_on_warning
263 | if versbose and not raise_on_warning:
264 | print('command was: {}'.format(cmd))
265 | print(err, file=sys.stderr)
266 |
267 | if do_raise or 'error' in err.lower():
268 | if versbose:
269 | print('command was: {}'.format(cmd))
270 | raise RuntimeError(err)
271 |
272 | if return_exist_code:
273 | return out, exit_code
274 | else:
275 | return out # This is the stdout from the shell command
276 |
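# Usage sketch:
#
#   out = run_bash('echo hello')                              # -> 'hello'
#   out, code = run_bash('ls /tmp', return_exist_code=True)   # also return the exit code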
277 |
278 | @contextmanager
279 | def temporary_random_numpy_seed(seed) -> object:
280 | """ From https://github.com/yuvalatzmon/COSMO/blob/master/src/utils/ml_utils.py#L701
281 | A context manager for a temporary random seed (only within context)
282 | When leaving the context the numpy random state is restored
283 |
284 | This function is inspired by http://stackoverflow.com/q/32679403, which is shared according to https://creativecommons.org/licenses/by-sa/3.0/
285 |
286 | License
287 | THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED.
288 |
289 | BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS.
290 |
291 | 1. Definitions
292 |
293 | "Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original, except that a work that constitutes a Collection will not be considered an Adaptation for the purpose of this License. For the avoidance of doubt, where the Work is a musical work, performance or phonogram, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered an Adaptation for the purpose of this License.
294 | "Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or performances, phonograms or broadcasts, or other works or subject matter other than works listed in Section 1(f) below, which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined below) for the purposes of this License.
295 | "Creative Commons Compatible License" means a license that is listed at https://creativecommons.org/compatiblelicenses that has been approved by Creative Commons as being essentially equivalent to this License, including, at a minimum, because that license: (i) contains terms that have the same purpose, meaning and effect as the License Elements of this License; and, (ii) explicitly permits the relicensing of adaptations of works made available under that license under this License or a Creative Commons jurisdiction license with the same License Elements as this License.
296 | "Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership.
297 | "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, ShareAlike.
298 | "Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License.
299 | "Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i) in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast.
300 | "Work" means the literary and/or artistic work offered under the terms of this License including without limitation any production in the literary, scientific and artistic domain, whatever may be the mode or form of its expression including digital form, such as a book, pamphlet and other writing; a lecture, address, sermon or other work of the same nature; a dramatic or dramatico-musical work; a choreographic work or entertainment in dumb show; a musical composition with or without words; a cinematographic work to which are assimilated works expressed by a process analogous to cinematography; a work of drawing, painting, architecture, sculpture, engraving or lithography; a photographic work to which are assimilated works expressed by a process analogous to photography; a work of applied art; an illustration, map, plan, sketch or three-dimensional work relative to geography, topography, architecture or science; a performance; a broadcast; a phonogram; a compilation of data to the extent it is protected as a copyrightable work; or a work performed by a variety or circus performer to the extent it is not otherwise considered a literary or artistic work.
301 | "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation.
302 | "Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images.
303 | "Reproduce" means to make copies of the Work by any means including without limitation by sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage of a protected performance or phonogram in digital form or other electronic medium.
304 | 2. Fair Dealing Rights. Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws.
305 |
306 | 3. License Grant. Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below:
307 |
308 | to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections;
309 | to create and Reproduce Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, a translation could be marked "The original work was translated from English to Spanish," or a modification could indicate "The original work has been modified.";
310 | to Distribute and Publicly Perform the Work including as incorporated in Collections; and,
311 | to Distribute and Publicly Perform Adaptations.
312 | For the avoidance of doubt:
313 |
314 | Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License;
315 | Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor waives the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; and,
316 | Voluntary License Schemes. The Licensor waives the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License.
317 | The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved.
318 |
319 | 4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions:
320 |
321 | You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(c), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(c), as requested.
322 | You may Distribute or Publicly Perform an Adaptation only under the terms of: (i) this License; (ii) a later version of this License with the same License Elements as this License; (iii) a Creative Commons jurisdiction license (either this or a later license version) that contains the same License Elements as this License (e.g., Attribution-ShareAlike 3.0 US); (iv) a Creative Commons Compatible License. If You license the Adaptation under one of the licenses mentioned in (iv), You must comply with the terms of that license. If You license the Adaptation under the terms of any of the licenses mentioned in (i), (ii) or (iii) (the "Applicable License"), You must comply with the terms of the Applicable License generally and the following provisions: (I) You must include a copy of, or the URI for, the Applicable License with every copy of each Adaptation You Distribute or Publicly Perform; (II) You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License; (III) You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform; (IV) when You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License.
323 | If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author (or pseudonym, if applicable) if supplied, and/or if the Original Author and/or Licensor designate another party or parties (e.g., a sponsor institute, publishing entity, journal) for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work if supplied; (iii) to the extent reasonably practicable, the URI, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and (iv), consistent with Section 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). The credit required by this Section 4(c) may be implemented in any reasonable manner; provided, however, that in the case of an Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties.
324 | Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g., Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise.
325 | 5. Representations, Warranties and Disclaimer
326 |
327 | UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU.
328 |
329 | 6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
330 |
331 | 7. Termination
332 |
333 | This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License.
334 | Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above.
335 | 8. Miscellaneous
336 |
337 | Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License.
338 | Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License.
339 | If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable.
340 | No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent.
341 | This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You.
342 | The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law.
343 |
344 |
345 |
346 |
347 | """
348 | state = np.random.get_state()  # save the global NumPy RNG state
349 | np.random.seed(seed)           # temporarily reseed
350 | try:
351 |     yield None
352 | finally:
353 |     # restore the saved state even if the body raises
354 |     np.random.set_state(state)
355 |
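356 | # Usage sketch: `temp_seed` below is a hypothetical placeholder for the
357 | # context manager defined above (its actual name is not shown in this
358 | # excerpt). Draws inside the block are reproducible, while the caller's
359 | # RNG stream resumes unchanged afterwards.
360 | #
361 | # with temp_seed(0):
362 | #     x = np.random.rand(3)   # same values on every run
363 | # y = np.random.rand(3)       # unaffected by the temporary seed
364 |
365 |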
--------------------------------------------------------------------------------