├── docker_cfg ├── inputrc └── Dockerfile ├── ATTOP ├── data │ ├── __init__.py │ └── dataset.py ├── models │ ├── __init__.py │ └── models.py └── LICENSE ├── taskmodularnets ├── __init__.py ├── data │ ├── __init__.py │ └── dataset.py ├── utils │ ├── __init__.py │ └── utils.py └── LICENSE ├── Causal.png ├── AO_CLEVr_examples.png ├── .gitignore ├── LICENSE_COSMO ├── LICENSE_HSIC ├── LICENSE ├── HSIC.py ├── experiment_zappos_TMNsplit.py ├── README.md ├── data.py ├── params.py ├── offline_early_stop.py ├── main.py ├── useful_utils.py ├── model.py ├── eval.py ├── train.py └── COSMO_utils.py /docker_cfg/inputrc: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ATTOP/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ATTOP/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /taskmodularnets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Causal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/causal_comp/HEAD/Causal.png -------------------------------------------------------------------------------- /AO_CLEVr_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/causal_comp/HEAD/AO_CLEVr_examples.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | outputs/ 3 | logs/ 4 | logs*/ 5 | tmp*/ 6 | beat.touch 7 | 8 | notebooks/figures 9 | 10 | *__pycache__ 11 | 12 | .ipynb_checkpoints/ 13 | .vscode/ 14 | .idea/ 15 | .ipynb_checkpoints 16 | -------------------------------------------------------------------------------- /taskmodularnets/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # taskmodularnets/LICENSE file. 6 | # -------------------------------------------------------------------------------- /taskmodularnets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # taskmodularnets/LICENSE file. 6 | # 7 | -------------------------------------------------------------------------------- /ATTOP/data/dataset.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the following repo 5 | # (released under the MIT License). 
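#
# Only `sample_negative` is kept from the original dataset module: given a
# positive (attribute, object) pair, it draws a random pair from
# self.train_pairs, recursing until the draw differs from the positive
# pair, and returns the corresponding (attr index, obj index).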
6 | # 7 | # Source: 8 | # https://github.com/Tushar-N/attributes-as-operators/blob/master/data/dataset.py 9 | # 10 | # The license for the original version of this file can be 11 | # found in ATTOP/LICENSE 12 | # --------------------------------------------------------------- 13 | 14 | import numpy as np 15 | 16 | def sample_negative(self, attr, obj): 17 | new_attr, new_obj = self.train_pairs[np.random.choice(len(self.train_pairs))] 18 | if new_attr == attr and new_obj == obj: 19 | return self.sample_negative(attr, obj) 20 | return (self.attr2idx[new_attr], self.obj2idx[new_obj]) 21 | -------------------------------------------------------------------------------- /docker_cfg/Dockerfile: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # located at the root directory. 6 | # --------------------------------------------------------------- 7 | 8 | FROM nvcr.io/nvidia/pytorch:20.01-py3 9 | RUN yes | pip install tensorboardX==1.2 10 | RUN yes | pip install wget 11 | RUN yes | pip install scikit-learn 12 | RUN yes | pip install ipymd 13 | RUN yes | pip install Pillow 14 | RUN yes | apt-get update 15 | RUN yes | apt-get install tmux 16 | RUN yes | apt-get install htop 17 | RUN yes | apt-get install bc 18 | RUN yes | conda install -c conda-forge accimage 19 | RUN yes | pip install notifiers 20 | RUN yes | pip install pyyaml 21 | RUN yes | pip install ruamel.yaml 22 | RUN yes | apt install rsync 23 | RUN yes | apt install git 24 | RUN yes | pip install dataclasses 25 | RUN yes | pip install simple_parsing 26 | RUN yes | pip install munch 27 | COPY inputrc /root/.inputrc 28 | 29 | 30 | -------------------------------------------------------------------------------- /ATTOP/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LICENSE_COSMO: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2020 Yuval Atzmon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | -------------------------------------------------------------------------------- /LICENSE_HSIC: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2020 Tom Beer and Bar Eini-Porat 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /ATTOP/models/models.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the following repo 5 | # (released under the MIT License). 
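#
# Only the MLP building block is kept from the original file. A minimal
# usage sketch (the dimensions are hypothetical): with num_layers=2 the
# module stacks Linear(512, 512) + ReLU followed by Linear(512, 300):
#
#   net = MLP(inp_dim=512, out_dim=300, num_layers=2, relu=False)
#   emb = net(torch.randn(8, 512))  # -> shape (8, 300)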
6 | # 7 | # Source: 8 | # https://github.com/Tushar-N/attributes-as-operators/blob/master/models/models.py 9 | # 10 | # The license for the original version of this file can be 11 | # found in ATTOP/LICENSE 12 | # --------------------------------------------------------------- 13 | 14 | import torch.nn as nn 15 | 16 | 17 | class MLP(nn.Module): 18 | def __init__(self, inp_dim, out_dim, num_layers=1, relu=True, bias=True): 19 | super(MLP, self).__init__() 20 | mod = [] 21 | for L in range(num_layers-1): 22 | mod.append(nn.Linear(inp_dim, inp_dim, bias=bias)) 23 | mod.append(nn.ReLU(True)) 24 | 25 | mod.append(nn.Linear(inp_dim, out_dim, bias=bias)) 26 | if relu: 27 | mod.append(nn.ReLU(True)) 28 | 29 | self.mod = nn.Sequential(*mod) 30 | 31 | def forward(self, x): 32 | output = self.mod(x) 33 | return output 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | Copyright 2021 NVIDIA 4 | 5 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | 11 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 14 | -------------------------------------------------------------------------------- /HSIC.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from a file in the following repo 5 | # (released under the MIT License). 6 | # 7 | # Source: 8 | # https://github.com/tom-beer/deep-scientific-discovery/blob/master/hsic.py 9 | # 10 | # The license for the original version of this file can be 11 | # found in this directory (LICENSE_HSIC). The modifications 12 | # to this file are subject to the License 13 | # located at the root directory. 
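#
# The biased HSIC estimator implemented below is
#   HSIC(X, Y) = trace(K H L H) / (m - 1)^2,   H = I - (1/m) 1 1^T,
# where K and L are kernel matrices over the two batches of representations.
# A minimal usage sketch (CUDA tensors are assumed, since the centering
# matrix H is allocated on 'cuda'):
#
#   X = torch.randn(128, 16, device='cuda')
#   Y = torch.randn(128, 8, device='cuda')
#   hsic_val, _, _ = HSIC(X, Y, kernelX="Gaussian", kernelY="Gaussian")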
14 | # --------------------------------------------------------------- 15 | 16 | 17 | import torch 18 | import numpy as np 19 | 20 | from useful_utils import ns_profiling_label 21 | 22 | 23 | def pairwise_distances(x): 24 | x_distances = torch.sum(x**2,-1).reshape((-1,1)) 25 | return -2*torch.mm(x,x.t()) + x_distances + x_distances.t() 26 | 27 | def kernelMatrixGaussian(x, sigma=1): 28 | 29 | pairwise_distances_ = pairwise_distances(x) 30 | gamma = -1.0 / (sigma ** 2) 31 | return torch.exp(gamma * pairwise_distances_) 32 | 33 | def kernelMatrixLinear(x): 34 | return torch.matmul(x,x.t()) 35 | 36 | # check 37 | def HSIC(X, Y, kernelX="Gaussian", kernelY="Gaussian", sigmaX=1, sigmaY=1, 38 | log_median_pairwise_distance=False): 39 | m,_ = X.shape 40 | assert(m>1) 41 | 42 | median_pairwise_distanceX, median_pairwise_distanceY = np.nan, np.nan 43 | if log_median_pairwise_distance: 44 | # This calc takes a long time. It is used for debugging and disabled by default. 45 | with ns_profiling_label('dist'): 46 | median_pairwise_distanceX = median_pairwise_distance(X) 47 | median_pairwise_distanceY = median_pairwise_distance(Y) 48 | 49 | with ns_profiling_label('Hkernel'): 50 | K = kernelMatrixGaussian(X,sigmaX) if kernelX == "Gaussian" else kernelMatrixLinear(X) 51 | L = kernelMatrixGaussian(Y,sigmaY) if kernelY == "Gaussian" else kernelMatrixLinear(Y) 52 | 53 | with ns_profiling_label('Hfinal'): 54 | H = torch.eye(m, device='cuda') - 1.0/m * torch.ones((m,m), device='cuda') 55 | H = H.float().cuda() 56 | 57 | Kc = torch.mm(H,torch.mm(K,H)) 58 | 59 | HSIC = torch.trace(torch.mm(L,Kc))/((m-1)**2) 60 | return HSIC, median_pairwise_distanceX, median_pairwise_distanceY 61 | 62 | 63 | def median_pairwise_distance(X): 64 | t = pairwise_distances(X).detach() 65 | triu_indices = t.triu(diagonal=1).nonzero().T 66 | 67 | if triu_indices[0].shape[0] == 0 or triu_indices[1].shape[0] == 0: 68 | return 0. 69 | else: 70 | return torch.median(t[triu_indices[0], triu_indices[1]]).item() 71 | 72 | -------------------------------------------------------------------------------- /experiment_zappos_TMNsplit.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the License 5 | # located at the root directory. 6 | # --------------------------------------------------------------- 7 | 8 | from pathlib import Path 9 | 10 | from dataclasses import dataclass 11 | from simple_parsing import ArgumentParser, ConflictResolution 12 | 13 | from params import CommandlineArgs 14 | import shlex 15 | 16 | from main import main 17 | from useful_utils import to_simple_parsing_args 18 | 19 | 20 | @dataclass 21 | class ScriptArgs(): 22 | output_dir: str = '/tmp/output_results_dir' 23 | """ Directory to save the results """ 24 | 25 | data_dir: str = '/local_data/zap_tmn' 26 | """ Directory that holds the data""" 27 | 28 | seed: int = 0 29 | """ Random seed. For zappos, choose \in [0..4] """ 30 | 31 | use_wandb: bool = True 32 | """ log results to w&b """ 33 | 34 | wandb_user: str = 'none' 35 | 36 | @classmethod 37 | def get_args(cls): 38 | args: cls = to_simple_parsing_args(cls) 39 | return args 40 | 41 | if __name__ == '__main__': 42 | script_args = ScriptArgs.get_args() 43 | 44 | # Seed=0: Unseen=27.4, Seen=33.9, Harmonic=30.3, Closed=56.7, AUC=25.2 45 | 46 | # Set commandline arguments for the experiment. 
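# The option string below is tokenized with shlex.split and parsed into the
# same CommandlineArgs dataclass that main.py consumes, so each flag behaves
# exactly like a real command-line option;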
47 | # see params.py for their documentation and meaning with respect to definitions in the paper 48 | 49 | experiment_commandline_opt_string = f"""--output_dir={script_args.output_dir} --data_dir={script_args.data_dir} 50 | --use_wandb={script_args.use_wandb} --wandb_user={script_args.wandb_user} 51 | --learn_embeddings=1 --nlayers_label=0 --nlayers_joint_ao=2 --h_dim=300 --E_num_hidden=2 --E_num_common_hidden=1 52 | --dataset_name=zap_tmn --dataset_variant=irrelevant --optimizer_name=adam --lr=0.0003 --max_epoch=1000 --alternate_ys=0 53 | --weight_decay=5e-05 --calc_AUC=1 --lambda_feat=0 --mu_img_feat=0 --lambda_aux_disjoint=100 --lambda_aux_img=0 --metadata_from_pkl=0 54 | --lambda_ao_emb=1000 --lambda_aux=0 --seed={script_args.seed} --alphaH=450.0 --HSIC_coeff=0.0045 --balanced_loss=0 --num_split=0 55 | --num_workers=8 --report_imbalanced_metrics=True""" 56 | if script_args.use_wandb: 57 | experiment_commandline_opt_string += """ --instance_name=dev_zap_tmn --project_name=causal_comp_prep_zap_tmn""" 58 | 59 | parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE) 60 | parser.add_arguments(CommandlineArgs, dest='cfg') 61 | main_args: CommandlineArgs = parser.parse_args(shlex.split(experiment_commandline_opt_string)).cfg 62 | main(main_args) 63 | 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 
3 | # A causal view of compositional zero-shot recognition
4 | This repository hosts the dataset and source code for the paper "A causal view of compositional zero-shot recognition" by Yuval Atzmon, Felix Kreuk, Uri Shalit, and Gal Chechik, NeurIPS 2020 (Spotlight).
5 | 
6 | 
7 | ## Code
8 | 
9 | ### Setup
10 | #### Build the docker image:
11 | ```
12 | cd docker_cfg
13 | docker build --network=host -t causal_comp -f Dockerfile .
14 | cd ..
15 | ```
16 | 
17 | 
18 | #### Prepare UT Zappos50K with the TMN split:
19 | To reproduce the UT Zappos50K results on the TMN evaluation split, prepare the dataset following the
20 | taskmodularnets repository:
21 | https://github.com/facebookresearch/taskmodularnets
22 | 
23 | The UT Zappos50K dataset is for academic, non-commercial use only, released by:
24 | 
25 | * A. Yu and K. Grauman. "Fine-Grained Visual Comparisons with Local Learning". In CVPR, 2014.
26 | * A. Yu and K. Grauman. "Semantic Jitter: Dense Supervision for Visual Comparisons via Synthetic Images". In ICCV, 2017.
27 | 
28 | 
29 | #### Reproduce the Zappos experiment:
30 | Notes:
31 | 1. Set `DATA_DIR` to the directory containing the data.
32 | 2. Set `SEED` to a value in [0..4].
33 | 3. Set `SOURCE_CODE_DIR` to the current project workdir.
34 | 4. Set `OUTPUT_DIR` to the directory where results will be saved.
35 | 
36 | ```
37 | SEED=0 # set seed \in [0..4]
38 | SOURCE_CODE_DIR=$HOME/git/causal_comp_prep/
39 | DATA_DIR=/local_data/zap_tmn # SET HERE THE DATA DIR
40 | DATASET=zap_tmn
41 | OUTPUT_DIR=/tmp/output/causal_comp_${DATASET}__seed${SEED}
42 | 
43 | # prepare output directory
44 | mkdir -p ${OUTPUT_DIR}
45 | rm -r ${OUTPUT_DIR}/*
46 | # run experiment
47 | docker run --net=host -v ${SOURCE_CODE_DIR}:/workspace/causal_comp -v ${DATA_DIR}:/data/zap_tmn -v ${OUTPUT_DIR}:/output --user $(id -u):$(id -g) --shm-size=1g --ulimit memlock=-1 --ulimit stack=6710886 --rm -it causal_comp /bin/bash -c "cd /workspace/causal_comp/; python experiment_zappos_TMNsplit.py --seed=${SEED} --output_dir=/output --data_dir=/data/zap_tmn --use_wandb=0"
48 | ```
49 | 
50 | #### Reproduce the AO-CLEVr experiments:
51 | Hyperparameters for reproducing the results will be published soon.
52 | 
53 | 
54 | ## AO-CLEVr Dataset
55 | 
56 | AO-CLEVr is a new synthetic-images dataset containing images of "easy" Attribute-Object categories, based on the CLEVr framework (Johnson et al., CVPR 2017). AO-CLEVr has attribute-object pairs created from 8 attributes: \{red, purple, yellow, blue, green, cyan, gray, brown\} and 3 object shapes \{sphere, cube, cylinder\}, yielding 24 attribute-object pairs. Each pair consists of 7500 images. Each image shows a single object that instantiates the attribute-object pair. The object is randomly assigned one of two sizes (small/large), one of two materials (rubber/metallic), a random position, and random lighting according to CLEVr defaults.
57 | 
58 | ![Examples of AO-CLEVr images](./AO_CLEVr_examples.png)
59 | 
60 | The dataset can be accessed from [the following URL](https://drive.google.com/drive/folders/1BBwW9VqzROgJXmvnfXcOxbLob8FB_jLf).
61 | 
62 | 
63 | ## Cite the paper
64 | If you use the contents of this project, please cite our paper.
65 | 
66 |     @inproceedings{neurips2020_causal_comp_atzmon,
67 |       author = {Atzmon, Yuval and Kreuk, Felix and Shalit, Uri and Chechik, Gal},
68 |       booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
69 |       title = {A causal view of compositional zero-shot recognition},
70 |       year = {2020}
71 |     }
72 | 
73 | For business inquiries, please contact [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com)
74 | For press and other inquiries, please contact Hector Marinez at [hmarinez@nvidia.com](hmarinez@nvidia.com) 75 | -------------------------------------------------------------------------------- /taskmodularnets/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # taskmodularnets/LICENSE file. 6 | # 7 | import os 8 | import re 9 | import torch 10 | import torchvision.transforms as transforms 11 | import torchvision.datasets as torchdata 12 | import numpy as np 13 | from torch.autograd import Variable 14 | import itertools 15 | import copy 16 | 17 | # Save the training script and all the arguments 18 | import shutil 19 | 20 | 21 | def save_args(args): 22 | shutil.copy('train_modular.py', args.cv_dir + '/' + args.name + '/') 23 | shutil.copy('archs/models.py', args.cv_dir + '/' + args.name + '/') 24 | with open(args.cv_dir + '/' + args.name + '/args.txt', 'w') as f: 25 | f.write(str(args)) 26 | 27 | 28 | class UnNormalizer: 29 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 30 | self.mean = mean 31 | self.std = std 32 | 33 | def __call__(self, tensor): 34 | for b in range(tensor.size(0)): 35 | for t, m, s in zip(tensor[b], self.mean, self.std): 36 | t.mul_(s).add_(m) 37 | return tensor 38 | 39 | 40 | def chunks(l, n): 41 | """Yield successive n-sized chunks from l.""" 42 | for i in range(0, len(l), n): 43 | yield l[i:i + n] 44 | 45 | 46 | def flatten(l): 47 | return list(itertools.chain.from_iterable(l)) 48 | 49 | 50 | class AverageMeter(object): 51 | """Computes and stores the average and current value""" 52 | 53 | def __init__(self): 54 | self.reset() 55 | 56 | def reset(self): 57 | self.val = 0 58 | self.avg = 0 59 | self.sum = 0 60 | self.count = 0 61 | 62 | def update(self, val, n=1): 63 | self.val = val 64 | self.sum += val * n 65 | self.count += n 66 | self.avg = self.sum / self.count 67 | 68 | 69 | def calc_pr_ovr_noref(counts, out): 70 | """ 71 | [P, R, score, ap] = calc_pr_ovr(counts, out, K) 72 | Input : 73 | counts : number of occurrences of this word in the ith image 74 | out : score for this image 75 | Output : 76 | P, R : precision and recall 77 | score : score which corresponds to the particular precision and recall 78 | ap : average precision 79 | """ 80 | #binarize counts 81 | out = out.astype(np.float64) 82 | counts = np.array(counts > 0, dtype=np.float32) 83 | tog = np.hstack((counts[:, np.newaxis].astype(np.float64), 84 | out[:, np.newaxis].astype(np.float64))) 85 | ind = np.argsort(out) 86 | ind = ind[::-1] 87 | score = np.array([tog[i, 1] for i in ind]) 88 | sortcounts = np.array([tog[i, 0] for i in ind]) 89 | 90 | tp = sortcounts 91 | fp = sortcounts.copy() 92 | for i in range(sortcounts.shape[0]): 93 | if sortcounts[i] >= 1: 94 | fp[i] = 0. 95 | elif sortcounts[i] < 1: 96 | fp[i] = 1. 
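    # tp/fp now hold per-example 0/1 indicators down the score-ranked list;
    # the cumulative sums below turn them into running true/false-positive
    # counts, from which precision and recall are computed.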
97 | 98 | tp = np.cumsum(tp) 99 | fp = np.cumsum(fp) 100 | # P = np.cumsum(tp)/(np.cumsum(tp) + np.cumsum(fp)); 101 | P = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 102 | 103 | numinst = np.sum(counts) 104 | 105 | R = tp / (numinst + 1e-10) 106 | 107 | ap = voc_ap(R, P) 108 | return P, R, score, ap 109 | 110 | 111 | def voc_ap(rec, prec): 112 | # correct AP calculation 113 | # first append sentinel values at the end 114 | mrec = np.concatenate(([0.], rec, [1.])) 115 | mpre = np.concatenate(([0.], prec, [0.])) 116 | 117 | # compute the precision envelope 118 | for i in range(mpre.size - 1, 0, -1): 119 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 120 | 121 | # to calculate area under PR curve, look for points 122 | # where X axis (recall) changes value 123 | i = np.where(mrec[1:] != mrec[:-1])[0] 124 | 125 | # and sum (\Delta recall) * prec 126 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 127 | return ap 128 | 129 | 130 | def roll(x, n): 131 | return torch.cat((x[-n:], x[:-n])) 132 | 133 | 134 | def load_word_embeddings(emb_file, vocab): 135 | 136 | vocab = [v.lower() for v in vocab] 137 | 138 | embeds = {} 139 | for line in open(emb_file, 'r'): 140 | line = line.strip().split(' ') 141 | wvec = torch.FloatTensor(list(map(float, line[1:]))) 142 | embeds[line[0]] = wvec 143 | 144 | # for zappos (should account for everything) 145 | custom_map = { 146 | 'Faux.Fur': 'fur', 147 | 'Faux.Leather': 'leather', 148 | 'Full.grain.leather': 'leather', 149 | 'Hair.Calf': 'hair', 150 | 'Patent.Leather': 'leather', 151 | 'Nubuck': 'leather', 152 | 'Boots.Ankle': 'boots', 153 | 'Boots.Knee.High': 'knee-high', 154 | 'Boots.Mid-Calf': 'midcalf', 155 | 'Shoes.Boat.Shoes': 'shoes', 156 | 'Shoes.Clogs.and.Mules': 'clogs', 157 | 'Shoes.Flats': 'flats', 158 | 'Shoes.Heels': 'heels', 159 | 'Shoes.Loafers': 'loafers', 160 | 'Shoes.Oxfords': 'oxfords', 161 | 'Shoes.Sneakers.and.Athletic.Shoes': 'sneakers', 162 | 'traffic_light': 'light', 163 | 'trash_can': 'trashcan' 164 | } 165 | for k in custom_map: 166 | embeds[k.lower()] = embeds[custom_map[k]] 167 | 168 | embeds = [embeds[k] for k in vocab] 169 | embeds = torch.stack(embeds) 170 | print('loaded embeddings', embeds.size()) 171 | 172 | return embeds 173 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the License 5 | # located at the root directory. 
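#
# A minimal usage sketch (the paths are hypothetical): CompDataFromDict
# below wraps a metadata dict (mirroring the TMN dataset attributes)
# together with image features pre-extracted to <data_dir>/features.t7:
#
#   train_set = CompDataFromDict(metadata_dict, data_subset='train_data',
#                                data_dir='/data/zap_tmn')
#   item = train_set[0]  # DataItem(feat, pos_attr_id, pos_obj_id, ...)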
6 | # --------------------------------------------------------------- 7 | 8 | from typing import NamedTuple 9 | 10 | import numpy as np 11 | 12 | from pathlib import Path 13 | from ATTOP.data.dataset import sample_negative as ATTOP_sample_negative 14 | 15 | import torch 16 | from torch.utils import data 17 | 18 | from useful_utils import categorical_histogram, get_and_update_num_calls 19 | from COSMO_utils import temporary_random_numpy_seed 20 | 21 | 22 | class DataItem(NamedTuple): 23 | """ A NamedTuple for returning a Dataset item """ 24 | feat: torch.Tensor 25 | pos_attr_id: int 26 | pos_obj_id: int 27 | neg_attr_id: int 28 | neg_obj_id: int 29 | image_fname: str 30 | 31 | 32 | class CompDataFromDict(): 33 | # noinspection PyMissingConstructor 34 | def __init__(self, dict_data: dict, data_subset: str, data_dir: str): 35 | 36 | # define instance variables to be retrieved from struct_data_dict 37 | self.split: str = 'TBD' 38 | self.phase: str = 'TBD' 39 | self.feat_dim: int = -1 40 | self.objs: list = [] 41 | self.attrs: list = [] 42 | self.attr2idx: dict = {} 43 | self.obj2idx: dict = {} 44 | self.pair2idx: dict = {} 45 | self.seen_pairs: list = [] 46 | self.all_open_pairs: list = [] 47 | self.closed_unseen_pairs: list = [] 48 | self.unseen_closed_val_pairs: list = [] 49 | self.unseen_closed_test_pairs: list = [] 50 | self.train_data: tuple = tuple() 51 | self.val_data: tuple = tuple() 52 | self.test_data: tuple = tuple() 53 | 54 | self.data_dir: str = data_dir 55 | 56 | # retrieve instance variables from struct_data_dict 57 | vars(self).update(dict_data) 58 | self.data = dict_data[data_subset] 59 | 60 | self.activations = {} 61 | features_dict = torch.load(Path(data_dir) / 'features.t7') 62 | for i, img_filename in enumerate(features_dict['files']): 63 | self.activations[img_filename] = features_dict['features'][i] 64 | 65 | self.input_shape = (self.feat_dim,) 66 | self.num_objs = len(self.objs) 67 | self.num_attrs = len(self.attrs) 68 | self.num_seen_pairs = len(self.seen_pairs) 69 | self.shape_obj_attr = (self.num_objs, self.num_attrs) 70 | 71 | self.flattened_seen_pairs_mask = self.get_flattened_pairs_mask(self.seen_pairs) 72 | self.flattened_closed_unseen_pairs_mask = self.get_flattened_pairs_mask(self.closed_unseen_pairs) 73 | self.flattened_all_open_pairs_mask = self.get_flattened_pairs_mask(self.all_open_pairs) 74 | self.seen_pairs_joint_class_ids = np.where(self.flattened_seen_pairs_mask) 75 | 76 | self.y1_freqs, self.y2_freqs, self.pairs_freqs = self._calc_freqs() 77 | self._just_load_labels = False 78 | 79 | self.train_pairs = self.seen_pairs 80 | 81 | def sample_negative(self, attr, obj): 82 | return ATTOP_sample_negative(self, attr, obj) 83 | 84 | def get_flattened_pairs_mask(self, pairs): 85 | pairs_ids = np.array([(self.obj2idx[obj], self.attr2idx[attr]) for attr, obj in pairs]) 86 | flattened_pairs = np.zeros(self.shape_obj_attr, dtype=bool) # init an array of False 87 | flattened_pairs[tuple(zip(*pairs_ids))] = True 88 | flattened_pairs = flattened_pairs.flatten() 89 | return flattened_pairs 90 | 91 | def just_load_labels(self, just_load_labels=True): 92 | self._just_load_labels = just_load_labels 93 | 94 | def get_all_labels(self): 95 | attrs = [] 96 | objs = [] 97 | joints = [] 98 | self.just_load_labels(True) 99 | for attrs_batch, objs_batch in self: 100 | if isinstance(attrs_batch, torch.Tensor): 101 | attrs_batch = attrs_batch.cpu().numpy() 102 | if isinstance(objs_batch, torch.Tensor): 103 | objs_batch = objs_batch.cpu().numpy() 104 | joint = 
self.to_joint_label(objs_batch, attrs_batch) 105 | 106 | attrs.append(attrs_batch) 107 | objs.append(objs_batch) 108 | joints.append(joint) 109 | 110 | self.just_load_labels(False) 111 | attrs = np.array(attrs) 112 | objs = np.array(objs) 113 | return attrs, objs, joints 114 | 115 | def _calc_freqs(self): 116 | y2_train, y1_train, ys_joint_train = self.get_all_labels() 117 | y1_freqs = categorical_histogram(y1_train, range(self.num_objs), plot=False, frac=True) 118 | y1_freqs[y1_freqs == 0] = np.nan 119 | y2_freqs = categorical_histogram(y2_train, range(self.num_attrs), plot=False, frac=True) 120 | y2_freqs[y2_freqs == 0] = np.nan 121 | 122 | pairs_freqs = categorical_histogram(ys_joint_train, 123 | range(self.num_objs * self.num_attrs), 124 | plot=False, frac=True) 125 | pairs_freqs[pairs_freqs == 0] = np.nan 126 | return y1_freqs, y2_freqs, pairs_freqs 127 | 128 | def get(self, name): 129 | return vars(self).get(name) 130 | 131 | def __getitem__(self, idx): 132 | image_fname, attr, obj = self.data[idx] 133 | pos_attr_id, pos_obj_id = self.attr2idx[attr], self.obj2idx[obj] 134 | if self._just_load_labels: 135 | return pos_attr_id, pos_obj_id 136 | 137 | num_calls_cnt = get_and_update_num_calls(self.__getitem__) 138 | 139 | negative_attr_id, negative_obj_id = -1, -1 # default values 140 | if self.phase == 'train': 141 | # we set a temp np seed to override a weird issue with 142 | # sample_negative() at __getitem__, where the sampled pairs 143 | # could not be deterministically reproduced: 144 | # Now at each call to _getitem_ we set the seed to a 834276 (chosen randomly) + the number of calls to _getitem_ 145 | with temporary_random_numpy_seed(834276 + num_calls_cnt): 146 | # draw a negative pair 147 | negative_attr_id, negative_obj_id = self.sample_negative(attr, obj) 148 | 149 | item = DataItem( 150 | feat=self.activations[image_fname], 151 | pos_attr_id=pos_attr_id, 152 | pos_obj_id=pos_obj_id, 153 | neg_attr_id=negative_attr_id, 154 | neg_obj_id=negative_obj_id, 155 | image_fname=image_fname, 156 | ) 157 | return item 158 | 159 | def __len__(self): 160 | return len(self.data) 161 | 162 | def to_joint_label(self, y1_batch, y2_batch): 163 | return (y1_batch * self.num_attrs + y2_batch) 164 | 165 | 166 | def get_data_loaders(train_dataset, valid_dataset, test_dataset, batch_size, 167 | num_workers=10, test_batchsize=None, shuffle_eval_set=True): 168 | if test_batchsize is None: 169 | test_batchsize = batch_size 170 | 171 | pin_memory = True 172 | if num_workers == 0: 173 | pin_memory = False 174 | print('num_workers = ', num_workers) 175 | print('pin_memory = ', pin_memory) 176 | train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, 177 | pin_memory=pin_memory) 178 | valid_loader = None 179 | if valid_dataset is not None and len(valid_dataset) > 0: 180 | valid_loader = data.DataLoader(valid_dataset, batch_size=test_batchsize, shuffle=shuffle_eval_set, 181 | num_workers=num_workers, pin_memory=pin_memory) 182 | test_loader = data.DataLoader(test_dataset, batch_size=test_batchsize, shuffle=shuffle_eval_set, 183 | num_workers=num_workers, pin_memory=pin_memory) 184 | return test_loader, train_loader, valid_loader 185 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
3 | # 4 | # This work is licensed under the License 5 | # located at the root directory. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | from collections import namedtuple 11 | from pathlib import Path 12 | import numpy as np 13 | 14 | from dataclasses import dataclass 15 | from simple_parsing import ArgumentParser, ConflictResolution, field 16 | 17 | from useful_utils import to_simple_parsing_args 18 | 19 | 20 | @dataclass 21 | class DataCfg(): 22 | # dataset name 23 | dataset_name: str = 'ao_clevr' 24 | 25 | # default: {dataset_basedir}/{dataset_name} 26 | data_dir: str = 'default' 27 | 28 | dataset_basedir: str = 'data' 29 | 30 | # VT_random|UV_random 31 | dataset_variant: str = 'VT_random' 32 | 33 | num_split: int = 5000 34 | 35 | metadata_from_pkl: bool = True 36 | 37 | def __post_init__(self): 38 | if self.data_dir == 'default': 39 | self.data_dir = os.path.join(self.dataset_basedir, self.dataset_name) 40 | 41 | def __getitem__(self, key): 42 | """ Allow accessing instance attributes as dictionary keys """ 43 | return getattr(self, key) 44 | 45 | 46 | @dataclass 47 | class MetricsCfg: 48 | calc_AUC: bool = True 49 | """ A flag to calc AUC metric (slower) """ 50 | 51 | def __post_init__(self): 52 | assert self.calc_AUC 53 | 54 | 55 | @dataclass(frozen=True) 56 | class EarlyStopMetric: 57 | # metric name 58 | metric: str 59 | 60 | # min: minimize or max: maximize 61 | polarity: str 62 | 63 | 64 | 65 | @dataclass 66 | class ModelCfg(): 67 | """" The hyper params to configure the architecture. 68 | Note: For backward compatibility, some of the variable names here are different than those mentioned in the 69 | paper. We explain in comments, how each variable is referenced at the paper. 70 | """ 71 | 72 | h_dim: int = 150 73 | """ Number of hidden units in each MLP layer. """ 74 | 75 | nlayers_label: int = 2 76 | """ Number of MLP layer, for h_O anb h_A. """ 77 | 78 | nlayers_joint_ao: int = 2 79 | """ Number of MLP layers, for g. """ 80 | 81 | 82 | E_num_hidden: int = 2 83 | """ Number of layers, for g^-1_A and g^-1_O """ 84 | 85 | E_num_common_hidden: int = 0 86 | """ Number of layers, when projecting pretrained image features to the feature space \X 87 | (as explained for Zappos in section C.1) 88 | """ 89 | 90 | mlp_activation: str = 'leaky_relu' 91 | 92 | learn_embeddings: bool = True 93 | """ Set to False to train as VisProd, with CE loss rather than an embedding loss. """ 94 | 95 | def __post_init__(self): 96 | self.VisProd: bool = not self.learn_embeddings 97 | 98 | assert self.E_num_common_hidden <= self.E_num_hidden 99 | 100 | 101 | @dataclass 102 | class TrainCfg: 103 | """" The hyper params to configure the training. See Supplementary C. "implementation details". 104 | Note: For backward compatibility, some of the variable names here are different than those mentioned in the 105 | paper. We explain in comments, how each variable is referenced at the paper. 106 | """ 107 | 108 | metrics: MetricsCfg 109 | 110 | 111 | batch_size: int = 2048 112 | """ batch size """ 113 | 114 | lr: float = 0.003 115 | """ initial learning rate """ 116 | 117 | max_epoch: int = 5 118 | """ max number of epochs """ 119 | 120 | alternate_ys: int = 21 121 | """ Whether and how to use alternate training. 0: no alternation|12: object then attr|21: attr then obj""" 122 | 123 | 124 | lr_step2: float = 3e-05 # 125 | """ Step 2 initial learning rate. 
Only relevant if alternate_ys != 0 """ 126 | 127 | max_epoch_step2: int = 1000 128 | """ Step 2 max number of epochs. Only relevant if alternate_ys != 0 """ 129 | 130 | weight_decay: float = 0.1 131 | """ weight_decay """ 132 | 133 | HSIC_coeff: float = 10 134 | """ \lambda_rep in paper """ 135 | 136 | alphaH: float = 0 137 | """ \lambda_oh in paper """ 138 | 139 | alphaH_step2 = -1 140 | """ Step 2 \lambda_oh. Only relevant if alternate_ys != 0. If set to -1, then take --alphaH value """ 141 | 142 | lambda_CE: float = 1 143 | """ a coefficient for L_data """ 144 | 145 | lambda_feat: float = 1 146 | """ \lambda_ao in paper """ 147 | 148 | lambda_ao_emb: float = 0 149 | """ \lambda_ao when projecting pretrained image features to the feature space \X 150 | (as explained for Zappos in section C.1) 151 | Note: --lambda_feat and --lambda_ao_emb cant be both non-zero (we raise exception for this case). 152 | """ 153 | 154 | lambda_aux_disjoint: float = 100 155 | """ \lambda_icore in paper """ 156 | 157 | lambda_aux_img: float = 10 158 | """ \lambda_ig in paper, when --lambda_feat>0""" 159 | 160 | lambda_aux: float = 0 161 | """ \lambda_ig in paper, when --lambda_ao_emb>0""" 162 | 163 | mu_img_feat: float = 0.1 164 | """ \lambda_ao at inference time """ 165 | 166 | balanced_loss: bool = True 167 | """ Weighed the loss of ||φa−ha||^2 and ||φo−ho||^2 according to the respective attribute and object frequencies in 168 | the training set (Described in supplementary C.2). """ 169 | 170 | triplet_loss_margin: float = 0.5 171 | """ The margin for the triplet loss. Same value as used by attributes-as-operators """ 172 | 173 | optimizer_name: str = 'Adam' 174 | 175 | seed: int = 0 176 | """ random seed """ 177 | 178 | test_batchsize: int = -1 179 | """batch-size for inference; default uses the training batch size""" 180 | 181 | verbose: bool = True 182 | num_workers: int = 8 183 | shuffle_eval_set: bool = True 184 | n_iter: int = field(init=False) 185 | mu_disjoint: float = field(init=False) 186 | mu_ao_emb: float = field(init=False) 187 | primary_early_stop_metric: EarlyStopMetric = field(init=False) 188 | freeze_class1: bool = field(init=False) 189 | freeze_class2: bool = field(init=False) 190 | Y12_balance_coeff: float = field(init=False) 191 | 192 | 193 | def __post_init__(self): 194 | # sanity checks (assertions) 195 | assert (self.alternate_ys in [0, 12, 21]) 196 | assert not ((self.lambda_ao_emb > 0) and (self.lambda_feat > 0)) 197 | if self.lambda_feat == 0: 198 | assert(self.mu_img_feat == 0) 199 | 200 | # assignments 201 | if self.test_batchsize <= 0: 202 | self.test_batchsize = self.batch_size 203 | if self.alphaH_step2 < 0: 204 | self.alphaH_step2 = self.alphaH 205 | 206 | self.mu_disjoint = self.lambda_CE 207 | self.mu_ao_emb = self.lambda_ao_emb 208 | self.primary_early_stop_metric = EarlyStopMetric('epoch', 'max') 209 | self.Y12_balance_coeff = 0.5 210 | self.freeze_class1 = False 211 | self.freeze_class2 = False 212 | self.n_iter = -1 # Should be updated after data is loaded 213 | 214 | 215 | 216 | def set_n_iter(self, num_train_samples, max_epoch = None): 217 | if max_epoch is None: 218 | max_epoch = self.max_epoch 219 | self.n_iter = int((max_epoch) * np.ceil(num_train_samples / self.batch_size)) 220 | 221 | def __getitem__(self, key): 222 | """ Allow accessing instance attributes as dictionary keys """ 223 | return getattr(self, key) 224 | 225 | @dataclass 226 | class ExperimentCfg(): 227 | output_dir: str = field(alias="-o") 228 | ignore_existing_output_contents: bool = True 229 | gpu: 
int = 0 230 | use_wandb: bool = True 231 | wandb_user: str = 'none' 232 | project_name: str = 'causal_comp_prep' 233 | experiment_name: str = 'default' 234 | instance_name: str = 'default' 235 | git_hash: str = '' 236 | sync_uid: str = '' 237 | report_imbalanced_metrics: bool = False 238 | 239 | # float precision when logging to CSV file 240 | csv_precision: int = 8 241 | 242 | delete_dumped_preds: bool = True 243 | 244 | 245 | def __post_init__(self): 246 | 247 | # Set default experiment name 248 | self._set_default_experiment_name() 249 | 250 | 251 | def _set_default_experiment_name(self): 252 | at_ngc: bool = ('NGC_JOB_ID' in os.environ.keys()) 253 | at_docker = np.in1d(['/opt/conda/bin'], np.array(sys.path))[0] 254 | at_local_docker = at_docker and not at_ngc 255 | name_suffix = '_local' 256 | if at_local_docker: 257 | name_suffix += '_docker' 258 | elif at_ngc: 259 | name_suffix = '_ngc' 260 | if self.experiment_name == 'default': 261 | self.experiment_name = 'dev' + name_suffix 262 | if self.instance_name == 'default': 263 | self.instance_name = 'dev' + name_suffix 264 | 265 | def __getitem__(self, key): 266 | """ Allow accessing instance attributes as dictionary keys """ 267 | return getattr(self, key) 268 | 269 | 270 | @dataclass 271 | class CommandlineArgs(): 272 | model: ModelCfg 273 | data: DataCfg 274 | train: TrainCfg 275 | exp: ExperimentCfg 276 | device: str = 'cuda' 277 | 278 | @classmethod 279 | def get_args(cls): 280 | args: cls = to_simple_parsing_args(cls) 281 | return args 282 | -------------------------------------------------------------------------------- /offline_early_stop.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the License 5 | # located at the root directory. 
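#
# Offline early stopping: training logs one row per epoch to summary.csv;
# this script re-reads those curves, picks the epoch that maximizes each
# validation metric listed in --early_stop_metrics, and writes the matching
# row (including the corresponding *_test columns) to results<i>.json.
# Example invocation (a sketch; the run directory is hypothetical):
#
#   python offline_early_stop.py --dir=/tmp/output/run0 \
#       --early_stop_metrics=open_H_valid,AUC_open_valid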
6 | # --------------------------------------------------------------- 7 | 8 | # from loguru import logger 9 | import argparse 10 | import glob 11 | import os 12 | import tempfile 13 | import warnings 14 | from collections import OrderedDict 15 | from copy import deepcopy 16 | from os.path import join 17 | import json 18 | from shutil import copyfile 19 | from datetime import datetime 20 | import pandas as pd 21 | import sys 22 | import numpy as np 23 | from scipy import signal 24 | 25 | from useful_utils import comma_seperated_str_to_list, wandb_myinit, slice_dict_to_dict, \ 26 | fill_missing_by_defaults 27 | from COSMO_utils import run_bash 28 | 29 | 30 | parser = argparse.ArgumentParser(description=__doc__) 31 | parser.add_argument('--dir', required=True, type=str) 32 | parser.add_argument('--early_stop_metrics', required=False, default=None, type=str) 33 | parser.add_argument('--ignore_skipping', default=True, action="store_true") 34 | parser.add_argument('--infer_on_incompleted', default=False, action="store_true") 35 | parser.add_argument('--smooth_val_curve', default=False, type=bool) 36 | parser.add_argument('--smooth_window_width', default=6, type=int) 37 | parser.add_argument('--use_wandb', default=False, type=bool) 38 | parser.add_argument('--wandb_project_name', default=None, type=str) 39 | parser.add_argument('--wandb_subset_metrics', default=False, type=bool) 40 | parser.add_argument('--eval_on_last_epoch', default=False, type=bool) 41 | 42 | 43 | def main(args): 44 | if isinstance(args, dict): 45 | args = fill_missing_by_defaults(args, parser) 46 | 47 | files = glob.glob(args.dir + "/*") 48 | # results_files = [file for file in files if "results" in basename(file)] 49 | print("###############") 50 | print("Starting offline_early_stop") 51 | print(f"running on '{args.dir}'") 52 | # if args.create_results_json: 53 | # if args.metric is None: 54 | # print("if creating empty results.json, must give specific metric") 55 | # with open(join(args.dir, "results.json"), 'w') as f: 56 | # json.dump({"metrics": {}, "train_cfg": {}, "meta_cfg": {}}, f) 57 | if args.early_stop_metrics is None: 58 | assert ValueError('--early_stop_metrics is required') 59 | if not args.infer_on_incompleted: 60 | assert (os.path.exists(join(args.dir, 'completed_training.touch')) or os.path.exists(join(args.dir, 'results.json'))) 61 | if join(args.dir, "summary.csv") not in files: 62 | raise (RuntimeError("no summary.csv file!\n")) 63 | 64 | if not args.ignore_skipping and os.path.exists(join(args.dir, "lock")): 65 | print("this folder was already processed, skipping!\n") 66 | sys.exit(0) 67 | else: 68 | with open(join(args.dir, "lock"), "w") as f: 69 | f.write("0") 70 | summary_csv = pd.read_csv(join(args.dir, 'summary.csv'), sep='|') 71 | 72 | def smooth_validation_curve(validation_curve): 73 | if args.smooth_val_curve: 74 | win = np.hanning(args.smooth_window_width) 75 | validation_curve = signal.convolve(validation_curve, win, mode='same', 76 | method='direct') / sum(win) 77 | validation_curve = pd.Series(validation_curve) 78 | 79 | return validation_curve 80 | 81 | es_metric_list = comma_seperated_str_to_list(args.early_stop_metrics) 82 | # get run arguments 83 | 84 | args_dict = json.load(open(join(args.dir, "args.json"), "r")) 85 | early_stop_results_dict = OrderedDict() 86 | for i, primary_early_stop_metric in enumerate(es_metric_list): 87 | metric_index = i+1 88 | results = deepcopy(args_dict) 89 | print('') 90 | 91 | new_results_json_file = join(args.dir, f"results{metric_index}.json") 92 | if 
os.path.exists(new_results_json_file): 93 | backup_file_name = new_results_json_file.replace(".json", 94 | f"_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json") 95 | copyfile(new_results_json_file, backup_file_name) 96 | print(f"backed up '{new_results_json_file}' => '{backup_file_name}'") 97 | 98 | print(f"creating new file: {new_results_json_file}") 99 | 100 | try: 101 | validation_curve = summary_csv[primary_early_stop_metric].copy() 102 | validation_curve = smooth_validation_curve(validation_curve) 103 | best_epoch = validation_curve.idxmax() 104 | if np.isnan(best_epoch): 105 | continue 106 | if args.eval_on_last_epoch: 107 | best_epoch = len(validation_curve) -1 108 | best_epoch_summary = summary_csv.iloc[[best_epoch]] 109 | best_epoch_test_score = best_epoch_summary[primary_early_stop_metric.replace("valid", "test")] 110 | best_epoch_summary = best_epoch_summary.to_dict(orient='index')[best_epoch] 111 | print(f"best epoch is: {best_epoch}") 112 | print(f"test score: {best_epoch_test_score}") 113 | 114 | results['metrics'] = best_epoch_summary 115 | results['train']['primary_early_stop_metric'] = primary_early_stop_metric 116 | json.dump(results, open(new_results_json_file, "w")) 117 | early_stop_results_dict[primary_early_stop_metric] = results 118 | except KeyError as e: 119 | warnings.warn(repr(e)) 120 | 121 | if args.use_wandb: 122 | import wandb 123 | offline_log_to_wandb(args.wandb_project_name, args_dict, early_stop_results_dict, summary_csv, 124 | workdir=args.dir, 125 | wandb_log_subset_of_metrics=args.wandb_subset_metrics) 126 | print("done offline_early_stop!\n") 127 | 128 | return early_stop_results_dict 129 | 130 | 131 | def offline_log_to_wandb(project_name, args_dict, early_stop_results_dict, summary_df, workdir=None, 132 | wandb_log_subset_of_metrics=False): 133 | 134 | if project_name is None: 135 | project_name = args_dict['exp']['project_name'] + '_offline' 136 | if wandb_log_subset_of_metrics: 137 | project_name += '_subset' 138 | print(f'Writing to W&B project {project_name}') 139 | 140 | curve_metric_names = None 141 | if wandb_log_subset_of_metrics: 142 | curve_metric_names = get_wandb_curve_metrics() 143 | 144 | print(f'Start dump results to W&B project: {project_name}') 145 | wandb_myinit(project_name=project_name, experiment_name=args_dict['exp']['experiment_name'], 146 | instance_name=args_dict['exp']['instance_name'], config=args_dict, workdir=workdir) 147 | 148 | 149 | global_step_name = 'epoch' 150 | summary_df = summary_df.set_index(global_step_name) 151 | print(f'Dump run curves') 152 | first_iter = True 153 | for global_step, step_metrics in summary_df.iterrows(): 154 | if first_iter: 155 | first_iter = False 156 | if curve_metric_names is not None: 157 | for metric in curve_metric_names: 158 | if metric not in step_metrics: 159 | warnings.warn(f"Can't log '{metric}'. 
It doesn't exists.") 160 | 161 | if wandb_log_subset_of_metrics: 162 | metrics_to_log = slice_dict_to_dict(step_metrics.to_dict(), curve_metric_names, ignore_missing_keys=True) 163 | else: 164 | # log all metrics 165 | metrics_to_log = step_metrics.to_dict() 166 | 167 | metrics_to_log[global_step_name] = global_step 168 | wandb.log(metrics_to_log) 169 | 170 | early_stop_results_to_wandb_summary(early_stop_results_dict) 171 | dump_preds_at_early_stop(early_stop_results_dict, workdir, use_wandb=True) 172 | 173 | # terminate nicely offline w&b run 174 | wandb.join() 175 | 176 | def dump_preds_at_early_stop(early_stop_results_dict, workdir, use_wandb): 177 | print(f'Save to the dumped predictions at early stop epochs') 178 | # dirpath = tempfile.mkdtemp() 179 | for es_metric, results_dict in early_stop_results_dict.items(): 180 | for phase_name in ('valid', 'test'): 181 | target_fname_preds = join(workdir, f'preds__{es_metric}_{phase_name}.npz') 182 | 183 | epoch = results_dict['metrics']['epoch'] 184 | fname = join(workdir, 'dump_preds', f'epoch_{epoch}', f'dump_preds_{phase_name}.npz') 185 | if os.path.exists(fname): 186 | run_bash(f'cp {fname} {target_fname_preds}') 187 | if use_wandb: 188 | import wandb 189 | wandb.save(target_fname_preds) 190 | print(f'Saved {target_fname_preds}') 191 | 192 | 193 | def early_stop_results_to_wandb_summary(early_stop_results_dict): 194 | print(f'Dump early stop results') 195 | wandb_summary = OrderedDict() 196 | for es_metric, results_dict in early_stop_results_dict.items(): 197 | wandb_summary[f'res__{es_metric}'] = results_dict['metrics'] 198 | import wandb 199 | wandb.run.summary.update(wandb_summary) 200 | 201 | 202 | 203 | 204 | def get_wandb_curve_metrics(): 205 | eval_metric_names = comma_seperated_str_to_list( 206 | 'y_joint_loss_mean, y1_loss_mean, y2_loss_mean' 207 | ', closed_balanced_acc' 208 | ', open_balanced_unseen_acc, open_balanced_seen_acc, open_H' 209 | ', y1_balanced_acc_unseen, y2_balanced_acc_unseen' 210 | ', y1_balanced_acc_seen, y2_balanced_acc_seen' 211 | ', closed_acc' 212 | ', unseen_open_acc, seen_open_acc, open_H_IMB' 213 | ', y1_acc_unseen, y2_acc_unseen' 214 | ) 215 | 216 | train_metric_names = comma_seperated_str_to_list('y1_loss, y2_loss, y_loss'#, d_loss' 217 | ', hsic_loss, total_loss'#, d_fool_loss' 218 | ', y1_acc, y2_acc'#, ds1_acc, ds2_acc, current_alpha' 219 | ', HSIC_cond1, HSIC_cond2' 220 | ', loss, leplus_loss, tloss_feat, tloss_ao_emb' 221 | ', tloss_a, tloss_o, loss_aux' 222 | ', loss_aux_disjoint_attr, loss_aux_disjoint_obj') 223 | 224 | 225 | logged_metrics = [] 226 | for metric in eval_metric_names: 227 | logged_metrics.append(metric + '_valid') 228 | logged_metrics.append(metric + '_test') 229 | 230 | for metric in train_metric_names: 231 | logged_metrics.append(metric + '_mean') 232 | 233 | return logged_metrics 234 | 235 | 236 | if __name__ == '__main__': 237 | args = parser.parse_args() 238 | main(args) 239 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the License 5 | # located at the root directory. 
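#
# End-to-end flow (sketch): parse CommandlineArgs -> seed all RNGs -> load
# the train/valid/test datasets -> train (jointly, or alternating between
# the attribute and object heads when alternate_ys is 12 or 21) -> run
# offline early stopping over summary.csv and print the selected metrics.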
6 | # --------------------------------------------------------------- 7 | 8 | import json 9 | import os, sys 10 | import pickle 11 | from os.path import join 12 | from copy import deepcopy 13 | from pathlib import Path 14 | from sys import argv 15 | import random 16 | 17 | import dataclasses 18 | from munch import Munch 19 | 20 | import useful_utils 21 | import offline_early_stop 22 | import COSMO_utils 23 | from useful_utils import SummaryWriter_withCSV, wandb_myinit 24 | 25 | import torch.nn 26 | import torch 27 | import numpy as np 28 | 29 | import signal 30 | 31 | from data import CompDataFromDict 32 | from params import CommandlineArgs, TrainCfg, ExperimentCfg 33 | from train import train, alternate_training 34 | 35 | 36 | def set_random_seeds(base_seed): 37 | random.seed(base_seed) 38 | np.random.seed(base_seed+7205) # 7205 is a base seed that was randomly chosen 39 | torch.random.manual_seed(base_seed+1000) 40 | torch.cuda.manual_seed(base_seed+1001) 41 | torch.backends.cudnn.deterministic = True 42 | 43 | 44 | def main(args: CommandlineArgs): 45 | train_cfg: TrainCfg = args.train 46 | exp_cfg: ExperimentCfg = args.exp 47 | set_random_seeds(base_seed=train_cfg.seed) 48 | 49 | # init logging 50 | writer = init_logging(args) 51 | 52 | # load data 53 | test_dataset, train_dataset, valid_dataset = load_data(args) 54 | train_cfg.set_n_iter(len(train_dataset.data)) 55 | 56 | # training 57 | with useful_utils.profileblock('complete training'): 58 | if train_cfg.alternate_ys == 0: 59 | ### train both heads jointly ## 60 | train(args, train_dataset, valid_dataset, test_dataset, writer) 61 | else: 62 | ### alternate between heads ### 63 | alternate_training(args, train_dataset, valid_dataset, test_dataset, writer) 64 | 65 | # ----- Finalizing ------ 66 | # dump log to csv 67 | writer.dump_to_csv(verbose=1, float_format=f'%.{args.exp.csv_precision}f') 68 | writer.close() 69 | 70 | # Indicate run has complete 71 | COSMO_utils.run_bash(f'touch {join(exp_cfg.output_dir, "completed_training.touch")}') 72 | 73 | # Process offline early stopping according to results at output_dir 74 | early_stop_results_dict = process_offline_early_stopping(exp_cfg) 75 | 76 | # Print results 77 | print_results(early_stop_results_dict, exp_cfg) 78 | 79 | # Delete temporary artifacts from output dir 80 | clear_output_dir(exp_cfg) 81 | 82 | print('Done.\n') 83 | 84 | 85 | def print_results(early_stop_results_dict, exp_cfg): 86 | from munch import munchify 87 | early_stop_results_dict = munchify(early_stop_results_dict) 88 | print('\n\n####################################') 89 | if exp_cfg.report_imbalanced_metrics: 90 | # E.g. Zappos 91 | U = 100 * early_stop_results_dict.open_H_IMB_valid.metrics.unseen_open_acc_test 92 | S = 100 * early_stop_results_dict.open_H_IMB_valid.metrics.seen_open_acc_test 93 | H = 100 * early_stop_results_dict.open_H_IMB_valid.metrics.open_H_IMB_test 94 | closed = 100 * early_stop_results_dict.AUC_open_valid.metrics.closed_acc_test 95 | AUC = 100 * early_stop_results_dict.AUC_open_valid.metrics.AUC_open_test 96 | print('Reporting IMbalanced metrics') 97 | print(f'Unseen={U:.1f}, Seen={S:.1f}, Harmonic={H:.1f}, Closed={closed:.1f}, AUC={AUC:.1f}') 98 | 99 | else: 100 | # e.g. 
AO-CLEVr 101 | U = 100 * early_stop_results_dict.open_H_valid.metrics.open_balanced_unseen_acc_test 102 | S = 100 * early_stop_results_dict.open_H_valid.metrics.open_balanced_seen_acc_test 103 | H = 100 * early_stop_results_dict.open_H_valid.metrics.open_H_test 104 | closed = 100 * early_stop_results_dict.closed_balanced_acc_valid.metrics.closed_balanced_acc_test 105 | print('Reporting Balanced metrics') 106 | print(f'Unseen={U:.1f}, Seen={S:.1f}, Harmonic={H:.1f}, Closed={closed:.1f}') 107 | print('####################################\n\n') 108 | 109 | 110 | def init_logging(args): 111 | exp_cfg: ExperimentCfg = args.exp 112 | output_dir = Path(exp_cfg.output_dir) 113 | if not exp_cfg.ignore_existing_output_contents and len(list(output_dir.iterdir())) > 0: 114 | raise ValueError(f'Output directory {output_dir} is not empty') 115 | 116 | args_dict = dataclasses.asdict(args) 117 | if exp_cfg.use_wandb: 118 | import wandb 119 | wandb_myinit(project_name=exp_cfg.project_name, experiment_name=exp_cfg.experiment_name, 120 | instance_name=exp_cfg.instance_name, config=args_dict, workdir=exp_cfg.output_dir, 121 | username=exp_cfg.wandb_user) 122 | # printing starts here - after initializing w&b 123 | print('commandline was:') 124 | print(' '.join(argv)) 125 | print(vars(args)) 126 | writer = SummaryWriter_withCSV(log_dir=exp_cfg.output_dir, suppress_tensorboard=True, wandb=exp_cfg.use_wandb) 127 | writer.set_print_options(pandas_max_columns=500, pandas_max_width=200) 128 | to_json(args_dict, exp_cfg.output_dir, filename='args.json') 129 | return writer 130 | 131 | 132 | def clear_output_dir(exp_cfg): 133 | # Always delete dumped (per-epoch) logits when done, because it takes a lot of space 134 | delete_dumped_logits(exp_cfg.output_dir) 135 | 136 | # Delete dumped (per epoch) decisions if required 137 | if exp_cfg.delete_dumped_preds: 138 | print('Delete logging of per-epoch dumped predictions') 139 | cmd = f'rm -rf {join(exp_cfg.output_dir, "dump_preds")}' 140 | print(cmd) 141 | COSMO_utils.run_bash(cmd) 142 | 143 | 144 | def process_offline_early_stopping(exp_cfg: ExperimentCfg): 145 | cfg_offline_early_stop = Munch() 146 | cfg_offline_early_stop.dir = exp_cfg.output_dir 147 | cfg_offline_early_stop.early_stop_metrics = 'open_H_valid,closed_balanced_acc_valid,open_H_IMB_valid,AUC_open_valid' 148 | early_stop_results_dict = offline_early_stop.main(cfg_offline_early_stop) 149 | if exp_cfg.use_wandb: 150 | # dump each early_stop result to currents project 151 | offline_early_stop.early_stop_results_to_wandb_summary(early_stop_results_dict) 152 | # and save the dumped predictions at its epoch 153 | offline_early_stop.dump_preds_at_early_stop(early_stop_results_dict, exp_cfg.output_dir, use_wandb=exp_cfg.use_wandb) 154 | return early_stop_results_dict 155 | 156 | 157 | def load_data(args: CommandlineArgs): 158 | if args.data.metadata_from_pkl: 159 | train_dataset, valid_dataset, test_dataset = load_pickled_metadata(args) 160 | print('load data from PKL') 161 | else: 162 | train_dataset, valid_dataset, test_dataset = load_TMN_data(args) 163 | print('load data using TMN project') 164 | return test_dataset, train_dataset, valid_dataset 165 | 166 | 167 | def to_json(args_dict, log_dir, filename): 168 | args_json = os.path.join(log_dir, filename) 169 | with open(args_json, 'w') as f: 170 | json.dump(args_dict, f) 171 | print(f'\nDump configuration to JSON file: {args_json}\n\n') 172 | 173 | 174 | def SIGINT_KeyboardInterrupt_handler(sig, frame): 175 | raise KeyboardInterrupt() 176 | 177 | 178 | def 
178 | def load_TMN_data(args: CommandlineArgs):
179 |     import sys
180 |     sys.path.append('taskmodularnets')
181 |     import taskmodularnets.data.dataset as tmn_data
182 |
183 |     dict_data = dict()
184 |     for subset in ['train', 'val', 'test']:
185 |         dTMN = tmn_data.CompositionDatasetActivations(root=args.data.data_dir,
186 |                                                       phase=subset,
187 |                                                       split='compositional-split-natural')
188 |
189 |
190 |         # Add class attributes according to the current project API
191 |         dTMN.all_open_pairs, dTMN.seen_pairs = \
192 |             dTMN.pairs, dTMN.train_pairs
193 |
194 |         # Get TMN unseen pairs, because val/test_pairs include both seen and unseen pairs
195 |         dTMN.unseen_closed_val_pairs = list(set(dTMN.val_pairs).difference(dTMN.seen_pairs))
196 |         dTMN.unseen_closed_test_pairs = list(set(dTMN.test_pairs).difference(dTMN.seen_pairs))
197 |
198 |         dTMN.closed_unseen_pairs = dict(
199 |             train=[],
200 |             val=dTMN.unseen_closed_val_pairs,
201 |             test=dTMN.unseen_closed_test_pairs)[subset]
202 |
203 |         dict_data[f'{subset}'] = deepcopy(vars(dTMN))
204 |
205 |     train_dataset = CompDataFromDict(dict_data['train'], data_subset='train_data', data_dir=args.data.data_dir)
206 |     valid_dataset = CompDataFromDict(dict_data['val'], data_subset='val_data', data_dir=args.data.data_dir)
207 |     test_dataset = CompDataFromDict(dict_data['test'], data_subset='test_data', data_dir=args.data.data_dir)
208 |
209 |     print('Seen (train) pairs: ', train_dataset.seen_pairs)
210 |     print('Unseen (val) pairs: ', train_dataset.unseen_closed_val_pairs)
211 |     print('Unseen (test) pairs: ', train_dataset.unseen_closed_test_pairs)
212 |
213 |     return train_dataset, valid_dataset, test_dataset
214 |
215 |
216 | def load_pickled_metadata(args: CommandlineArgs):
217 |     data_cfg = args.data
218 |     dataset_name = deepcopy(data_cfg['dataset_name'])
219 |     dataset_variant = deepcopy(data_cfg['dataset_variant'])
220 |     meta_path = Path(f"{data_cfg['data_dir']}/metadata_pickles")
221 |     random_state_path = Path(f"{data_cfg['data_dir']}/np_random_state_pickles")
222 |     meta_path = meta_path.expanduser()
223 |
224 |     dict_data = dict()
225 |     seen_seed = args.train.seed
226 |     for subset in ['train', 'valid', 'test']:
227 |         metadata_full_filename = meta_path / f"metadata_{dataset_name}__{dataset_variant}__comp_seed_{data_cfg['num_split']}__seen_seed_{seen_seed}__{subset}.pkl"
228 |         dict_data[f'{subset}'] = deepcopy(pickle.load(open(metadata_full_filename, 'rb')))
229 |
230 |     np_rnd_state_fname = random_state_path / f"np_random_state_{dataset_name}__{dataset_variant}__comp_seed_{data_cfg['num_split']}__seen_seed_{seen_seed}.pkl"
231 |     np_seed_state = pickle.load(open(np_rnd_state_fname, 'rb'))
232 |     np.random.set_state(np_seed_state)
233 |
234 |     train_dataset = CompDataFromDict(dict_data['train'], data_subset='train_data', data_dir=data_cfg['data_dir'])
235 |     valid_dataset = CompDataFromDict(dict_data['valid'], data_subset='val_data', data_dir=data_cfg['data_dir'])
236 |     test_dataset = CompDataFromDict(dict_data['test'], data_subset='test_data', data_dir=data_cfg['data_dir'])
237 |
238 |     print('Seen (train) pairs: ', train_dataset.seen_pairs)
239 |     print('Unseen (val) pairs: ', train_dataset.unseen_closed_val_pairs)
240 |     print('Unseen (test) pairs: ', train_dataset.unseen_closed_test_pairs)
241 |
242 |     return train_dataset, valid_dataset, test_dataset
243 |
244 |
245 | def delete_dumped_logits(logdir):
246 |     # Delete dumped logits
247 |     COSMO_utils.run_bash(f"find {join(logdir, 'dump_preds')} -name 'logits*' -delete")
248 |
249 |
250 | if __name__ == '__main__':
251 |     args = CommandlineArgs.get_args()
252 | 
main(args) 253 | 254 | -------------------------------------------------------------------------------- /taskmodularnets/data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # taskmodularnets/LICENSE file. 6 | # 7 | import numpy as np 8 | import torch.utils.data as tdata 9 | import torch 10 | import torchvision.transforms as transforms 11 | from PIL import Image 12 | import glob 13 | import os 14 | import tqdm 15 | import torchvision.models as tmodels 16 | import torch.nn as nn 17 | from torch.autograd import Variable 18 | import torch 19 | import bz2 20 | from utils import utils 21 | import h5py 22 | import pdb 23 | import itertools 24 | import os 25 | import collections 26 | import scipy.io 27 | from sklearn.model_selection import train_test_split 28 | 29 | 30 | class ImageLoader: 31 | def __init__(self, root): 32 | self.img_dir = root 33 | 34 | def __call__(self, img): 35 | file = '%s/%s' % (self.img_dir, img) 36 | img = Image.open(file).convert('RGB') 37 | return img 38 | 39 | 40 | def imagenet_transform(phase): 41 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 42 | 43 | if phase == 'train': 44 | transform = transforms.Compose([ 45 | transforms.RandomResizedCrop(224), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ToTensor(), 48 | transforms.Normalize(mean, std) 49 | ]) 50 | elif phase == 'test' or phase == 'val': 51 | transform = transforms.Compose([ 52 | transforms.Resize(256), 53 | transforms.CenterCrop(224), 54 | transforms.ToTensor(), 55 | transforms.Normalize(mean, std) 56 | ]) 57 | 58 | return transform 59 | 60 | 61 | #------------------------------------------------------------------------------------------------------------------------------------# 62 | 63 | 64 | class CompositionDataset(tdata.Dataset): 65 | def __init__( 66 | self, 67 | root, 68 | phase, 69 | split='compositional-split', 70 | subset=False, 71 | num_negs=1, 72 | pair_dropout=0.0, 73 | ): 74 | self.root = root 75 | self.phase = phase 76 | self.split = split 77 | self.num_negs = num_negs 78 | self.pair_dropout = pair_dropout 79 | 80 | self.feat_dim = None 81 | self.transform = imagenet_transform(phase) 82 | self.loader = ImageLoader(self.root + '/images/') 83 | 84 | self.attrs, self.objs, self.pairs, \ 85 | self.train_pairs, self.val_pairs, \ 86 | self.test_pairs = self.parse_split() 87 | 88 | self.train_data, self.val_data, self.test_data = self.get_split_info() 89 | if self.phase == 'train': 90 | self.data = self.train_data 91 | elif self.phase == 'val': 92 | self.data = self.val_data 93 | else: 94 | self.data = self.test_data 95 | if subset: 96 | ind = np.arange(len(self.data)) 97 | ind = ind[::len(ind) // 1000] 98 | self.data = [self.data[i] for i in ind] 99 | 100 | self.obj2idx = {obj: idx for idx, obj in enumerate(self.objs)} 101 | self.attr2idx = {attr: idx for idx, attr in enumerate(self.attrs)} 102 | self.pair2idx = {pair: idx for idx, pair in enumerate(self.pairs)} 103 | 104 | print('# train pairs: %d | # val pairs: %d | # test pairs: %d' % (len( 105 | self.train_pairs), len(self.val_pairs), len(self.test_pairs))) 106 | print('# train images: %d | # val images: %d | # test images: %d' % 107 | (len(self.train_data), len(self.val_data), len(self.test_data))) 108 | 109 | # fix later -- affordance thing 110 | # return {object: all attrs that occur with obj} 111 | self.obj_affordance = 
{} 112 | self.train_obj_affordance = {} 113 | for _obj in self.objs: 114 | candidates = [ 115 | attr 116 | for (_, attr, 117 | obj) in self.train_data + self.val_data + self.test_data 118 | if obj == _obj 119 | ] 120 | self.obj_affordance[_obj] = sorted(list(set(candidates))) 121 | 122 | candidates = [ 123 | attr for (_, attr, obj) in self.train_data if obj == _obj 124 | ] 125 | self.train_obj_affordance[_obj] = sorted(list(set(candidates))) 126 | 127 | self.sample_indices = list(range(len(self.data))) 128 | self.sample_pairs = self.train_pairs 129 | 130 | def reset_dropout(self): 131 | self.sample_indices = list(range(len(self.data))) 132 | self.sample_pairs = self.train_pairs 133 | 134 | shuffled_ind = np.random.permutation(len(self.train_pairs)) 135 | n_pairs = int((1 - self.pair_dropout) * len(self.train_pairs)) 136 | self.sample_pairs = [ 137 | self.train_pairs[pi] for pi in shuffled_ind[:n_pairs] 138 | ] 139 | print('Using {} pairs out of {} pairs right now'.format( 140 | n_pairs, len(self.train_pairs))) 141 | self.sample_indices = [ 142 | i for i in range(len(self.data)) 143 | if (self.data[i][1], self.data[i][2]) in self.sample_pairs 144 | ] 145 | print('Using {} images out of {} images right now'.format( 146 | len(self.sample_indices), len(self.data))) 147 | 148 | def get_split_info(self): 149 | data = torch.load(self.root + '/metadata_{}.t7'.format(self.split)) 150 | train_data, val_data, test_data = [], [], [] 151 | for instance in data: 152 | image, attr, obj, settype = instance['image'], instance[ 153 | 'attr'], instance['obj'], instance['set'] 154 | 155 | if attr == 'NA' or (attr, 156 | obj) not in self.pairs or settype == 'NA': 157 | # ignore instances with unlabeled attributes 158 | # ignore instances that are not in current split 159 | continue 160 | 161 | data_i = [image, attr, obj] 162 | if settype == 'train': 163 | train_data.append(data_i) 164 | elif settype == 'val': 165 | val_data.append(data_i) 166 | else: 167 | test_data.append(data_i) 168 | 169 | return train_data, val_data, test_data 170 | 171 | def parse_split(self): 172 | def parse_pairs(pair_list): 173 | with open(pair_list, 'r') as f: 174 | pairs = f.read().strip().split('\n') 175 | pairs = [t.split() for t in pairs] 176 | pairs = list(map(tuple, pairs)) 177 | attrs, objs = zip(*pairs) 178 | return attrs, objs, pairs 179 | 180 | tr_attrs, tr_objs, tr_pairs = parse_pairs( 181 | '%s/%s/train_pairs.txt' % (self.root, self.split)) 182 | vl_attrs, vl_objs, vl_pairs = parse_pairs( 183 | '%s/%s/val_pairs.txt' % (self.root, self.split)) 184 | ts_attrs, ts_objs, ts_pairs = parse_pairs( 185 | '%s/%s/test_pairs.txt' % (self.root, self.split)) 186 | 187 | all_attrs, all_objs = sorted( 188 | list(set(tr_attrs + vl_attrs + ts_attrs))), sorted( 189 | list(set(tr_objs + vl_objs + ts_objs))) 190 | all_pairs = sorted(list(set(tr_pairs + vl_pairs + ts_pairs))) 191 | 192 | return all_attrs, all_objs, all_pairs, tr_pairs, vl_pairs, ts_pairs 193 | 194 | def sample_negative(self, attr, obj): 195 | new_attr, new_obj = self.sample_pairs[np.random.choice( 196 | len(self.sample_pairs))] 197 | if new_attr == attr and new_obj == obj: 198 | return self.sample_negative(attr, obj) 199 | return (self.attr2idx[new_attr], self.obj2idx[new_obj]) 200 | 201 | def sample_affordance(self, attr, obj): 202 | new_attr = np.random.choice(self.obj_affordance[obj]) 203 | if new_attr == attr: 204 | return self.sample_affordance(attr, obj) 205 | return self.attr2idx[new_attr] 206 | 207 | def sample_train_affordance(self, attr, obj): 208 | new_attr = 
np.random.choice(self.train_obj_affordance[obj]) 209 | if new_attr == attr: 210 | return self.sample_train_affordance(attr, obj) 211 | return self.attr2idx[new_attr] 212 | 213 | def __getitem__(self, index): 214 | index = self.sample_indices[index] 215 | image, attr, obj = self.data[index] 216 | img = self.loader(image) 217 | img = self.transform(img) 218 | 219 | data = [ 220 | img, self.attr2idx[attr], self.obj2idx[obj], self.pair2idx[(attr, 221 | obj)] 222 | ] 223 | 224 | if self.phase == 'train': 225 | all_neg_attrs = [] 226 | all_neg_objs = [] 227 | for _ in range(self.num_negs): 228 | neg_attr, neg_obj = self.sample_negative( 229 | attr, obj) # negative example for triplet loss 230 | all_neg_objs.append(neg_obj) 231 | all_neg_attrs.append(neg_attr) 232 | neg_attr = torch.LongTensor(all_neg_attrs) 233 | neg_obj = torch.LongTensor(all_neg_objs) 234 | inv_attr = self.sample_train_affordance( 235 | attr, obj) # attribute for inverse regularizer 236 | comm_attr = self.sample_affordance( 237 | inv_attr, obj) # attribute for commutative regularizer 238 | data += [neg_attr, neg_obj, inv_attr, comm_attr] 239 | return data 240 | 241 | def __len__(self): 242 | return len(self.sample_indices) 243 | 244 | 245 | #------------------------------------------------------------------------------------------------------------------------------------# 246 | 247 | 248 | class CompositionDatasetActivations(CompositionDataset): 249 | def __init__( 250 | self, 251 | root, 252 | phase, 253 | split, 254 | subset=False, 255 | num_negs=1, 256 | pair_dropout=0.0, 257 | ): 258 | super(CompositionDatasetActivations, self).__init__( 259 | root, 260 | phase, 261 | split, 262 | subset=subset, 263 | num_negs=num_negs, 264 | pair_dropout=pair_dropout, 265 | ) 266 | 267 | # precompute the activations -- weird. 
Fix pls 268 | feat_file = '%s/features.t7' % root 269 | if not os.path.exists(feat_file): 270 | with torch.no_grad(): 271 | self.generate_features(feat_file) 272 | 273 | activation_data = torch.load(feat_file) 274 | self.activations = dict( 275 | zip(activation_data['files'], activation_data['features'])) 276 | self.feat_dim = activation_data['features'].size(1) 277 | 278 | print('%d activations loaded' % (len(self.activations))) 279 | 280 | def generate_features(self, out_file): 281 | 282 | data = self.train_data + self.val_data + self.test_data 283 | transform = imagenet_transform('test') 284 | feat_extractor = tmodels.resnet18(pretrained=True) 285 | feat_extractor.fc = nn.Sequential() 286 | feat_extractor.eval().cuda() 287 | 288 | image_feats = [] 289 | image_files = [] 290 | for chunk in tqdm.tqdm( 291 | utils.chunks(data, 512), total=len(data) // 512): 292 | files, attrs, objs = zip(*chunk) 293 | imgs = list(map(self.loader, files)) 294 | imgs = list(map(transform, imgs)) 295 | feats = feat_extractor(torch.stack(imgs, 0).cuda()) 296 | image_feats.append(feats.data.cpu()) 297 | image_files += files 298 | image_feats = torch.cat(image_feats, 0) 299 | print('features for %d images generated' % (len(image_files))) 300 | 301 | torch.save({'features': image_feats, 'files': image_files}, out_file) 302 | 303 | def __getitem__(self, index): 304 | data = super(CompositionDatasetActivations, self).__getitem__(index) 305 | index = self.sample_indices[index] 306 | image, attr, obj = self.data[index] 307 | data[0] = self.activations[image] 308 | return data 309 | -------------------------------------------------------------------------------- /useful_utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the License 5 | # located at the root directory. 
6 | # --------------------------------------------------------------- 7 | import re 8 | from collections import OrderedDict, defaultdict 9 | import time 10 | from copy import deepcopy 11 | import os 12 | 13 | import torch 14 | 15 | from tensorboardX import SummaryWriter 16 | import tensorboardX 17 | import pandas as pd 18 | import numpy as np 19 | 20 | from COSMO_utils import run_bash 21 | 22 | 23 | class SummaryWriter_withCSV(SummaryWriter): 24 | """ A wrapper for tensorboard SummaryWriter that is based on pandas, 25 | and writes to CSV and optionally to W&B""" 26 | def __init__(self, log_dir, *args, **kwargs): 27 | 28 | global_step_name = 'epoch' 29 | if 'global_step_name' in kwargs: 30 | global_step_name = kwargs['global_step_name'] 31 | kwargs.pop('global_step_name') 32 | 33 | self.global_step_name = global_step_name 34 | 35 | self.log_wandb = False 36 | if 'wandb' in kwargs: 37 | self.log_wandb = kwargs['wandb'] 38 | kwargs.pop('wandb') 39 | 40 | if self.log_wandb: 41 | import wandb 42 | 43 | self.suppress_tensorboard = kwargs.get('suppress_tensorboard', False) 44 | 45 | if not self.suppress_tensorboard: 46 | super(SummaryWriter_withCSV, self).__init__(log_dir, *args, **kwargs) 47 | 48 | self.df = pd.DataFrame() 49 | self.df.index.name = self.global_step_name 50 | self._log_dir = log_dir 51 | self.last_global_step = -1 52 | 53 | def add_scalar(self, tag, scalar_value, global_step=None): 54 | if global_step is not None: 55 | self.df.loc[global_step, tag] = scalar_value 56 | 57 | # if finalized last step, dump its metrics to wandb 58 | if self.last_global_step < global_step: 59 | if self.log_wandb and self.last_global_step >= 0: 60 | with ns_profiling_label('wandb_log'): 61 | import wandb 62 | wandb.log(self.df.loc[self.last_global_step, :].to_dict(), sync=False) 63 | self.last_global_step = global_step 64 | 65 | if not self.suppress_tensorboard: 66 | super().add_scalar(tag, scalar_value, global_step=global_step) 67 | 68 | def add_summary(self, summary, global_step=None): 69 | summary_proto = tensorboardX.summary.Summary() 70 | 71 | if isinstance(summary, bytes): 72 | summary_list = [value for value in summary_proto.FromString(summary).value] 73 | else: 74 | summary_list = [value for value in summary.value] 75 | for val in summary_list: 76 | self.add_scalar(val.tag, val.simple_value, global_step=global_step) 77 | 78 | 79 | def set_print_options(self, pandas_max_columns=500, pandas_max_width=1000): 80 | pd.set_option('display.max_columns', pandas_max_columns) 81 | pd.set_option('display.width', pandas_max_width) 82 | 83 | def last_results_as_string(self, regex_filter_out=None): 84 | if regex_filter_out is not None: 85 | s = self.df.loc[:, ~self.df.columns.str.match(regex_filter_out)].iloc[-1, :] 86 | else: 87 | s = self.df.iloc[-1, :] 88 | string = ' '.join([f'{key}={s[key]:.2g}' for key in s.keys()]) 89 | return string 90 | 91 | def dump_to_csv(self, fname='summary.csv', sep='|', verbose=0, **kwargs_df_to_csv): 92 | fullfname = os.path.join(self._log_dir, fname) 93 | self.df.to_csv(fullfname, sep=sep, **kwargs_df_to_csv) 94 | if verbose>0: 95 | print(f'Dump history to CSV file: {fullfname}') 96 | if verbose == 2: 97 | print('') 98 | with open(fullfname) as f: 99 | print(f.read()) 100 | def close(self): 101 | 102 | # dump last step results to wandb 103 | if self.log_wandb: 104 | with ns_profiling_label('wandb_log'): 105 | import wandb 106 | wandb.log(self.df.loc[self.last_global_step, :].to_dict(), sync=False) 107 | 108 | if not self.suppress_tensorboard: 109 | super().close() 110 | 111 
| 112 | def slice_dict_to_dict(d, keys, returned_keys_prefix='', returned_keys_postfix='', ignore_missing_keys=False): 113 | """ Returns a tuple from dictionary values, ordered and slice by given keys 114 | keys can be a list, or a CSV string 115 | """ 116 | if isinstance(keys, str): 117 | keys = keys[:-1] if keys[-1] == ',' else keys 118 | keys = re.split(', |[, ]', keys) 119 | 120 | if returned_keys_prefix != '' or returned_keys_postfix != '': 121 | return OrderedDict((returned_keys_prefix + k + returned_keys_postfix, d[k]) for k in keys) 122 | 123 | if ignore_missing_keys: 124 | return OrderedDict((k, d[k]) for k in keys if k in d) 125 | else: 126 | return OrderedDict((k, d[k]) for k in keys) 127 | 128 | 129 | def clone_model(model): 130 | # clone a pytorch model 131 | return deepcopy(model) 132 | 133 | 134 | def is_uncommited_git_repo(list_ignored_regex_pattern_filenames=None, ignore_untracked_files=True): 135 | """ Check if there are uncommited files in workdir. 136 | Can ignore specific file names or regex patterns 137 | """ 138 | 139 | uncommitted_files_list = get_uncommitted_files(list_ignored_regex_pattern_filenames, ignore_untracked_files) 140 | 141 | if uncommitted_files_list: 142 | return True 143 | else: 144 | return False 145 | 146 | 147 | def get_uncommitted_files(list_ignored_regex_pattern_filenames, ignore_untracked_files): 148 | ignore_untracked_files_str = '' 149 | if ignore_untracked_files: 150 | ignore_untracked_files_str = ' -uno' 151 | if list_ignored_regex_pattern_filenames is None: 152 | list_ignored_regex_pattern_filenames = [] 153 | git_status = run_bash('git status --porcelain' + ignore_untracked_files_str) 154 | uncommitted_files = [] 155 | for line in git_status.split('\n'): 156 | ignore_current_file = False 157 | if line: 158 | fname = re.split(' ?\w +', line)[1] 159 | for ig_file_regex_pattern in list_ignored_regex_pattern_filenames: 160 | if re.match(ig_file_regex_pattern, fname): 161 | ignore_current_file = True 162 | break 163 | if ignore_current_file: 164 | continue 165 | else: 166 | uncommitted_files.append(fname) 167 | return uncommitted_files 168 | 169 | 170 | def list_to_2d_tuple(l): 171 | return tuple(tuple(tup) for tup in l) 172 | 173 | 174 | def categorical_histogram(data, labels_list, plot=True, frac=True, plt_show=False): 175 | import matplotlib.pyplot as plt 176 | s_counts = pd.Series(data).value_counts() 177 | s_frac = s_counts/s_counts.sum() 178 | hist_dict = s_counts.to_dict() 179 | if frac: 180 | hist_dict = s_frac.to_dict() 181 | hist = [] 182 | for ix, _ in enumerate(labels_list): 183 | hist.append(hist_dict.get(ix, 0)) 184 | 185 | if plot: 186 | pd.Series(hist, index=labels_list).plot(kind='bar') 187 | if frac: 188 | plt.ylim((0,1)) 189 | if plt_show: 190 | plt.show() 191 | else: 192 | return np.array(hist, dtype='float32') 193 | 194 | 195 | def to_torch(array, device): 196 | array = np.asanyarray(array) # cast to array if not an array. 
othrwise do nothing 197 | return torch.from_numpy(array).to(device) 198 | 199 | 200 | def comma_seperated_str_to_list(comma_seperated_str, regex_sep=r', |[, ]'): 201 | return re.split(regex_sep, comma_seperated_str) 202 | 203 | 204 | 205 | class profileblock(object): 206 | """ 207 | Usage example: 208 | with profileblock(label='abc'): 209 | time.sleep(0.1) 210 | 211 | """ 212 | def __init__(self, label=None, disable=False): 213 | self.disable = disable 214 | self.label = '' 215 | if label is not None: 216 | self.label = label + ': ' 217 | 218 | def __enter__(self): 219 | if self.disable: 220 | return 221 | self.tic = time.time() 222 | return self 223 | 224 | def __exit__(self, type, value, traceback): 225 | if self.disable: 226 | return 227 | elapsed = np.round(time.time() - self.tic, 2) 228 | print(f'{self.label} Elapsed {elapsed} sec') 229 | 230 | ########### 231 | 232 | 233 | ################## 234 | 235 | from contextlib import contextmanager 236 | from torch.cuda import nvtx 237 | 238 | 239 | # @contextmanager 240 | # def ns_profiling_label(label, disable=False): 241 | # """ 242 | # Wraps a code block with a label for profiling using "nsight-systems" 243 | # Usage example: 244 | # with ns_profiling_label('epoch %d'%epoch): 245 | # << CODE FOR TRAINING AN EPOCH >> 246 | # 247 | # """ 248 | # if not disable: 249 | # nvtx.range_push(label) 250 | # try: 251 | # yield None 252 | # finally: 253 | # if not disable: 254 | # nvtx.range_pop() 255 | 256 | 257 | @contextmanager 258 | def ns_profiling_label(label): 259 | """ 260 | A do nothing version of ns_profiling_label() 261 | 262 | """ 263 | try: 264 | yield None 265 | finally: 266 | pass 267 | 268 | 269 | def torch_nans(*args, **kwargs): 270 | return np.nan*torch.zeros(*args, **kwargs) 271 | 272 | 273 | 274 | class batch_torch_logger(): 275 | def __init__(self, cs_str_args=None, num_batches=None, nanmean_args_cs_str=None, device=None): 276 | """ 277 | cs_str_args: arguments list as a comma separated string 278 | """ 279 | assert(num_batches is not None) 280 | self.device = device 281 | self.loggers = {} 282 | if cs_str_args is not None: 283 | args_list = comma_seperated_str_to_list(cs_str_args) 284 | self.loggers = dict((arg_name, torch_nans(num_batches, device=device)) for arg_name in args_list) 285 | 286 | self.nanmean_args_list = [] 287 | if nanmean_args_cs_str is not None: 288 | self.nanmean_args_list = comma_seperated_str_to_list(nanmean_args_cs_str) 289 | 290 | 291 | self.cnt = -1 292 | 293 | def new_batch(self): 294 | self.cnt += 1 295 | 296 | def log(self, locals_dict): 297 | for arg in self.loggers.keys(): 298 | self.log_arg(arg, locals_dict.get(arg, torch.tensor(-9999.).to(self.device))) 299 | 300 | def log_arg(self, arg, value): 301 | with ns_profiling_label(f'log arg'): 302 | try: 303 | if type(value) == float or isinstance(value, np.number): 304 | value = torch.FloatTensor([value]) 305 | self.loggers[arg][self.cnt] = value.detach() 306 | except AttributeError: 307 | print(f'Error: arg name = {arg}') 308 | raise 309 | 310 | def mean(self, arg): 311 | if arg in self.nanmean_args_list: 312 | return torch_nanmean(self.loggers[arg][:(self.cnt + 1)]).detach().item() 313 | else: 314 | return self.loggers[arg][:(self.cnt+1)].mean().detach().item() 315 | 316 | def get_means(self): 317 | return OrderedDict((arg + '_mean', self.mean(arg)) for arg in self.loggers.keys()) 318 | 319 | def torch_nanmean(x): 320 | return x[~torch.isnan(x)].mean() 321 | 322 | 323 | def wandb_myinit(project_name, experiment_name, instance_name, config, 
workdir=None, wandb_output_dir='/tmp/wandb', 324 | reinit=True, username='user'): 325 | 326 | import wandb 327 | tags = [experiment_name] 328 | if experiment_name.startswith('qa_'): 329 | tags.append('qa') 330 | config['workdir'] = workdir 331 | wandb.init(project=project_name, name=instance_name, config=config, tags=tags, dir=wandb_output_dir, reinit=reinit, 332 | entity=username) 333 | 334 | 335 | def get_all_argparse_keys(parser): 336 | return [action.dest for action in parser._actions] 337 | 338 | 339 | def fill_missing_by_defaults(args_dict, argparse_parser): 340 | all_arg_keys = get_all_argparse_keys(argparse_parser) 341 | for key in all_arg_keys: 342 | if key not in args_dict: 343 | args_dict[key] = argparse_parser.get_default(key) 344 | return args_dict 345 | 346 | 347 | def get_and_update_num_calls(func_ptr): 348 | try: 349 | get_and_update_num_calls.num_calls_cnt[func_ptr] += 1 350 | except AttributeError as e: 351 | if 'num_calls_cnt' in repr(e): 352 | get_and_update_num_calls.num_calls_cnt = defaultdict(int) 353 | else: 354 | raise 355 | 356 | return get_and_update_num_calls.num_calls_cnt[func_ptr] 357 | 358 | 359 | def duplicate(x, times): 360 | return tuple(deepcopy(x) for _ in range(times)) 361 | 362 | 363 | def to_simple_parsing_args(some_dataclass_type): 364 | """ 365 | Add this as a classmethod to some dataclass in order to make its arguments accessible from commandline 366 | Example: 367 | 368 | @classmethod 369 | def get_args(cls): 370 | args: cls = to_simple_parsing_args(cls) 371 | return args 372 | 373 | """ 374 | from simple_parsing import ArgumentParser, ConflictResolution 375 | parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE) 376 | parser.add_arguments(some_dataclass_type, dest='cfg') 377 | args: some_dataclass_type = parser.parse_args().cfg 378 | return args 379 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the License 5 | # located at the root directory. 
6 | # --------------------------------------------------------------- 7 | 8 | from copy import deepcopy 9 | from typing import Union 10 | 11 | import torch 12 | from torch.nn import functional as F 13 | from torch import nn 14 | from torch.nn.functional import triplet_margin_loss 15 | 16 | from data import CompDataFromDict 17 | from useful_utils import ns_profiling_label 18 | from params import CommandlineArgs 19 | from ATTOP.models.models import MLP as ATTOP_MLP 20 | 21 | def get_model(args: CommandlineArgs, dataset: CompDataFromDict): 22 | cfg = args.model 23 | if cfg.E_num_common_hidden==0: 24 | ECommon = nn.Sequential() 25 | ECommon.output_shape = lambda: (None, dataset.input_shape[0]) 26 | else: 27 | ECommon = ATTOP_MLP(dataset.input_shape[0], cfg.h_dim, num_layers=cfg.E_num_common_hidden - 1, relu=True, bias=True).to(args.device) 28 | ECommon.output_shape = lambda: (None, cfg.h_dim) 29 | 30 | 31 | if not cfg.VisProd: 32 | h_A = Label_MLP_embed(dataset.num_attrs, cfg.h_dim, 33 | num_layers=cfg.nlayers_label).to(args.device) 34 | h_O = Label_MLP_embed(dataset.num_objs, cfg.h_dim, 35 | num_layers=cfg.nlayers_label).to(args.device) 36 | g1_emb_to_hidden_feat = ATTOP_MLP(2 * cfg.h_dim, cfg.h_dim, 37 | num_layers=cfg.nlayers_joint_ao).to(args.device) 38 | g2_feat_to_image_feat = ATTOP_MLP(cfg.h_dim, dataset.input_shape[0], num_layers=cfg.nlayers_joint_ao 39 | ).to(args.device) 40 | 41 | g_inv_O = MLP_Encoder(ECommon.output_shape()[1], h_dim=cfg.h_dim, 42 | E_num_common_hidden=cfg.E_num_hidden - cfg.E_num_common_hidden, 43 | mlp_activation='leaky_relu', BN=True).to(args.device) # category 44 | g_inv_A = MLP_Encoder(ECommon.output_shape()[1], h_dim=cfg.h_dim, 45 | E_num_common_hidden=cfg.E_num_hidden - cfg.E_num_common_hidden, 46 | mlp_activation='leaky_relu', BN=True).to(args.device) # category 47 | 48 | emb_cf_O = EmbeddingClassifier(h_O, image_feat_dim=g_inv_O.h_dim, device=args.device).to(args.device) 49 | # redundant historic call - just to make sure random-number-gen is kept aligned with original codebase 50 | _ = EmbeddingClassifier(h_A, image_feat_dim=g_inv_A.h_dim, device=args.device).to(args.device) 51 | 52 | emb_cf_A = EmbeddingClassifier(h_A, image_feat_dim=g_inv_A.h_dim, device=args.device).to(args.device) 53 | # redundant historic call - just to make sure random-number-gen is kept aligned with original codebase 54 | _ = EmbeddingClassifier(h_O, image_feat_dim=g_inv_O.h_dim, device=args.device).to(args.device) 55 | 56 | model = CompModel(ECommon, g_inv_O, g_inv_A, emb_cf_O, emb_cf_A, h_O, h_A, g1_emb_to_hidden_feat, g2_feat_to_image_feat, 57 | args, dataset).to(args.device) 58 | 59 | return model 60 | 61 | 62 | def nll_sum_loss(attr_logits, obj_logits, attr_gt, obj_gt, nll_loss_funcs): 63 | nll_loss_attr = nll_loss_funcs.y2(attr_logits, attr_gt) 64 | nll_loss_obj = nll_loss_funcs.y1(obj_logits, obj_gt) 65 | return nll_loss_attr + nll_loss_obj 66 | 67 | 68 | def get_activation_layer(activation): 69 | return dict(leaky_relu=nn.LeakyReLU(), 70 | relu=nn.ReLU(), 71 | )[activation] 72 | 73 | class MLP_block(nn.Module): 74 | def __init__(self, input_dim, output_dim, mlp_activation='leaky_relu', BN=True): 75 | super(MLP_block, self).__init__() 76 | layers_list = [] 77 | layers_list += [nn.Linear(input_dim, output_dim)] 78 | if BN: 79 | layers_list += [nn.BatchNorm1d(num_features=output_dim)] 80 | layers_list += [get_activation_layer(mlp_activation)] 81 | self.NN = nn.Sequential(*layers_list) 82 | 83 | def forward(self, x): 84 | return self.NN(x) 85 | 86 | 87 | class 
Label_MLP_embed(nn.Module): 88 | def __init__(self, num_embeddings, embedding_dim, num_layers=0): 89 | super(Label_MLP_embed, self).__init__() 90 | self.lin_emb = nn.Embedding(num_embeddings, embedding_dim) 91 | layers_list = [] 92 | layers_list += [self.lin_emb] 93 | if num_layers > 0: 94 | layers_list += [ATTOP_MLP(embedding_dim, embedding_dim, num_layers=num_layers)] 95 | self.NN = nn.Sequential(*layers_list) 96 | self.embedding_dim = embedding_dim 97 | 98 | def forward(self, tokens): 99 | output = self.NN(tokens) 100 | return output 101 | 102 | 103 | 104 | class MLP_Encoder(nn.Module): 105 | def __init__(self, input_dim, h_dim=150, E_num_common_hidden=1, mlp_activation='leaky_relu', BN=True, 106 | linear_last_layer=False, **kwargs): 107 | super().__init__() 108 | 109 | self.h_dim = h_dim 110 | layers_list = [] 111 | prev_dim = input_dim 112 | for n in range(E_num_common_hidden): 113 | if n==0: 114 | layers_list += [MLP_block(prev_dim, h_dim, mlp_activation=mlp_activation, BN=BN)] 115 | else: 116 | layers_list += [MLP_block(prev_dim, h_dim, mlp_activation=mlp_activation, BN=BN)] 117 | prev_dim = h_dim 118 | 119 | 120 | if linear_last_layer: 121 | layers_list += [nn.Linear(prev_dim, prev_dim)] 122 | 123 | self._output_shape = prev_dim 124 | self.layers_list = layers_list 125 | self.NN = nn.Sequential(*layers_list) 126 | 127 | 128 | def forward(self, x): 129 | return self.NN(x) 130 | 131 | def output_shape(self): 132 | return (None, self._output_shape) 133 | 134 | 135 | class EmbeddingClassifier(nn.Module): 136 | def __init__(self, embeddings, image_feat_dim, device): 137 | super().__init__() 138 | self.num_emb_class, _ = list(embeddings.parameters())[0].shape 139 | self.image_feat_dim = image_feat_dim 140 | self.device = device 141 | 142 | self.embeddings = embeddings 143 | self.emb_alignment = nn.Linear(image_feat_dim, self.embeddings.embedding_dim) 144 | def forward(self, feat): 145 | feat = self.emb_alignment(feat) 146 | emb_per_class = self.embeddings(torch.arange(0, self.num_emb_class, dtype=int).to(self.device)).T 147 | out = -((feat[:, :, None] - emb_per_class) ** 2).sum(1) 148 | return out 149 | 150 | 151 | def LinearSoftmaxLogits(input_dim, output_dim): 152 | return nn.Sequential(nn.Linear(input_dim, output_dim), 153 | nn.LogSoftmax()) 154 | 155 | 156 | class CompModel(nn.Module): 157 | def __init__(self, ECommon, g_inv_O, g_inv_A, emb_cf_O, emb_cf_A, h_O, h_A, g1_emb_to_hidden_feat, g2_feat_to_image_feat, 158 | args: CommandlineArgs, dataset): 159 | """ 160 | Input: 161 | E: encoder 162 | M: classifier 163 | D: discriminator 164 | alpha: weighting parameter of label classifier and domain classifier 165 | num_classes: the number of classes 166 | """ 167 | super(CompModel, self).__init__() 168 | model_cfg = args.model 169 | self.args = args 170 | self.is_mutual = True 171 | self.ECommon = ECommon 172 | self.g_inv_O = g_inv_O 173 | self.g_inv_A = g_inv_A 174 | self.emb_cf_O = emb_cf_O 175 | self.emb_cf_A = emb_cf_A 176 | self.h_O = h_O 177 | self.h_A = h_A 178 | self.g1_emb_to_hidden_feat = g1_emb_to_hidden_feat 179 | self.g2_feat_to_image_feat = g2_feat_to_image_feat 180 | self.loader = None 181 | self.num_objs = dataset.num_objs 182 | self.num_attrs = dataset.num_attrs 183 | self.attrs_idxs, self.objs_idxs = self._get_ao_outerprod_idxs() 184 | self.last_feature_common = None 185 | 186 | if not model_cfg.VisProd: 187 | self.obj_inv_core_logits = LinearSoftmaxLogits(model_cfg.h_dim, self.num_objs).to(args.device) 188 | self.attr_inv_core_logits = 
LinearSoftmaxLogits(model_cfg.h_dim, self.num_attrs).to(args.device) 189 | 190 | self.obj_inv_g_hidden_logits = LinearSoftmaxLogits(model_cfg.h_dim, self.num_objs).to(args.device) 191 | self.attr_inv_g_hidden_logits = LinearSoftmaxLogits(model_cfg.h_dim, self.num_attrs).to(args.device) 192 | 193 | self.obj_inv_g_imgfeat_logits = LinearSoftmaxLogits(dataset.input_shape[0], self.num_objs).to(args.device) 194 | self.attr_inv_g_imgfeat_logits = LinearSoftmaxLogits(dataset.input_shape[0], self.num_attrs).to(args.device) 195 | 196 | self.device = args.device 197 | # check (rename) 198 | self.mu_disjoint = args.train.mu_disjoint 199 | self.mu_ao_emb = args.train.mu_ao_emb 200 | self.mu_img_feat = args.train.mu_img_feat 201 | 202 | def encode(self, input_data, freeze_class1=False, freeze_class2=False): 203 | feature_common = self.ECommon(input_data) 204 | self.last_feature_common = feature_common 205 | if feature_common is input_data: 206 | self.last_feature_common = None 207 | 208 | if freeze_class1: 209 | with torch.no_grad(): 210 | feature1 = self.g_inv_O(feature_common).detach() 211 | else: 212 | feature1 = self.g_inv_O(feature_common) 213 | 214 | 215 | if freeze_class2: 216 | with torch.no_grad(): 217 | feature2 = self.g_inv_A(feature_common).detach() 218 | else: 219 | feature2 = self.g_inv_A(feature_common) 220 | 221 | return feature1, feature2, feature_common 222 | 223 | def forward(self, input_data, 224 | freeze_class1=False, freeze_class2=False): 225 | 226 | ### init and definitions 227 | freeze_class = freeze_class1, freeze_class2 228 | classifiers = self.emb_cf_O, self.emb_cf_A 229 | class_outputs = [None, None] 230 | 231 | def set_grad_disabled(condition): 232 | return torch.set_grad_enabled(not condition) 233 | ### end init and definitions 234 | 235 | feature1, feature2, feature_common = self.encode(input_data, freeze_class1, freeze_class2) 236 | 237 | for m, feature in enumerate([feature1, feature2]): 238 | with set_grad_disabled(freeze_class[m]): 239 | class_outputs[m] = classifiers[m](feature) 240 | if freeze_class[m]: 241 | class_outputs[m] = class_outputs[m].detach() 242 | 243 | joint_output = (class_outputs[0][..., None] + class_outputs[1][..., None, :]) 244 | joint_output = torch.flatten(joint_output[:, :], start_dim=1) # flatten 245 | 246 | # inference 247 | if not self.training and not self.args.model.VisProd: 248 | joint_output = self.mu_disjoint * joint_output 249 | 250 | if self.mu_img_feat > 0 or self.mu_ao_emb > 0: 251 | flattened_ao_emb_joint_scores, flattened_img_emb_scores = \ 252 | self.get_joint_embed_classification_scores(self.attrs_idxs, self.objs_idxs, 253 | self.last_feature_common, input_data) 254 | 255 | joint_output += self.mu_ao_emb * flattened_ao_emb_joint_scores 256 | joint_output += self.mu_img_feat * flattened_img_emb_scores 257 | 258 | scores_emb = joint_output.view((-1, self.num_objs, self.num_attrs)) 259 | class_outputs[0] = scores_emb.max(axis=2)[0].detach() 260 | class_outputs[1] = scores_emb.max(axis=1)[0].detach() # obj, attr 261 | 262 | return tuple(class_outputs + [feature1, feature2, joint_output]) 263 | 264 | def eval_pair_embed_losses(self, args: CommandlineArgs, img_feat, img_hidden_emb, attr_labels, obj_labels, 265 | neg_attr_labels, neg_obj_labels, nll_loss_funcs): 266 | device = args.device 267 | 268 | with ns_profiling_label('labels_to_embeddings'): 269 | h_A_pos, h_O_pos, g_hidden_pos, g_img_pos = self.labels_to_embeddings(attr_labels, obj_labels) 270 | _, _, g_hidden_neg, g_img_neg = self.labels_to_embeddings(neg_attr_labels, 
neg_obj_labels) 271 | 272 | tloss_g_imgfeat = torch.tensor(0.).to(device) 273 | if args.train.lambda_feat > 0: 274 | with ns_profiling_label('tloss_g_imgfeat'): 275 | tloss_g_imgfeat = triplet_margin_loss(img_feat, g_img_pos, g_img_neg, 276 | margin=args.train.triplet_loss_margin) 277 | 278 | tloss_g_hidden = torch.tensor(0.).to(device) 279 | if args.train.lambda_ao_emb > 0: 280 | with ns_profiling_label('tloss_g_hidden'): 281 | tloss_g_hidden = triplet_margin_loss(img_hidden_emb, g_hidden_pos, g_hidden_neg, 282 | margin=args.train.triplet_loss_margin) 283 | 284 | 285 | # Loss_invert terms 286 | loss_inv_core = torch.tensor(0.).to(device) 287 | if args.train.lambda_aux_disjoint > 0: # check hp name 288 | with ns_profiling_label('loss_inv_core'): 289 | loss_inv_core = nll_sum_loss(self.attr_inv_core_logits(h_A_pos), 290 | self.obj_inv_core_logits(h_O_pos), 291 | attr_labels, obj_labels, nll_loss_funcs) 292 | 293 | loss_inv_g_imgfeat = torch.tensor(0.).to(device) 294 | if args.train.lambda_aux_img > 0: # check hp name 295 | with ns_profiling_label('loss_inv_g_imgfeat'): 296 | loss_inv_g_imgfeat = nll_sum_loss(self.attr_inv_g_imgfeat_logits(g_img_pos), 297 | self.obj_inv_g_imgfeat_logits(g_img_pos), 298 | attr_labels, obj_labels, nll_loss_funcs) 299 | 300 | loss_inv_g_hidden = torch.tensor(0.).to(device) 301 | if args.train.lambda_aux > 0: # check hp name 302 | with ns_profiling_label('loss_inv_g_hidden'): 303 | loss_inv_g_hidden = nll_sum_loss(self.attr_inv_g_hidden_logits(g_hidden_pos), 304 | self.obj_inv_g_hidden_logits(g_hidden_pos), 305 | attr_labels, obj_labels, nll_loss_funcs) 306 | 307 | 308 | return tloss_g_hidden, tloss_g_imgfeat, loss_inv_core, loss_inv_g_hidden, loss_inv_g_imgfeat 309 | 310 | def labels_to_embeddings(self, attr_labels, obj_labels): 311 | """ 312 | h_A 313 | | 314 | attr_labels -> h_A -> 315 | > g1_emb_to_hidden_feat -> g2_feat_to_image_feat -> 316 | obj_labels -> h_O -> | | 317 | | g_hidden g_img 318 | h_O 319 | 320 | """ 321 | h_A = self.h_A(attr_labels) 322 | h_O = self.h_O(obj_labels) 323 | 324 | g_hidden = self.g1_emb_to_hidden_feat(torch.cat((h_A, h_O), dim=1)) 325 | g_img = self.g2_feat_to_image_feat(g_hidden) 326 | 327 | return h_A, h_O, g_hidden, g_img 328 | 329 | def get_joint_embed_classification_scores(self, attrs, objs, common_emb_feat, img_feat): 330 | _, _, g_hidden, g_img = self.labels_to_embeddings(attrs, objs) 331 | vec_dist_img_emb = ((img_feat[:, :, None] - g_img.T[None, :, :])) 332 | flattened_img_emb_scores = -((vec_dist_img_emb ** 2).sum(1)) 333 | 334 | if common_emb_feat is not None: 335 | vec_dist_joint_ao_emb = ((common_emb_feat[:, :, None] - g_hidden.T[None, :, :])) 336 | flattened_joint_scores = -((vec_dist_joint_ao_emb ** 2).sum(1)) 337 | else: 338 | flattened_joint_scores = 0*flattened_img_emb_scores.detach() 339 | 340 | return flattened_joint_scores, flattened_img_emb_scores 341 | 342 | def _get_ao_outerprod_idxs(self): 343 | device = self.args.device 344 | outerprod_pairs = torch.cartesian_prod(torch.arange(0, self.num_objs, device=device), 345 | torch.arange(0, self.num_attrs, device=device)) 346 | objs_idxs = outerprod_pairs[:, 0] 347 | attrs_idxs = outerprod_pairs[:, 1] 348 | return attrs_idxs, objs_idxs 349 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
3 | #
4 | # This work is licensed under the License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 |
8 | import os
9 |
10 | import torch
11 | from torch import Tensor
12 | import numpy as np
13 |
14 | from model import CompModel
15 | from data import CompDataFromDict
16 | from useful_utils import ns_profiling_label, batch_torch_logger, list_to_2d_tuple, slice_dict_to_dict, to_torch
17 | from COSMO_utils import calc_cs_ausuc, per_class_balanced_accuracy
18 |
19 |
20 | def evaluation_step(model, valid_loader, test_loader, loss_funcs, writer, epoch, n_epochs, curr_epoch_metrics,
21 |                     early_stop_metric_name, best_ES_metric_value, calc_AUC=False):
22 |     """
23 |     Forward-passes the validation and test data, updates the metrics, and finds the best epoch according to the validation metric.
24 |     """
25 |     model.eval()
26 |     # a path for dumping the logits and predictions of each epoch
27 |     dump_preds_path = os.path.join(writer._log_dir, 'dump_preds', f'epoch_{epoch}')
28 |
29 |     # get metrics on validation set
30 |     with ns_profiling_label('eval val set'):
31 |         metrics_valid = eval_model_with_dataloader(model, valid_loader, loss_funcs, phase_name='valid',
32 |                                                    calc_AUC=calc_AUC, dump_to_fs_basedir=dump_preds_path)
33 |     metrics_valid['epoch'] = epoch  # log epoch number as a metric
34 |     # update metrics dictionary with validation metrics
35 |     curr_epoch_metrics.update(metrics_valid)
36 |
37 |     # get metrics on test set
38 |     with ns_profiling_label('eval test set'):
39 |         metrics_test = eval_model_with_dataloader(model, test_loader, loss_funcs, phase_name='test', calc_AUC=calc_AUC,
40 |                                                   dump_to_fs_basedir=dump_preds_path)
41 |     # update metrics dictionary with test metrics
42 |     curr_epoch_metrics.update(metrics_test)
43 |
44 |
45 |     # Early Stop (ES) monitoring
46 |     current_ES_metric_value = metrics_valid[early_stop_metric_name.metric]
47 |     ES_metric_polarity = 2 * (early_stop_metric_name.polarity == 'max') - 1
48 |     is_best = False
49 |     if best_ES_metric_value*ES_metric_polarity <= current_ES_metric_value*ES_metric_polarity:
50 |         is_best = True
51 |         best_ES_metric_value = current_ES_metric_value
52 |
53 |     model.train()
54 |
55 |     return best_ES_metric_value, is_best
56 |
57 |
58 | def eval_model_with_dataloader(model, data_loader, loss_funcs, phase_name,
59 |                                calc_AUC=False,
60 |                                dump_to_fs_basedir=None):
61 |     if dump_to_fs_basedir is not None:
62 |         os.makedirs(dump_to_fs_basedir, exist_ok=True)
63 |     print(f'\n\n\nEvaluating metrics for {phase_name} phase:')
64 |
65 |     # get model logits and ground-truth for evaluation data
66 |     with ns_profiling_label('forward_pass_data'):
67 |         logits_pred_ys_1, logits_pred_ys_2, logits_ys_joint, y1, y2, ys_joint, logger_means, logits_filenames \
68 |             = forward_pass_data(model, data_loader, loss_funcs)
69 |
70 |     preds_dump_dict, results = calc_metrics_from_logits(data_loader.dataset, logits_ys_joint, ys_joint, y1, y2,
71 |                                                         logits_filenames, phase_name, logger_means, calc_AUC)
72 |
73 |     ### dump predictions to filesystem
74 |     if dump_to_fs_basedir is not None:
75 |         # cast to numpy
76 |         preds_dump_dict.update(
77 |             dict((k, v.detach().cpu().numpy()) for k, v in preds_dump_dict.items() if isinstance(v, torch.Tensor)))
78 |         fname = os.path.join(dump_to_fs_basedir, 'dump_preds' + f'_{phase_name}') + '.npz'
79 |         np.savez_compressed(fname, **preds_dump_dict)
80 |         print(f'dumped predictions to: {fname}')
81 |
82 |     return results
83 |
84 |
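# Batch layout assumed by forward_pass_data() below and by the training loop in train.py
# (inferred from how both unpack each batch; a sketch, not an authoritative spec):
#
#   batch = (image_features,   # X
#            attr_label,       # y2
#            obj_label,        # y1
#            neg_attrs,        # sampled negatives for the triplet losses
#            neg_objs,
#            filename)         # image filename, kept for dumping predictions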
85 | def forward_pass_data(model: CompModel, data_loader, loss_funcs):
86 |     device = model.device
87 |     with torch.no_grad():
88 |         num_batches = len(data_loader)
89 |         batch_metrics_logger = batch_torch_logger(num_batches=num_batches,
90 |                                                   cs_str_args='y1_loss, y2_loss, y_sum_loss',
91 |                                                   device=device)
92 |
93 |         # predefine lists that aggregate data across batches
94 |         ys_1 = [None] * num_batches
95 |         ys_2 = [None] * num_batches
96 |         fname_list = [None] * num_batches
97 |         logits_pred_ys_1 = [None] * num_batches
98 |         logits_pred_ys_2 = [None] * num_batches
99 |         ys_joint = [None] * num_batches
100 |         logits_ys_joint = [None] * num_batches
101 |         data_iter = iter(data_loader)
102 |
103 |         # iterate on all the data - forward-pass it through the model and aggregate the ground-truth labels,
104 |         # the logits, and the losses
105 |         for i in range(num_batches):
106 |             X_batch, y2_batch, y1_batch, _, _, fname_batch = next(data_iter)
107 |
108 |             with ns_profiling_label('process batch'):
109 |                 batch_metrics_logger.new_batch()
110 |
111 |                 with torch.no_grad():  # redundant here: the whole function already runs under no_grad
112 |
113 |                     with ns_profiling_label('copy to gpu'):
114 |                         X_batch = X_batch.float().to(device)
115 |                         y1_batch = y1_batch.long().to(device)
116 |                         y2_batch = y2_batch.long().to(device)
117 |
118 |                     with ns_profiling_label('fwd pass'):
119 |                         y1_pred, y2_pred, _, _, joint_pred = model(X_batch)
120 |                         y1_pred = y1_pred.detach()
121 |                         y2_pred = y2_pred.detach()
122 |
123 |                     # calc a joint label
124 |                     y_joint = data_loader.dataset.to_joint_label(y1_batch, y2_batch).detach()
125 |
126 |                     # evaluate the loss for the current eval batch
127 |                     if loss_funcs is not None:
128 |                         y1_loss = loss_funcs.y1(y1_pred, y1_batch).detach()
129 |                         y2_loss = loss_funcs.y2(y2_pred, y2_batch).detach()
130 |                         y_sum_loss = (y1_loss + y2_loss).detach()
131 |
132 |                     batch_metrics_logger.log(locals_dict=locals())
133 |
134 |                 # aggregate labels and logits of the current batch
135 |                 ys_1[i] = y1_batch
136 |                 ys_2[i] = y2_batch
137 |                 ys_joint[i] = y_joint
138 |
139 |                 fname_list[i] = fname_batch
140 |
141 |                 logits_pred_ys_1[i] = y1_pred
142 |                 logits_pred_ys_2[i] = y2_pred
143 |                 logits_ys_joint[i] = joint_pred
144 |
145 |     # calc the mean of each metric across all batches
146 |     logger_means = batch_metrics_logger.get_means()
147 |
148 |     # concat per-batch ground-truth labels list to tensors
149 |     y1 = torch.cat(ys_1).detach()
150 |     y2 = torch.cat(ys_2).detach()
151 |     ys_joint = torch.cat(ys_joint).detach()
152 |
153 |     # concat per-batch filenames list to a single tuple
154 |     fname = sum(list_to_2d_tuple(fname_list), tuple())
155 |
156 |     # concat per-batch logits list to tensors
157 |     logits_pred_ys_1 = torch.cat(logits_pred_ys_1).detach()
158 |     logits_pred_ys_2 = torch.cat(logits_pred_ys_2).detach()
159 |     logits_ys_joint = torch.cat(logits_ys_joint).detach()
160 |     return logits_pred_ys_1, logits_pred_ys_2, logits_ys_joint, y1, y2, ys_joint, logger_means, fname
161 |
162 |
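# calc_metrics_from_logits() below scores every (obj, attr) pair, then restricts each
# evaluation regime (open / closed) by overwriting the logits of out-of-set pairs with
# -inf before taking argmax. A minimal, self-contained illustration of the trick
# (the names here are hypothetical):
#
#   >>> import numpy as np
#   >>> logits = np.array([[0.2, 0.9, 0.1, 0.4]])      # scores over 4 candidate pairs
#   >>> in_set = np.array([True, False, True, True])   # pair 1 is outside the evaluated set
#   >>> masked = logits.copy()
#   >>> masked[:, ~in_set] = -np.inf                   # excluded pairs can never win the argmax
#   >>> masked.argmax(1)
#   array([3])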
163 | def calc_metrics_from_logits(dataset, logits_ys_joint, ys_joint, y1, y2, logits_fnames, phase_name, logger_means={},
164 |                              calc_AUC=False, device=None):
165 |
166 |     """
167 |     Notes: 1. This function assumes that the logits contain all (cartesian-product) combinations of attribute and object,
168 |               i.e. including combinations that are not included in the open set.
169 |            2. calc_AUC is optional, because it slows down the evaluation step
170 |     """
171 |
172 |     # Prepare variables
173 |     fname = logits_fnames  # filenames
174 |
175 |
176 |     # Get:
177 |     # (1) logits-based scores for objects and for attributes
178 |     # (2) pairs indexing variables
179 |     logits_pred_ys_1, logits_pred_ys_2 = \
180 |         prepare_logits_for_metrics_calc(logits_ys_joint, dataset)
181 |
182 |
183 |     shape_obj_attr = dataset.shape_obj_attr
184 |     flattened_seen_pairs_mask = dataset.flattened_seen_pairs_mask
185 |     flattened_closed_unseen_pairs_mask = dataset.flattened_closed_unseen_pairs_mask
186 |     flattened_all_open_pairs_mask = dataset.flattened_all_open_pairs_mask
187 |
188 |     # get indices of seen samples and of unseen samples
189 |     seen_pairs_joint_class_ids = dataset.seen_pairs_joint_class_ids
190 |     ids_seen_samples = np.where(np.in1d(ys_joint.cpu(), seen_pairs_joint_class_ids[0]))[0]
191 |     ids_unseen_samples = np.where(~np.in1d(ys_joint.cpu(), seen_pairs_joint_class_ids[0]))[0]
192 |
193 |     # ======== Open-Set Accuracy Metrics ====================
194 |
195 |     # filter out logits of irrelevant pairs by setting their scores to -inf
196 |     logits_ys_open = logits_ys_joint.clone()
197 |     logits_ys_open[:, ~flattened_all_open_pairs_mask] = -np.inf
198 |
199 |     # Calc standard (imbalanced) accuracy metrics for the open set
200 |     unseen_open_acc = acc_from_logits(logits_ys_open, ys_joint, ids_unseen_samples)
201 |     seen_open_acc = acc_from_logits(logits_ys_open, ys_joint, ids_seen_samples)
202 |     open_H_IMB = 2 * (unseen_open_acc * seen_open_acc) / (unseen_open_acc + seen_open_acc + 1e-7)
203 |
204 |     # Calc balanced accuracy metrics for the open set
205 |     num_class_joint_unseen = len(np.unique(ys_joint[ids_unseen_samples].cpu()))
206 |     num_class_joint_seen = len(ys_joint[ids_seen_samples].unique())
207 |     pred_open_unseen = logits_ys_open.argmax(1)[ids_unseen_samples]
208 |     open_balanced_unseen_acc = per_class_balanced_accuracy(ys_joint[ids_unseen_samples].cpu().numpy(),
209 |                                                            pred_open_unseen.cpu().numpy(), num_class_joint_unseen)
210 |     pred_open_seen = logits_ys_open.argmax(1)[ids_seen_samples]
211 |     open_balanced_seen_acc = per_class_balanced_accuracy(ys_joint[ids_seen_samples].cpu().numpy(),
212 |                                                          pred_open_seen.cpu().numpy(), num_class_joint_seen)
213 |     # harmonic mean of the balanced seen/unseen accuracies
214 |     open_H = 2 * (open_balanced_unseen_acc * open_balanced_seen_acc) / (
215 |             open_balanced_unseen_acc + open_balanced_seen_acc)
216 |     logits_ys_open = None  # release memory
217 |
218 |
219 |     # ======== Closed-Set Accuracy Metrics ====================
220 |
221 |     # filter out logits of irrelevant pairs by setting their scores to -inf
222 |     logits_ZS_ys_closed = logits_ys_joint.clone()
223 |     logits_ZS_ys_closed[:, ~flattened_closed_unseen_pairs_mask] = -np.inf
224 |
225 |     # Calc standard (imbalanced) accuracy metrics for the closed set
226 |     closed_acc = acc_from_logits(logits_ZS_ys_closed, ys_joint, ids_unseen_samples)
227 |     # Calc balanced accuracy metrics for the closed set
228 |     pred_closed_unseen = logits_ZS_ys_closed.argmax(1)[ids_unseen_samples]
229 |     closed_balanced_acc = per_class_balanced_accuracy(ys_joint[ids_unseen_samples].cpu().numpy(),
230 |                                                       pred_closed_unseen.cpu().numpy(), num_class_joint_unseen)
231 |     closed_balanced_acc_random = calc_random_balanced_baseline_by_logits_neginf(logits_ZS_ys_closed)
232 |     pred_closed_unseen, logits_ZS_ys_closed = None, None  # release memory
233 |
234 |     # ===== Unseen Accuracy metrics for objects (y1) or attributes (y2) =====
235 |     # Note 'unseen' indicates that y1 resides in an unseen *combination* of 
(y1,y2), not that y1 is unseen 236 | # (and similarly for y2) 237 | pred_y_1_unseen = logits_pred_ys_1.argmax(1)[ids_unseen_samples] 238 | pred_y_2_unseen = logits_pred_ys_2.argmax(1)[ids_unseen_samples] 239 | pred_joint_unseen = logits_ys_joint.argmax(1)[ids_unseen_samples] 240 | y1_acc_unseen = (pred_y_1_unseen == y1[ids_unseen_samples]).sum().float() / len(y1[ids_unseen_samples]) 241 | y2_acc_unseen = (pred_y_2_unseen == y2[ids_unseen_samples]).sum().float() / len(y2[ids_unseen_samples]) 242 | 243 | # balanced accuracy metric. 244 | # NOTE !!: balanced accuracy metric ignores classes that don't participate in the set (relevant to val set) 245 | num_class_y1_unseen = len(y1[ids_unseen_samples].unique()) 246 | y1_balanced_acc_unseen = per_class_balanced_accuracy(y1[ids_unseen_samples], pred_y_1_unseen, num_class_y1_unseen) 247 | y1_balanced_acc_unseen_random = calc_random_balanced_baseline_by_logits_neginf(logits_pred_ys_1) 248 | num_class_y2_unseen = len(y2[ids_unseen_samples].unique()) 249 | y2_balanced_acc_unseen = per_class_balanced_accuracy(y2[ids_unseen_samples], pred_y_2_unseen, num_class_y2_unseen) 250 | y2_balanced_acc_unseen_random = calc_random_balanced_baseline_by_logits_neginf(logits_pred_ys_2) 251 | 252 | ## 253 | # Init a dictionary (preds_dump_dict) that holds variables to dump to file system 254 | preds_dump_dict = slice_dict_to_dict(locals(), 'flattened_seen_pairs_mask, seen_pairs_joint_class_ids, shape_obj_attr, ' 255 | 'y1, y2, ys_joint, fname, ' 256 | 'ids_seen_samples, ids_unseen_samples, ' 257 | 'flattened_closed_unseen_pairs_mask, ' 258 | 'num_class_joint_unseen, num_class_joint_seen, ' 259 | 'flattened_all_open_pairs_mask, pred_open_unseen, pred_open_seen, ' 260 | 'pred_closed_unseen, ' 261 | 'pred_y_1_unseen, pred_y_2_unseen, pred_joint_unseen, num_class_y1_unseen, ' 262 | 'num_class_y2_unseen') 263 | 264 | if calc_AUC: 265 | AUC_open, AUC_open_balanced = calc_ausuc_from_logits(preds_dump_dict, logits_ys_joint, dataset, device) 266 | 267 | results = build_results_dict(locals(), logger_means, phase_name, calc_AUC) 268 | return preds_dump_dict, results 269 | 270 | 271 | def build_results_dict(local_vars_dict, logger_means, phase_name, calc_AUC): 272 | _phase_name = f'_{phase_name}' 273 | results = slice_dict_to_dict( 274 | local_vars_dict, 275 | ['closed_acc', 276 | 'closed_balanced_acc', 277 | 'open_H', 278 | 'open_H_IMB', 279 | 'open_balanced_seen_acc', 280 | 'open_balanced_unseen_acc', 281 | 'seen_open_acc', 282 | 'unseen_open_acc', 283 | 'y1_acc_unseen', 284 | 'y1_balanced_acc_unseen', 285 | 'y1_balanced_acc_unseen_random', 286 | 'y2_acc_unseen', 287 | 'y2_balanced_acc_unseen', 288 | 'y2_balanced_acc_unseen_random'], 289 | returned_keys_postfix=_phase_name) 290 | logger_means_with_postfix = slice_dict_to_dict(logger_means, logger_means.keys(), 291 | returned_keys_postfix=_phase_name) 292 | results.update(logger_means_with_postfix) 293 | # cast torch & numpy scalars to ordinary python scalar (in order to be compatibility with json.dump() ) 294 | results.update(dict((k, v.cpu().detach().item()) for k, v in results.items() if isinstance(v, torch.Tensor))) 295 | results.update(dict((k, v.item()) for k, v in results.items() if isinstance(v, np.number))) 296 | if calc_AUC: 297 | results['AUC_open' + _phase_name] = local_vars_dict['AUC_open'] 298 | results['AUC_open_balanced' + _phase_name] = local_vars_dict['AUC_open_balanced'] 299 | return results 300 | 301 | 302 | 303 | def prepare_logits_for_metrics_calc(logits_ys_joint, dataset): 304 | """ Returns: 305 | (1) 
logits based scores for objects and for attributes 306 | (2) pairs indexing variables 307 | """ 308 | unflattened_logits = logits_ys_joint.view((-1, dataset.num_objs, dataset.num_attrs)) 309 | logits_pred_ys_1 = unflattened_logits.max(dim=2)[0].detach() 310 | logits_pred_ys_2 = unflattened_logits.max(dim=1)[0].detach() 311 | return logits_pred_ys_1, logits_pred_ys_2 312 | 313 | 314 | def accuracy(predictions: Tensor, labels: Tensor, return_tensor=False): 315 | acc = (predictions == labels).float().mean() 316 | if not return_tensor: 317 | acc = acc.item() 318 | return acc 319 | 320 | 321 | def acc_from_logits(logits: Tensor, labels: Tensor, subset_ids=None, return_tensor=False): 322 | predictions = logits.argmax(1) 323 | if subset_ids is None: 324 | return accuracy(predictions, labels, return_tensor=return_tensor) 325 | else: 326 | if len(subset_ids) == 0: 327 | return np.nan 328 | return accuracy(predictions[subset_ids], labels[subset_ids], return_tensor=return_tensor) 329 | 330 | 331 | def calc_ausuc_from_logits(preds_dump_dict, logits_ys_joint, dataset: CompDataFromDict, device): 332 | # AUSUC when evaluating on open pairs. 333 | # We use the code from https://github.com/yuvalatzmon/COSMO/blob/master/src/metrics.py (CVPR 2019) 334 | # For that we first represent our logits and labels to match that API 335 | 336 | ys_joint = preds_dump_dict['ys_joint'] 337 | seen_pairs_joint_class_ids = preds_dump_dict['seen_pairs_joint_class_ids'] 338 | flattened_closed_unseen_pairs = preds_dump_dict['flattened_closed_unseen_pairs_mask'] 339 | flattened_all_open_pairs_mask = preds_dump_dict['flattened_all_open_pairs_mask'] 340 | # shape_obj_attr = preds_dump_dict['shape_obj_attr'] 341 | 342 | # filter-in only the logits of the open pairs 343 | ids_open_pairs = np.where(flattened_all_open_pairs_mask)[0] 344 | logits_open = logits_ys_joint[:, ids_open_pairs] 345 | logits_ys_joint = None # release memory 346 | 347 | # get a mapping from indices of cartesian-product for y1,y2 to open-pairs 348 | inv_ids_open_pairs = (0 * flattened_all_open_pairs_mask) # init an array of zeros 349 | inv_ids_open_pairs[ids_open_pairs] = np.array(list(range(len(ids_open_pairs)))) 350 | 351 | if not isinstance(logits_open, torch.Tensor): 352 | logits_open = to_torch(logits_open, device) 353 | logits_open = logits_open.to(device) 354 | 355 | if isinstance(ys_joint, torch.Tensor): 356 | ys_joint = ys_joint.cpu() 357 | 358 | ys_open = to_torch(inv_ids_open_pairs[ys_joint], device) 359 | seen_pairs = inv_ids_open_pairs[seen_pairs_joint_class_ids[0]] 360 | unseen_pairs = inv_ids_open_pairs[np.where(flattened_closed_unseen_pairs)[0]] 361 | 362 | AUSUC_open = calc_cs_ausuc(logits_open, ys_open, seen_pairs, unseen_pairs, use_balanced_accuracy=False) 363 | AUSUC_open_balanced = calc_cs_ausuc(logits_open, ys_open, seen_pairs, unseen_pairs, use_balanced_accuracy=True) 364 | 365 | return AUSUC_open, AUSUC_open_balanced 366 | 367 | 368 | def calc_random_balanced_baseline_by_logits_neginf(logits): 369 | num_active = (logits[0, :] != -np.inf).sum().item() 370 | if num_active == 0: 371 | return np.nan 372 | else: 373 | return 1./num_active 374 | 375 | 376 | 377 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # located at the root directory.
6 | # ---------------------------------------------------------------
7 | 
8 | from collections import OrderedDict
9 | from typing import NamedTuple
10 | 
11 | import numpy as np
12 | import torch
13 | from torch import nn, optim
14 | from torch.nn.functional import one_hot
15 | 
16 | from HSIC import HSIC
17 | from model import get_model
18 | 
19 | from useful_utils import to_torch, batch_torch_logger, ns_profiling_label, profileblock, clone_model
20 | from data import get_data_loaders
21 | from pprint import pprint
22 | from eval import evaluation_step, acc_from_logits, eval_model_with_dataloader
23 | 
24 | from params import EarlyStopMetric, CommandlineArgs
25 | from model import CompModel
26 | 
27 | 
28 | def train(args: CommandlineArgs, train_dataset, valid_dataset, test_dataset, writer, model: CompModel = None):
29 | # Init
30 | train_cfg = args.train
31 | 
32 | best_metrics = {}
33 | epoch = -1
34 | start_epoch = 0
35 | device = args.device
36 | if len(writer.df) > 0:
37 | start_epoch = writer.df.index.max()
38 | 
39 | # Get PyTorch data loaders
40 | test_loader, train_loader, valid_loader = get_data_loaders(train_dataset, valid_dataset, test_dataset,
41 | train_cfg.batch_size, train_cfg.num_workers,
42 | test_batchsize=train_cfg.test_batchsize,
43 | shuffle_eval_set=train_cfg.shuffle_eval_set)
44 | 
45 | if model is None:
46 | model: CompModel = get_model(args, train_dataset)
47 | best_model = clone_model(model)
48 | 
49 | ## NOTE:
50 | # y1 refers to object labels
51 | # y2 refers to attribute labels
52 | num_classes1 = train_dataset.num_objs
53 | num_classes2 = train_dataset.num_attrs
54 | 
55 | class NLLLossFuncs(NamedTuple):
56 | y1: nn.NLLLoss
57 | y2: nn.NLLLoss
58 | 
59 | nll_loss_funcs = NLLLossFuncs(y1=nn.NLLLoss(), y2=nn.NLLLoss())
60 | if train_cfg.balanced_loss:
61 | nll_loss_funcs = NLLLossFuncs(y1=nn.NLLLoss(weight=to_torch(1 / train_dataset.y1_freqs, device)),
62 | y2=nn.NLLLoss(weight=to_torch(1 / train_dataset.y2_freqs, device)))
63 | 
64 | itr_per_epoch = len(train_loader)
65 | n_epochs = train_cfg.n_iter // itr_per_epoch
66 | 
67 | best_primary_metric = np.inf * (2 * (train_cfg.primary_early_stop_metric.polarity == 'min') - 1)
68 | 
69 | optimizer = get_optimizer(train_cfg.optimizer_name, train_cfg.lr, train_cfg.weight_decay, model, args)
70 | 
71 | epoch_range = range(start_epoch + 1, start_epoch + n_epochs + 1)
72 | data_iterator = iter(train_loader)
73 | 
74 | for epoch in epoch_range:
75 | with profileblock(label='Epoch train step'):
76 | # Select which tensors to log, averaging over all batches per epoch.
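# A sketch of the logger contract assumed here (batch_torch_logger lives in
# useful_utils and is not shown in this file): the cs_str_args names are
# comma-separated local-variable names collected on each batch, get_means()
# returns their per-epoch means, and names under nanmean_args_cs_str are
# reduced with a NaN-ignoring mean. A minimal (hypothetical) usage:
#
#   logger = batch_torch_logger(num_batches=len(loader), cs_str_args='y1_loss', device=device)
#   for batch in loader:
#       logger.new_batch()
#       y1_loss = compute_loss(batch)      # any tensor named in cs_str_args
#       logger.log(locals_dict=locals())   # picks y1_loss out of locals()
#   epoch_means = logger.get_means()       # {'y1_loss': mean over batches}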
77 | logger = batch_torch_logger(num_batches=len(train_loader),
78 | cs_str_args='y1_loss, y2_loss, y_loss, '
79 | 'L_rep, '
80 | 'y1_acc, y2_acc, '
81 | 'HSIC_cond1, HSIC_cond2, '
82 | 'pairwise_dist_cond1_repr1, '
83 | 'pairwise_dist_cond1_repr2, '
84 | 'pairwise_dist_cond2_repr1, '
85 | 'pairwise_dist_cond2_repr2, '
86 | 'HSIC_label_cond1, HSIC_label_cond2',
87 | nanmean_args_cs_str='pairwise_dist_cond1_repr1, '
88 | 'pairwise_dist_cond1_repr2, '
89 | 'pairwise_dist_cond2_repr1, '
90 | 'pairwise_dist_cond2_repr2, '
91 | 'tloss_a, tloss_o, tloss_g_imgfeat, '
92 | 'loss_inv_core, loss_inv_g_hidden, loss_inv_g_imgfeat',
93 | device=device
94 | )
95 | 
96 | 
97 | for batch_cnt in range(len(train_loader)):
98 | logger.new_batch()
99 | 
100 | optimizer.zero_grad()
101 | 
102 | with ns_profiling_label('fetch batch'):
103 | try:
104 | batch = next(data_iterator)
105 | except StopIteration:
106 | data_iterator = iter(train_loader)
107 | batch = next(data_iterator)
108 | 
109 | with ns_profiling_label('send to gpu'):
110 | X, y2, y1 = batch[0], batch[1], batch[2]
111 | neg_attrs, neg_objs = batch[3].to(device), batch[4].to(device)
112 | X = X.float().to(device)  # images
113 | y1 = y1.long().to(device)  # object labels
114 | y2 = y2.long().to(device)  # attribute labels
115 | 
116 | with ns_profiling_label('forward pass'):
117 | # y1_scores, y2_scores are logits of negative-squared-distances in the embedding space
118 | # repr1, repr2 are phi_hat1, phi_hat2 in the paper
119 | y1_scores, y2_scores, repr1, repr2, _ = \
120 | model(X, freeze_class1=train_cfg.freeze_class1,
121 | freeze_class2=train_cfg.freeze_class2)
122 | 
123 | y1_loss = nll_loss_funcs.y1(y1_scores, y1)
124 | y2_loss = nll_loss_funcs.y2(y2_scores, y2)
125 | y_loss = y1_loss * train_cfg.Y12_balance_coeff + y2_loss * (1 - train_cfg.Y12_balance_coeff)
126 | L_data = train_cfg.lambda_CE * y_loss
127 | 
128 | L_invert = 0.
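# The full objective assembled below is
#
#     loss = L_data + L_indep + L_invert
#
# where, with coefficients taken from train_cfg and c = Y12_balance_coeff:
#     L_data   = lambda_CE * (c * y1_loss + (1 - c) * y2_loss)
#                + lambda_ao_emb * tloss_g_hidden + lambda_feat * tloss_g_imgfeat
#                (the two triplet terms are skipped for the VisProd baseline)
#     L_invert = lambda_aux_disjoint * loss_inv_core + lambda_aux * loss_inv_g_hidden
#                + lambda_aux_img * loss_inv_g_imgfeat   (0 for VisProd)
#     L_indep  = L_rep + L_oh1 + L_oh2, the HSIC conditional-independence penalties.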
129 | if not args.model.VisProd: 130 | # pair embedding losses 131 | tloss_g_hidden, tloss_g_imgfeat, loss_inv_core, loss_inv_g_hidden, loss_inv_g_imgfeat = \ 132 | model.eval_pair_embed_losses(args, X, model.last_feature_common, y2, y1, neg_attrs, 133 | neg_objs, nll_loss_funcs) 134 | 135 | # aggregate triplet loss into L_data 136 | L_data += train_cfg.lambda_ao_emb * tloss_g_hidden 137 | L_data += train_cfg.lambda_feat * tloss_g_imgfeat 138 | 139 | # aggregate components of L_invert 140 | L_invert += train_cfg.lambda_aux_disjoint * loss_inv_core 141 | L_invert += train_cfg.lambda_aux * loss_inv_g_hidden 142 | L_invert += train_cfg.lambda_aux_img * loss_inv_g_imgfeat 143 | 144 | 145 | ys = (y1, y2) 146 | L_rep, HSIC_rep_loss_terms, HSIC_mean_of_median_pairwise_dist_terms = \ 147 | conditional_indep_losses(repr1, repr2, ys, train_cfg.HSIC_coeff, indep_coeff2=train_cfg.HSIC_coeff, 148 | num_classes1=num_classes1, 149 | num_classes2=num_classes2, log_median_pairwise_distance=False, 150 | device=device) 151 | 152 | ohy1 = one_hot(y1, num_classes1) 153 | ohy2 = one_hot(y2, num_classes2) 154 | L_oh1, HSIC_oh_loss_terms1, _ = \ 155 | conditional_indep_losses(ohy2, repr1, ys, train_cfg.alphaH, indep_coeff2=0, num_classes1=num_classes1, 156 | num_classes2=num_classes2, log_median_pairwise_distance=False, 157 | device=device) 158 | 159 | L_oh2, HSIC_oh_loss_terms2, _ = \ 160 | conditional_indep_losses(ohy1, repr2, ys, 0, indep_coeff2=train_cfg.alphaH, num_classes1=num_classes1, 161 | num_classes2=num_classes2, log_median_pairwise_distance=False, 162 | device=device) 163 | 164 | L_indep = L_rep + L_oh1 + L_oh2 165 | 166 | loss = L_data + L_indep + L_invert 167 | 168 | 169 | with ns_profiling_label('loss and update'): 170 | loss.backward() 171 | optimizer.step() 172 | 173 | # log the metrics 174 | with ns_profiling_label('log batch'): 175 | 176 | # extract indep loss terms from lists for logging 177 | HSIC_cond1, HSIC_cond2, pairwise_dist_cond1_repr1, pairwise_dist_cond1_repr2, \ 178 | pairwise_dist_cond2_repr1, pairwise_dist_cond2_repr2 = \ 179 | HSIC_logging_terms(HSIC_rep_loss_terms, HSIC_mean_of_median_pairwise_dist_terms) 180 | 181 | HSIC_label_cond1 = HSIC_oh_loss_terms1[0] 182 | HSIC_label_cond2 = HSIC_oh_loss_terms2[1] 183 | 184 | with ns_profiling_label('calc y1 train acc'): 185 | y1_acc = acc_from_logits(y1_scores, y1, return_tensor=True).detach() 186 | with ns_profiling_label('calc y2 train acc'): 187 | y2_acc = acc_from_logits(y2_scores, y2, return_tensor=True).detach() 188 | 189 | logger.log(locals_dict=locals()) 190 | 191 | curr_epoch_metrics = OrderedDict() 192 | curr_epoch_metrics.update(logger.get_means()) 193 | with profileblock(label='Evaluation step'): 194 | 195 | best_primary_metric, is_best = evaluation_step(model, valid_loader, test_loader, nll_loss_funcs, writer, epoch, 196 | n_epochs, curr_epoch_metrics, 197 | early_stop_metric_name=train_cfg.primary_early_stop_metric, 198 | best_ES_metric_value=best_primary_metric, 199 | calc_AUC=train_cfg.metrics.calc_AUC) 200 | 201 | # write current epoch metrics to metrics logger 202 | with ns_profiling_label('write eval step metrics'): 203 | for metric_key, value in curr_epoch_metrics.items(): 204 | writer.add_scalar(f'{metric_key}', value, epoch) 205 | # dump collected metrics to csv 206 | writer.dump_to_csv() 207 | 208 | # print all columns 209 | last_results_as_string = writer.last_results_as_string() 210 | last_results_as_string = '\n '.join(last_results_as_string.split('\n')) 211 | if train_cfg.verbose: 212 | print('\n[%d/%d]' % 
(epoch, n_epochs), last_results_as_string)
213 | 
214 | if is_best:
215 | best_model = clone_model(model)
216 | best_metrics = writer.df.iloc[-1, :].to_dict()
217 | best_metrics['epoch'] = int(writer.df.iloc[-1, :].name)
218 | if train_cfg.verbose:
219 | print(f'Best! (@epoch {epoch})')
220 | 
221 | 
222 | model = best_model
223 | model.eval()
224 | 
225 | print('Best epoch was: ', best_metrics['epoch'])
226 | print(f'Primary early stop monitor was {train_cfg.primary_early_stop_metric}')
227 | 
228 | val_metrics = eval_model_with_dataloader(model, valid_loader, nll_loss_funcs, phase_name='valid')
229 | best_metrics.update(val_metrics)
230 | print('Val metrics on best val epoch:')
231 | pprint([(k, v) for k, v in val_metrics.items()])
232 | 
233 | test_metrics = eval_model_with_dataloader(model, test_loader, nll_loss_funcs, phase_name='test')
234 | best_metrics.update(test_metrics)
235 | print('\n\nTest metrics on best val epoch:')
236 | pprint([(k, v) for k, v in test_metrics.items()])
237 | 
238 | # cast numpy items to their original type
239 | for k, v in best_metrics.items():
240 | if isinstance(v, np.number):
241 | best_metrics[k] = v.item()
242 | 
243 | #### two redundant evaluation calls, kept so that the random-number-generator
244 | # state stays aligned with the original training script
245 | _ = eval_model_with_dataloader(model, valid_loader, nll_loss_funcs, phase_name='valid')
246 | _ = eval_model_with_dataloader(model, test_loader, nll_loss_funcs, phase_name='test')
247 | 
248 | 
249 | return model, best_metrics
250 | 
251 | 
252 | def alternate_training(args, train_dataset, valid_dataset, test_dataset, writer):
253 | train_cfg = args.train
254 | ### alternate between heads ###
255 | ## train first head ##
256 | # Save 'HSIC_coeff' for usage during step 2. For step 1, set HSIC_coeff to 0.
257 | HSIC_coeff_step2 = train_cfg.HSIC_coeff
258 | train_cfg.HSIC_coeff = 0
259 | if train_cfg.alternate_ys == 12:
260 | train_cfg.Y12_balance_coeff = 1
261 | train_cfg.primary_early_stop_metric = EarlyStopMetric('y1_balanced_acc_unseen_valid', 'max')
262 | train_cfg.freeze_class1 = False
263 | train_cfg.freeze_class2 = True
264 | elif train_cfg.alternate_ys == 21:
265 | train_cfg.Y12_balance_coeff = 0
266 | train_cfg.primary_early_stop_metric = EarlyStopMetric('y2_balanced_acc_unseen_valid', 'max')
267 | train_cfg.freeze_class1 = True
268 | train_cfg.freeze_class2 = False
269 | else:
270 | raise ValueError(f"Unexpected train_cfg.alternate_ys: {train_cfg.alternate_ys}")
271 | print(
272 | f"first iter ay={train_cfg.alternate_ys}, primary_early_stop_metric={train_cfg.primary_early_stop_metric.metric}")
273 | model, best_metrics_dict1 = train(args, train_dataset, valid_dataset, test_dataset, writer)
274 | print('step1 best metrics epoch = ', best_metrics_dict1['epoch'])
275 | 
276 | ## train 2nd head ##
277 | model.train()
278 | train_cfg.set_n_iter(len(train_dataset.data), (1 + train_cfg.max_epoch_step2))
279 | train_cfg.lr = train_cfg.lr_step2
280 | if train_cfg.alternate_ys == 12:
281 | train_cfg.Y12_balance_coeff = 0
282 | train_cfg.primary_early_stop_metric = EarlyStopMetric('epoch', 'max')
283 | train_cfg.freeze_class1 = True
284 | train_cfg.freeze_class2 = False
285 | elif train_cfg.alternate_ys == 21:
286 | train_cfg.Y12_balance_coeff = 1
287 | train_cfg.primary_early_stop_metric = EarlyStopMetric('epoch', 'max')
288 | train_cfg.freeze_class1 = False
289 | train_cfg.freeze_class2 = True
290 | train_cfg.HSIC_coeff = HSIC_coeff_step2
291 | train(args, train_dataset, valid_dataset, test_dataset, writer, model=model)
292 | 
293 | 
294 | 
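# Note on alternate_training (above): training proceeds in two phases that
# alternate which classification head is optimized. In step 1, HSIC_coeff is
# temporarily set to 0 and a single head is trained (alternate_ys == 12
# starts with the object head y1; 21 starts with the attribute head y2),
# early-stopping on that head's balanced unseen accuracy on the validation
# set. In step 2 the frozen/trained roles are swapped, the saved HSIC_coeff
# is restored, the learning rate switches to lr_step2, and training runs for
# a fixed number of epochs (the early-stop metric degenerates to 'epoch').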
295 | def get_optimizer(optimizer_name, lr, weight_decay, model, args: CommandlineArgs):
296 | """ Returns an optimizer instance. """
297 | 
298 | # list the weights to optimize
299 | obj_related_weights = list(model.g_inv_O.parameters()) + list(model.emb_cf_O.parameters())
300 | attr_related_weights = list(model.g_inv_A.parameters()) + list(model.emb_cf_A.parameters())
301 | pair_related_weights = []
302 | if not args.model.VisProd:
303 | obj_related_weights += list(model.obj_inv_core_logits.parameters())
304 | attr_related_weights += list(model.attr_inv_core_logits.parameters())
305 | 
306 | obj_related_weights += list(model.obj_inv_g_hidden_logits.parameters())
307 | attr_related_weights += list(model.attr_inv_g_hidden_logits.parameters())
308 | 
309 | obj_related_weights += list(model.obj_inv_g_imgfeat_logits.parameters())
310 | attr_related_weights += list(model.attr_inv_g_imgfeat_logits.parameters())
311 | 
312 | pair_related_weights = list(model.g1_emb_to_hidden_feat.parameters()) + list(model.g2_feat_to_image_feat.parameters())
313 | all_weights = obj_related_weights + attr_related_weights + pair_related_weights + list(model.ECommon.parameters())
314 | 
315 | # set optimizer hyper-params
316 | optimizer_kwargs = dict(lr=lr, weight_decay=weight_decay)
317 | if optimizer_name.lower() == 'nest':
318 | optimizer_kwargs.update(momentum=0.9, nesterov=True)
319 | 
320 | # choose optimizer
321 | optimizer_class = dict(adam=optim.Adam, sgd=optim.SGD, nest=optim.SGD)[optimizer_name.lower()]
322 | 
323 | # initialize an optimizer instance
324 | optimizer = optimizer_class(all_weights, **optimizer_kwargs)
325 | 
326 | return optimizer
327 | 
328 | 
329 | def conditional_indep_losses(repr1, repr2, ys, indep_coeff, indep_coeff2=None, num_classes1=None, num_classes2=None,
330 | Hkernel='L', Hkernel_sigma_obj=None, Hkernel_sigma_attr=None,
331 | log_median_pairwise_distance=False, device=None):
332 | # Computes label-conditioned HSIC penalties between repr1 and repr2, conditioning on ys[m] for each head m.
333 | 
334 | normalize_to_mean = (num_classes1, num_classes2)
335 | 
336 | if indep_coeff2 is None:
337 | indep_coeff2 = indep_coeff
338 | 
339 | HSIC_loss_terms = []
340 | HSIC_mean_of_median_pairwise_dist_terms = []
341 | with ns_profiling_label('HSIC/d loss calc'):
342 | # iterate over both heads
343 | for m, num_class in enumerate((num_classes1, num_classes2)):
344 | with ns_profiling_label(f'iter m={m}'):
345 | HSIC_tmp_loss = 0.
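# The block below groups the batch by the m-th label: the labels are sorted,
# and the boundaries between runs of identical labels are the indices of the
# nonzero first differences, so each slice [unique_ixs[j], unique_ixs[j+1])
# selects the samples of one class. For example (hypothetical batch):
#
#   labels [2, 0, 2, 1, 2]  ->  sorted [0, 1, 2, 2, 2]
#   unique_ixs = [0, 1, 2, 5]  ->  class slices of sizes 1, 1 and 3
#
# HSIC(repr1, repr2) is then accumulated over the samples of each class
# (classes with fewer than 2 samples in the batch are skipped) and divided
# by the total number of classes for this head, giving a mean per-class,
# label-conditioned HSIC.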
346 | HSIC_median_pw_y1 = [] 347 | HSIC_median_pw_y2 = [] 348 | 349 | labels_in_batch_sorted, indices = torch.sort(ys[m]) 350 | unique_ixs = 1 + (labels_in_batch_sorted[1:] - labels_in_batch_sorted[:-1]).nonzero() 351 | unique_ixs = [0] + unique_ixs.flatten().cpu().numpy().tolist() + [len(ys[m])] 352 | 353 | for j in range(len(unique_ixs)-1): 354 | current_class_indices = unique_ixs[j], unique_ixs[j + 1] 355 | count = current_class_indices[1] - current_class_indices[0] 356 | if count < 2: 357 | continue 358 | curr_class_slice = slice(*current_class_indices) 359 | curr_class_indices = indices[curr_class_slice].sort()[0] 360 | 361 | with ns_profiling_label(f'iter j={j}'): 362 | HSIC_kernel = dict(G='Gaussian', L='Linear')[Hkernel] 363 | with ns_profiling_label('HSIC call'): 364 | hsic_loss_i, median_pairwise_distance_y1, median_pairwise_distance_y2 = \ 365 | HSIC(repr1[curr_class_indices, :].float(), repr2[curr_class_indices, :].float(), 366 | kernelX=HSIC_kernel, kernelY=HSIC_kernel, 367 | sigmaX=Hkernel_sigma_obj, sigmaY=Hkernel_sigma_attr, 368 | log_median_pairwise_distance=log_median_pairwise_distance) 369 | HSIC_tmp_loss += hsic_loss_i 370 | HSIC_median_pw_y1.append(median_pairwise_distance_y1) 371 | HSIC_median_pw_y2.append(median_pairwise_distance_y2) 372 | 373 | HSIC_tmp_loss = HSIC_tmp_loss / normalize_to_mean[m] 374 | HSIC_loss_terms.append(HSIC_tmp_loss) 375 | HSIC_mean_of_median_pairwise_dist_terms.append([np.mean(HSIC_median_pw_y1), np.mean(HSIC_median_pw_y2)]) 376 | 377 | indep_loss = torch.tensor(0.).to(device) 378 | if indep_coeff > 0: 379 | indep_loss = (indep_coeff * HSIC_loss_terms[0] + indep_coeff2 * HSIC_loss_terms[1]) / 2 380 | return indep_loss, HSIC_loss_terms, HSIC_mean_of_median_pairwise_dist_terms 381 | 382 | 383 | 384 | 385 | def HSIC_logging_terms(HSIC_loss_terms, HSIC_mean_of_median_pairwise_dist_terms): 386 | """ This is just a utility function for naming monitored values of HSIC loss """ 387 | HSIC_cond1, HSIC_cond2, pairwise_dist_cond1_repr1, pairwise_dist_cond1_repr2, \ 388 | pairwise_dist_cond2_repr1, pairwise_dist_cond2_repr2 = [np.nan] * 6 389 | 390 | if HSIC_loss_terms: 391 | HSIC_cond1 = HSIC_loss_terms[0] 392 | HSIC_cond2 = HSIC_loss_terms[1] 393 | pairwise_dist_cond1_repr1 = HSIC_mean_of_median_pairwise_dist_terms[0][0] 394 | pairwise_dist_cond1_repr2 = HSIC_mean_of_median_pairwise_dist_terms[0][1] 395 | pairwise_dist_cond2_repr1 = HSIC_mean_of_median_pairwise_dist_terms[1][0] 396 | pairwise_dist_cond2_repr2 = HSIC_mean_of_median_pairwise_dist_terms[1][1] 397 | 398 | return HSIC_cond1, HSIC_cond2, pairwise_dist_cond1_repr1, pairwise_dist_cond1_repr2, \ 399 | pairwise_dist_cond2_repr1, pairwise_dist_cond2_repr2 400 | -------------------------------------------------------------------------------- /taskmodularnets/LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. 
Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. 
Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 
136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. 
To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. 
To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /COSMO_utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from the following repository: 5 | # https://github.com/yuvalatzmon/COSMO/ 6 | # 7 | # The license for the original version of this file can be 8 | # found in this directory (LICENSE_COSMO). The modifications 9 | # to this file are subject to the License 10 | # located at the root directory. 
11 | # --------------------------------------------------------------- 12 | 13 | 14 | 15 | import subprocess 16 | import sys 17 | from contextlib import contextmanager 18 | 19 | import numpy as np 20 | import pandas as pd 21 | import torch 22 | from sklearn.metrics import auc 23 | 24 | 25 | def _AUSUC(Acc_tr, Acc_ts): 26 | 27 | """ Calc area under seen-unseen curve 28 | Source: https://github.com/yuvalatzmon/COSMO/blob/master/src/metrics.py 29 | """ 30 | 31 | # Sort by X axis 32 | X_sorted_arg = np.argsort(Acc_tr) 33 | sorted_X = np.array(Acc_tr)[X_sorted_arg] 34 | sorted_Y = np.array(Acc_ts)[X_sorted_arg] 35 | 36 | # zero pad 37 | leftmost_X, leftmost_Y = 0, sorted_Y[0] 38 | rightmost_X, rightmost_Y = sorted_X[-1], 0 39 | sorted_X = np.block([np.array([leftmost_X]), sorted_X, np.array([rightmost_X])]) 40 | sorted_Y = np.block([np.array([leftmost_Y]), sorted_Y, np.array([rightmost_Y])]) 41 | 42 | # eval AUC 43 | AUSUC = auc(sorted_X, sorted_Y) 44 | 45 | return AUSUC 46 | 47 | 48 | def calc_cs_ausuc(pred, y_gt, seen_classses, unseen_classses, use_balanced_accuracy=True, gamma_range=None, verbose=True): 49 | """ Calc area under seen-unseen curve, according to calibrated stacking (Chao et al. 2016) 50 | Adapted from: https://github.com/yuvalatzmon/COSMO/blob/master/src/metrics.py 51 | 52 | NOTE: pred cannot accept -np.inf values 53 | """ 54 | 55 | assert(pred.min() != -np.inf) 56 | 57 | # make numbers positive with min = 0 58 | # pred = pred.copy() - pred.min() 59 | 60 | if gamma_range is None: 61 | # make a log spaced search range 62 | gamma = abs((pred[:, unseen_classses].max(dim=1)[0] - pred[:, seen_classses].max(dim=1)[0]).min()) + abs( 63 | (pred[:, unseen_classses].max(dim=1)[0] - pred[:, seen_classses].max(dim=1)[0]).max()) 64 | gamma = gamma.item() 65 | pos_range = (-(np.logspace(0, np.log10(gamma + 1), 50) - 1)).tolist() 66 | neg_range = (np.logspace(0, np.log10(gamma +1), 50) - 1).tolist() 67 | gamma_range = sorted(pos_range + neg_range) 68 | 69 | # torch.cuda.empty_cache() 70 | Acc_tr_values, Acc_ts_values = calc_acc_tr_ts_over_gamma_range(gamma_range, pred, y_gt, seen_classses, unseen_classses, use_balanced_accuracy) 71 | # torch.cuda.empty_cache() 72 | 73 | cs_ausuc = _AUSUC(Acc_tr=Acc_tr_values, Acc_ts=Acc_ts_values) 74 | if min(Acc_tr_values) > 0.01: 75 | print(f'CS AUSUC ERROR: Increase gamma range (add low values), because min(Acc_tr_values) equals {min(Acc_tr_values)}') 76 | return np.nan 77 | if min(Acc_ts_values) > 0.01: 78 | print(f'CS AUSUC ERROR: Increase gamma range (add high values), because min(Acc_ts_values) equals {min(Acc_ts_values)}') 79 | return np.nan 80 | if verbose: 81 | print('AUSUC: max(acc_seen) = ', max(Acc_tr_values)) 82 | print('AUSUC: max(acc_unseen)', max(Acc_ts_values)) 83 | print(f'AUSUC (by Calibrated Stacking) = {cs_ausuc:.3f}') 84 | return cs_ausuc 85 | 86 | 87 | def calc_acc_tr_ts_over_gamma_range(gamma_range, pred, y_gt, seen_classses, unseen_classses, use_balanced_accuracy): 88 | Acc_tr_values = [] 89 | Acc_ts_values = [] 90 | 91 | for gamma in gamma_range: 92 | params = gamma, pred, seen_classses, unseen_classses, use_balanced_accuracy, y_gt 93 | Acc_tr, Acc_ts = cs_at_single_operating_point(params) 94 | Acc_tr_values.append(Acc_tr) 95 | Acc_ts_values.append(Acc_ts) 96 | 97 | return Acc_tr_values, Acc_ts_values 98 | 99 | 100 | def cs_at_single_operating_point(params): 101 | gamma, pred, seen_classses, unseen_classses, use_balanced_accuracy, y_gt = params 102 | if isinstance(pred, torch.Tensor): 103 | cs_pred = pred.clone() 104 | else: 105 | 
cs_pred = pred.copy() 106 | 107 | cs_pred[:, seen_classses] -= gamma 108 | zs_metrics = ZSL_Metrics(seen_classses, unseen_classses) 109 | Acc_ts, Acc_tr, H = zs_metrics.generlized_scores(y_gt.cpu().numpy(), cs_pred, num_class_according_y_true=True, 110 | use_balanced_accuracy=use_balanced_accuracy) 111 | cs_pred = None # release memory 112 | # torch.cuda.empty_cache() # release memory 113 | 114 | return Acc_tr, Acc_ts 115 | 116 | 117 | class ZSL_Metrics(): 118 | def __init__(self, seen_classes, unseen_classes, report_entropy=False): 119 | self._seen_classes = np.sort(seen_classes) 120 | self._unseen_classes = np.sort(unseen_classes) 121 | self._n_seen = len(seen_classes) 122 | self._n_unseen = len(unseen_classes) 123 | self._report_entropy = report_entropy 124 | 125 | assert(self._n_seen == len(np.unique(seen_classes))) # sanity check 126 | assert(self._n_unseen == len(np.unique(unseen_classes))) # sanity check 127 | 128 | 129 | def unseen_balanced_accuracy(self, y_true, pred_softmax): 130 | Acc_zs, Ent_zs = self._subset_classes_balanced_accuracy(y_true, pred_softmax, 131 | self._unseen_classes) 132 | if self._report_entropy: 133 | return Acc_zs, Ent_zs 134 | else: 135 | return Acc_zs 136 | 137 | def seen_balanced_accuracy(self, y_true, pred_softmax): 138 | Acc_seen, Ent_seen = self._subset_classes_balanced_accuracy(y_true, 139 | pred_softmax, 140 | self._seen_classes) 141 | if self._report_entropy: 142 | return Acc_seen, Ent_seen 143 | else: 144 | return Acc_seen 145 | 146 | def generlized_scores(self, y_true_cpu, pred_softmax, num_class_according_y_true=True, use_balanced_accuracy=True): 147 | 148 | Acc_ts, Ent_ts = self._generalized_unseen_accuracy(y_true_cpu, pred_softmax, num_class_according_y_true, 149 | use_balanced_accuracy) 150 | Acc_tr, Ent_tr = self._generalized_seen_accuracy(y_true_cpu, pred_softmax, num_class_according_y_true, 151 | use_balanced_accuracy) 152 | H = 2*Acc_tr*Acc_ts/(Acc_tr + Acc_ts + 1e-8) 153 | Ent_H = 2*Ent_tr*Ent_ts/(Ent_tr + Ent_ts + 1e-8) 154 | 155 | if self._report_entropy: 156 | return Acc_ts, Acc_tr, H, Ent_ts, Ent_tr, Ent_H 157 | else: 158 | return Acc_ts, Acc_tr, H 159 | 160 | def _generalized_unseen_accuracy(self, y_true_cpu, pred_softmax, num_class_according_y_true, use_balanced_accuracy): 161 | return self._generalized_subset_accuracy(y_true_cpu, pred_softmax, 162 | self._unseen_classes, num_class_according_y_true, use_balanced_accuracy) 163 | 164 | def _generalized_seen_accuracy(self, y_true_cpu, pred_softmax, num_class_according_y_true, use_balanced_accuracy): 165 | return self._generalized_subset_accuracy(y_true_cpu, pred_softmax, 166 | self._seen_classes, num_class_according_y_true, use_balanced_accuracy) 167 | 168 | def _generalized_subset_accuracy(self, y_true_cpu, pred_softmax, subset_classes, num_class_according_y_true, 169 | use_balanced_accuracy): 170 | is_member = np.in1d # np.in1d is like MATLAB's ismember 171 | ix_subset_samples = is_member(y_true_cpu, subset_classes) 172 | 173 | y_true_subset = y_true_cpu[ix_subset_samples] 174 | all_classes = np.sort(np.block([self._seen_classes, self._unseen_classes])) 175 | 176 | if isinstance(pred_softmax, torch.Tensor): 177 | amax = (pred_softmax[:, all_classes]).argmax(1).cpu() 178 | else: 179 | amax = (pred_softmax[:, all_classes]).argmax(1) 180 | 181 | y_pred = all_classes[amax] 182 | y_pred_subset = y_pred[ix_subset_samples] 183 | 184 | if use_balanced_accuracy: 185 | 186 | num_class = len(subset_classes) 187 | # infer number of classes according to unique(y_true_subset) 188 | if 
num_class_according_y_true:
189 | num_class = len(np.unique(y_true_subset))
190 | Acc = float(per_class_balanced_accuracy(y_true_subset, y_pred_subset, num_class))
191 | else:
192 | Acc = (y_true_subset == y_pred_subset).mean()
193 | 
194 | # Ent = float(entropy2(pred_softmax[ix_subset_samples, :][:, all_classes]).mean())
195 | Ent = 0*Acc + 1e-3  # disabled because it's too slow
196 | return Acc, Ent
197 | 
198 | def _subset_classes_balanced_accuracy(self, y_true, pred_softmax, subset_classes):
199 | is_member = np.in1d  # np.in1d is like MATLAB's ismember
200 | ix_subset_samples = is_member(y_true, subset_classes)
201 | 
202 | y_true_zs = y_true[ix_subset_samples]
203 | y_pred = subset_classes[(pred_softmax[:, subset_classes]).argmax(dim=1)]
204 | y_pred_zs = y_pred[ix_subset_samples]
205 | 
206 | Acc = float(per_class_balanced_accuracy(y_true_zs, y_pred_zs, len(subset_classes)))
207 | # Ent = float(entropy2(pred_softmax[:, subset_classes]).mean())
208 | Ent = 0*Acc + 1e-3  # disabled because it's too slow
209 | return Acc, Ent
210 | 
211 | 
212 | def per_class_balanced_accuracy(y_true, y_pred, num_class=None):
213 | """ A balanced accuracy metric as in Xian (CVPR 2017). Accuracy is
214 | evaluated individually per class, and then uniformly averaged between
215 | classes.
216 | """
217 | if len(y_true) == 0 or num_class == 0:
218 | return np.nan
219 | 
220 | if isinstance(y_true, torch.Tensor):
221 | y_true = y_true.flatten().cpu().numpy().astype('int32')
222 | else:
223 | y_true = y_true.flatten().astype('int32')
224 | 
225 | if isinstance(y_pred, torch.Tensor):
226 | y_pred = y_pred.cpu().numpy()
227 | 
228 | 
229 | if num_class is None:
230 | num_class = len(np.unique(np.block([y_true, y_pred])))
231 | 
232 | max_class_id = 1+max([num_class, y_true.max(), y_pred.max()])
233 | 
234 | counts_per_class_s = pd.Series(y_true).value_counts()
235 | counts_per_class = np.zeros((max_class_id,))
236 | counts_per_class[counts_per_class_s.index] = counts_per_class_s.values
237 | 
238 | accuracy = (1.*(y_pred == y_true) / counts_per_class[y_true]).sum() / num_class
239 | return accuracy.astype('float32')
240 | 
241 | 
242 | 
243 | def run_bash(cmd, raise_on_err=True, raise_on_warning=False, versbose=True, return_exist_code=False, err_ind_by_exitcode=False):
244 | """ This function takes Bash commands and returns their stdout
245 | Returns: string (stdout)
246 | :type cmd: string
247 | """
248 | p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
249 | stderr=subprocess.PIPE, executable='/bin/bash')
250 | # p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
251 | out, err = p.communicate()
252 | out = out.strip().decode('utf-8')
253 | err = err.strip().decode('utf-8')
254 | exit_code = p.returncode
255 | is_err = err != ''
256 | if err_ind_by_exitcode:
257 | is_err = (exit_code != 0)
258 | 
259 | if is_err and raise_on_err:
260 | do_raise = True
261 | if 'warning' in err.lower():
262 | do_raise = raise_on_warning
263 | if versbose and not raise_on_warning:
264 | print('command was: {}'.format(cmd))
265 | print(err, file=sys.stderr)
266 | 
267 | if do_raise or 'error' in err.lower():
268 | if versbose:
269 | print('command was: {}'.format(cmd))
270 | raise RuntimeError(err)
271 | 
272 | if return_exist_code:
273 | return out, exit_code
274 | else:
275 | return out  # This is the stdout from the shell command
276 | 
277 | 
278 | @contextmanager
279 | def temporary_random_numpy_seed(seed) -> object:
280 | """ From https://github.com/yuvalatzmon/COSMO/blob/master/src/utils/ml_utils.py#L701
281 | A context 
manager for a temporary random seed (only within context) 282 | When leaving the context the numpy random state is restored 283 | 284 | This function is inspired by http://stackoverflow.com/q/32679403, which is shared according to https://creativecommons.org/licenses/by-sa/3.0/ 285 | 286 | License 287 | THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED. 288 | 289 | BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS. 290 | 291 | 1. Definitions 292 | 293 | "Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original, except that a work that constitutes a Collection will not be considered an Adaptation for the purpose of this License. For the avoidance of doubt, where the Work is a musical work, performance or phonogram, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered an Adaptation for the purpose of this License. 294 | "Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or performances, phonograms or broadcasts, or other works or subject matter other than works listed in Section 1(f) below, which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined below) for the purposes of this License. 295 | "Creative Commons Compatible License" means a license that is listed at https://creativecommons.org/compatiblelicenses that has been approved by Creative Commons as being essentially equivalent to this License, including, at a minimum, because that license: (i) contains terms that have the same purpose, meaning and effect as the License Elements of this License; and, (ii) explicitly permits the relicensing of adaptations of works made available under that license under this License or a Creative Commons jurisdiction license with the same License Elements as this License. 296 | "Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership. 297 | "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, ShareAlike. 298 | "Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License. 
299 | "Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i) in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast. 300 | "Work" means the literary and/or artistic work offered under the terms of this License including without limitation any production in the literary, scientific and artistic domain, whatever may be the mode or form of its expression including digital form, such as a book, pamphlet and other writing; a lecture, address, sermon or other work of the same nature; a dramatic or dramatico-musical work; a choreographic work or entertainment in dumb show; a musical composition with or without words; a cinematographic work to which are assimilated works expressed by a process analogous to cinematography; a work of drawing, painting, architecture, sculpture, engraving or lithography; a photographic work to which are assimilated works expressed by a process analogous to photography; a work of applied art; an illustration, map, plan, sketch or three-dimensional work relative to geography, topography, architecture or science; a performance; a broadcast; a phonogram; a compilation of data to the extent it is protected as a copyrightable work; or a work performed by a variety or circus performer to the extent it is not otherwise considered a literary or artistic work. 301 | "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation. 302 | "Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images. 303 | "Reproduce" means to make copies of the Work by any means including without limitation by sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage of a protected performance or phonogram in digital form or other electronic medium. 304 | 2. Fair Dealing Rights. Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws. 305 | 306 | 3. License Grant. 
Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below: 307 | 308 | to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections; 309 | to create and Reproduce Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, a translation could be marked "The original work was translated from English to Spanish," or a modification could indicate "The original work has been modified."; 310 | to Distribute and Publicly Perform the Work including as incorporated in Collections; and, 311 | to Distribute and Publicly Perform Adaptations. 312 | For the avoidance of doubt: 313 | 314 | Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; 315 | Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor waives the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; and, 316 | Voluntary License Schemes. The Licensor waives the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License. 317 | The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved. 318 | 319 | 4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions: 320 | 321 | You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. 
322 | You may Distribute or Publicly Perform an Adaptation only under the terms of: (i) this License; (ii) a later version of this License with the same License Elements as this License; (iii) a Creative Commons jurisdiction license (either this or a later license version) that contains the same License Elements as this License (e.g., Attribution-ShareAlike 3.0 US); (iv) a Creative Commons Compatible License. If you license the Adaptation under one of the licenses mentioned in (iv), you must comply with the terms of that license. If you license the Adaptation under the terms of any of the licenses mentioned in (i), (ii) or (iii) (the "Applicable License"), you must comply with the terms of the Applicable License generally and the following provisions: (I) You must include a copy of, or the URI for, the Applicable License with every copy of each Adaptation You Distribute or Publicly Perform; (II) You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License; (III) You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform; (IV) when You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License.
323 | If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author (or pseudonym, if applicable) if supplied, and/or if the Original Author and/or Licensor designate another party or parties (e.g., a sponsor institute, publishing entity, journal) for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work if supplied; (iii) to the extent reasonably practicable, the URI, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and (iv), consistent with Section 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). The credit required by this Section 4(c) may be implemented in any reasonable manner; provided, however, that in the case of an Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties.
324 | Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise.
325 | 5. Representations, Warranties and Disclaimer
326 |
327 | UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU.
328 |
329 | 6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
330 |
331 | 7. Termination
332 |
333 | This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License.
334 | Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however, that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above.
335 | 8. Miscellaneous
336 |
337 | Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License.
338 | Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License.
339 | If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable.
340 | No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent.
341 | This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You.
342 | The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law.
343 |
344 |
345 |
346 |
347 | """
348 | state = np.random.get_state()
349 | np.random.seed(seed)
350 | yield None
351 | np.random.set_state(state)
352 |
353 |
354 |
--------------------------------------------------------------------------------
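The code fragment above (lines 348-351) is the body of a generator-based context manager that temporarily reseeds numpy's global RNG and restores the saved state afterwards; the license text above it is evidently the function's docstring, and the enclosing function definition appears earlier in this file, above this excerpt. For reference, a minimal self-contained sketch of the same pattern is shown below. The name `temp_numpy_seed` is hypothetical (the real name is defined earlier in the file), and the `try`/`finally` is an addition over the fragment above, so that the state is restored even if the caller's block raises.

```python
from contextlib import contextmanager

import numpy as np


@contextmanager
def temp_numpy_seed(seed):
    # Hypothetical name; the actual function is defined earlier in this file.
    state = np.random.get_state()   # save the current global RNG state
    np.random.seed(seed)            # reseed deterministically for the block
    try:
        yield None                  # run the caller's with-block
    finally:
        np.random.set_state(state)  # restore the saved state, even on error


# Usage: draws inside the block are reproducible; the outer RNG stream resumes unchanged.
with temp_numpy_seed(0):
    print(np.random.rand(2))  # identical output on every run
print(np.random.rand(2))      # continues the pre-existing stream
```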