├── CUB
│   ├── __init__.py
│   ├── config.py
│   ├── utils.py
│   ├── hyperparam_checking.py
│   ├── README.md
│   ├── models.py
│   ├── hyperopt.py
│   ├── data_processing.py
│   ├── postprocessing.py
│   ├── dataset.py
│   ├── inference.py
│   └── train.py
├── SYNTHETIC
│   ├── config.py
│   ├── backbone.py
│   ├── README.md
│   ├── models.py
│   ├── template_model.py
│   ├── dataset.py
│   ├── data.py
│   ├── train.py
│   └── inference.py
├── figures
│   ├── cub_main_result.png
│   ├── skincon_main_result.png
│   └── synthetic_main_result.png
├── SKINCON
│   ├── .gitignore
│   ├── utils.py
│   ├── README.md
│   ├── backbone.py
│   ├── models.py
│   ├── config.py
│   ├── data_processing.py
│   ├── generate_new_data.py
│   ├── dataset.py
│   └── inference.py
├── .gitignore
├── requirements.txt
├── README.md
├── experiments.py
└── analysis.py

/CUB/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/SYNTHETIC/config.py:
--------------------------------------------------------------------------------
1 | BASE_DIR=''
2 | N_CLASSES=10
3 | N_ATTRIBUTES=10
--------------------------------------------------------------------------------
/figures/cub_main_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ssbin4/Closer-Intervention-CBM/HEAD/figures/cub_main_result.png
--------------------------------------------------------------------------------
/figures/skincon_main_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ssbin4/Closer-Intervention-CBM/HEAD/figures/skincon_main_result.png
--------------------------------------------------------------------------------
/figures/synthetic_main_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ssbin4/Closer-Intervention-CBM/HEAD/figures/synthetic_main_result.png
--------------------------------------------------------------------------------
/SKINCON/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | __pycache__/
3 | .vscode/
4 | output/
5 | model/
6 | scripts/
7 | *.sh
8 | *.out
9 | wandb/
10 | *.txt
11 | pretrained/
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *pkl
2 | *out
3 | .ipynb_checkpoints/
4 | *pyc
5 | .idea/
6 | wandb/
7 | CUB_200_2011/
8 | CUB_processed/
9 | PredConcepts/
10 | __pycache__/
11 | .vscode/
12 | output/
13 | model/
14 | scripts/
15 | *.sh
16 | *.out
17 | *.txt
18 | !requirements.txt
19 | data/
20 | *.pth
21 | .vscode
22 | graph
23 | pretrained/
--------------------------------------------------------------------------------
/SKINCON/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import numpy as np
5 | 
6 | 
7 | def seed_everything(seed=42):
8 |     random.seed(seed)
9 |     os.environ['PYTHONHASHSEED'] = str(seed)
10 |     np.random.seed(seed)
11 |     torch.manual_seed(seed)
12 |     torch.cuda.manual_seed(seed)
13 |     torch.cuda.manual_seed_all(seed)
14 |     torch.backends.cudnn.deterministic = True
15 |     torch.backends.cudnn.benchmark = False
--------------------------------------------------------------------------------
/CUB/config.py:
--------------------------------------------------------------------------------
1 | # General
2 | BASE_DIR = ''
3 | N_ATTRIBUTES = 312
4 | N_CLASSES = 200
5 | 
6 | GROUP_DICT = {0: [0, 1, 2, 3], 1: [4, 5, 6, 7, 8, 9], 2: [10, 11, 12, 13, 14, 15], 3: [16, 17, 18, 19, 20, 21], 4: [22, 23, 24], 5: [25, 26, 27, 28, 29, 30], 6: [31], 7: [32, 33, 34, 35, 36], 8: [37, 38], 9: [39, 40, 41, 42, 43, 44], 10: [45, 46, 47, 48, 49], 11: [50], 12: [51, 52], 13: [53, 54, 55, 56, 57, 58], 14: [59, 60, 61, 62, 63], 15: [64, 65, 66, 67, 68, 69], 16: [70, 71, 72, 73, 74, 75], 17: [76, 77], 18: [78, 79, 80], 19: [81, 82], 20: [83, 84, 85], 21: [86, 87, 88], 22: [89], 23: [90, 91, 92, 93, 94, 95], 24: [96, 97, 98], 25: [99, 100, 101], 26: [102, 103, 104, 105, 106, 107], 27: [108, 109, 110, 111]}
7 | 
8 | # Training
9 | UPWEIGHT_RATIO = 9.0
10 | MIN_LR = 0.0001
11 | LR_DECAY_SIZE = 0.1
12 | 
--------------------------------------------------------------------------------
/SYNTHETIC/backbone.py:
--------------------------------------------------------------------------------
1 | from SYNTHETIC.template_model import MLP, End2EndModel
2 | 
3 | 
4 | # Independent & Sequential Model
5 | def ModelXtoC(input_dim, n_attributes, expand_dim):
6 |     return MLP(input_dim=input_dim, num_classes=n_attributes, expand_dim=expand_dim)
7 | 
8 | # Independent Model
9 | def ModelOracleCtoY(n_attributes, num_classes, expand_dim):
10 |     # X -> C part is separate, this is only the C -> Y part
11 |     model = MLP(input_dim=n_attributes, num_classes=num_classes, expand_dim=expand_dim)
12 |     return model
13 | 
14 | # Sequential Model
15 | def ModelXtoChat_ChatToY(n_attributes, num_classes, expand_dim):
16 |     # X -> C part is separate, this is only the C -> Y part (same as Independent model)
17 |     return ModelOracleCtoY(n_attributes, num_classes, expand_dim)
18 | 
19 | # Joint Model
20 | def ModelXtoCtoY(input_dim, num_classes, n_attributes, expand_dim,
21 |                  use_relu, use_sigmoid):
22 |     model1 = MLP(input_dim=input_dim, num_classes=n_attributes, expand_dim=expand_dim)
23 |     model2 = MLP(input_dim=n_attributes, num_classes=num_classes, expand_dim=expand_dim)
24 |     return End2EndModel(model1, model2, n_attributes, use_relu, use_sigmoid)
--------------------------------------------------------------------------------
/SYNTHETIC/README.md:
--------------------------------------------------------------------------------
1 | # Generating the data
2 | 
3 | To generate the synthetic data with input noise, use the following command.
4 | 
5 | ```
6 | python3 data.py -exp GenData -out_dir ${data_dir} -alpha_mean ${alpha_mean} -alpha_var ${alpha_var} -z_var ${z_var} -input_dim ${input_dim} -n_attributes ${n_attributes} -n_classes ${n_classes} -n_groups ${n_groups}
7 | ```
8 | 
9 | You can generate the data with hidden concepts using the following command, where 'data_dir' is the directory of the previously generated dataset.
10 | ```
11 | python3 data.py -exp 'Hidden' -data_dir ${data_dir} -out_dir ${hidden_data_dir} -hidden_ratio ${hidden_ratio}
12 | ```
13 | 
14 | # Training the models
15 | 
16 | The following command is an example of training the XtoC model.
17 | 
18 | ```
19 | python3 ../experiments.py synthetic Concept_XtoC --seed ${seed} -ckpt '' -log_dir ${log_dir} -e 1000 -optimizer sgd -use_attr -weighted_loss multiple -data_dir ${data_dir} -n_attributes ${n_attributes} -input_dim ${input_dim} -n_classes ${n_classes} -normalize_loss -b 64 -weight_decay ${reg} -lr ${init_lr} -scheduler_step ${step} -bottleneck -expand_dim 100
20 | ```
21 | 
22 | # Test-time intervention
23 | 
24 | You can use a similar command to the one in the CUB readme, with `-n_attributes ${n_attributes}`, `-input_dim ${input_dim}`, and the synthetic data directories substituted.
--------------------------------------------------------------------------------
/SKINCON/README.md:
--------------------------------------------------------------------------------
1 | # Dataset preprocessing
2 | 
3 | Download the images from [Fitzpatrick17k](https://github.com/mattgroh/fitzpatrick17k) and save them in the data/fitz/ directory.
4 | 
5 | Download the annotations_fitzpatrick17k.csv file from [SKINCON](https://skincon-dataset.github.io/) and save it in the data/ directory.
6 | 
7 | Obtain the pkl files with the binary class labels using the following command.
8 | ```
9 | python3 data_processing.py -save_dir ${save_dir} -data_dir data/ -class_label 'binary'
10 | ```
11 | 
12 | For the experiments, we use only the 22 concepts that are present in at least 50 images.
13 | The following command generates the new data into the 'data_dir' directory from the data in 'modify_data_dir'.
14 | ```
15 | python3 generate_new_data.py --exp Concept --data_dir ${data_dir} --modify_data_dir ${modify_data_dir}
16 | ```
17 | 
18 | 
19 | # Training the models
20 | 
21 | Download the pretrained models from [Disparities in Dermatology AI Performance on a Diverse, Curated Clinical Image Set](https://drive.google.com/drive/folders/1WscikgfyQWg1OTPem_JZ-8EjbCQ_FHxm).
22 | 
23 | The training code is based on [Concept Bottleneck Models](https://github.com/yewsiang/ConceptBottleneck).
24 | 
25 | The following command is an example of training the XtoC model with the data saved in the 'data/22concepts/binary' directory.
26 | 
27 | ```
28 | python3 ../experiments.py skincon Concept_XtoC -seed ${seed} -ckpt 1 -log_dir ${log_dir} -e 1000 -optimizer sgd -pretrained -use_aux -use_attr -weighted_loss multiple -data_dir data/22concepts/ -n_attributes 22 -normalize_loss -b 64 -weight_decay ${weight_decay} -lr ${init_lr} -scheduler_step ${step} -bottleneck -class_label binary
29 | ```
30 | 
31 | # Test-time intervention
32 | 
33 | You can use a similar command to the one for CUB; a sketch is given below.
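34 | 
35 | For example, an intervention run for independent models might look like the following sketch. This command is illustrative rather than taken verbatim from the scripts: the flags mirror the CUB `tti.py` example, with the SkinCon-specific values (`-n_attributes 22`, `-class_label binary`) substituted, and the exact flag set and data directories may need adjusting.
36 | ```
37 | python3 tti.py -model_dirs ${model_dirs} -model_dirs2 ${model_dirs2} -use_attr -bottleneck -criterion ${criterion} -level ${level} -inference_mode ${inference_mode} -n_trials 1 -n_attributes 22 -data_dir data/22concepts/ -use_sigmoid -class_label binary
38 | ```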
-------------------------------------------------------------------------------- /CUB/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common functions for visualization in different ipython notebooks 3 | """ 4 | import os 5 | import random 6 | from matplotlib.pyplot import figure, imshow, axis, show 7 | from matplotlib.image import imread 8 | 9 | N_CLASSES = 200 10 | N_ATTRIBUTES = 312 11 | 12 | def get_class_attribute_names(img_dir = 'CUB_200_2011/images/', feature_file='CUB_200_2011/attributes/attributes.txt'): 13 | """ 14 | Returns: 15 | class_to_folder: map class id (0 to 199) to the path to the corresponding image folder (containing actual class names) 16 | attr_id_to_name: map attribute id (0 to 311) to actual attribute name read from feature_file argument 17 | """ 18 | class_to_folder = dict() 19 | for folder in os.listdir(img_dir): 20 | class_id = int(folder.split('.')[0]) 21 | class_to_folder[class_id - 1] = os.path.join(img_dir, folder) 22 | 23 | attr_id_to_name = dict() 24 | with open(feature_file, 'r') as f: 25 | for line in f: 26 | idx, name = line.strip().split(' ') 27 | attr_id_to_name[int(idx) - 1] = name 28 | return class_to_folder, attr_id_to_name 29 | 30 | def sample_files(class_label, class_to_folder, number_of_files=10): 31 | """ 32 | Given a class id, extract the path to the corresponding image folder and sample number_of_files randomly from that folder 33 | """ 34 | folder = class_to_folder[class_label] 35 | class_files = random.sample(os.listdir(folder), number_of_files) 36 | class_files = [os.path.join(folder, f) for f in class_files] 37 | return class_files 38 | 39 | def show_img_horizontally(list_of_files): 40 | """ 41 | Given a list of files, display them horizontally in the notebook output 42 | """ 43 | fig = figure(figsize=(40,40)) 44 | number_of_files = len(list_of_files) 45 | for i in range(number_of_files): 46 | a=fig.add_subplot(1,number_of_files,i+1) 47 | image = imread(list_of_files[i]) 48 | imshow(image) 49 | axis('off') 50 | show(block=True) 51 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | beautifulsoup4==4.11.1 2 | brokenaxes==0.5.0 3 | brotlipy==0.7.0 4 | certifi @ file:///opt/conda/conda-bld/certifi_1655968806487/work/certifi 5 | cffi @ file:///home/builder/ci_310/cffi_1642753365720/work 6 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 7 | click==8.1.3 8 | cryptography @ file:///tmp/build/80754af9/cryptography_1652101588599/work 9 | cycler==0.11.0 10 | data==0.4 11 | decorator==5.1.1 12 | docker-pycreds==0.4.0 13 | filelock==3.8.2 14 | fonttools==4.33.3 15 | funcsigs==1.0.2 16 | future==0.18.3 17 | gdown==4.6.0 18 | gitdb==4.0.9 19 | GitPython==3.1.27 20 | graphviz==0.20 21 | idna @ file:///tmp/build/80754af9/idna_1637925883363/work 22 | joblib==1.1.0 23 | kiwisolver==1.4.3 24 | latex==0.7.0 25 | matplotlib==3.5.2 26 | mkl-fft==1.3.1 27 | mkl-random @ file:///home/builder/ci_310/mkl_random_1641843545607/work 28 | mkl-service==2.4.0 29 | numpy @ file:///opt/conda/conda-bld/numpy_and_numpy_base_1652801679809/work 30 | packaging==21.3 31 | pandas==1.5.2 32 | pathtools==0.1.2 33 | Pillow==9.0.1 34 | promise==2.3 35 | protobuf==3.20.1 36 | psutil==5.9.1 37 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 38 | pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work 39 | pyparsing==3.0.9 40 
| PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work 41 | python-dateutil==2.8.2 42 | pytz==2022.6 43 | PyYAML==6.0 44 | requests @ file:///opt/conda/conda-bld/requests_1641824580448/work 45 | scikit-learn==1.1.1 46 | scipy==1.8.1 47 | seaborn==0.12.2 48 | sentry-sdk==1.6.0 49 | setproctitle==1.2.3 50 | shortuuid==1.0.9 51 | shutilwhich==1.1.0 52 | six @ file:///tmp/build/80754af9/six_1644875935023/work 53 | sklearn==0.0 54 | smmap==5.0.0 55 | soupsieve==2.3.2.post1 56 | tempdir==0.7.1 57 | threadpoolctl==3.1.0 58 | torch==1.12.0 59 | torchaudio==0.12.0 60 | torchsummary==1.5.1 61 | torchvision==0.13.0 62 | torchviz==0.0.2 63 | tqdm==4.64.1 64 | typing_extensions @ file:///opt/conda/conda-bld/typing_extensions_1647553014482/work 65 | urllib3 @ file:///opt/conda/conda-bld/urllib3_1650637206367/work 66 | wandb==0.12.21 67 | -------------------------------------------------------------------------------- /CUB/hyperparam_checking.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import datetime 5 | import subprocess 6 | from CUB.config import BASE_DIR 7 | 8 | 9 | def find_early_stop_epoch(log_file, patience): 10 | all_best_epochs = find_all_best_epochs(log_file) 11 | #print("last best epoch:", all_best_epochs[-1]) 12 | for i, epoch in enumerate(all_best_epochs): 13 | if i == len(all_best_epochs) - 1: 14 | return epoch 15 | if epoch > 500 and all_best_epochs[i+1] - epoch > patience: 16 | return epoch 17 | return None if not all_best_epochs else all_best_epochs[-1] 18 | 19 | def find_best_config_retrain(path, patience=1000): 20 | best_records = dict() 21 | for config in os.listdir(path): 22 | config_path = os.path.join(path, config) 23 | if not os.path.isdir(config_path): 24 | continue 25 | #lambda_val = re.findall(r"attr_loss_weight_\d*\.\d+", config)[0].split('_')[-1] 26 | #if float(lambda_val) > 1: 27 | # continue 28 | if 'end2end' in config: 29 | model_type = 'end2end' 30 | elif 'bottleneck' in config: 31 | model_type = 'bottleneck' 32 | elif 'onlyAttr' in config: 33 | model_type = 'onlyAttr' 34 | elif 'simple_finetune' in config: 35 | model_type = 'simple_finetune' 36 | else: 37 | model_type = 'multitask' 38 | log_file = os.path.join(config_path, 'log.txt') 39 | all_val_acc = find_best_perf(log_file) 40 | epoch = find_early_stop_epoch(log_file, patience) 41 | if epoch is None: 42 | continue 43 | print(config_path) 44 | print (model_type, epoch, all_val_acc[epoch], '\n') 45 | 46 | 47 | if __name__ == "__main__": 48 | cmd = sys.argv[1] 49 | if 'invalid' in cmd: 50 | find_invalid_runs(BASE_DIR) 51 | else: 52 | results_dir = os.path.join(BASE_DIR, sys.argv[2]) 53 | if 'retrain' in cmd: 54 | find_best_config_retrain(results_dir) 55 | else: 56 | find_best_config_hyperparam_tune(results_dir) 57 | -------------------------------------------------------------------------------- /CUB/README.md: -------------------------------------------------------------------------------- 1 | # CUB Dataset 2 | ## Dataset preprocessing and training the models 3 | 4 | Our code is based on [Concept Bottleneck Models](https://github.com/yewsiang/ConceptBottleneck). 5 | Please check the above repository for dataset preprocessing and training the models. 6 | 7 | ## Test-time intervention 8 | 9 | ### Basic usage 10 | Conduct test-time intervention for independent (IND) models using the following command. 
11 | ```
12 | python3 tti.py -model_dirs ${model_dirs} -model_dirs2 ${model_dirs2} -use_attr -bottleneck -criterion ${criterion} -level ${level} -inference_mode ${inference_mode} -n_trials 1 -n_attributes 112 -data_dir2 CUB_processed/CUB_raw -data_dir CUB_processed/class_attr_data_10 -use_sigmoid -class_level -use_invisible -no_intervention_when_invisible
13 | ```
14 | 
15 | - The criterion is one of ['rand', 'ucp', 'lcp', 'cctp', 'ectp', 'edutp'].
16 | - The intervention level is one of ['i+s', 'i+b', 'g+s', 'g+b'].
17 | - The inference mode is the conceptualization strategy, e.g., 'soft', 'hard', or 'samp'.
18 | 
19 | ### Other training strategies
20 | - SEQ
21 | ```
22 | python3 tti.py -model_dirs ${model_dirs} -model_dirs2 ${model_dirs2} -use_attr -bottleneck -criterion ${criterion} -level ${level} -inference_mode ${inference_mode} -n_trials 1 -n_attributes 112 -data_dir2 CUB_processed/CUB_raw -data_dir CUB_processed/class_attr_data_10 -use_invisible -class_level -no_intervention_when_invisible
23 | ```
24 | - JNT
25 | ```
26 | python3 tti.py -model_dirs ${model_dirs} -use_attr -bottleneck -criterion ${criterion} -level ${level} -inference_mode ${inference_mode} -n_trials 1 -n_attributes 112 -data_dir2 CUB_processed/CUB_raw -data_dir CUB_processed/class_attr_data_10 -use_invisible -class_level -no_intervention_when_invisible
27 | ```
28 | - JNT+P
29 | ```
30 | python3 tti.py -model_dirs ${model_dirs} -use_attr -bottleneck -criterion ${criterion} -level ${level} -inference_mode ${inference_mode} -n_trials 1 -n_attributes 112 -data_dir2 CUB_processed/CUB_raw -data_dir CUB_processed/class_attr_data_10 -use_invisible -use_sigmoid -class_level -no_intervention_when_invisible
31 | ```
32 | 
33 | ### Comparison between different intervention levels
34 | 
35 | For the group-wise interventions, the comparison values against individual intervention used for Figure 5-(b) and Figure 18 are printed in the last lines of the output.
36 | 
37 | ## Cost of intervention
38 | 
39 | Calculate the mean labeling time used for the experiments in Appendix C.
40 | ``` 41 | python3 generate_new_data.py --exp Cost 42 | ``` 43 | -------------------------------------------------------------------------------- /CUB/models.py: -------------------------------------------------------------------------------- 1 | from CUB.template_model import MLP, inception_v3, End2EndModel 2 | 3 | 4 | # Independent & Sequential Model 5 | def ModelXtoC(pretrained, freeze, num_classes, use_aux, n_attributes, expand_dim, three_class): 6 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 7 | n_attributes=n_attributes, bottleneck=True, expand_dim=expand_dim, 8 | three_class=three_class) 9 | 10 | # Independent Model 11 | def ModelOracleCtoY(n_class_attr, n_attributes, num_classes, expand_dim): 12 | # X -> C part is separate, this is only the C -> Y part 13 | if n_class_attr == 3: 14 | model = MLP(input_dim=n_attributes * n_class_attr, num_classes=num_classes, expand_dim=expand_dim) 15 | else: 16 | model = MLP(input_dim=n_attributes, num_classes=num_classes, expand_dim=expand_dim) 17 | return model 18 | 19 | # Sequential Model 20 | def ModelXtoChat_ChatToY(n_class_attr, n_attributes, num_classes, expand_dim): 21 | # X -> C part is separate, this is only the C -> Y part (same as Independent model) 22 | return ModelOracleCtoY(n_class_attr, n_attributes, num_classes, expand_dim) 23 | 24 | # Joint Model 25 | def ModelXtoCtoY(n_class_attr, pretrained, freeze, num_classes, use_aux, n_attributes, expand_dim, 26 | use_relu, use_sigmoid): 27 | model1 = inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 28 | n_attributes=n_attributes, bottleneck=True, expand_dim=expand_dim, 29 | three_class=(n_class_attr == 3)) 30 | if n_class_attr == 3: 31 | model2 = MLP(input_dim=n_attributes * n_class_attr, num_classes=num_classes, expand_dim=expand_dim) 32 | else: 33 | model2 = MLP(input_dim=n_attributes, num_classes=num_classes, expand_dim=expand_dim) 34 | return End2EndModel(model1, model2, use_relu, use_sigmoid, n_class_attr) 35 | 36 | # Standard Model 37 | def ModelXtoY(pretrained, freeze, num_classes, use_aux): 38 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux) 39 | 40 | # Multitask Model 41 | def ModelXtoCY(pretrained, freeze, num_classes, use_aux, n_attributes, three_class, connect_CY): 42 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 43 | n_attributes=n_attributes, bottleneck=False, three_class=three_class, 44 | connect_CY=connect_CY) -------------------------------------------------------------------------------- /SYNTHETIC/models.py: -------------------------------------------------------------------------------- 1 | from CUB.template_model import MLP, inception_v3, End2EndModel 2 | 3 | 4 | # Independent & Sequential Model 5 | def ModelXtoC(pretrained, freeze, num_classes, use_aux, n_attributes, expand_dim, three_class): 6 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 7 | n_attributes=n_attributes, bottleneck=True, expand_dim=expand_dim, 8 | three_class=three_class) 9 | 10 | # Independent Model 11 | def ModelOracleCtoY(n_class_attr, n_attributes, num_classes, expand_dim): 12 | # X -> C part is separate, this is only the C -> Y part 13 | if n_class_attr == 3: 14 | model = MLP(input_dim=n_attributes * n_class_attr, num_classes=num_classes, expand_dim=expand_dim) 15 | else: 16 | model = MLP(input_dim=n_attributes, 
num_classes=num_classes, expand_dim=expand_dim) 17 | return model 18 | 19 | # Sequential Model 20 | def ModelXtoChat_ChatToY(n_class_attr, n_attributes, num_classes, expand_dim): 21 | # X -> C part is separate, this is only the C -> Y part (same as Independent model) 22 | return ModelOracleCtoY(n_class_attr, n_attributes, num_classes, expand_dim) 23 | 24 | # Joint Model 25 | def ModelXtoCtoY(n_class_attr, pretrained, freeze, num_classes, use_aux, n_attributes, expand_dim, 26 | use_relu, use_sigmoid): 27 | model1 = inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 28 | n_attributes=n_attributes, bottleneck=True, expand_dim=expand_dim, 29 | three_class=(n_class_attr == 3)) 30 | if n_class_attr == 3: 31 | model2 = MLP(input_dim=n_attributes * n_class_attr, num_classes=num_classes, expand_dim=expand_dim) 32 | else: 33 | model2 = MLP(input_dim=n_attributes, num_classes=num_classes, expand_dim=expand_dim) 34 | return End2EndModel(model1, model2, use_relu, use_sigmoid, n_class_attr) 35 | 36 | # Standard Model 37 | def ModelXtoY(pretrained, freeze, num_classes, use_aux): 38 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux) 39 | 40 | # Multitask Model 41 | def ModelXtoCY(pretrained, freeze, num_classes, use_aux, n_attributes, three_class, connect_CY): 42 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 43 | n_attributes=n_attributes, bottleneck=False, three_class=three_class, 44 | connect_CY=connect_CY) -------------------------------------------------------------------------------- /SKINCON/backbone.py: -------------------------------------------------------------------------------- 1 | from SKINCON.template_model import MLP, inception_v3, End2EndModel 2 | 3 | 4 | # Independent & Sequential Model 5 | def ModelXtoC(pretrained, freeze, num_classes, use_aux, n_attributes, expand_dim, three_class): 6 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 7 | n_attributes=n_attributes, bottleneck=True, expand_dim=expand_dim, 8 | three_class=three_class) 9 | 10 | # Independent Model 11 | def ModelOracleCtoY(n_class_attr, n_attributes, num_classes, expand_dim): 12 | # X -> C part is separate, this is only the C -> Y part 13 | if n_class_attr == 3: 14 | model = MLP(input_dim=n_attributes * n_class_attr, num_classes=num_classes, expand_dim=expand_dim) 15 | else: 16 | model = MLP(input_dim=n_attributes, num_classes=num_classes, expand_dim=expand_dim) 17 | return model 18 | 19 | # Sequential Model 20 | def ModelXtoChat_ChatToY(n_class_attr, n_attributes, num_classes, expand_dim): 21 | # X -> C part is separate, this is only the C -> Y part (same as Independent model) 22 | return ModelOracleCtoY(n_class_attr, n_attributes, num_classes, expand_dim) 23 | 24 | # Joint Model 25 | def ModelXtoCtoY(n_class_attr, pretrained, freeze, num_classes, use_aux, n_attributes, expand_dim, 26 | use_relu, use_sigmoid): 27 | model1 = inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 28 | n_attributes=n_attributes, bottleneck=True, expand_dim=expand_dim, 29 | three_class=(n_class_attr == 3)) 30 | if n_class_attr == 3: 31 | model2 = MLP(input_dim=n_attributes * n_class_attr, num_classes=num_classes, expand_dim=expand_dim) 32 | else: 33 | model2 = MLP(input_dim=n_attributes, num_classes=num_classes, expand_dim=expand_dim) 34 | return End2EndModel(model1, model2, use_relu, 
use_sigmoid, n_class_attr) 35 | 36 | # Standard Model 37 | def ModelXtoY(pretrained, freeze, num_classes, use_aux): 38 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux) 39 | 40 | # Multitask Model 41 | def ModelXtoCY(pretrained, freeze, num_classes, use_aux, n_attributes, three_class, connect_CY): 42 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 43 | n_attributes=n_attributes, bottleneck=False, three_class=three_class, 44 | connect_CY=connect_CY) -------------------------------------------------------------------------------- /SKINCON/models.py: -------------------------------------------------------------------------------- 1 | from CUB.template_model import CBM_AUC, MLP, inception_v3, End2EndModel 2 | 3 | 4 | # Independent & Sequential Model 5 | def ModelXtoC(pretrained, freeze, num_classes, use_aux, n_attributes, expand_dim, three_class): 6 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 7 | n_attributes=n_attributes, bottleneck=True, expand_dim=expand_dim, 8 | three_class=three_class) 9 | 10 | # Independent Model 11 | def ModelOracleCtoY(n_class_attr, n_attributes, num_classes, expand_dim): 12 | # X -> C part is separate, this is only the C -> Y part 13 | if n_class_attr == 3: 14 | model = MLP(input_dim=n_attributes * n_class_attr, num_classes=num_classes, expand_dim=expand_dim) 15 | else: 16 | model = MLP(input_dim=n_attributes, num_classes=num_classes, expand_dim=expand_dim) 17 | return model 18 | 19 | # Sequential Model 20 | def ModelXtoChat_ChatToY(n_class_attr, n_attributes, num_classes, expand_dim): 21 | # X -> C part is separate, this is only the C -> Y part (same as Independent model) 22 | return ModelOracleCtoY(n_class_attr, n_attributes, num_classes, expand_dim) 23 | 24 | # Joint Model 25 | def ModelXtoCtoY(n_class_attr, pretrained, freeze, num_classes, use_aux, n_attributes, expand_dim, 26 | use_relu, use_sigmoid): 27 | model1 = inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 28 | n_attributes=n_attributes, bottleneck=True, expand_dim=expand_dim, 29 | three_class=(n_class_attr == 3)) 30 | if n_class_attr == 3: 31 | model2 = MLP(input_dim=n_attributes * n_class_attr, num_classes=num_classes, expand_dim=expand_dim) 32 | else: 33 | model2 = MLP(input_dim=n_attributes, num_classes=num_classes, expand_dim=expand_dim) 34 | return End2EndModel(model1, model2, use_relu, use_sigmoid, n_class_attr) 35 | 36 | # CBM-AUC 37 | def ModelXtoCtoYPlusXtoY(n_class_attr, pretrained, freeze, num_classes, use_aux, n_attributes, expand_dim, 38 | use_relu, use_sigmoid): 39 | return CBM_AUC() 40 | 41 | # Standard Model 42 | def ModelXtoY(pretrained, freeze, num_classes, use_aux): 43 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux) 44 | 45 | # Multitask Model 46 | def ModelXtoCY(pretrained, freeze, num_classes, use_aux, n_attributes, three_class, connect_CY): 47 | return inception_v3(pretrained=pretrained, freeze=freeze, num_classes=num_classes, aux_logits=use_aux, 48 | n_attributes=n_attributes, bottleneck=False, three_class=three_class, 49 | connect_CY=connect_CY) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Closer Look at the Intervention Procedure of Concept Bottleneck Models 2 | 3 | This 
repository contains the source code for the ICML 2023 paper [A Closer Look at the Intervention Procedure of Concept Bottleneck Models](https://arxiv.org/abs/2302.14260) by Sungbin Shin, Yohan Jo, Sungsoo Ahn, and Namhoon Lee.
4 | 
5 | Our experiments are based on the following three datasets.
6 | 
7 | - [Caltech-UCSD Birds-200-2011 (CUB-200-2011)](https://www.vision.caltech.edu/datasets/cub_200_2011/)
8 | - [SKIN Concepts Dataset (SkinCon)](https://skincon-dataset.github.io/)
9 | - Synthetic
10 | 
11 | ## TL;DR
12 | We develop various ways of selecting intervening concepts to improve the intervention effectiveness and conduct an array of in-depth analyses as to how they evolve under different circumstances.
13 | 
14 | ## Abstract
15 | Concept bottleneck models (CBMs) are a class of interpretable neural network models that predict the target response of a given input based on its high-level concepts.
16 | Unlike the standard end-to-end models, CBMs enable domain experts to intervene on the predicted concepts and rectify any mistakes at test time, so that more accurate task predictions can be made at the end.
17 | While such intervenability provides a powerful avenue of control, many aspects of the intervention procedure remain rather unexplored.
18 | In this work, we develop various ways of selecting intervening concepts to improve the intervention effectiveness and conduct an array of in-depth analyses as to how they evolve under different circumstances.
19 | Specifically, we find that an informed intervention strategy can reduce the task error more than ten times compared to the current baseline under the same amount of intervention counts in realistic settings, and yet, this can vary quite significantly when taking into account different intervention granularity.
20 | We verify our findings through comprehensive evaluations, not only on the standard real datasets, but also on synthetic datasets that we generate based on a set of different causal graphs.
21 | We further discover some major pitfalls of the current practices which, without a proper addressing, raise concerns on reliability and fairness of the intervention procedure.
22 | 
23 | 
24 | | ![fig](./figures/cub_main_result.png) | ![fig](./figures/skincon_main_result.png) | ![fig](./figures/synthetic_main_result.png) |
25 | |:--------------:|:----------:|:----------------------:|
26 | | CUB | SkinCon | Synthetic |
27 | 
28 | ## Requirements
29 | 
30 | Install the required libraries using the following command.
31 | ```
32 | pip install -r requirements.txt
33 | ```
34 | 
35 | ## Usage
36 | See the readme file of each folder for how to preprocess the dataset, train the models, and conduct test-time interventions.
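37 | 
38 | All experiments are launched through `experiments.py`; its first argument selects the dataset, its second the experiment type, and the remaining flags are dataset-specific. The lines below are an illustrative sketch (see the per-folder readmes for the authoritative flags and working directories):
39 | ```
40 | python3 experiments.py <cub|skincon|synthetic> <Concept_XtoC|Independent_CtoY|Sequential_CtoY|Joint|Standard|Multitask|TTI> [dataset-specific flags]
41 | 
42 | # e.g., training a concept predictor on SkinCon (full command in SKINCON/README.md)
43 | python3 experiments.py skincon Concept_XtoC -seed ${seed} -log_dir ${log_dir} -e 1000 -use_attr -bottleneck -n_attributes 22 -data_dir data/22concepts/ -class_label binary
44 | ```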
45 | 
46 | ## Citation
47 | 
48 | ```
49 | @inproceedings{shin2023closer,
50 |     title={A Closer Look at the Intervention Procedure of Concept Bottleneck Models},
51 |     author={Shin, Sungbin and Jo, Yohan and Ahn, Sungsoo and Lee, Namhoon},
52 |     booktitle={International Conference on Machine Learning (ICML)},
53 |     year={2023}
54 | }
55 | ```
--------------------------------------------------------------------------------
/CUB/hyperopt.py:
--------------------------------------------------------------------------------
1 | """
2 | Tune hyperparameters for end2end and multitask models with different lambda values
3 | """
4 | import os
5 | import sys
6 | import argparse
7 | import subprocess
8 | 
9 | BASE_DIR = ''
10 | DATA_DIR = 'class_attr_data_10'
11 | N_ATTR = 112
12 | USE_RELU = False
13 | USE_SIGMOID = False
14 | 
15 | all_lr = [0.01, 0.001]
16 | all_optimizer = ['SGD'] #, 'RMSprop']
17 | all_batch_size = [64]
18 | all_lambda_val = [0.001, 0.01, 0.1, 0, 1]
19 | all_scheduler_step = [1000, 20, 10, 15] # large scheduler step = constant lr
20 | all_weight_decay = [0.0004, 0.00004]
21 | all_model_type = ['simple_finetune', 'onlyAttr', 'bottleneck', 'multitask', 'end2end']
22 | 
23 | 
24 | all_configs = [{'model_type': m, 'lr': lr, 'batch_size': b, 'optimizer': o, 'lambda': l, 'scheduler_step': s, 'weight_decay': w}
25 |                for m in all_model_type for lr in all_lr for b in all_batch_size for o in all_optimizer for l in all_lambda_val for s in all_scheduler_step for w in all_weight_decay]
26 | BASE_COMMAND = 'python train.py -e 2000 -pretrained %s'
27 | 
28 | def launch_job(config, save_dir):
29 |     save_path = os.path.join(BASE_DIR, save_dir)
30 |     if not os.path.exists(save_path):
31 |         os.mkdir(save_path)
32 |     if config['model_type'] not in ['multitask', 'end2end']:
33 |         if config['lambda'] != 1:
34 |             return
35 | 
36 |     if config['model_type'] == 'simple_finetune':
37 |         model_suffix = ''
38 |     else:
39 |         model_suffix = '-use_attr -weighted_loss multiple -data_dir %s -n_attributes %d -attr_loss_weight %.3f -normalize_loss' % (DATA_DIR, N_ATTR, config['lambda'])
40 |         if USE_RELU:
41 |             model_suffix += ' -use_relu'
42 |         if USE_SIGMOID:
43 |             model_suffix += ' -use_sigmoid'
44 | 
45 |     if config['model_type'] == 'end2end':
46 |         model_suffix += ' -end2end'
47 |     elif config['model_type'] == 'bottleneck':
48 |         model_suffix += ' -bottleneck'
49 |     elif config['model_type'] == 'onlyAttr':
50 |         model_suffix += ' -no_img'
51 |     command = model_suffix + ' -batch_size %d -lr %f -optimizer %s -weight_decay %f -scheduler_step %s' % (config['batch_size'], config['lr'], config['optimizer'], config['weight_decay'], config['scheduler_step'])
52 |     log_dir = os.path.join(save_path, config['model_type'])
53 |     if not os.path.exists(log_dir):
54 |         os.mkdir(log_dir)
55 |     log_dir = os.path.join(log_dir, '_'.join(command.split(' ')))
56 |     log_dir = log_dir.replace('-', '')
57 |     command = command + ' -log_dir %s' % log_dir
58 |     command = BASE_COMMAND % command
59 |     print("Launch command:", command, '\n')
60 |     subprocess.run(command, shell=True)  # command is a single shell string, so shell=True is required
61 | 
62 | def parse_arguments(parser=None):
63 |     if parser is None: parser = argparse.ArgumentParser(description='PyTorch Training')
64 |     parser.add_argument('--save_dir', type=str, required=True, help='directory to save the tuning results')
65 |     return parser.parse_args()
66 | 
67 | def run(args):
68 |     for config in all_configs:
69 |         launch_job(config, args.save_dir)
70 | 
71 | if __name__ == "__main__":
72 |     args = parse_arguments()
73 |     run(args)
74 | 
--------------------------------------------------------------------------------
/experiments.py:
-------------------------------------------------------------------------------- 1 | 2 | import pdb 3 | import sys 4 | 5 | 6 | import torch.backends.cudnn as cudnn 7 | import random 8 | 9 | 10 | def run_experiments(dataset, args): 11 | 12 | if dataset == 'CUB': 13 | from CUB.train import ( 14 | train_X_to_C, 15 | train_oracle_C_to_y_and_test_on_Chat, 16 | train_Chat_to_y_and_test_on_Chat, 17 | train_X_to_C_to_y, 18 | train_X_to_y, 19 | train_X_to_Cy, 20 | test_time_intervention 21 | ) 22 | 23 | elif dataset == 'SKINCON': 24 | from SKINCON.train import ( 25 | train_X_to_C, 26 | train_oracle_C_to_y_and_test_on_Chat, 27 | train_Chat_to_y_and_test_on_Chat, 28 | train_X_to_C_to_y, 29 | train_X_to_y, 30 | train_X_to_Cy, 31 | test_time_intervention 32 | ) 33 | 34 | elif dataset == 'SYNTHETIC': 35 | from SYNTHETIC.train import ( 36 | train_X_to_C, 37 | train_oracle_C_to_y_and_test_on_Chat, 38 | train_Chat_to_y_and_test_on_Chat, 39 | train_X_to_C_to_y 40 | ) 41 | 42 | experiment = args[0].exp 43 | if experiment == 'Concept_XtoC': 44 | train_X_to_C(*args) 45 | 46 | elif experiment == 'Independent_CtoY': 47 | train_oracle_C_to_y_and_test_on_Chat(*args) 48 | 49 | elif experiment == 'Sequential_CtoY': 50 | train_Chat_to_y_and_test_on_Chat(*args) 51 | 52 | elif experiment == 'Joint': 53 | train_X_to_C_to_y(*args) 54 | 55 | elif experiment == 'Standard': 56 | train_X_to_y(*args) 57 | 58 | elif experiment == 'Multitask': 59 | train_X_to_Cy(*args) 60 | 61 | elif experiment == 'TTI': 62 | test_time_intervention(*args) 63 | 64 | def parse_arguments(): 65 | # First arg must be dataset, and based on which dataset it is, we will parse arguments accordingly 66 | assert len(sys.argv) > 2, 'You need to specify dataset and experiment' 67 | assert sys.argv[1].upper() in ['CUB', 'SKINCON', 'SYNTHETIC'], 'Please specify the dataset' 68 | assert sys.argv[2] in ['Concept_XtoC', 'Independent_CtoY', 'Sequential_CtoY', 69 | 'Standard', 'StandardWithAuxC', 'Multitask', 'Joint','TTI'], \ 70 | 'Please specify valid experiment. 
Current: %s' % sys.argv[2]
71 |     dataset = sys.argv[1].upper()
72 |     experiment = sys.argv[2].upper()
73 | 
74 |     # Handle according to the dataset
75 |     if dataset == 'CUB':
76 |         from CUB.train import parse_arguments
77 |     elif dataset == 'SKINCON':
78 |         from SKINCON.train import parse_arguments
79 |     elif dataset == 'SYNTHETIC':
80 |         from SYNTHETIC.train import parse_arguments
81 | 
82 |     args = parse_arguments(experiment=experiment)
83 |     return dataset, args
84 | 
85 | if __name__ == '__main__':
86 | 
87 |     import torch
88 |     import numpy as np
89 | 
90 |     dataset, args = parse_arguments()
91 | 
92 |     if args[0].fix_seed:
93 |         torch.manual_seed(0)
94 |         torch.cuda.manual_seed(0)
95 |         torch.cuda.manual_seed_all(0)
96 |         np.random.seed(0)
97 |         cudnn.benchmark = False
98 |         cudnn.deterministic = True
99 |         random.seed(0)
100 |     else:
101 |         # Seeds
102 |         np.random.seed(args[0].seed)
103 |         torch.manual_seed(args[0].seed)
104 | 
105 |     run_experiments(dataset, args)
106 | 
--------------------------------------------------------------------------------
/SYNTHETIC/template_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Template models for the synthetic dataset: simple MLP blocks plus an End2EndModel
3 | wrapper that chains the X -> C and C -> Y parts
4 | """
5 | 
6 | import os
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import Parameter
10 | import torch.nn.functional as F
11 | import torch.utils.model_zoo as model_zoo
12 | 
13 | import torchvision.models as models
14 | from torchvision import transforms
15 | 
16 | class End2EndModel(torch.nn.Module):
17 |     def __init__(self, model1, model2, n_attributes, use_relu=False, use_sigmoid=False, n_class_attr=2):
18 |         super(End2EndModel, self).__init__()
19 |         self.first_model = model1
20 |         self.sec_model = model2
21 |         self.use_relu = use_relu
22 |         self.use_sigmoid = use_sigmoid
23 |         self.n_attributes = n_attributes
24 | 
25 |     def forward_stage2(self, stage1_out):
26 |         if self.use_relu:
27 |             attr_outputs = [nn.ReLU()(o) for o in stage1_out]
28 |         elif self.use_sigmoid:
29 |             attr_outputs = [torch.nn.Sigmoid()(o) for o in stage1_out]
30 |         else:
31 |             attr_outputs = stage1_out
32 | 
33 |         stage2_inputs = attr_outputs
34 |         stage2_inputs = torch.cat(stage2_inputs, dim=1)
35 |         all_out = [self.sec_model(stage2_inputs)]
36 |         all_out.extend(stage1_out)
37 |         return all_out
38 | 
39 |     def forward(self, x):
40 |         outputs = self.first_model(x)
41 |         new_outputs = []
42 | 
43 |         for c in range(self.n_attributes):
44 |             new_outputs.append(outputs[:, c].unsqueeze(1))
45 | 
46 |         outputs = new_outputs
47 |         return self.forward_stage2(outputs)
48 | 
49 | class MLP(nn.Module):
50 |     def __init__(self, input_dim, num_classes, expand_dim):
51 |         super(MLP, self).__init__()
52 |         self.expand_dim = expand_dim
53 |         self.num_classes = num_classes
54 |         if self.expand_dim:
55 |             self.linear = nn.Linear(input_dim, expand_dim)
56 |             self.activation1 = torch.nn.ReLU()
57 |             self.linear2 = nn.Linear(expand_dim, expand_dim) #softmax is automatically handled by loss function
58 |             self.activation2 = torch.nn.ReLU()
59 |             self.linear3 = nn.Linear(expand_dim, num_classes) #softmax is automatically handled by loss function
60 |         else:
61 |             self.linear = nn.Linear(input_dim, num_classes)
62 | 
63 |     def forward(self, x):
64 |         x = self.linear(x)
65 |         if hasattr(self, 'expand_dim') and self.expand_dim:
66 |             x = self.activation1(x)
67 |             x = self.linear2(x)
68 |             x = self.activation2(x)
69 |             x = 
self.linear3(x) 70 | return x 71 | 72 | 73 | class FC(nn.Module): 74 | 75 | def __init__(self, input_dim, output_dim, expand_dim, stddev=None): 76 | """ 77 | Extend standard Torch Linear layer to include the option of expanding into 2 Linear layers 78 | """ 79 | super(FC, self).__init__() 80 | self.expand_dim = expand_dim 81 | if self.expand_dim > 0: 82 | self.relu = nn.ReLU() 83 | self.fc_new = nn.Linear(input_dim, expand_dim) 84 | self.fc = nn.Linear(expand_dim, output_dim) 85 | else: 86 | self.fc = nn.Linear(input_dim, output_dim) 87 | if stddev: 88 | self.fc.stddev = stddev 89 | if expand_dim > 0: 90 | self.fc_new.stddev = stddev 91 | 92 | def forward(self, x): 93 | if self.expand_dim > 0: 94 | x = self.fc_new(x) 95 | x = self.relu(x) 96 | x = self.fc(x) 97 | return x -------------------------------------------------------------------------------- /SKINCON/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # General 4 | BASE_DIR = '' 5 | N_ATTRIBUTES = 22 6 | # For binary classification, output is scalar value 7 | N_CLASSES = {'binary': 2, 'three': 3, 'nine': 9, 'whole': 114} 8 | N_Y_OUTPUTS = {'binary': 2, 'three': 3, 'nine': 9, 'whole': 114} 9 | 10 | #GROUP_DICT = {0: [0, 1, 2, 3], 1: [4, 5, 6, 7, 8, 9], 2: [10, 11, 12, 13, 14, 15], 3: [16, 17, 18, 19, 20, 21], 4: [22, 23, 24], 5: [25, 26, 27, 28, 29, 30], 6: [31], 7: [32, 33, 34, 35, 36], 8: [37, 38], 9: [39, 40, 41, 42, 43, 44], 10: [45, 46, 47, 48, 49], 11: [50], 12: [51, 52], 13: [53, 54, 55, 56, 57, 58], 14: [59, 60, 61, 62, 63], 15: [64, 65, 66, 67, 68, 69], 16: [70, 71, 72, 73, 74, 75], 17: [76, 77], 18: [78, 79, 80], 19: [81, 82], 20: [83, 84, 85], 21: [86, 87, 88], 22: [89], 23: [90, 91, 92, 93, 94, 95], 24: [96, 97, 98], 25: [99, 100, 101], 26: [102, 103, 104, 105, 106, 107], 27: [108, 109, 110, 111]} 11 | 12 | # Training 13 | UPWEIGHT_RATIO = 9.0 14 | MIN_LR = 0.0001 15 | LR_DECAY_SIZE = 0.1 16 | 17 | ALL_LABELS_RATIO = [0.004464285714285714, 0.013392857142857142, 0.02072704081632653, 0.00510204081632653, 0.00701530612244898, 0.01020408163265306, 0.0194515306122449, 0.004464285714285714, 0.03220663265306122, 0.003826530612244898, 0.0028698979591836736, 0.003826530612244898, 0.004464285714285714, 0.005739795918367347, 0.004464285714285714, 0.009566326530612245, 0.0031887755102040817, 0.009885204081632654, 0.0031887755102040817, 0.013073979591836735, 0.0028698979591836736, 0.00510204081632653, 0.011798469387755101, 0.006377551020408163, 0.005420918367346939, 0.00350765306122449, 0.006377551020408163, 0.0031887755102040817, 0.015306122448979591, 0.00510204081632653, 0.00350765306122449, 0.006696428571428571, 0.01881377551020408, 0.006058673469387755, 0.011798469387755101, 0.0028698979591836736, 0.01211734693877551, 0.005739795918367347, 0.0047831632653061226, 0.00350765306122449, 0.005739795918367347, 0.007653061224489796, 0.007653061224489796, 0.007334183673469388, 0.0047831632653061226, 0.003826530612244898, 0.006377551020408163, 0.006377551020408163, 0.028698979591836735, 0.005420918367346939, 0.0028698979591836736, 0.02806122448979592, 0.007653061224489796, 0.007334183673469388, 0.006377551020408163, 0.005739795918367347, 0.016900510204081634, 0.007334183673469388, 0.004464285714285714, 0.005739795918367347, 0.011798469387755101, 0.0031887755102040817, 0.00350765306122449, 0.006696428571428571, 0.014987244897959183, 0.0031887755102040817, 0.008928571428571428, 0.00350765306122449, 0.02072704081632653, 0.004145408163265306, 0.00510204081632653, 
0.009247448979591837, 0.004145408163265306, 0.00860969387755102, 0.0047831632653061226, 0.02072704081632653, 0.007653061224489796, 0.0031887755102040817, 0.004464285714285714, 0.013392857142857142, 0.015625, 0.010841836734693877, 0.006058673469387755, 0.008928571428571428, 0.004464285714285714, 0.011479591836734694, 0.036989795918367346, 0.0031887755102040817, 0.006377551020408163, 0.005739795918367347, 0.006058673469387755, 0.022321428571428572, 0.02295918367346939, 0.02295918367346939, 0.006377551020408163, 0.009247448979591837, 0.003826530612244898, 0.004145408163265306, 0.03762755102040816, 0.00350765306122449, 0.007971938775510204, 0.0047831632653061226, 0.002551020408163265, 0.0047831632653061226, 0.006377551020408163, 0.007334183673469388, 0.005420918367346939, 0.008290816326530613, 0.006377551020408163, 0.009566326530612245, 0.006696428571428571, 0.010841836734693877, 0.004464285714285714, 0.005420918367346939] 18 | NINE_PARTITIONS_RATIO = [0.0625, 0.056760204081632654, 0.014349489795918368, 0.07079081632653061, 0.654655612244898, 0.011798469387755101, 0.007653061224489796, 0.08801020408163265, 0.033482142857142856] 19 | THREE_PARTITIONS_RATIO = [0.13329081632653061, 0.1412627551020408, 0.7254464285714286] 20 | BINARY_RATIO = [0.8587372448979592, 0.1412627551020408] 21 | 22 | CLASS_RATIO = {'whole': ALL_LABELS_RATIO, 'nine': NINE_PARTITIONS_RATIO, 'three': THREE_PARTITIONS_RATIO, 'binary': BINARY_RATIO} 23 | 24 | FITZ_FOLDER = "./data/" 25 | FITZPATRICK_CSV = os.path.join(FITZ_FOLDER, "fitzpatrick17k.csv") 26 | FITZPATRICK_DATA = os.path.join(FITZ_FOLDER, "fitz") 27 | FITZ_CONCEPTS_CSV = "./data/annotations_fitzpatrick17k.csv" 28 | 29 | DDI_DIR = "/path/to/ddi" 30 | DDI_CSV = os.path.join(DDI_DIR, "ddi_metadata.csv") 31 | # From the DDI paper. 
32 | 33 | CONCEPT_CANDIDATES = ['Vesicle', 'Papule', 34 | 'Macule', 'Plaque', 'Abscess', 'Pustule', 35 | 'Bulla', 'Patch', 'Nodule', 'Ulcer', 'Crust', 'Erosion', 36 | 'Excoriation', 'Atrophy', 'Exudate', 'Purpura/Petechiae', 'Fissure', 37 | 'Induration', 'Xerosis', 'Telangiectasia', 'Scale', 'Scar', 'Friable', 38 | 'Sclerosis', 'Pedunculated', 'Exophytic/Fungating', 39 | 'Warty/Papillomatous', 'Dome-shaped', 'Flat topped', 40 | 'Brown(Hyperpigmentation)', 'Translucent', 'White(Hypopigmentation)', 41 | 'Purple', 'Yellow', 'Black', 'Erythema', 'Comedo', 'Lichenification', 42 | 'Blue', 'Umbilicated', 'Poikiloderma', 'Salmon', 'Wheal', 'Acuminate', 43 | 'Burrow', 'Gray', 'Pigmented', 'Cyst'] 44 | -------------------------------------------------------------------------------- /CUB/data_processing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make train, val, test datasets based on train_test_split.txt, and by sampling val_ratio of the official train data to make a validation set 3 | Each dataset is a list of metadata, each includes official image id, full image path, class label, attribute labels, attribute certainty scores, and attribute labels calibrated for uncertainty 4 | """ 5 | import os 6 | import random 7 | import pickle 8 | import argparse 9 | from os import listdir 10 | from os.path import isfile, isdir, join 11 | from collections import defaultdict as ddict 12 | 13 | 14 | def extract_data(data_dir): 15 | cwd = os.getcwd() 16 | data_path = join(cwd,data_dir + '/images') 17 | val_ratio = 0.2 18 | 19 | path_to_id_map = dict() #map from full image path to image id 20 | id_to_path_map = dict() 21 | with open(data_path.replace('images', 'images.txt'), 'r') as f: 22 | for line in f: 23 | items = line.strip().split() 24 | path_to_id_map[join(data_path, items[1])] = int(items[0]) 25 | id_to_path_map[items[0]] = join(data_path, items[1]) 26 | 27 | attribute_labels_all = ddict(list) #map from image id to a list of attribute labels 28 | attribute_certainties_all = ddict(list) #map from image id to a list of attribute certainties 29 | attribute_uncertain_labels_all = ddict(list) #map from image id to a list of attribute labels calibrated for uncertainty 30 | # 1 = not visible, 2 = guessing, 3 = probably, 4 = definitely 31 | uncertainty_map = {1: {1: 0, 2: 0.5, 3: 0.75, 4:1}, #calibrate main label based on uncertainty label 32 | 0: {1: 0, 2: 0.5, 3: 0.25, 4: 0}} 33 | with open(join(cwd, data_dir + '/attributes/image_attribute_labels.txt'), 'r') as f: 34 | for line in f: 35 | file_idx, attribute_idx, attribute_label, attribute_certainty = line.strip().split()[:4] 36 | attribute_label = int(attribute_label) 37 | attribute_certainty = int(attribute_certainty) 38 | uncertain_label = uncertainty_map[attribute_label][attribute_certainty] 39 | attribute_labels_all[int(file_idx)].append(attribute_label) 40 | attribute_uncertain_labels_all[int(file_idx)].append(uncertain_label) 41 | attribute_certainties_all[int(file_idx)].append(attribute_certainty) 42 | 43 | is_train_test = dict() #map from image id to 0 / 1 (1 = train) 44 | train_val_id=[] 45 | with open(join(cwd, data_dir + '/train_test_split.txt'), 'r') as f: 46 | for line in f: 47 | idx, is_train = line.strip().split() 48 | is_train_test[int(idx)] = int(is_train) 49 | if int(is_train): 50 | train_val_id.append(idx) 51 | print("Number of train images from official train test split:", sum(list(is_train_test.values()))) 52 | 53 | 54 | val_id=random.sample(train_val_id, int(len(train_val_id)/5)) 55 | 56 
| val_files=[] 57 | 58 | for i in val_id: 59 | val_files.append(id_to_path_map[i]) 60 | 61 | 62 | train_val_data, test_data = [], [] 63 | train_data, val_data = [], [] 64 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] 65 | folder_list.sort() #sort by class index 66 | for i, folder in enumerate(folder_list): 67 | folder_path = join(data_path, folder) 68 | classfile_list = [cf for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')] 69 | #classfile_list.sort() 70 | for cf in classfile_list: 71 | img_id = path_to_id_map[join(folder_path, cf)] 72 | img_path = join(folder_path, cf) 73 | metadata = {'id': img_id, 'img_path': img_path, 'class_label': i, 74 | 'attribute_label': attribute_labels_all[img_id], 'attribute_certainty': attribute_certainties_all[img_id], 75 | 'uncertain_attribute_label': attribute_uncertain_labels_all[img_id], 'original_attribute_label': attribute_labels_all[img_id]} 76 | if is_train_test[img_id]: 77 | train_val_data.append(metadata) 78 | if val_files is not None: 79 | if img_path in val_files: 80 | val_data.append(metadata) 81 | else: 82 | train_data.append(metadata) 83 | else: 84 | test_data.append(metadata) 85 | 86 | random.shuffle(train_val_data) 87 | split = int(val_ratio * len(train_val_data)) 88 | train_data = train_val_data[split :] 89 | val_data = train_val_data[: split] 90 | print('Size of train set:', len(train_data)) 91 | return train_data, val_data, test_data 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser(description='Dataset preparation') 96 | parser.add_argument('-save_dir', '-d', help='Where to save the new datasets') 97 | parser.add_argument('-data_dir', help='Where to load the datasets') 98 | args = parser.parse_args() 99 | train_data, val_data, test_data = extract_data(args.data_dir) 100 | 101 | for dataset in ['train','val','test']: 102 | print("Processing %s set" % dataset) 103 | f = open(join(os.getcwd(), args.save_dir, dataset + '.pkl'), 'wb+') 104 | if 'train' in dataset: 105 | pickle.dump(train_data, f) 106 | elif 'val' in dataset: 107 | pickle.dump(val_data, f) 108 | else: 109 | pickle.dump(test_data, f) 110 | f.close() 111 | 112 | -------------------------------------------------------------------------------- /SYNTHETIC/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | General utils for training, evaluation and data loading 3 | """ 4 | import os 5 | import torch 6 | import pickle 7 | import numpy as np 8 | import torchvision.transforms as transforms 9 | 10 | from PIL import Image 11 | from torch.utils.data import BatchSampler 12 | from torch.utils.data import Dataset, DataLoader 13 | 14 | BASE_DIR = '' 15 | 16 | import pandas as pd 17 | 18 | 19 | class SYNTHETICDataset(Dataset): 20 | """ 21 | Returns a compatible Torch Dataset object customized for the CUB dataset 22 | """ 23 | 24 | def __init__(self, pkl_file_paths, use_attr, no_img): 25 | """ 26 | Arguments: 27 | pkl_file_paths: list of full path to all the pkl data 28 | use_attr: whether to load the attributes (e.g. False for simple finetune) 29 | no_img: whether to load the images (e.g. False for A -> Y model) 30 | uncertain_label: if True, use 'uncertain_attribute_label' field (i.e. label weighted by uncertainty score, e.g. 1 & 3(probably) -> 0.75) 31 | image_dir: default = 'images'. Will be append to the parent dir 32 | n_class_attr: number of classes to predict for each attribute. 
If 3, then make a separate class for not visible 33 | transform: whether to apply any special transformation. Default = None, i.e. use standard ImageNet preprocessing 34 | 35 | return_certainty: whether to return uncertainty label or not 36 | return_path: whether to return image path or not 37 | class_label: which class label to use 38 | """ 39 | self.data = [] 40 | self.is_train = any(["train" in path for path in pkl_file_paths]) 41 | if not self.is_train: 42 | assert any([("test" in path) or ("val" in path) for path in pkl_file_paths]) 43 | for file_path in pkl_file_paths: 44 | self.data.extend(pickle.load(open(file_path, 'rb'))) 45 | self.use_attr = use_attr 46 | self.no_img = no_img 47 | 48 | def __len__(self): 49 | return len(self.data) 50 | 51 | def __getitem__(self, idx): 52 | data = self.data[idx] 53 | input = data['input'] 54 | class_label = data['label'] 55 | if self.use_attr: 56 | attr_label = data['attribute_label'] 57 | if self.no_img: 58 | return attr_label, class_label 59 | else: 60 | return input, class_label, attr_label 61 | else: 62 | return input, class_label 63 | 64 | 65 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): 66 | """Samples elements randomly from a given list of indices for imbalanced dataset 67 | Arguments: 68 | indices (list, optional): a list of indices 69 | num_samples (int, optional): number of samples to draw 70 | """ 71 | 72 | def __init__(self, dataset, indices=None): 73 | # if indices is not provided, 74 | # all elements in the dataset will be considered 75 | self.indices = list(range(len(dataset))) \ 76 | if indices is None else indices 77 | 78 | # if num_samples is not provided, 79 | # draw `len(indices)` samples in each iteration 80 | self.num_samples = len(self.indices) 81 | 82 | # distribution of classes in the dataset 83 | label_to_count = {} 84 | for idx in self.indices: 85 | label = self._get_label(dataset, idx) 86 | if label in label_to_count: 87 | label_to_count[label] += 1 88 | else: 89 | label_to_count[label] = 1 90 | 91 | # weight for each sample 92 | weights = [1.0 / label_to_count[self._get_label(dataset, idx)] 93 | for idx in self.indices] 94 | self.weights = torch.DoubleTensor(weights) 95 | 96 | def _get_label(self, dataset, idx): # Note: for single attribute dataset 97 | return dataset.data[idx]['attribute_label'][0] 98 | 99 | def __iter__(self): 100 | idx = (self.indices[i] for i in torch.multinomial( 101 | self.weights, self.num_samples, replacement=True)) 102 | return idx 103 | 104 | def __len__(self): 105 | return self.num_samples 106 | 107 | def load_data(pkl_paths, use_attr, no_img, batch_size, resampling=False): 108 | """ 109 | Note: Inception needs (299,299,3) images with inputs scaled between -1 and 1 110 | Loads data with transformations applied, and upsample the minority class if there is class imbalance and weighted loss is not used 111 | NOTE: resampling is customized for first attribute only, so change sampler.py if necessary 112 | """ 113 | 114 | dataset = SYNTHETICDataset(pkl_paths, use_attr, no_img) 115 | is_training = any(['train.pkl' in f for f in pkl_paths]) 116 | if is_training: 117 | drop_last = True 118 | shuffle = True 119 | else: 120 | drop_last = False 121 | shuffle = False 122 | if resampling: 123 | sampler = BatchSampler(ImbalancedDatasetSampler(dataset), batch_size=batch_size, drop_last=drop_last) 124 | loader = DataLoader(dataset, batch_sampler=sampler) 125 | else: 126 | loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) 127 | return loader 128 | 
129 | def find_class_imbalance(pkl_file, multiple_attr=False, attr_idx=-1): 130 | """ 131 | Calculate class imbalance ratio for binary attribute labels stored in pkl_file 132 | If attr_idx >= 0, then only return ratio for the corresponding attribute id 133 | If multiple_attr is True, then return imbalance ratio separately for each attribute. Else, calculate the overall imbalance across all attributes 134 | """ 135 | imbalance_ratio = [] 136 | data = pickle.load(open(os.path.join(BASE_DIR, pkl_file), 'rb')) 137 | n = len(data) 138 | n_attr = len(data[0]['attribute_label']) 139 | if attr_idx >= 0: 140 | n_attr = 1 141 | if multiple_attr: 142 | n_ones = [0] * n_attr 143 | total = [n] * n_attr 144 | else: 145 | n_ones = [0] 146 | total = [n * n_attr] 147 | for d in data: 148 | labels = d['attribute_label'] 149 | if multiple_attr: 150 | for i in range(n_attr): 151 | n_ones[i] += labels[i] 152 | else: 153 | if attr_idx >= 0: 154 | n_ones[0] += labels[attr_idx] 155 | else: 156 | n_ones[0] += sum(labels) 157 | for j in range(len(n_ones)): 158 | imbalance_ratio.append(total[j]/n_ones[j] - 1) 159 | if not multiple_attr: #e.g. [9.0] --> [9.0] * 312 160 | imbalance_ratio *= n_attr 161 | return imbalance_ratio -------------------------------------------------------------------------------- /SKINCON/data_processing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make train, val, test datasets based on train_test_split.txt, and by sampling val_ratio of the official train data to make a validation set 3 | Each dataset is a list of metadata, each includes official image id, full image path, class label, attribute labels, attribute certainty scores, and attribute labels calibrated for uncertainty 4 | """ 5 | import os 6 | import random 7 | import pickle 8 | import argparse 9 | from os import listdir 10 | from os.path import isfile, isdir, join 11 | from collections import defaultdict as ddict 12 | 13 | import pandas as pd 14 | from config import FITZPATRICK_CSV, FITZ_CONCEPTS_CSV 15 | 16 | import numpy as np 17 | from sklearn.model_selection import train_test_split 18 | 19 | def extract_data(data_dir): 20 | cwd = os.getcwd() 21 | data_path = join(cwd,data_dir + 'fitz') 22 | val_ratio = 0.15 23 | test_ratio = 0.15 24 | 25 | fitz_concepts = pd.read_csv(FITZ_CONCEPTS_CSV) 26 | fitz_concepts = fitz_concepts[fitz_concepts["Do not consider this image"] != 1] 27 | 28 | path_to_id_map = dict() #map from full image path to image id 29 | id_to_path_map = dict() 30 | id_to_hash_map = dict() 31 | attribute_labels_all = ddict(list) #map from image id to a list of attribute labels 32 | id_list = [] 33 | with open(FITZ_CONCEPTS_CSV, 'r') as f: 34 | for i, line in enumerate(f): 35 | if i == 0: 36 | continue 37 | items = line.strip().split(',') 38 | file_idx = int(items[0]) 39 | img_path = join(data_path, items[1]) 40 | path_to_id_map[img_path] = file_idx 41 | id_to_path_map[file_idx] = img_path 42 | id_to_hash_map[file_idx] = items[1][:-4] 43 | id_list.append(file_idx) 44 | for j in range(2, len(items)-1): 45 | attribute_label = int(items[j]) 46 | attribute_labels_all[int(file_idx)].append(attribute_label) 47 | 48 | is_train_test = dict() #map from image id to 0 / 1 (1 = train) 49 | 50 | n_samples = len(attribute_labels_all) 51 | 52 | all_data = [] 53 | 54 | fitz_data = pd.read_csv(FITZPATRICK_CSV) 55 | 56 | fitz_data['label_int'] = pd.Categorical(fitz_data['label']).codes 57 | fitz_data['nine_partition_int'] = pd.Categorical(fitz_data['nine_partition_label']).codes 58 | 
fitz_data['three_partition_int'] = pd.Categorical(fitz_data['three_partition_label']).codes 59 | 60 | num_classes_whole = fitz_data['label_int'].max() + 1 61 | num_classes_nine = fitz_data['nine_partition_int'].max() + 1 62 | num_classes_three = fitz_data['three_partition_int'].max() + 1 63 | print('Number of classes: ', num_classes_whole, num_classes_nine, num_classes_three) 64 | 65 | labels_all = [] 66 | nine_partitions_all = [] 67 | three_partitions_all = [] 68 | binary_all = [] 69 | 70 | for idx in id_list: 71 | img_id = idx 72 | img_path = id_to_path_map[img_id] 73 | hash_id = id_to_hash_map[img_id] 74 | item = fitz_data.loc[fitz_data['md5hash'] == hash_id].values[0] 75 | label = item[-3] 76 | nine_partition_label = item[-2] 77 | three_partition_label = item[-1] 78 | 79 | if item[5] == 'malignant': 80 | benign_malignant = 1 81 | else: 82 | benign_malignant = 0 83 | metadata = {'id': img_id, 'img_path': img_path, 'attribute_label': attribute_labels_all[img_id], 84 | 'label': label, 'nine_partition_label': nine_partition_label, 85 | 'three_partition_label': three_partition_label, 'benign_malignant': benign_malignant} 86 | all_data.append(metadata) 87 | 88 | labels_all.append(label) 89 | nine_partitions_all.append(nine_partition_label) 90 | three_partitions_all.append(three_partition_label) 91 | binary_all.append(benign_malignant) 92 | 93 | if args.class_label == 'whole': 94 | labels_list = np.array(labels_all) 95 | class_size = num_classes_whole 96 | elif args.class_label == 'nine': 97 | labels_list = np.array(nine_partitions_all) 98 | class_size = num_classes_nine 99 | elif args.class_label == 'three': 100 | labels_list = np.array(three_partitions_all) 101 | class_size = num_classes_three 102 | elif args.class_label == 'binary': 103 | labels_list = np.array(binary_all) 104 | class_size = 2 105 | 106 | print(labels_list) 107 | 108 | 109 | id_list = np.arange(n_samples) 110 | trainval_id, test_id, trainval_label, test_label = train_test_split(id_list, labels_list, stratify=labels_list, test_size=test_ratio, random_state=42) 111 | 112 | val_ratio = val_ratio/(1 - test_ratio) 113 | train_id, val_id, train_label, val_label = train_test_split(trainval_id, trainval_label, stratify=trainval_label, test_size=val_ratio, random_state=42) 114 | 115 | all_data = np.array(all_data) 116 | trainval_data = all_data[trainval_id] 117 | test_data = all_data[test_id] 118 | val_data = all_data[val_id] 119 | train_data = all_data[train_id] 120 | train_data = train_data.tolist() 121 | val_data = val_data.tolist() 122 | test_data = test_data.tolist() 123 | 124 | class_cnt = np.zeros(class_size) 125 | for i in range(len(trainval_data)): 126 | label = trainval_label[i] 127 | class_cnt[label] += 1 128 | 129 | print('Ratio of each class: ', args.class_label, (class_cnt/len(trainval_data)).tolist()) 130 | 131 | print('Size of train/val/test set: ', len(train_data), len(val_data), len(test_data)) 132 | 133 | return train_data, val_data, test_data 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser(description='Dataset preparation') 137 | parser.add_argument('-save_dir', '-d', help='Where to save the new datasets') 138 | parser.add_argument('-data_dir', help='Where to load the datasets') 139 | parser.add_argument('-class_label', help='class label for stratified split') 140 | args = parser.parse_args() 141 | train_data, val_data, test_data = extract_data(args.data_dir) 142 | 143 | directory = join(os.getcwd(), args.save_dir, args.class_label) 144 | 145 | for dataset in ['train','val','test']: 
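# each split is written to <save_dir>/<class_label>/<split>.pkl by the open() call below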
146 | print("Processing %s set" % dataset) 147 | if not os.path.exists(directory): 148 | os.makedirs(directory) 149 | f = open(join(os.getcwd(), args.save_dir, args.class_label, dataset + '.pkl'), 'wb+') 150 | if 'train' in dataset: 151 | pickle.dump(train_data, f) 152 | elif 'val' in dataset: 153 | pickle.dump(val_data, f) 154 | else: 155 | pickle.dump(test_data, f) 156 | f.close() 157 | 158 | -------------------------------------------------------------------------------- /CUB/postprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Automatically extract best config and epoch and retrain the model on both train + val sets 3 | """ 4 | import os 5 | import subprocess 6 | import re 7 | import argparse 8 | from hyperparam_checking import find_best_config_hyperparam_tune, find_best_perf 9 | 10 | 11 | def retrain(hyperparam_tune_path, save_path, all_model_types=[], all_lambdas=[], shots=[], adversarial=False): 12 | """ 13 | Retrain only the best hyperparam config for each model type on both train + val sets 14 | """ 15 | if not os.path.exists(save_path): 16 | os.mkdir(save_path) 17 | best_records = find_best_config_hyperparam_tune(hyperparam_tune_path) 18 | all_data_dir = [] 19 | if shots: 20 | for n_shots in shots: 21 | all_data_dir.append('class_attr_data_10_%d_shot' % n_shots) 22 | else: 23 | all_data_dir.append('class_attr_data_10') 24 | 25 | for data_dir in all_data_dir: 26 | for model_type, v in best_records.items(): 27 | _, epoch, config_dir = v 28 | if all_model_types and not any([t in model_type for t in all_model_types]): 29 | continue 30 | model_path = os.path.join(config_dir, '%d_model.pth' % epoch) 31 | log_dir = os.path.join(save_path, config_dir.split('/')[-1] + '_' + data_dir) 32 | command = 'python train_sigmoid.py -log_dir %s -e 1000 -optimizer sgd -pretrained -use_aux %s' 33 | if 'simple_finetune' in model_path: 34 | model_suffix = '' 35 | else: 36 | lambda_val = float(re.findall(r"attr_loss_weight_\d*\.\d+", config_dir)[0].split('_')[-1]) 37 | if any([t in model_type for t in ['multitask', 'end2end']]) and (all_lambdas and lambda_val not in all_lambdas): 38 | continue 39 | model_suffix = '-use_attr -weighted_loss multiple -data_dir %s -n_attributes 112 -attr_loss_weight %.3f -normalize_loss' % (data_dir, lambda_val) 40 | if 'relu' in hyperparam_tune_path: 41 | model_suffix += ' -use_relu' 42 | print("USE RELU") 43 | 44 | if 'end2end' in model_path: 45 | model_suffix += ' -end2end' 46 | elif 'bottleneck' in model_path: 47 | model_suffix += ' -bottleneck' 48 | elif 'onlyAttr' in model_path: 49 | model_suffix += ' -no_img' 50 | scheduler_step = int(re.findall(r"scheduler_step_\d*", config_dir)[0].split('_')[-1]) 51 | weight_decay = float(re.findall(r"weight_decay_\d*\.\d+", config_dir)[0].split('_')[-1]) 52 | lr = float(re.findall(r"lr_\d*\.\d+", config_dir)[0].split('_')[-1]) 53 | 54 | model_suffix = model_suffix + " -batch_size %d -weight_decay %f -lr %f -scheduler_step %d" % (64, weight_decay, lr, scheduler_step) 55 | command = command % (log_dir, model_suffix) 56 | if not shots: #also train on val set 57 | command += (' -ckpt %s' % model_path) 58 | if adversarial: 59 | command += ' -image_dir CUB_adversarial/CUB_fixed/train/' 60 | print(command) 61 | subprocess.run([command]) 62 | 63 | def run_inference(retrain_path, model_types=[], all_lambdas=[], feature_group=False, sequential=False): 64 | """ 65 | Assuming there is only one model of each (model type, lambda value) in retrain_path 66 | Run inference on test set 
63 | def run_inference(retrain_path, model_types=[], all_lambdas=[], feature_group=False, sequential=False):
64 | """
65 | Assuming there is only one model of each (model type, lambda value) in retrain_path,
66 | run inference on the test set using the best epoch obtained from retraining.
67 | If model_types is specified, then only run inference for those model types
68 | """
69 | for config in os.listdir(retrain_path):
70 | config_dir = os.path.join(retrain_path, config)
71 | if not os.path.isdir(config_dir):
72 | continue
73 | if 'bottleneck' in config:
74 | model_type = 'bottleneck'
75 | elif 'end2end' in config:
76 | model_type = 'end2end'
77 | elif 'use_attr' in config and 'onlyAttr' not in config:
78 | model_type = 'multitask'
79 | elif 'onlyAttr' not in config:
80 | model_type = 'simple_finetune'
81 | else:
82 | model_type = 'onlyAttr'
83 | if model_types and model_type not in model_types:
84 | continue
85 | all_val_acc = find_best_perf(os.path.join(config_dir, 'log.txt'))
86 | epoch = all_val_acc.index(max(all_val_acc))
87 | #epoch = round(epoch, -1) - 20
88 | if epoch < 0:
89 | print(config_dir, ' has not started training')
90 | print(epoch, '\t', config)
91 | model_path = os.path.join(config_dir, '%d_model.pth' % epoch)
92 | if 'attr_loss_weight' in model_path:
93 | lambda_val = float(re.findall(r"attr_loss_weight_\d*\.\d+", config_dir)[0].split('_')[-1])
94 | else:
95 | lambda_val = 1
96 | if any([t in model_types for t in ['multitask', 'end2end']]) and (all_lambdas and lambda_val not in all_lambdas):
97 | continue
98 | if 'NEW_SIGMOID_MODEL' in retrain_path or 'NEW_MODEL' in retrain_path:
99 | command = 'python inference_sigmoid.py -model_dir %s -eval_data test' % model_path
100 | else:
101 | command = 'python inference.py -model_dir %s -eval_data test' % model_path
102 | if feature_group:
103 | command += ' -feature_group_results'
104 | if 'use_attr' in model_path:
105 | command += ' -use_attr -n_attributes 112 -data_dir class_attr_data_10'
106 | if 'onlyAttr' in model_path:
107 | continue
108 | if 'bottleneck' in model_path:
109 | def find_onlyAttr_dir(retrain_path, model_path):
110 | if 'few_shots' in retrain_path:
111 | n_shots = re.findall(r"\d+_shot", model_path)[0]
112 | if sequential:
113 | dir_name = [c for c in os.listdir(retrain_path) if 'onlyAttr_Ahat' in c and n_shots in c][0]
114 | else:
115 | dir_name = [c for c in os.listdir(retrain_path) if 'onlyAttr' in c and 'onlyAttr_Ahat' not in c and n_shots in c][0]
116 | else:
117 | if sequential:
118 | dir_name = [c for c in os.listdir(retrain_path) if 'onlyAttr_Ahat' in c][0]
119 | else:
120 | dir_name = [c for c in os.listdir(retrain_path) if 'onlyAttr' in c and 'onlyAttr_Ahat' not in c][0]
121 | return os.path.join(retrain_path, dir_name)
122 | 
123 | onlyAttr_dir = find_onlyAttr_dir(retrain_path, model_path)
124 | val_acc = find_best_perf(os.path.join(onlyAttr_dir, 'log.txt'))
125 | model2_path = os.path.join(onlyAttr_dir, '%d_model.pth' % (val_acc.index(max(val_acc))))
126 | config_dir = os.path.join(retrain_path, config)
127 | command += (' -model_dir2 %s -bottleneck' % model2_path)
128 | if 'onlyAttr_Ahat' not in model2_path:
129 | command += ' -use_sigmoid'
130 | if 'adversarial' in retrain_path:
131 | command += ' -image_dir CUB_adversarial/CUB_fixed/test/'
132 | subprocess.run(command, shell=True)  # command is a single shell string, not an argv list
133 | #TODO: write test inference results to a separate folder
134 | 
135 | 
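# Usage sketch for run_inference() above (illustrative only; the directory name is hypothetical):
#
# run_inference('retrain_results/', model_types=['end2end'], all_lambdas=[0.001])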
136 | if __name__ == "__main__":
137 | parser = argparse.ArgumentParser(description='PyTorch Training')
138 | parser.add_argument('-save_path', default=None, help='where the trained models are saved')
139 | parser.add_argument('-results_path', default=None, help='where to parse for the best performance')
140 | args = parser.parse_args()
141 | #retrain(args.results_path, args.save_path, all_model_types=['bottleneck', 'end2end'], all_lambdas=['0.01'], shots=[])
142 | run_inference(args.results_path, model_types=['end2end'], all_lambdas=[0.001], sequential=True)
143 | 
--------------------------------------------------------------------------------
/SKINCON/generate_new_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Create variants of the initial SKINCON dataset
3 | """
4 | import os
5 | import sys
6 | import copy
7 | import torch
8 | import random
9 | import pickle
10 | import argparse
11 | import numpy as np
12 | from PIL import Image
13 | from shutil import copyfile
14 | import torchvision.transforms as transforms
15 | from collections import defaultdict as ddict
16 | 
17 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
18 | from SKINCON.config import N_ATTRIBUTES, N_CLASSES
19 | 
20 | from tqdm import tqdm
21 | from sklearn.manifold import TSNE
22 | 
23 | def remove_sparse_concepts(min_count, out_dir, modify_data_dir='', keep_instance_data=False):
24 | """
25 | Use the train/val/test pkl files to keep only the concepts present in at least min_count images
26 | Transform the data in modify_data_dir and save the new datasets to out_dir
27 | """
28 | data = []
29 | for dataset in ['train', 'val', 'test']:
30 | full_path = os.path.join(os.getcwd(), modify_data_dir, dataset + '.pkl' )
31 | data.extend(pickle.load(open(full_path, 'rb')))
32 | 
33 | attr_count = np.zeros(len(data[0]['attribute_label']))
34 | for d in data:
35 | for attr_idx, a in enumerate(d['attribute_label']):
36 | attr_count[attr_idx] += a
37 | 
38 | mask = np.where(attr_count >= min_count) #select attributes that are present in at least [min_count] images
39 | 
40 | collapse_fn = lambda d: list(np.array(d['attribute_label'])[mask])
41 | create_new_dataset(out_dir, 'attribute_label', collapse_fn, data_dir=modify_data_dir)
42 | 
43 | def create_logits_data(model_path, out_dir, data_dir='', use_relu=False, use_sigmoid=False):
44 | """
45 | Replace attribute labels in data_dir with the logits output by the model from model_path and save the new data to out_dir
46 | """
47 | model = torch.load(model_path)
48 | get_logits_train = lambda d: inference(d['img_path'], model, use_relu, use_sigmoid, is_train=True)
49 | get_logits_test = lambda d: inference(d['img_path'], model, use_relu, use_sigmoid, is_train=False)
50 | create_new_dataset(out_dir, 'attribute_label', get_logits_train, datasets=['train'], data_dir=data_dir)
51 | create_new_dataset(out_dir, 'attribute_label', get_logits_test, datasets=['val', 'test'], data_dir=data_dir)  # use the eval-time transform for val/test
52 | 
53 | def inference(img_path, model, use_relu, use_sigmoid, is_train, resol=299, layer_idx=None):
54 | """
55 | For a single image stored in img_path, run inference using model and return the predicted concepts A_hat (if layer_idx is None) or values extracted from layer layer_idx
56 | """
57 | model.eval()
58 | # see utils.py
59 | if is_train:
60 | transform = transforms.Compose([
61 | transforms.ColorJitter(brightness=32 / 255, saturation=(0.5, 1.5)),
62 | transforms.RandomResizedCrop(resol),
63 | transforms.RandomHorizontalFlip(),
64 | transforms.ToTensor(), # implicitly divides by 255
65 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[2, 2, 2])
66 | ])
67 | else:
68 | transform = transforms.Compose([
69 | transforms.CenterCrop(resol),
70 | transforms.ToTensor(), # implicitly divides by 255
71 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[2, 2, 2])
72 | ])
73 | 
74 | try:
75 | idx = img_path.split('/').index('data')
76 | img_path = '/'.join(img_path.split('/')[idx:])
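# img_path is now relative to the working directory, starting at the 'data' folder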
77 | img = Image.open(img_path).convert('RGB')
78 | except:  # fall back to the raw train/test directory layout
79 | img_path_split = img_path.split('/')
80 | split = 'train' if is_train else 'test'
81 | img_path = '/'.join(img_path_split[:2] + [split] + img_path_split[2:])
82 | img = Image.open(img_path).convert('RGB')
83 | 
84 | img_tensor = transform(img).unsqueeze(0)
85 | input_var = torch.autograd.Variable(img_tensor).cuda()
86 | if layer_idx is not None:
87 | all_mods = list(model.modules())
88 | cropped_model = torch.nn.Sequential(*list(model.children())[:layer_idx]) # nn.ModuleList(all_mods[:layer_idx])
89 | print(type(input_var), input_var.shape, input_var)
90 | return cropped_model(input_var)
91 | 
92 | outputs = model(input_var)
93 | if use_relu:
94 | attr_outputs = [torch.nn.ReLU()(o) for o in outputs]
95 | elif use_sigmoid:
96 | attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs]
97 | else:
98 | attr_outputs = outputs
99 | 
100 | attr_outputs = torch.cat([o.unsqueeze(1) for o in attr_outputs], dim=1).squeeze()
101 | return list(attr_outputs.data.cpu().numpy())
102 | 
103 | 
104 | def inference_no_grad(img_path, model, use_relu, use_sigmoid, is_train, resol=299, layer_idx=None):
105 | """
106 | Extract activation from layer_idx of model for input from img_path (for linear probe)
107 | """
108 | with torch.no_grad():
109 | attr_outputs = inference(img_path, model, use_relu, use_sigmoid, is_train, resol, layer_idx)
110 | #return [list(o.cpu().numpy().squeeze())[0] for o in attr_outputs]
111 | return [o.cpu().numpy().squeeze()[()] for o in attr_outputs]
112 | 
113 | 
114 | def create_new_dataset(out_dir, field_change, compute_fn, datasets=['train', 'val', 'test'], data_dir=''):
115 | """
116 | Generic function that, given datasets stored in data_dir, modifies/adds one field of the metadata in each dataset based on compute_fn
117 | and saves the new datasets to out_dir
118 | compute_fn should take in a metadata object (that includes 'img_path', 'class_label', 'attribute_label', etc.)
119 | and return the updated value for field_change 120 | """ 121 | if not os.path.exists(out_dir): 122 | os.makedirs(out_dir) 123 | for dataset in datasets: 124 | path = os.path.join(data_dir, dataset + '.pkl') 125 | if not os.path.exists(path): 126 | continue 127 | data = pickle.load(open(path, 'rb')) 128 | new_data = [] 129 | for d in data: 130 | new_d = copy.deepcopy(d) 131 | new_value = compute_fn(d) 132 | if field_change in d: 133 | old_value = d[field_change] 134 | assert (type(old_value) == type(new_value)) 135 | new_d[field_change] = new_value 136 | new_data.append(new_d) 137 | f = open(os.path.join(out_dir, dataset + '.pkl'), 'wb') 138 | pickle.dump(new_data, f) 139 | f.close() 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--exp', type=str, 144 | choices=['ExtractConcepts', 'ExtractProbeRepresentations', 'DataEfficiencySplits', 'ChangeAdversarialDataDir', 'Concept', 'MV'], 145 | help='Name of experiment to run.') 146 | parser.add_argument('--model_path', type=str, help='Path of model') 147 | parser.add_argument('--out_dir', type=str, help='Output directory') 148 | parser.add_argument('--data_dir', type=str, help='Data directory') 149 | parser.add_argument('--adv_data_dir', type=str, help='Adversarial data directory') 150 | parser.add_argument('--train_splits', type=str, nargs='+', help='Train splits to use') 151 | parser.add_argument('--use_relu', action='store_true', help='Use Relu') 152 | parser.add_argument('--use_sigmoid', action='store_true', help='Use Sigmoid') 153 | parser.add_argument('--layer_idx', type=int, default=None, help='Layer id to extract probe representations') 154 | parser.add_argument('--n_samples', type=int, help='Number of samples for data efficiency split') 155 | parser.add_argument('--splits_dir', type=str, help='Data dir of splits') 156 | parser.add_argument('--modify_data_dir', type=str, help="Data dir to be modified") 157 | parser.add_argument('-class_label', type=str, default='binary', help='which class label to use') 158 | args = parser.parse_args() 159 | 160 | if args.exp == 'Concept': 161 | out_dir = os.path.join(os.getcwd(), args.data_dir) 162 | modify_data_dir = os.path.join(os.getcwd(), args.modify_data_dir) 163 | remove_sparse_concepts(50, out_dir, modify_data_dir=modify_data_dir) 164 | elif args.exp == 'ExtractConcepts': 165 | create_logits_data(args.model_path, args.out_dir, args.data_dir, args.use_relu, args.use_sigmoid) -------------------------------------------------------------------------------- /CUB/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | General utils for training, evaluation and data loading 3 | """ 4 | import os 5 | import torch 6 | import pickle 7 | import numpy as np 8 | import torchvision.transforms as transforms 9 | 10 | from PIL import Image 11 | from CUB.config import BASE_DIR, N_ATTRIBUTES 12 | from torch.utils.data import BatchSampler 13 | from torch.utils.data import Dataset, DataLoader 14 | 15 | 16 | class CUBDataset(Dataset): 17 | """ 18 | Returns a compatible Torch Dataset object customized for the CUB dataset 19 | """ 20 | 21 | def __init__(self, pkl_file_paths, use_attr, no_img, uncertain_label, image_dir, n_class_attr, transform=None): 22 | """ 23 | Arguments: 24 | pkl_file_paths: list of full path to all the pkl data 25 | use_attr: whether to load the attributes (e.g. False for simple finetune) 26 | no_img: whether to load the images (e.g. 
False for A -> Y model)
27 | uncertain_label: if True, use 'uncertain_attribute_label' field (i.e. label weighted by uncertainty score, e.g. 1 & 3(probably) -> 0.75)
28 | image_dir: default = 'images'. Will be appended to the parent dir
29 | n_class_attr: number of classes to predict for each attribute. If 3, then make a separate class for not visible
30 | transform: whether to apply any special transformation. Default = None, i.e. use standard ImageNet preprocessing
31 | 
32 | 
33 | 
34 | """
35 | self.data = []
36 | self.is_train = any(["train" in path for path in pkl_file_paths])
37 | if not self.is_train:
38 | assert any([("test" in path) or ("val" in path) for path in pkl_file_paths])
39 | for file_path in pkl_file_paths:
40 | self.data.extend(pickle.load(open(file_path, 'rb')))
41 | self.transform = transform
42 | self.use_attr = use_attr
43 | self.no_img = no_img
44 | self.uncertain_label = uncertain_label
45 | self.image_dir = image_dir
46 | self.n_class_attr = n_class_attr
47 | 
48 | def __len__(self):
49 | return len(self.data)
50 | 
51 | def __getitem__(self, idx):
52 | img_data = self.data[idx]
53 | img_path = img_data['img_path']
54 | # Trim unnecessary paths
55 | try:
56 | idx = img_path.split('/').index('CUB_200_2011')
57 | if self.image_dir != 'images':
58 | img_path = '/'.join([self.image_dir] + img_path.split('/')[idx+1:])
59 | img_path = img_path.replace('images/', '')
60 | else:
61 | img_path = '/'.join(img_path.split('/')[idx:])
62 | img = Image.open(img_path).convert('RGB')
63 | except:  # fall back to the train/test directory layout
64 | img_path_split = img_path.split('/')
65 | split = 'train' if self.is_train else 'test'
66 | img_path = '/'.join(img_path_split[:2] + [split] + img_path_split[2:])
67 | img = Image.open(img_path).convert('RGB')
68 | 
69 | class_label = img_data['class_label']
70 | if self.transform:
71 | img = self.transform(img)
72 | 
73 | if self.use_attr:
74 | if self.uncertain_label:
75 | attr_label = img_data['uncertain_attribute_label']
76 | else:
77 | attr_label = img_data['attribute_label']
78 | if self.no_img:
79 | if self.n_class_attr == 3:
80 | one_hot_attr_label = np.zeros((N_ATTRIBUTES, self.n_class_attr))
81 | one_hot_attr_label[np.arange(N_ATTRIBUTES), attr_label] = 1
82 | return one_hot_attr_label, class_label
83 | else:
84 | return attr_label, class_label
85 | else:
86 | return img, class_label, attr_label
87 | else:
88 | return img, class_label
89 | 
90 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
91 | """Samples elements randomly from a given list of indices for an imbalanced dataset
92 | Arguments:
93 | indices (list, optional): a list of indices
94 | num_samples (int, optional): number of samples to draw
95 | """
96 | 
97 | def __init__(self, dataset, indices=None):
98 | # if indices is not provided,
99 | # all elements in the dataset will be considered
100 | self.indices = list(range(len(dataset))) \
101 | if indices is None else indices
102 | 
103 | # if num_samples is not provided,
104 | # draw `len(indices)` samples in each iteration
105 | self.num_samples = len(self.indices)
106 | 
107 | # distribution of classes in the dataset
108 | label_to_count = {}
109 | for idx in self.indices:
110 | label = self._get_label(dataset, idx)
111 | if label in label_to_count:
112 | label_to_count[label] += 1
113 | else:
114 | label_to_count[label] = 1
115 | 
116 | # weight for each sample
117 | weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
118 | for idx in 
self.indices] 119 | self.weights = torch.DoubleTensor(weights) 120 | 121 | def _get_label(self, dataset, idx): # Note: for single attribute dataset 122 | return dataset.data[idx]['attribute_label'][0] 123 | 124 | def __iter__(self): 125 | idx = (self.indices[i] for i in torch.multinomial( 126 | self.weights, self.num_samples, replacement=True)) 127 | return idx 128 | 129 | def __len__(self): 130 | return self.num_samples 131 | 132 | def load_data(pkl_paths, use_attr, no_img, batch_size, uncertain_label=False, n_class_attr=2, image_dir='images', resampling=False, resol=299): 133 | """ 134 | Note: Inception needs (299,299,3) images with inputs scaled between -1 and 1 135 | Loads data with transformations applied, and upsample the minority class if there is class imbalance and weighted loss is not used 136 | NOTE: resampling is customized for first attribute only, so change sampler.py if necessary 137 | """ 138 | resized_resol = int(resol * 256/224) 139 | is_training = any(['train.pkl' in f for f in pkl_paths]) 140 | if is_training: 141 | transform = transforms.Compose([ 142 | #transforms.Resize((resized_resol, resized_resol)), 143 | #transforms.RandomSizedCrop(resol), 144 | transforms.ColorJitter(brightness=32/255, saturation=(0.5, 1.5)), 145 | transforms.RandomResizedCrop(resol), 146 | transforms.RandomHorizontalFlip(), 147 | transforms.ToTensor(), #implicitly divides by 255 148 | transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]) 149 | #transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ]), 150 | ]) 151 | else: 152 | transform = transforms.Compose([ 153 | #transforms.Resize((resized_resol, resized_resol)), 154 | transforms.CenterCrop(resol), 155 | transforms.ToTensor(), #implicitly divides by 255 156 | transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]) 157 | #transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ]), 158 | ]) 159 | 160 | dataset = CUBDataset(pkl_paths, use_attr, no_img, uncertain_label, image_dir, n_class_attr, transform) 161 | if is_training: 162 | drop_last = True 163 | shuffle = True 164 | else: 165 | drop_last = False 166 | shuffle = False 167 | if resampling: 168 | sampler = BatchSampler(ImbalancedDatasetSampler(dataset), batch_size=batch_size, drop_last=drop_last) 169 | loader = DataLoader(dataset, batch_sampler=sampler) 170 | else: 171 | loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) 172 | return loader 173 | 174 | def find_class_imbalance(pkl_file, multiple_attr=False, attr_idx=-1): 175 | """ 176 | Calculate class imbalance ratio for binary attribute labels stored in pkl_file 177 | If attr_idx >= 0, then only return ratio for the corresponding attribute id 178 | If multiple_attr is True, then return imbalance ratio separately for each attribute. 
Else, calculate the overall imbalance across all attributes
179 | """
180 | imbalance_ratio = []
181 | data = pickle.load(open(os.path.join(BASE_DIR, pkl_file), 'rb'))
182 | n = len(data)
183 | n_attr = len(data[0]['attribute_label'])
184 | if attr_idx >= 0:
185 | n_attr = 1
186 | if multiple_attr:
187 | n_ones = [0] * n_attr
188 | total = [n] * n_attr
189 | else:
190 | n_ones = [0]
191 | total = [n * n_attr]
192 | for d in data:
193 | labels = d['attribute_label']
194 | if multiple_attr:
195 | for i in range(n_attr):
196 | n_ones[i] += labels[i]
197 | else:
198 | if attr_idx >= 0:
199 | n_ones[0] += labels[attr_idx]
200 | else:
201 | n_ones[0] += sum(labels)
202 | for j in range(len(n_ones)):
203 | imbalance_ratio.append(total[j]/n_ones[j] - 1)
204 | if not multiple_attr: #e.g. [9.0] --> [9.0] * 312
205 | imbalance_ratio *= n_attr
206 | return imbalance_ratio
207 | 
--------------------------------------------------------------------------------
/SKINCON/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | General utils for training, evaluation and data loading
3 | """
4 | import os
5 | import torch
6 | import pickle
7 | import numpy as np
8 | import torchvision.transforms as transforms
9 | 
10 | from PIL import Image
11 | from CUB.config import BASE_DIR, N_ATTRIBUTES
12 | from torch.utils.data import BatchSampler
13 | from torch.utils.data import Dataset, DataLoader
14 | 
15 | import pandas as pd
16 | from config import FITZPATRICK_CSV, FITZ_CONCEPTS_CSV, FITZPATRICK_DATA, CONCEPT_CANDIDATES
17 | 
18 | 
19 | class SKINCONDataset(Dataset):
20 | """
21 | Returns a compatible Torch Dataset object customized for the SKINCON dataset
22 | """
23 | 
24 | def __init__(self, pkl_file_paths, use_attr, no_img, uncertain_label, image_dir, n_class_attr, transform=None, chat=False, return_path=False, class_label='binary'):
25 | """
26 | Arguments:
27 | pkl_file_paths: list of full paths to all the pkl data
28 | use_attr: whether to load the attributes (e.g. False for simple finetune)
29 | no_img: whether to load the images (e.g. False for A -> Y model)
30 | uncertain_label: if True, use 'uncertain_attribute_label' field (i.e. label weighted by uncertainty score, e.g. 1 & 3(probably) -> 0.75)
31 | image_dir: default = 'images'. Will be appended to the parent dir
32 | n_class_attr: number of classes to predict for each attribute. If 3, then make a separate class for not visible
33 | transform: whether to apply any special transformation. Default = None, i.e.
use standard ImageNet preprocessing 34 | 35 | return_certainty: whether to return uncertainty label or not 36 | return_path: whether to return image path or not 37 | class_label: which class label to use 38 | """ 39 | self.data = [] 40 | self.is_train = any(["train" in path for path in pkl_file_paths]) 41 | if not self.is_train: 42 | assert any([("test" in path) or ("val" in path) for path in pkl_file_paths]) 43 | for file_path in pkl_file_paths: 44 | self.data.extend(pickle.load(open(file_path, 'rb'))) 45 | self.transform = transform 46 | self.use_attr = use_attr 47 | self.no_img = no_img 48 | self.uncertain_label = uncertain_label 49 | self.image_dir = image_dir 50 | self.n_class_attr = n_class_attr 51 | self.chat = chat 52 | self.return_path = return_path 53 | self.class_label = class_label 54 | 55 | def __len__(self): 56 | return len(self.data) 57 | 58 | def __getitem__(self, idx): 59 | img_data = self.data[idx] 60 | img_path = img_data['img_path'] 61 | try: 62 | idx = img_path.split('/').index('data') 63 | img_path = '/'.join(img_path.split('/')[idx:]) 64 | img = Image.open(img_path).convert('RGB') 65 | except: 66 | img_path_split = img_path.split('/') 67 | split = 'train' if self.is_train else 'test' 68 | img_path = '/'.join(img_path_split[:2] + [split] + img_path_split[2:]) 69 | img = Image.open(img_path).convert('RGB') 70 | 71 | if self.class_label == 'binary': 72 | class_label = img_data['benign_malignant'] 73 | elif self.class_label == 'three': 74 | class_label = img_data['three_partition_label'] 75 | elif self.class_label == 'nine': 76 | class_label = img_data['nine_partition_label'] 77 | elif self.class_label == 'whole': 78 | class_label = img_data['label'] 79 | 80 | if self.transform: 81 | img = self.transform(img) 82 | 83 | if self.use_attr: 84 | if self.uncertain_label: 85 | attr_label = img_data['uncertain_attribute_label'] 86 | else: 87 | attr_label = img_data['attribute_label'] 88 | if self.no_img: 89 | if self.n_class_attr == 3: 90 | one_hot_attr_label = np.zeros((N_ATTRIBUTES, self.n_class_attr)) 91 | one_hot_attr_label[np.arange(N_ATTRIBUTES), attr_label] = 1 92 | return one_hot_attr_label, class_label 93 | else: 94 | return attr_label, class_label 95 | else: 96 | if self.chat: 97 | return img, class_label, attr_label, img_data['chat'] 98 | else: 99 | if self.return_path: 100 | return img, class_label, attr_label, img_data['img_path'] 101 | else: 102 | return img, class_label, attr_label 103 | else: 104 | return img, class_label 105 | 106 | 107 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): 108 | """Samples elements randomly from a given list of indices for imbalanced dataset 109 | Arguments: 110 | indices (list, optional): a list of indices 111 | num_samples (int, optional): number of samples to draw 112 | """ 113 | 114 | def __init__(self, dataset, indices=None): 115 | # if indices is not provided, 116 | # all elements in the dataset will be considered 117 | self.indices = list(range(len(dataset))) \ 118 | if indices is None else indices 119 | 120 | # if num_samples is not provided, 121 | # draw `len(indices)` samples in each iteration 122 | self.num_samples = len(self.indices) 123 | 124 | # distribution of classes in the dataset 125 | label_to_count = {} 126 | for idx in self.indices: 127 | label = self._get_label(dataset, idx) 128 | if label in label_to_count: 129 | label_to_count[label] += 1 130 | else: 131 | label_to_count[label] = 1 132 | 133 | # weight for each sample 134 | weights = [1.0 / label_to_count[self._get_label(dataset, idx)] 135 
| for idx in self.indices] 136 | self.weights = torch.DoubleTensor(weights) 137 | 138 | def _get_label(self, dataset, idx): # Note: for single attribute dataset 139 | return dataset.data[idx]['attribute_label'][0] 140 | 141 | def __iter__(self): 142 | idx = (self.indices[i] for i in torch.multinomial( 143 | self.weights, self.num_samples, replacement=True)) 144 | return idx 145 | 146 | def __len__(self): 147 | return self.num_samples 148 | 149 | def load_data(pkl_paths, use_attr, no_img, batch_size, uncertain_label=False, n_class_attr=2, image_dir='images', resampling=False, resol=299, chat=False, return_path=False, class_label='binary'): 150 | """ 151 | Note: Inception needs (299,299,3) images with inputs scaled between -1 and 1 152 | Loads data with transformations applied, and upsample the minority class if there is class imbalance and weighted loss is not used 153 | NOTE: resampling is customized for first attribute only, so change sampler.py if necessary 154 | """ 155 | resized_resol = int(resol * 256/224) 156 | is_training = any(['train.pkl' in f for f in pkl_paths]) 157 | if is_training: 158 | transform = transforms.Compose([ 159 | #transforms.Resize((resized_resol, resized_resol)), 160 | #transforms.RandomSizedCrop(resol), 161 | transforms.ColorJitter(brightness=32/255, saturation=(0.5, 1.5)), 162 | transforms.RandomResizedCrop(resol), 163 | transforms.RandomHorizontalFlip(), 164 | transforms.ToTensor(), #implicitly divides by 255 165 | transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]) 166 | #transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ]), 167 | ]) 168 | else: 169 | transform = transforms.Compose([ 170 | #transforms.Resize((resized_resol, resized_resol)), 171 | transforms.CenterCrop(resol), 172 | transforms.ToTensor(), #implicitly divides by 255 173 | transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]) 174 | #transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ]), 175 | ]) 176 | 177 | dataset = SKINCONDataset(pkl_paths, use_attr, no_img, uncertain_label, image_dir, n_class_attr, transform, chat=chat, return_path=return_path, class_label=class_label) 178 | if is_training: 179 | drop_last = True 180 | shuffle = True 181 | else: 182 | drop_last = False 183 | shuffle = False 184 | if resampling: 185 | sampler = BatchSampler(ImbalancedDatasetSampler(dataset), batch_size=batch_size, drop_last=drop_last) 186 | loader = DataLoader(dataset, batch_sampler=sampler) 187 | else: 188 | loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) 189 | return loader 190 | 191 | def find_class_imbalance(pkl_file, multiple_attr=False, attr_idx=-1): 192 | """ 193 | Calculate class imbalance ratio for binary attribute labels stored in pkl_file 194 | If attr_idx >= 0, then only return ratio for the corresponding attribute id 195 | If multiple_attr is True, then return imbalance ratio separately for each attribute. 
Else, calculate the overall imbalance across all attributes 196 | """ 197 | imbalance_ratio = [] 198 | data = pickle.load(open(os.path.join(BASE_DIR, pkl_file), 'rb')) 199 | n = len(data) 200 | n_attr = len(data[0]['attribute_label']) 201 | if attr_idx >= 0: 202 | n_attr = 1 203 | if multiple_attr: 204 | n_ones = [0] * n_attr 205 | total = [n] * n_attr 206 | else: 207 | n_ones = [0] 208 | total = [n * n_attr] 209 | for d in data: 210 | labels = d['attribute_label'] 211 | if multiple_attr: 212 | for i in range(n_attr): 213 | n_ones[i] += labels[i] 214 | else: 215 | if attr_idx >= 0: 216 | n_ones[0] += labels[attr_idx] 217 | else: 218 | n_ones[0] += sum(labels) 219 | for j in range(len(n_ones)): 220 | imbalance_ratio.append(total[j]/n_ones[j] - 1) 221 | if not multiple_attr: #e.g. [9.0] --> [9.0] * 312 222 | imbalance_ratio *= n_attr 223 | return imbalance_ratio 224 | -------------------------------------------------------------------------------- /SYNTHETIC/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import random 5 | import argparse 6 | 7 | import numpy as np 8 | 9 | import copy 10 | import torch 11 | 12 | from sklearn.model_selection import train_test_split 13 | 14 | import sys 15 | sys.path.insert(0, '../') 16 | 17 | def gen_data(args): 18 | if not os.path.exists(args.out_dir): 19 | os.makedirs(args.out_dir) 20 | 21 | # sample \alpha 22 | alpha = np.random.normal(loc=args.alpha_mean, scale=args.alpha_var, size=args.n_attributes) 23 | 24 | alpha = np.clip(alpha, 0, 1) 25 | 26 | beta = np.random.uniform(size=(args.n_groups, args.n_attributes)) 27 | 28 | group_concepts = (beta >= alpha).astype(int) 29 | 30 | n_classes_per_group = args.n_classes // args.n_groups 31 | 32 | class_concepts = np.empty((args.n_classes, args.n_attributes)) 33 | 34 | for group in range(args.n_groups): 35 | group_concept_value = group_concepts[group] 36 | change_indexes = np.random.choice(args.n_attributes, n_classes_per_group, replace=False) 37 | for i, index in enumerate(change_indexes): 38 | class_concept = group_concept_value.copy() 39 | class_concept[index] = 1 - class_concept[index] 40 | class_index = group * n_classes_per_group + i 41 | class_concepts[class_index, :] = class_concept 42 | 43 | class_concepts = class_concepts.astype(int) 44 | print(class_concepts) 45 | 46 | w_x = np.random.normal(scale=args.w_var, size=(args.input_dim, args.n_attributes)) 47 | 48 | np.save(os.path.join(args.out_dir, 'wx_save'), w_x) 49 | 50 | new_data = [] 51 | 52 | labels_list = [] 53 | 54 | for y in range(args.n_classes): 55 | for i in range(args.n_samples_per_class): 56 | z = np.random.normal(scale=args.z_var, size=args.input_dim) 57 | c = class_concepts[y, :] 58 | x = w_x @ c + z 59 | x = np.float32(x) 60 | id = y*args.n_samples_per_class + i 61 | data = {'id': id, 'input': x, 'z': z, 'attribute_label': c, 'label': y} 62 | new_data.append(data) 63 | labels_list.append(y) 64 | new_data = np.array(new_data) 65 | n_total_samples = args.n_samples_per_class * args.n_classes 66 | trainval_id, test_id, trainval_label, test_label = train_test_split(np.arange(n_total_samples), labels_list, stratify=labels_list, test_size=args.test_ratio, random_state=42) 67 | 68 | val_ratio = args.val_ratio/(1 - args.test_ratio) 69 | train_id, val_id, train_label, val_label = train_test_split(trainval_id, trainval_label, stratify=trainval_label, test_size=val_ratio, random_state=42) 70 | 71 | test_data = new_data[test_id] 72 | val_data = new_data[val_id] 
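# Note on the two-stage split above: with the defaults test_ratio = val_ratio = 0.15, the second train_test_split draws 0.15 / (1 - 0.15) ~= 0.176 of the trainval pool, so the final train/val/test fractions are roughly 70/15/15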
73 | train_data = new_data[train_id]
74 | 
75 | f_test = open(os.path.join(args.out_dir, 'test.pkl'), 'wb')
76 | f_val = open(os.path.join(args.out_dir, 'val.pkl'), 'wb')
77 | f_train = open(os.path.join(args.out_dir, 'train.pkl'), 'wb')
78 | 
79 | pickle.dump(test_data, f_test)
80 | pickle.dump(val_data, f_val)
81 | pickle.dump(train_data, f_train)
82 | 
83 | f_test.close()
84 | f_val.close()
85 | f_train.close()
86 | 
87 | def inference(d, model, use_relu, use_sigmoid, is_train, layer_idx=None):
88 | """
89 | For a single sample d, run inference using model and return the predicted concepts A_hat (if layer_idx is None) or values extracted from layer layer_idx
90 | """
91 | model.eval()
92 | 
93 | inputs = torch.from_numpy(d['input']).unsqueeze(0)
94 | input_var = torch.autograd.Variable(inputs).cuda()
95 | 
96 | outputs = model(input_var)
97 | new_outputs = []
98 | 
99 | for c in range(args.n_attributes):  # relies on the module-level `args` parsed in __main__
100 | new_outputs.append(outputs[:, c].unsqueeze(1))
101 | 
102 | outputs = new_outputs
103 | if use_relu:
104 | attr_outputs = [torch.nn.ReLU()(o) for o in outputs]
105 | elif use_sigmoid:
106 | attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs]
107 | else:
108 | attr_outputs = outputs
109 | 
110 | attr_outputs = torch.cat([o.unsqueeze(1) for o in attr_outputs], dim=1).squeeze()
111 | return list(attr_outputs.data.cpu().numpy())
112 | 
113 | def create_new_dataset(out_dir, field_change, compute_fn, datasets=['train', 'val', 'test'], data_dir=''):
114 | """
115 | Generic function that, given datasets stored in data_dir, modifies/adds one field of the metadata in each dataset based on compute_fn
116 | and saves the new datasets to out_dir
117 | compute_fn should take in a metadata object (that includes 'img_path', 'class_label', 'attribute_label', etc.)
118 | and return the updated value for field_change
119 | """
120 | if not os.path.exists(out_dir):
121 | os.makedirs(out_dir)
122 | for dataset in datasets:
123 | path = os.path.join(data_dir, dataset + '.pkl')
124 | if not os.path.exists(path):
125 | continue
126 | data = pickle.load(open(path, 'rb'))
127 | new_data = []
128 | for d in data:
129 | new_d = copy.deepcopy(d)
130 | new_value = compute_fn(d)
131 | if field_change in d:
132 | old_value = d[field_change]
133 | new_value = np.array(new_value)
134 | assert (type(old_value) == type(new_value))
135 | new_d[field_change] = new_value
136 | new_data.append(new_d)
137 | f = open(os.path.join(out_dir, dataset + '.pkl'), 'wb')
138 | pickle.dump(new_data, f)
139 | f.close()
140 | 
141 | def create_logits_data(model_path, out_dir, data_dir='', use_relu=False, use_sigmoid=False):
142 | """
143 | Replace attribute labels in data_dir with the logits output by the model from model_path and save the new data to out_dir
144 | """
145 | model = torch.load(model_path)
146 | get_logits_train = lambda d: inference(d, model, use_relu, use_sigmoid, is_train=True)
147 | get_logits_test = lambda d: inference(d, model, use_relu, use_sigmoid, is_train=False)
148 | create_new_dataset(out_dir, 'attribute_label', get_logits_train, datasets=['train'], data_dir=data_dir)
149 | create_new_dataset(out_dir, 'attribute_label', get_logits_test, datasets=['val', 'test'], data_dir=data_dir)  # use the eval-time path for val/test
150 | 
151 | def create_hidden_data(out_dir, data_dir, hidden_ratio, n_attributes, datasets=['train', 'val', 'test']):
152 | if not os.path.exists(out_dir):
153 | os.makedirs(out_dir)
154 | revealed_concept_num = int((1 - hidden_ratio) * n_attributes)
155 | revealed_concepts_indexes = np.random.choice(args.n_attributes, revealed_concept_num, 
replace=False) 156 | for dataset in datasets: 157 | path = os.path.join(data_dir, dataset + '.pkl') 158 | if not os.path.exists(path): 159 | continue 160 | data = pickle.load(open(path, 'rb')) 161 | new_data = [] 162 | for d in data: 163 | new_d = copy.deepcopy(d) 164 | attribute_label = d['attribute_label'] 165 | new_d['attribute_label'] = attribute_label[revealed_concepts_indexes] 166 | assert (type(d['attribute_label']) == type(new_d['attribute_label'])) 167 | new_data.append(new_d) 168 | f = open(os.path.join(out_dir, dataset + '.pkl'), 'wb') 169 | pickle.dump(new_data, f) 170 | f.close() 171 | 172 | def create_diversity_data(out_dir, data_dir, diversity_ratio, n_attributes, datasets=['train', 'val', 'test']): 173 | if not os.path.exists(out_dir): 174 | os.makedirs(out_dir) 175 | w_x = np.load(os.path.join(data_dir, 'wx_save.npy')) 176 | for dataset in datasets: 177 | path = os.path.join(data_dir, dataset + '.pkl') 178 | if not os.path.exists(path): 179 | continue 180 | data = pickle.load(open(path, 'rb')) 181 | new_data = [] 182 | for d in data: 183 | new_d = copy.deepcopy(d) 184 | attribute_label = d['attribute_label'] 185 | reverted_label = 1 - attribute_label 186 | reverted_mask = np.random.uniform(size=n_attributes) < np.ones(n_attributes) * diversity_ratio 187 | new_d['attribute_label'] = np.where(reverted_mask == 1, reverted_label, attribute_label) 188 | assert (type(d['attribute_label']) == type(new_d['attribute_label'])) 189 | new_x = w_x @ new_d['attribute_label'] + d['z'] 190 | new_d['input'] = np.float32(new_x) 191 | assert (type(d['input']) == type(new_d['input'])) 192 | new_data.append(new_d) 193 | f = open(os.path.join(out_dir, dataset + '.pkl'), 'wb') 194 | pickle.dump(new_data, f) 195 | f.close() 196 | 197 | def create_sparsity_data(args, datasets=['train', 'val', 'test']): 198 | if not os.path.exists(args.out_dir): 199 | os.makedirs(args.out_dir) 200 | alpha = np.random.normal(loc=args.alpha_mean, scale=args.alpha_var, size=args.n_attributes) 201 | 202 | alpha = np.clip(alpha, 0, 1) 203 | 204 | beta = np.random.uniform(size=(args.n_groups, args.n_attributes)) 205 | 206 | group_concepts = (beta >= alpha).astype(int) 207 | 208 | n_classes_per_group = args.n_classes // args.n_groups 209 | 210 | class_concepts = np.empty((args.n_classes, args.n_attributes)) 211 | 212 | for group in range(args.n_groups): 213 | group_concept_value = group_concepts[group] 214 | class_concepts[group * n_classes_per_group, :] = group_concept_value 215 | change_indexes = np.random.choice(args.n_attributes, n_classes_per_group, replace=False) 216 | prev_concept = group_concept_value.copy() 217 | for i, index in enumerate(change_indexes): 218 | class_concept = group_concept_value.copy() 219 | class_concept[index] = 1 - class_concept[index] 220 | class_index = group * n_classes_per_group + i 221 | class_concepts[class_index, :] = class_concept 222 | 223 | w_x = np.load(os.path.join(args.data_dir, 'wx_save.npy')) 224 | for dataset in datasets: 225 | path = os.path.join(args.data_dir, dataset + '.pkl') 226 | if not os.path.exists(path): 227 | continue 228 | data = pickle.load(open(path, 'rb')) 229 | new_data = [] 230 | for d in data: 231 | new_d = copy.deepcopy(d) 232 | y = d['label'] 233 | c = class_concepts[y, :] 234 | z = d['z'] 235 | x = w_x@c + z 236 | x = np.float32(x) 237 | new_d['attribute_label'] = c 238 | new_d['input'] = x 239 | assert (type(d['attribute_label']) == type(new_d['attribute_label'])) 240 | assert (type(d['input']) == type(new_d['input'])) 241 | new_data.append(new_d) 242 | f 
= open(os.path.join(args.out_dir, dataset + '.pkl'), 'wb') 243 | pickle.dump(new_data, f) 244 | f.close() 245 | 246 | if __name__ == '__main__': 247 | parser = argparse.ArgumentParser() 248 | parser.add_argument('-exp', type=str, 249 | choices=['GenData', 'Hidden', 'Diversity', 'ExtractConcepts', 'Sparsity', 'Similarity'], 250 | help='Name of experiment to run.', default='GenData') 251 | parser.add_argument('-out_dir', type=str, help='Output directory') 252 | parser.add_argument('-n_samples_per_class', type=int, help='Number of samples to generate', default=100) 253 | parser.add_argument('-data_dir', type=str, help="Data dir to be modified") 254 | parser.add_argument('-input_dim', type=int, help='dimension of x', default=100) 255 | parser.add_argument('-n_attributes', type=int, help='dimension of c', default=100) 256 | parser.add_argument('-n_classes', type=int, help='dimension of y', default=100) 257 | parser.add_argument('-z_var', type=float, help='variance of z', default=2.0) 258 | parser.add_argument('-alpha_var', type=float, help='variance of alpha', default=0.1) 259 | parser.add_argument('-alpha_mean', type=float, help='mean of alpha', default=0.8) 260 | parser.add_argument('-w_var', type=float, help='variance of matrix W', default=0.1) 261 | parser.add_argument('-diversity_ratio', type=float, help='concept diversity factor') 262 | parser.add_argument('-hidden_ratio', type=float, help='number of hidden concepts') 263 | parser.add_argument('-test_ratio', type=float, help='ratio of test samples', default=0.15) 264 | parser.add_argument('-val_ratio', type=float, help='ratio of validation samples', default=0.15) 265 | parser.add_argument('-datasets', default=['train', 'test'], help='datasets to generate') 266 | parser.add_argument('-model_path', type=str, help='Path of model') 267 | parser.add_argument('--use_relu', action='store_true', help='Use Relu') 268 | parser.add_argument('--use_sigmoid', action='store_true', help='Use Sigmoid') 269 | parser.add_argument('-n_groups', type=int, help='number of similar groups', default=50) 270 | args = parser.parse_args() 271 | 272 | if args.exp == 'GenData': 273 | gen_data(args) 274 | elif args.exp == 'ExtractConcepts': 275 | create_logits_data(args.model_path, args.out_dir, args.data_dir, args.use_relu, args.use_sigmoid) 276 | elif args.exp == 'Hidden': 277 | create_hidden_data(args.out_dir, args.data_dir, args.hidden_ratio, args.n_attributes) 278 | elif args.exp == 'Diversity': 279 | create_diversity_data(args.out_dir, args.data_dir, args.diversity_ratio, args.n_attributes) 280 | elif args.exp == 'Sparsity': 281 | create_sparsity_data(args) 282 | elif args.exp == 'Similarity': 283 | create_sparsity_data(args) 284 | -------------------------------------------------------------------------------- /analysis.py: -------------------------------------------------------------------------------- 1 | 2 | import pdb 3 | import os 4 | import sys 5 | import sklearn 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from scipy.stats import pearsonr, spearmanr 9 | from sklearn.metrics import mean_squared_error, precision_recall_fscore_support, accuracy_score, precision_score, recall_score, balanced_accuracy_score, classification_report 10 | 11 | 12 | # ---------------------- OAI ---------------------- 13 | def plot(x, y, **kw): 14 | if kw.get('multiple_plots'): 15 | # MANY lines on MANY plots 16 | assert not kw.get('multiple_plot_cols') is None 17 | ncols = kw['multiple_plot_cols'] 18 | titles = kw['multiple_plot_titles'] 19 | suptitle = 
kw['suptitle'] 20 | sharex = kw['sharex'] if kw.get('sharex') else False 21 | sharey = kw['sharey'] if kw.get('sharey') else False 22 | nplots = len(x) 23 | nrows = np.ceil(nplots / ncols).astype(np.int32) 24 | fig_dims_w = 16 25 | fig_dims_h = nrows * (0.25 * fig_dims_w) 26 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, 27 | figsize=(fig_dims_w, fig_dims_h), 28 | sharex=sharex, sharey=sharey) 29 | if len(axes.shape) == 1: axes = axes[None,:] 30 | 31 | for n in range(nplots): 32 | i, j = n // ncols, n % ncols 33 | subplt = axes[i, j] 34 | for k, (x_, y_) in enumerate(zip(x[n], y[n])): 35 | plot_types = kw.get('plot_types') 36 | if plot_types: 37 | plot_type, plot_args = plot_types[n][k] 38 | if plot_type == 'line': 39 | subplt.plot(x_, y_, **plot_args) 40 | elif plot_type == 'scatter': 41 | subplt.scatter(x_, y_, **plot_args) 42 | else: 43 | subplt.plot(x_, y_) 44 | handle_plot_kwargs(subplt, **kw) 45 | subplt.set_title(titles[n]) 46 | fig.suptitle(**suptitle) 47 | plt.tight_layout() 48 | else: 49 | # ONE line on ONE plot 50 | plt.plot(x, y) 51 | plot_template_ending(**kw) 52 | plt.show() 53 | 54 | def handle_plot_kwargs(subplot=None, **kw): 55 | curr_plot = subplot if subplot else plt 56 | if kw.get('title'): curr_plot.title(kw['title']) 57 | if kw.get('xlabel'): curr_plot.xlabel(kw['xlabel']) 58 | if kw.get('ylabel'): curr_plot.ylabel(kw['ylabel']) 59 | if kw.get('margins'): curr_plot.margins(kw['margins']) 60 | if kw.get('xticks'): curr_plot.xticks(**kw['xticks']) 61 | if kw.get('yticks'): curr_plot.yticks(**kw['yticks']) 62 | if kw.get('xlim'): curr_plot.xlim(**kw['xlim']) 63 | if kw.get('ylim'): curr_plot.ylim(**kw['ylim']) 64 | if kw.get('set_xlim'): curr_plot.set_xlim(**kw['set_xlim']) 65 | if kw.get('set_ylim'): curr_plot.set_ylim(**kw['set_ylim']) 66 | if kw.get('subplots_adjust'): curr_plot.subplots_adjust(**kw['subplots_adjust']) 67 | 68 | def plot_template_ending(**kw): 69 | # Standard template ending for the plots 70 | handle_plot_kwargs(**kw) 71 | plt.show() 72 | 73 | def plot_violin(x_category, y, **kw): 74 | unique = np.unique(x_category) 75 | plot_x = range(len(unique)) 76 | plot_y = [y[x_category == val] for val in unique] 77 | plt.violinplot(plot_y, plot_x, points=60, widths=0.7, showmeans=False, showextrema=True, 78 | showmedians=True, bw_method=0.5) 79 | plot_template_ending(**kw) 80 | 81 | def plot_rmse(y_true, y_pred, **kw): 82 | unique = np.unique(y_true) 83 | plot_x = range(len(unique)) 84 | ids = [y_true == val for val in unique] 85 | rmses = [np.sqrt(mean_squared_error(y_true[idx], y_pred[idx])) for idx in ids] 86 | rmse = np.sqrt(mean_squared_error(y_true, y_pred)) 87 | kw['title'] = 'RMSE = %.3f' % rmse 88 | plot(plot_x, rmses, **kw) 89 | 90 | def plot_distributions(data, names, discrete=True): 91 | assert data.shape[1] == len(names) 92 | x = data.astype(np.int32) 93 | 94 | nplots = data.shape[1] 95 | ncols = 4 96 | nrows = np.ceil(nplots / ncols).astype(np.int32) 97 | fig, axes = plt.subplots(nrows=nrows, ncols=4, figsize=(12,12)) 98 | for n in range(nplots): 99 | i, j = n // ncols, n % ncols 100 | data = x[:,n] 101 | if discrete: 102 | nbins = len(np.unique(data)) 103 | axes[i, j].hist(data, bins=nbins) 104 | axes[i, j].set_title(names[n]) 105 | plt.tight_layout() 106 | plt.show() 107 | 108 | def assign_value_to_bins(value, bins, use_integer_bins=True): 109 | shape = value.shape 110 | value_vec = value.reshape(-1) 111 | dist = np.abs(value_vec[:,None] - bins[None,:]) 112 | bin_id = np.argmin(dist, axis=1) 113 | if use_integer_bins: 114 | new_values = 
bin_id 115 | else: 116 | new_values = bins[bin_id] 117 | new_values = new_values.reshape(shape) 118 | return new_values 119 | 120 | def convert_continuous_back_to_ordinal(y_true, y_pred, use_integer_bins=False): 121 | # Convert y_true into categories 122 | unique_y_true = np.unique(y_true) # (C,) 123 | N_classes = len(unique_y_true) 124 | one_hot_y_true = (y_true[:, None] == unique_y_true[None, :]) # (N,C) 125 | cat_y_true = np.dot(one_hot_y_true, np.arange(N_classes)) # (N,) 126 | y_pred_binned_i = assign_value_to_bins(y_pred, unique_y_true, use_integer_bins=use_integer_bins) 127 | return y_pred_binned_i, cat_y_true 128 | 129 | def assess_performance(y, yhat, names, prediction_type, prefix, verbose=False): 130 | """ 131 | Return standard metrics of performance of y and yhat. 132 | """ 133 | assert y.shape == yhat.shape, print('(%s) y: %s, yhat: %s' % (prefix, str(y.shape), str(yhat.shape)) ) 134 | assert y.shape[1] == len(names), print('%s) y: %s, len(names): %d' % (prefix, str(y.shape), len(names)) ) 135 | 136 | metrics = {} 137 | for i, name in enumerate(names): 138 | # This is to give each variable a unique key in the metrics dict 139 | prefix_name = '%s_%s_' % (prefix, name) 140 | # y and yhat can be (N,D), we analyse col by col 141 | y_i = y[:,i] 142 | yhat_i = yhat[:,i] 143 | 144 | if prediction_type == 'binary': 145 | assert set(np.unique(y_i)) == {0, 1} 146 | assert set(np.unique(yhat_i)) != {0, 1} 147 | fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true=y_i, y_score=yhat_i) 148 | auc = sklearn.metrics.roc_auc_score(y_true=y_i, y_score=yhat_i) 149 | auprc = sklearn.metrics.average_precision_score(y_true=y_i, y_score=yhat_i) 150 | metrics.update({ 151 | prefix_name+'auc': auc, 152 | prefix_name+'auprc': auprc, 153 | prefix_name+'tpr': tpr, 154 | prefix_name+'fpr': fpr 155 | }) 156 | 157 | elif prediction_type == 'multiclass': 158 | precision, recall, fbeta, support = precision_recall_fscore_support(y_i, yhat_i) 159 | metrics.update({ 160 | prefix_name+'precision': precision, 161 | prefix_name+'recall': recall, 162 | prefix_name+'fbeta': fbeta, 163 | prefix_name+'support': support, 164 | prefix_name+'macro_precision': np.mean(precision), 165 | prefix_name+'macro_recall': np.mean(recall), 166 | prefix_name+'macro_F1': np.mean(fbeta), 167 | }) 168 | 169 | elif prediction_type in ['continuous', 'continuous_ordinal']: 170 | r = pearsonr(y_i, yhat_i)[0] 171 | spearman_r = spearmanr(y_i, yhat_i)[0] 172 | rmse = np.sqrt(np.mean((y_i - yhat_i) ** 2)) 173 | metrics.update({ 174 | prefix_name+'r': r, 175 | prefix_name+'rmse': rmse, 176 | prefix_name+'negative_rmse': -rmse, 177 | prefix_name+'r^2': r ** 2, 178 | prefix_name+'spearman_r': spearman_r, 179 | prefix_name+'spearman_r^2': spearman_r ** 2, 180 | }) 181 | 182 | # Continuous ordinal means that the class is categorical & ordinal in nature but represented as continuous 183 | if prediction_type == 'continuous_ordinal': 184 | yhat_round_i, cat_y_i = convert_continuous_back_to_ordinal(y_i, yhat_i, use_integer_bins=True) 185 | precision, recall, fbeta, support = precision_recall_fscore_support(cat_y_i, yhat_round_i) 186 | 187 | metrics.update({ 188 | prefix_name+'precision': precision, 189 | prefix_name+'recall': recall, 190 | prefix_name+'F1': fbeta, 191 | prefix_name+'acc': accuracy_score(cat_y_i, yhat_round_i), 192 | prefix_name+'support': support, 193 | prefix_name+'macro_precision': np.mean(precision), 194 | prefix_name+'macro_recall': np.mean(recall), 195 | prefix_name+'macro_F1': np.mean(fbeta), 196 | }) 197 | 198 | 
metrics[prefix_name+'pred'] = yhat_i 199 | metrics[prefix_name+'true'] = y_i 200 | 201 | if verbose: 202 | if prediction_type in ['multiclass', 'continuous_ordinal']: 203 | N_classes = len(np.unique(y_i)) 204 | out = ('%11s |' % prefix_name[:-1]) + ('%8s|' * N_classes) % tuple([str(i) for i in range(N_classes)]) 205 | for metric in ['precision', 'recall', 'F1', 'support']: 206 | out += ('\n%11s |' % metric) 207 | for cls_id in range(N_classes): 208 | if metric == 'support': 209 | out += ' %6d |' % (metrics[prefix_name+metric][cls_id]) 210 | else: 211 | out += ' %04.1f |' % (metrics[prefix_name+metric][cls_id] * 100.) 212 | out += '\nMacro precision: %2.1f' % ((metrics[prefix_name+'macro_precision']) * 100.) 213 | out += '\nMacro recall : %2.1f' % ((metrics[prefix_name+'macro_recall']) * 100.) 214 | out += '\nMacro F1 : %2.1f' % ((metrics[prefix_name+'macro_F1']) * 100.) 215 | print(out) 216 | 217 | for metric in metrics: 218 | metric_type = '_'.join(metric.split('_')[2:]) 219 | if metric_type in ['tpr', 'fpr', 'precision', 'recall', 'F1', 'support', 'pred', 'true']: 220 | continue 221 | if np.isnan(metrics[metric]): 222 | pass 223 | # print(metric, metrics[metric]) 224 | # raise Exception("%s is a nan, something is weird about your predictor" % metric) 225 | return metrics 226 | 227 | # ---------------------- CUB ---------------------- 228 | class Logger(object): 229 | """ 230 | Log results to a file and flush() to view instant updates 231 | """ 232 | 233 | def __init__(self, fpath=None): 234 | self.console = sys.stdout 235 | self.file = None 236 | if fpath is not None: 237 | self.file = open(fpath, 'w') 238 | 239 | def __del__(self): 240 | self.close() 241 | 242 | def __enter__(self): 243 | pass 244 | 245 | def __exit__(self, *args): 246 | self.close() 247 | 248 | def write(self, msg): 249 | self.console.write(msg) 250 | if self.file is not None: 251 | self.file.write(msg) 252 | 253 | def flush(self): 254 | self.console.flush() 255 | if self.file is not None: 256 | self.file.flush() 257 | os.fsync(self.file.fileno()) 258 | 259 | def close(self): 260 | self.console.close() 261 | if self.file is not None: 262 | self.file.close() 263 | 264 | class AverageMeter(object): 265 | """ 266 | Computes and stores the average and current value 267 | """ 268 | 269 | def __init__(self): 270 | self.reset() 271 | 272 | def reset(self): 273 | self.val = 0 274 | self.avg = 0 275 | self.sum = 0 276 | self.count = 0 277 | 278 | def update(self, val, n=1): 279 | self.val = val 280 | self.sum += val * n 281 | self.count += n 282 | self.avg = self.sum / self.count 283 | 284 | def accuracy(output, target, topk=(1,)): 285 | """ 286 | Computes the precision@k for the specified values of k 287 | output and target are Torch tensors 288 | """ 289 | maxk = max(topk) 290 | batch_size = target.size(0) 291 | _, pred = output.topk(maxk, 1, True, True) 292 | pred = pred.t() 293 | temp = target.view(1, -1).expand_as(pred) 294 | temp = temp.cuda() 295 | 296 | correct = pred.eq(temp) 297 | 298 | res = [] 299 | for k in topk: 300 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 301 | res.append(correct_k.mul_(100.0 / batch_size)) 302 | return res 303 | 304 | def binary_accuracy(output, target): 305 | """ 306 | Computes the accuracy for multiple binary predictions 307 | output and target are Torch tensors 308 | """ 309 | pred = output.cpu() >= 0.5 310 | #print(list(output.data.cpu().numpy())) 311 | #print(list(pred.data[0].numpy())) 312 | #print(list(target.data[0].numpy())) 313 | #print(pred.size(), 
def binary_accuracy(output, target):
    """
    Computes the accuracy for multiple binary predictions
    output and target are Torch tensors
    """
    pred = output.cpu() >= 0.5
    acc = (pred.int()).eq(target.int()).sum()
    acc = acc*100 / np.prod(np.array(target.size()))
    return acc

def multiclass_metric(output, target):
    """
    Return balanced accuracy score (average of recall for each class) in case of class imbalance,
    and classification report containing precision, recall, F1 score for each class
    """
    balanced_acc = balanced_accuracy_score(target, output)
    report = classification_report(target, output)
    return balanced_acc, report
--------------------------------------------------------------------------------
/SYNTHETIC/train.py:
--------------------------------------------------------------------------------
"""
Train the MLP-based models on the synthetic dataset
"""
import pdb
import os
import sys
import argparse

from SYNTHETIC.template_model import MLP
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import math
import torch
import numpy as np
from analysis import Logger, AverageMeter, accuracy, binary_accuracy

from SYNTHETIC.dataset import load_data, find_class_imbalance
from SYNTHETIC.backbone import ModelXtoChat_ChatToY, ModelXtoC, ModelOracleCtoY, ModelXtoCtoY

import pickle
import pandas as pd

# Training
UPWEIGHT_RATIO = 9.0
MIN_LR = 0.0001
LR_DECAY_SIZE = 0.1
BASE_DIR = ''


def run_epoch_simple(model, optimizer, loader, loss_meter, acc_meter, criterion, args, is_training):
    """
    A -> Y: Predicting class labels using only attributes with MLP
    """
    if is_training:
        model.train()
    else:
        model.eval()
    for _, data in enumerate(loader):
        inputs, labels = data
        if isinstance(inputs, list):
            inputs = torch.stack(inputs).t().float()
        inputs = torch.flatten(inputs, start_dim=1).float()
        inputs_var = torch.autograd.Variable(inputs)
        inputs_var = inputs_var.cuda() if torch.cuda.is_available() else inputs_var
        labels_var = torch.autograd.Variable(labels)
        labels_var = labels_var.cuda() if torch.cuda.is_available() else labels_var

        outputs = model(inputs_var)
        loss = criterion(outputs, labels_var)
        acc = accuracy(outputs, labels, topk=(1,))
        loss_meter.update(loss.item(), inputs.size(0))
        acc_meter.update(acc[0], inputs.size(0))

        if is_training:
            optimizer.zero_grad() #zero the parameter gradients
            loss.backward()
            optimizer.step() #optimizer step to update parameters
    return loss_meter, acc_meter

def run_epoch(model, optimizer, loader, loss_meter, acc_meter, criterion, attr_criterion, args, is_training):
    """
    For the rest of the networks (X -> A, cotraining, simple finetune)
    """
    if is_training:
        model.train()
    else:
        model.eval()

    all_labels = []
    all_preds = []

    for _, data in enumerate(loader):
        if attr_criterion is None:
            inputs, labels = data
            attr_labels, attr_labels_var = None, None
        else:
            inputs, labels, attr_labels = data
            if args.n_attributes > 1:
                attr_labels = [i.long() for i in attr_labels]
                attr_labels = torch.stack(attr_labels)#.float() #N x n_attributes
            else:
                if isinstance(attr_labels, list):
                    attr_labels = attr_labels[0]
                attr_labels = attr_labels.unsqueeze(1)
            attr_labels_var = torch.autograd.Variable(attr_labels).float()
            attr_labels_var = attr_labels_var.cuda() if torch.cuda.is_available() else attr_labels_var
        inputs_var = torch.autograd.Variable(inputs)
        inputs_var = inputs_var.cuda() if torch.cuda.is_available() else inputs_var
        labels_var = torch.autograd.Variable(labels)
        labels_var = labels_var.cuda() if torch.cuda.is_available() else labels_var

        outputs = model(inputs_var)

        if args.bottleneck:
            new_outputs = []
            for c in range(args.n_attributes):
                new_outputs.append(outputs[:, c].unsqueeze(1))
            outputs = new_outputs
        losses = []
        out_start = 0
        if not args.bottleneck:
            loss_main = criterion(outputs[0], labels_var)
            losses.append(loss_main)
            out_start = 1
        if attr_criterion is not None and args.attr_loss_weight > 0: #X -> A, cotraining, end2end
            for i in range(len(attr_criterion)):
                losses.append(args.attr_loss_weight * attr_criterion[i](outputs[i+out_start].squeeze().type(torch.cuda.FloatTensor), attr_labels_var[:, i]))

        if args.bottleneck: #attribute accuracy
            sigmoid_outputs = torch.nn.Sigmoid()(torch.cat(outputs, dim=1))
            acc = binary_accuracy(sigmoid_outputs, attr_labels)
            acc_meter.update(acc.data.cpu().numpy(), inputs.size(0))
        else:
            acc = accuracy(outputs[0], labels, topk=(1,)) #only care about class prediction accuracy
            acc_meter.update(acc[0], inputs.size(0))

        if attr_criterion is not None:
            if args.bottleneck:
                total_loss = sum(losses) / args.n_attributes
            else: #cotraining, loss by class prediction and loss by attribute prediction have the same weight
                total_loss = losses[0] + sum(losses[1:])
                if args.normalize_loss:
                    total_loss = total_loss / (1 + args.attr_loss_weight * args.n_attributes)
        else: #finetune
            total_loss = sum(losses)

        loss_meter.update(total_loss.item(), inputs.size(0))
        if is_training:
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

    return loss_meter, acc_meter

def train(model, args):
    # Determine imbalance
    imbalance = None
    if args.use_attr and not args.no_img and args.weighted_loss:
        train_data_path = os.path.join(BASE_DIR, args.data_dir, 'train.pkl')
        if args.weighted_loss == 'multiple':
            imbalance = find_class_imbalance(train_data_path, True)
        else:
            imbalance = find_class_imbalance(train_data_path, False)

    if os.path.exists(args.log_dir): # job restarted by cluster
        for f in os.listdir(args.log_dir):
            os.remove(os.path.join(args.log_dir, f))
    else:
        os.makedirs(args.log_dir)

    logger = Logger(os.path.join(args.log_dir, 'log.txt'))
    logger.write(str(args) + '\n')
    logger.write(str(imbalance) + '\n')
    logger.flush()

    model = model.cuda()
    criterion = torch.nn.CrossEntropyLoss()

    if args.use_attr and not args.no_img:
        attr_criterion = [] #separate criterion (loss function) for each attribute
        if args.weighted_loss:
            assert(imbalance is not None)
            for ratio in imbalance:
                if ratio == float('inf'):
                    ratio = 1
                attr_criterion.append(torch.nn.BCEWithLogitsLoss(weight=torch.FloatTensor([ratio]).cuda()))
        else:
            for i in range(args.n_attributes):
                attr_criterion.append(torch.nn.CrossEntropyLoss())
    else:
        attr_criterion = None
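    # A note on the weighted concept loss (our summary, under the assumption
    # that find_class_imbalance returns roughly the ratio of negative to
    # positive labels per concept): a size-1 `weight` tensor in
    # BCEWithLogitsLoss rescales every element of that concept's loss, so
    # rarer concepts contribute proportionally more to the total loss. Note
    # this differs from `pos_weight`, which would upweight only the positive
    # labels within a concept.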
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'RMSprop':
        optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5, threshold=0.00001, min_lr=0.00001, eps=1e-08)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step, gamma=0.1)
    stop_epoch = int(math.log(MIN_LR / args.lr) / math.log(LR_DECAY_SIZE)) * args.scheduler_step
    print("Stop epoch: ", stop_epoch)
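    # Worked example (ours, not from the source): with lr=0.01, MIN_LR=0.0001
    # and LR_DECAY_SIZE=0.1,
    #   log(0.0001 / 0.01) / log(0.1) = 2   (up to floating-point rounding),
    # so stop_epoch = 2 * scheduler_step: after two 10x decays the learning
    # rate would fall below MIN_LR, and the scheduler is no longer stepped.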
    train_data_path = os.path.join(BASE_DIR, args.data_dir, 'train.pkl')
    val_data_path = train_data_path.replace('train.pkl', 'val.pkl')
    logger.write('train data path: %s\n' % train_data_path)

    if args.ckpt: #retraining; note: unlike the CUB trainer, this still loads only train.pkl
        train_loader = load_data([train_data_path], args.use_attr, args.no_img, args.batch_size, args.resampling)
        val_loader = None
    else:
        train_loader = load_data([train_data_path], args.use_attr, args.no_img, args.batch_size, args.resampling)
        val_loader = load_data([val_data_path], args.use_attr, args.no_img, args.batch_size, args.resampling)

    best_val_epoch = -1
    best_val_loss = float('inf')
    best_val_acc = 0

    for epoch in range(0, args.epochs):
        train_loss_meter = AverageMeter()
        train_acc_meter = AverageMeter()

        if args.no_img:
            train_loss_meter, train_acc_meter = run_epoch_simple(model, optimizer, train_loader, train_loss_meter, train_acc_meter, criterion, args, is_training=True)
        else:
            train_loss_meter, train_acc_meter = run_epoch(model, optimizer, train_loader, train_loss_meter, train_acc_meter, criterion, attr_criterion, args, is_training=True)

        if not args.ckpt: # evaluate on val set
            val_loss_meter = AverageMeter()
            val_acc_meter = AverageMeter()

            with torch.no_grad():
                if args.no_img:
                    val_loss_meter, val_acc_meter = run_epoch_simple(model, optimizer, val_loader, val_loss_meter, val_acc_meter, criterion, args, is_training=False)
                else:
                    val_loss_meter, val_acc_meter = run_epoch(model, optimizer, val_loader, val_loss_meter, val_acc_meter, criterion, attr_criterion, args, is_training=False)
        else: #retraining: track train metrics instead
            val_loss_meter = train_loss_meter
            val_acc_meter = train_acc_meter

        if best_val_acc < val_acc_meter.avg:
            best_val_epoch = epoch
            best_val_acc = val_acc_meter.avg
            logger.write('New best model at epoch %d\n' % epoch)
            torch.save(model, os.path.join(args.log_dir, 'best_model_%d.pth' % args.seed))

        train_loss_avg = train_loss_meter.avg
        val_loss_avg = val_loss_meter.avg

        logger.write('Epoch [%d]:\tTrain loss: %.4f\tTrain accuracy: %.4f\t'
                     'Val loss: %.4f\tVal acc: %.4f\t'
                     'Best val epoch: %d\n'
                     % (epoch, train_loss_avg, train_acc_meter.avg, val_loss_avg, val_acc_meter.avg, best_val_epoch))
        logger.flush()

        if epoch <= stop_epoch:
            scheduler.step(epoch) #scheduler step to update lr at the end of epoch
            #inspect lr
            if epoch % 10 == 0:
                print('Current lr:', scheduler.get_lr())

        # if epoch % args.save_step == 0:
        #     torch.save(model, os.path.join(args.log_dir, '%d_model.pth' % epoch))

        if epoch >= 100 and val_acc_meter.avg < 3:
            print("Early stopping because of low accuracy")
            break

        if epoch - best_val_epoch >= 100:
            print("Early stopping because acc hasn't improved for a long time")
            break

def train_X_to_C(args):
    model = ModelXtoC(input_dim=args.input_dim, n_attributes=args.n_attributes, expand_dim=args.expand_dim)
    train(model, args)

def train_oracle_C_to_y_and_test_on_Chat(args):
    model = ModelOracleCtoY(n_attributes=args.n_attributes,
                            num_classes=args.n_classes, expand_dim=args.expand_dim)
    train(model, args)

def train_Chat_to_y_and_test_on_Chat(args):
    model = ModelOracleCtoY(n_attributes=args.n_attributes,
                            num_classes=args.n_classes, expand_dim=args.expand_dim)
    train(model, args)

def train_X_to_C_to_y(args):
    model = ModelXtoCtoY(input_dim=args.input_dim, num_classes=args.n_classes, n_attributes=args.n_attributes, expand_dim=args.expand_dim,
                         use_relu=args.use_relu, use_sigmoid=args.use_sigmoid)
    train(model, args)
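# How these entry points map onto the concept-bottleneck training schemes
# (a summary we added for orientation, following the labels in backbone.py):
#   - Independent: ModelXtoC (X -> C) and ModelOracleCtoY (C -> Y, trained on
#     ground-truth concepts) are trained separately.
#   - Sequential:  the same C -> Y head, but trained on concepts predicted by
#     an already-trained X -> C model, supplied via the data pipeline.
#   - Joint:       ModelXtoCtoY optimizes the class loss plus
#     attr_loss_weight times the summed concept losses end to end.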
def parse_arguments(experiment):
    # Get argparse configs from user
    parser = argparse.ArgumentParser(description='SYNTHETIC Training')
    parser.add_argument('dataset', type=str, help='Name of the dataset.')
    parser.add_argument('exp', type=str,
                        choices=['Concept_XtoC', 'Independent_CtoY', 'Sequential_CtoY',
                                 'Standard', 'Multitask', 'Joint', 'AUC', 'Debiased',
                                 'Probe',
                                 'TTI', 'Robustness', 'HyperparameterSearch'],
                        help='Name of experiment to run.')
    parser.add_argument('--seed', required=True, type=int, help='Numpy and torch seed.')

    parser.add_argument('-log_dir', default=None, help='where the trained model is saved')
    parser.add_argument('-batch_size', '-b', type=int, help='mini-batch size')
    parser.add_argument('-epochs', '-e', type=int, help='epochs for training process')
    parser.add_argument('-save_step', default=1000, type=int, help='number of epochs to save model')
    parser.add_argument('-lr', type=float, help="learning rate")
    parser.add_argument('-weight_decay', type=float, default=5e-5, help='weight decay for optimizer')
    parser.add_argument('-use_attr', action='store_true',
                        help='whether to use attributes (FOR COTRAINING ARCHITECTURE ONLY)')
    parser.add_argument('-attr_loss_weight', default=1.0, type=float, help='weight for loss by predicting attributes')
    parser.add_argument('-no_img', action='store_true',
                        help='if included, only use attributes (and not raw imgs) for class prediction')
    parser.add_argument('-bottleneck', help='whether to predict attributes before class labels', action='store_true')
    parser.add_argument('-weighted_loss', default='', # note: may need to reduce lr
                        help='Whether to use weighted loss for single attribute or multiple ones')
    parser.add_argument('-n_attributes', type=int, default=100,
                        help='number of attributes (concepts) used in the bottleneck')
    parser.add_argument('-input_dim', type=int, default=100, help='dimension of x')
    parser.add_argument('-n_classes', type=int, default=100, help='number of classes')
    parser.add_argument('-expand_dim', type=int, default=0,
                        help='dimension of hidden layer (if we want to increase model capacity) - for bottleneck only')
    parser.add_argument('-n_class_attr', type=int, default=2,
                        help='whether attr prediction is a binary or ternary classification')
    parser.add_argument('-data_dir', default='official_datasets', help='directory to the training data')
    parser.add_argument('-image_dir', default='fitz', help='test image folder to run inference on')
    parser.add_argument('-resampling', help='Whether to use resampling', action='store_true')
    parser.add_argument('-end2end', action='store_true',
                        help='Whether to train X -> A -> Y end to end. Train cmd is the same as cotraining + this arg')
    parser.add_argument('-optimizer', default='SGD', help='Type of optimizer to use, options incl SGD, RMSProp, Adam')
    parser.add_argument('-ckpt', default='', help='For retraining on both train + val set')
    parser.add_argument('-scheduler_step', type=int, default=1000,
                        help='Number of epochs before decaying current learning rate by a factor of 0.1')
    parser.add_argument('-normalize_loss', action='store_true',
                        help='Whether to normalize loss by taking attr_loss_weight into account')
    parser.add_argument('-use_relu', action='store_true',
                        help='Whether to include relu activation before using attributes to predict Y. '
                             'For end2end & bottleneck model')
    parser.add_argument('-use_sigmoid', action='store_true',
                        help='Whether to include sigmoid activation before using attributes to predict Y. '
                             'For end2end & bottleneck model')
    parser.add_argument('-fix_seed', action='store_true', help='whether to fix seed')
    args = parser.parse_args()
    return (args,)
--------------------------------------------------------------------------------
/SYNTHETIC/inference.py:
--------------------------------------------------------------------------------
"""
Evaluate trained models on the synthetic test set
"""
import os
import sys
import torch
import joblib
import argparse
import numpy as np
from sklearn.metrics import f1_score
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from SYNTHETIC.dataset import load_data
from SYNTHETIC.config import BASE_DIR, N_CLASSES, N_ATTRIBUTES
from analysis import AverageMeter, multiclass_metric, accuracy, binary_accuracy

K = [1, 3, 5] #top k class accuracies to compute

def eval(args):
    """
    Run inference using model (and model2 if bottleneck)
    Returns: (for notebook analysis)
    all_class_labels: flattened list of class labels for each image
    topk_class_outputs: array of top k class ids predicted for each image.
Shape = size of test set * max(K) 25 | all_class_outputs: array of all logit outputs for class prediction, shape = N_TEST * N_CLASS 26 | all_attr_labels: flattened list of labels for each attribute for each image (length = N_ATTRIBUTES * N_TEST) 27 | all_attr_outputs: flatted list of attribute logits (after ReLU/ Sigmoid respectively) predicted for each attribute for each image (length = N_ATTRIBUTES * N_TEST) 28 | all_attr_outputs_sigmoid: flatted list of attribute logits predicted (after Sigmoid) for each attribute for each image (length = N_ATTRIBUTES * N_TEST) 29 | wrong_idx: image ids where the model got the wrong class prediction (to compare with other models) 30 | """ 31 | if args.model_dir: 32 | model = torch.load(args.model_dir) 33 | else: 34 | model = None 35 | 36 | if not hasattr(model, 'use_relu'): 37 | if args.use_relu: 38 | model.use_relu = True 39 | else: 40 | model.use_relu = False 41 | if not hasattr(model, 'use_sigmoid'): 42 | if args.use_sigmoid: 43 | model.use_sigmoid = True 44 | else: 45 | model.use_sigmoid = False 46 | if not hasattr(model, 'cy_fc'): 47 | model.cy_fc = None 48 | model.eval() 49 | 50 | incorrect_idx_list = [[] for _ in range(16) ] 51 | 52 | if args.model_dir2: 53 | if 'rf' in args.model_dir2: 54 | model2 = joblib.load(args.model_dir2) 55 | else: 56 | model2 = torch.load(args.model_dir2) 57 | if not hasattr(model2, 'use_relu'): 58 | if args.use_relu: 59 | model2.use_relu = True 60 | else: 61 | model2.use_relu = False 62 | if not hasattr(model2, 'use_sigmoid'): 63 | if args.use_sigmoid: 64 | model2.use_sigmoid = True 65 | else: 66 | model2.use_sigmoid = False 67 | model2.eval() 68 | else: 69 | model2 = None 70 | 71 | if args.use_attr: 72 | attr_acc_meter = [AverageMeter()] 73 | if args.feature_group_results: # compute acc for each feature individually in addition to the overall accuracy 74 | for _ in range(args.n_attributes): 75 | attr_acc_meter.append(AverageMeter()) 76 | else: 77 | attr_acc_meter = None 78 | 79 | class_acc_meter = [] 80 | for j in range(len(K)): 81 | class_acc_meter.append(AverageMeter()) 82 | 83 | data_dir = os.path.join(BASE_DIR, args.data_dir, args.eval_data + '.pkl') 84 | loader = load_data([data_dir], args.use_attr, args.no_img, args.batch_size) 85 | all_outputs, all_targets = [], [] 86 | all_attr_labels, all_attr_outputs, all_attr_outputs_sigmoid, all_attr_outputs2 = [], [], [], [] 87 | all_class_labels, all_class_outputs, all_class_logits = [], [], [] 88 | topk_class_labels, topk_class_outputs = [], [] 89 | 90 | all_image_paths = [] 91 | 92 | data_size = 0 93 | 94 | incorrect_cnt = np.zeros(args.n_attributes) # i-th element: number of images where i number of concepts are mispredicted 95 | y_correct_by_incorrect_c = np.zeros(args.n_attributes) # i-th element: number of correctly classified images when i number of concepts are mispredicted 96 | 97 | for data_idx, data in enumerate(loader): 98 | if args.use_attr: 99 | if args.no_img: # A -> Y 100 | inputs, labels = data 101 | if isinstance(inputs, list): 102 | inputs = torch.stack(inputs).t().float() 103 | inputs = inputs.float() 104 | # inputs = torch.flatten(inputs, start_dim=1).float() 105 | else: 106 | inputs, labels, attr_labels = data 107 | attr_labels = [i.long() for i in attr_labels] 108 | attr_labels = torch.stack(attr_labels) # N x 312 109 | else: # simple finetune 110 | inputs, labels = data 111 | 112 | inputs_var = torch.autograd.Variable(inputs).cuda() 113 | labels_var = torch.autograd.Variable(labels).cuda() 114 | 115 | 116 | outputs = model(inputs_var) 117 | if 
args.bottleneck:
            new_outputs = []
            for c in range(args.n_attributes):
                new_outputs.append(outputs[:, c].unsqueeze(1))
            outputs = new_outputs

        if args.use_attr:
            if args.no_img: # A -> Y
                class_outputs = outputs
            else:
                if args.bottleneck:
                    if args.use_relu:
                        attr_outputs = [torch.nn.ReLU()(o) for o in outputs]
                        attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs]
                    elif args.use_sigmoid:
                        attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs]
                        attr_outputs_sigmoid = attr_outputs
                    else:
                        attr_outputs = outputs
                        attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs]
                    if model2:
                        stage2_inputs = torch.cat(attr_outputs, dim=1)
                        if args.inference_mode == 'soft' or args.inference_mode == 'hard':
                            if args.inference_mode == 'hard':
                                stage2_inputs = stage2_inputs >= (torch.ones_like(stage2_inputs) * 0.5)
                                stage2_inputs = stage2_inputs.float()
                            class_outputs = model2(stage2_inputs)
                        elif args.inference_mode == 'samp':
                            class_outputs_all = []
                            for _ in range(args.mc_samples):
                                rand_num = torch.rand(*stage2_inputs.size())
                                sampled_stage2_inputs = rand_num.cuda() < stage2_inputs
                                sampled_stage2_inputs = sampled_stage2_inputs.float()
                                _class_outputs = model2(sampled_stage2_inputs)
                                class_outputs_all.append(_class_outputs) # collect each MC sample's logits
                            class_outputs = torch.mean(torch.stack(class_outputs_all, dim=0), dim=0)
                    else: # for debugging bottleneck performance without running stage 2
                        class_outputs = torch.zeros([inputs.size(0), N_CLASSES],
                                                    dtype=torch.float64).cuda() # ignore this
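                    # Note on the three inference modes above (our summary):
                    # 'soft' feeds the raw concept activations to the label
                    # predictor; 'hard' thresholds them at 0.5 first, which
                    # is what makes test-time concept intervention well
                    # defined; 'samp' treats each sigmoid output as a
                    # Bernoulli probability and averages the label logits
                    # over args.mc_samples Monte Carlo draws.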
                else: # cotraining, end2end
                    if args.use_relu:
                        attr_outputs = [torch.nn.ReLU()(o) for o in outputs[1:]]
                        attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs[1:]]
                    elif args.use_sigmoid:
                        attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs[1:]]
                        attr_outputs_sigmoid = attr_outputs
                    else:
                        attr_outputs = outputs[1:]
                        attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs[1:]]

                    class_outputs = outputs[0]

                batch_incorrect_num = np.zeros(inputs.size(0))

                for i in range(args.n_attributes):
                    pred = attr_outputs_sigmoid[i].squeeze().cpu() >= 0.5
                    eq = (pred.int()).eq(attr_labels[:, i].int())
                    not_eq = [int(not ele) for ele in eq]
                    batch_incorrect_num += not_eq

                    acc = binary_accuracy(attr_outputs_sigmoid[i].squeeze(), attr_labels[:, i])
                    acc = acc.data.cpu().numpy()
                    attr_acc_meter[0].update(acc, inputs.size(0))
                    if args.feature_group_results: # keep track of accuracy of individual attributes
                        attr_acc_meter[i + 1].update(acc, inputs.size(0))

                attr_outputs = torch.cat([o.unsqueeze(1) for o in attr_outputs], dim=1)
                attr_outputs_sigmoid = torch.cat([o for o in attr_outputs_sigmoid], dim=1)
                all_attr_outputs.extend(list(attr_outputs.flatten().data.cpu().numpy()))
                all_attr_outputs_sigmoid.extend(list(attr_outputs_sigmoid.flatten().data.cpu().numpy()))
                all_attr_labels.extend(list(attr_labels.flatten().data.cpu().numpy()))

                for i in range(inputs.size(0)):
                    incorrect_cnt[int(batch_incorrect_num[i])] += 1

                for i in range(inputs.size(0)):
                    incorrect = int(batch_incorrect_num[i])
                    if incorrect >= 15:
                        incorrect_idx_list[15].append(data_size + i)
                    else:
                        incorrect_idx_list[incorrect].append(data_size + i)
                data_size += inputs.size(0)

        else:
            class_outputs = outputs[0]

        _, topk_preds = class_outputs.topk(max(K), 1, True, True)
        _, preds = class_outputs.topk(1, 1, True, True)

        all_class_outputs.extend(list(preds.detach().cpu().numpy().flatten()))
        all_class_labels.extend(list(labels.data.cpu().numpy()))
        all_class_logits.extend(class_outputs.detach().cpu().numpy())
        topk_class_outputs.extend(topk_preds.detach().cpu().numpy())
        topk_class_labels.extend(labels.view(-1, 1).expand_as(preds))

        np.set_printoptions(threshold=sys.maxsize)
        class_acc = accuracy(class_outputs, labels, topk=K) # only class prediction accuracy

        _, pred = class_outputs.topk(1, 1, True, True)
        pred = pred.t()
        temp = labels.view(1, -1).expand_as(pred)
        temp = temp.cuda()
        class_eq = pred.eq(temp)[0]

        for i in range(inputs.size(0)):
            incorrect_num = int(batch_incorrect_num[i])
            y_correct_by_incorrect_c[incorrect_num] += int(class_eq[i])

        for m in range(len(class_acc_meter)):
            class_acc_meter[m].update(class_acc[m], inputs.size(0))

    all_class_logits = np.vstack(all_class_logits)
    topk_class_outputs = np.vstack(topk_class_outputs)
    topk_class_labels = np.vstack(topk_class_labels)
    wrong_idx = np.where(np.sum(topk_class_outputs == topk_class_labels, axis=1) == 0)[0]

    yerr_by_cerr = 1 - np.array(y_correct_by_incorrect_c) / (np.array(incorrect_cnt) + 1e-6)
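    # Reading yerr_by_cerr (a worked example we added): index i holds the
    # class-error rate among test images with exactly i mispredicted
    # concepts. If 50 images have i = 2 and 30 of them are classified
    # correctly, then yerr_by_cerr[2] = 1 - 30/50 = 0.4. The 1e-6 term avoids
    # dividing by zero for counts that never occur.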
    for j in range(len(K)):
        print('Average top %d class accuracy: %.5f' % (K[j], class_acc_meter[j].avg))

    if args.use_attr and not args.no_img: # print some metrics for attribute prediction performance
        print('Average attribute accuracy: %.5f' % attr_acc_meter[0].avg)
        all_attr_outputs_int = np.array(all_attr_outputs_sigmoid) >= 0.5
        if args.feature_group_results:
            n = len(all_attr_labels)
            all_attr_acc, all_attr_f1 = [], []
            for i in range(args.n_attributes):
                acc_meter = attr_acc_meter[1 + i]
                attr_acc = float(acc_meter.avg)
                attr_preds = [all_attr_outputs_int[j] for j in range(n) if j % args.n_attributes == i]
                attr_labels = [all_attr_labels[j] for j in range(n) if j % args.n_attributes == i]
                attr_f1 = f1_score(attr_labels, attr_preds)
                all_attr_acc.append(attr_acc)
                all_attr_f1.append(attr_f1)

            bins = np.arange(0, 1.01, 0.1)
            acc_bin_ids = np.digitize(np.array(all_attr_acc) / 100.0, bins)
            acc_counts_per_bin = [np.sum(acc_bin_ids == (i + 1)) for i in range(len(bins))]
            f1_bin_ids = np.digitize(np.array(all_attr_f1), bins)
            f1_counts_per_bin = [np.sum(f1_bin_ids == (i + 1)) for i in range(len(bins))]
            print("Accuracy bins:")
            print(acc_counts_per_bin)
            print("F1 bins:")
            print(f1_counts_per_bin)
            np.savetxt(os.path.join(args.log_dir, 'concepts.txt'), f1_counts_per_bin)

        balanced_acc, report = multiclass_metric(all_attr_outputs_int, all_attr_labels)
        f1 = f1_score(all_attr_labels, all_attr_outputs_int)
        print("Total 1's predicted:", sum(np.array(all_attr_outputs_sigmoid) >= 0.5) / len(all_attr_outputs_sigmoid))
        print('Avg attribute balanced acc: %.5f' % (balanced_acc))
        print("Avg attribute F1 score: %.5f" % f1)
        print(report + '\n')

    return class_acc_meter, attr_acc_meter, all_class_labels, topk_class_outputs, all_class_logits, all_attr_labels, all_attr_outputs, all_attr_outputs_sigmoid, wrong_idx, all_attr_outputs2, incorrect_idx_list, all_image_paths, yerr_by_cerr

if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser(description='PyTorch Training')
    parser.add_argument('-log_dir', default='.', help='where results are stored')
    parser.add_argument('-model_dirs', default=None, nargs='+', help='where the trained models are saved')
    parser.add_argument('-model_dirs2', default=None, nargs='+', help='where the second-stage trained models are saved (for bottleneck only)')
    parser.add_argument('-eval_data', default='test', help='Type of data (train/ val/ test) to be used')
    parser.add_argument('-use_attr', help='whether to use attributes (FOR COTRAINING ARCHITECTURE ONLY)', action='store_true')
    parser.add_argument('-no_img', help='if included, only use attributes (and not raw imgs) for class prediction', action='store_true')
    parser.add_argument('-bottleneck', help='whether to predict attributes before class labels', action='store_true')
    parser.add_argument('-image_dir', default='images', help='test image folder to run inference on')
    parser.add_argument('-n_class_attr', type=int, default=2, help='whether attr prediction is a binary or ternary classification')
    parser.add_argument('-data_dir', default='', help='directory to the data used for evaluation')
    parser.add_argument('-n_attributes', type=int, default=N_ATTRIBUTES, help='number of attributes (concepts) used in the bottleneck')
    parser.add_argument('-attribute_group', default=None, help='file listing the (trained) model directory for each attribute group')
    parser.add_argument('-feature_group_results', help='whether to print out performance of individual attributes', action='store_true')
    parser.add_argument('-use_relu', help='Whether to include relu activation before using attributes to predict Y. For end2end & bottleneck model', action='store_true')
    parser.add_argument('-use_sigmoid', help='Whether to include sigmoid activation before using attributes to predict Y. For end2end & bottleneck model', action='store_true')
    parser.add_argument('-inference_mode', default='soft', help='mode of inference (soft/ hard/ samp)')
    parser.add_argument('-mc_samples', default=5, type=int, help='Number of MC samples for samp inference mode')
    args = parser.parse_args()
    args.batch_size = 16

    print(args)
    y_results, c_results = [], []
    yerr_by_cerr_results = []
    for i, model_dir in enumerate(args.model_dirs):
        args.model_dir = model_dir
        args.model_dir2 = args.model_dirs2[i] if args.model_dirs2 else None
        result = eval(args)
        class_acc_meter, attr_acc_meter = result[0], result[1]
        yerr_by_cerr = result[-1] # yerr_by_cerr is the last element of eval's return tuple
        y_results.append(1 - class_acc_meter[0].avg[0].item() / 100.)
        if attr_acc_meter is not None:
            c_results.append(1 - attr_acc_meter[0].avg.item() / 100.)
315 | else: 316 | c_results.append(-1) 317 | yerr_by_cerr_results.append(yerr_by_cerr) 318 | values = (np.mean(y_results), np.std(y_results), np.mean(c_results), np.std(c_results)) 319 | output_string = '%.4f %.4f %.4f %.4f' % values 320 | print_string = 'Error of y: %.4f +- %.4f, Error of C: %.4f +- %.4f' % values 321 | print(print_string) 322 | 323 | yerr_by_cerr_mean = np.array(yerr_by_cerr_results).mean(axis=0) 324 | yerr_by_cerr_std = np.array(yerr_by_cerr_results).std(axis=0) 325 | 326 | #print("yerr_by_cerr_mean", yerr_by_cerr_mean.tolist()) 327 | #print("yerr_by_cerr_std", yerr_by_cerr_std.tolist()) 328 | output = open(os.path.join(args.log_dir, 'results.txt'), 'w') 329 | output.write(output_string) -------------------------------------------------------------------------------- /CUB/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate trained models on the official CUB test set 3 | """ 4 | import os 5 | import sys 6 | import torch 7 | import joblib 8 | import argparse 9 | import numpy as np 10 | from sklearn.metrics import f1_score 11 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 12 | 13 | from CUB.dataset import load_data 14 | from CUB.config import BASE_DIR, N_CLASSES, N_ATTRIBUTES, GROUP_DICT 15 | from analysis import AverageMeter, multiclass_metric, accuracy, binary_accuracy 16 | 17 | K = [1, 3, 5] #top k class accuracies to compute 18 | 19 | def eval(args): 20 | """ 21 | Run inference using model (and model2 if bottleneck) 22 | Returns: (for notebook analysis) 23 | all_class_labels: flattened list of class labels for each image 24 | topk_class_outputs: array of top k class ids predicted for each image. Shape = size of test set * max(K) 25 | all_class_outputs: array of all logit outputs for class prediction, shape = N_TEST * N_CLASS 26 | all_attr_labels: flattened list of labels for each attribute for each image (length = N_ATTRIBUTES * N_TEST) 27 | all_attr_outputs: flatted list of attribute logits (after ReLU/ Sigmoid respectively) predicted for each attribute for each image (length = N_ATTRIBUTES * N_TEST) 28 | all_attr_outputs_sigmoid: flatted list of attribute logits predicted (after Sigmoid) for each attribute for each image (length = N_ATTRIBUTES * N_TEST) 29 | wrong_idx: image ids where the model got the wrong class prediction (to compare with other models) 30 | incorrect_idx_list: list of lists where the i-th list contains the images which predicted i concepts incorrectly 31 | """ 32 | if args.model_dir: 33 | model = torch.load(args.model_dir) 34 | else: 35 | model = None 36 | 37 | if not hasattr(model, 'use_relu'): 38 | if args.use_relu: 39 | model.use_relu = True 40 | else: 41 | model.use_relu = False 42 | if not hasattr(model, 'use_sigmoid'): 43 | if args.use_sigmoid: 44 | model.use_sigmoid = True 45 | else: 46 | model.use_sigmoid = False 47 | if not hasattr(model, 'cy_fc'): 48 | model.cy_fc = None 49 | model.eval() 50 | 51 | incorrect_idx_list = [[] for _ in range(16) ] 52 | 53 | if args.model_dir2: 54 | if 'rf' in args.model_dir2: 55 | model2 = joblib.load(args.model_dir2) 56 | else: 57 | model2 = torch.load(args.model_dir2) 58 | if not hasattr(model2, 'use_relu'): 59 | if args.use_relu: 60 | model2.use_relu = True 61 | else: 62 | model2.use_relu = False 63 | if not hasattr(model2, 'use_sigmoid'): 64 | if args.use_sigmoid: 65 | model2.use_sigmoid = True 66 | else: 67 | model2.use_sigmoid = False 68 | model2.eval() 69 | else: 70 | model2 = None 71 | 72 | if args.use_attr: 73 | 
attr_acc_meter = [AverageMeter()]
        if args.feature_group_results: # compute acc for each feature individually in addition to the overall accuracy
            for _ in range(args.n_attributes):
                attr_acc_meter.append(AverageMeter())
    else:
        attr_acc_meter = None

    class_acc_meter = []
    for j in range(len(K)):
        class_acc_meter.append(AverageMeter())

    data_dir = os.path.join(BASE_DIR, args.data_dir, args.eval_data + '.pkl')
    loader = load_data([data_dir], args.use_attr, args.no_img, args.batch_size, image_dir=args.image_dir,
                       n_class_attr=args.n_class_attr)

    all_outputs, all_targets = [], []
    all_attr_labels, all_attr_outputs, all_attr_outputs_sigmoid, all_attr_outputs2 = [], [], [], []
    all_class_labels, all_class_outputs, all_class_logits = [], [], []
    topk_class_labels, topk_class_outputs = [], []

    data_size = 0

    incorrect_cnt = np.zeros(args.n_attributes) # i-th element: number of images where 'i' concepts are mispredicted
    y_correct_by_incorrect_c = np.zeros(args.n_attributes) # i-th element: number of correctly classified images when i number of concepts are mispredicted

    for data_idx, data in enumerate(loader):
        if args.use_attr:
            if args.no_img: # A -> Y
                inputs, labels = data
                if isinstance(inputs, list):
                    inputs = torch.stack(inputs).t().float()
                inputs = inputs.float()
            else:
                inputs, labels, attr_labels = data
                attr_labels = torch.stack(attr_labels).t() # N x 312
        else: # simple finetune
            inputs, labels = data

        inputs_var = torch.autograd.Variable(inputs).cuda()
        labels_var = torch.autograd.Variable(labels).cuda()

        if args.attribute_group:
            outputs = []
            with open(args.attribute_group, 'r') as f: # one trained model path per attribute group
                for line in f:
                    attr_model = torch.load(line.strip())
                    outputs.extend(attr_model(inputs_var))
        else:
            outputs = model(inputs_var)
        if args.use_attr:
            if args.no_img: # A -> Y
                class_outputs = outputs
            else:
                if args.bottleneck:
                    if args.use_relu:
                        attr_outputs = [torch.nn.ReLU()(o) for o in outputs]
                        attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs]
                    elif args.use_sigmoid:
                        attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs]
                        attr_outputs_sigmoid = attr_outputs
                    else:
                        attr_outputs = outputs
                        attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs]
                    if model2:
                        stage2_inputs = torch.cat(attr_outputs, dim=1)
                        if args.inference_mode == 'soft' or args.inference_mode == 'hard':
                            if args.inference_mode == 'hard':
                                stage2_inputs = stage2_inputs >= (torch.ones_like(stage2_inputs) * 0.5)
                                stage2_inputs = stage2_inputs.float()
                            class_outputs = model2(stage2_inputs)
                        elif args.inference_mode == 'samp':
                            class_outputs_all = []
                            for _ in range(args.mc_samples):
                                rand_num = torch.rand(*stage2_inputs.size())
                                sampled_stage2_inputs = rand_num.cuda() < stage2_inputs
                                sampled_stage2_inputs = sampled_stage2_inputs.float()
                                _class_outputs = model2(sampled_stage2_inputs)
                                class_outputs_all.append(_class_outputs) # collect each MC sample's logits
                            class_outputs = torch.mean(torch.stack(class_outputs_all, dim=0), dim=0)
                    else: # for debugging bottleneck performance without running stage 2
                        class_outputs = torch.zeros([inputs.size(0), N_CLASSES],
                                                    dtype=torch.float64).cuda() # ignore this
                else: # cotraining, end2end
                    if args.use_relu:
158 | attr_outputs = [torch.nn.ReLU()(o) for o in outputs[1:]] 159 | attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs[1:]] 160 | elif args.use_sigmoid: 161 | attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs[1:]] 162 | attr_outputs_sigmoid = attr_outputs 163 | else: 164 | attr_outputs = outputs[1:] 165 | attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs[1:]] 166 | 167 | class_outputs = outputs[0] 168 | 169 | batch_incorrect_num = np.zeros(inputs.size(0)) 170 | 171 | for i in range(args.n_attributes): 172 | 173 | pred = attr_outputs_sigmoid[i].squeeze().cpu() >= 0.5 174 | eq = (pred.int()).eq(attr_labels[:,i].int()) 175 | 176 | not_eq = [int(not ele) for ele in eq] 177 | 178 | batch_incorrect_num += not_eq 179 | 180 | acc = binary_accuracy(attr_outputs_sigmoid[i].squeeze(), attr_labels[:, i]) 181 | acc = acc.data.cpu().numpy() 182 | attr_acc_meter[0].update(acc, inputs.size(0)) 183 | if args.feature_group_results: # keep track of accuracy of individual attributes 184 | attr_acc_meter[i + 1].update(acc, inputs.size(0)) 185 | 186 | 187 | attr_outputs = torch.cat([o.unsqueeze(1) for o in attr_outputs], dim=1) 188 | attr_outputs_sigmoid = torch.cat([o for o in attr_outputs_sigmoid], dim=1) 189 | all_attr_outputs.extend(list(attr_outputs.flatten().data.cpu().numpy())) 190 | all_attr_outputs_sigmoid.extend(list(attr_outputs_sigmoid.flatten().data.cpu().numpy())) 191 | all_attr_labels.extend(list(attr_labels.flatten().data.cpu().numpy())) 192 | 193 | for i in range(inputs.size(0)): 194 | incorrect_cnt[int(batch_incorrect_num[i])] += 1 195 | 196 | 197 | for i in range(inputs.size(0)): 198 | incorrect = int(batch_incorrect_num[i]) 199 | if incorrect >= 15: 200 | incorrect_idx_list[15].append(data_size + i) 201 | else: 202 | incorrect_idx_list[incorrect].append(data_size + i) 203 | data_size += inputs.size(0) 204 | 205 | else: 206 | class_outputs = outputs[0] 207 | 208 | _, topk_preds = class_outputs.topk(max(K), 1, True, True) 209 | _, preds = class_outputs.topk(1, 1, True, True) 210 | 211 | all_class_outputs.extend(list(preds.detach().cpu().numpy().flatten())) 212 | all_class_labels.extend(list(labels.data.cpu().numpy())) 213 | all_class_logits.extend(class_outputs.detach().cpu().numpy()) 214 | topk_class_outputs.extend(topk_preds.detach().cpu().numpy()) 215 | topk_class_labels.extend(labels.view(-1, 1).expand_as(preds)) 216 | 217 | 218 | np.set_printoptions(threshold=sys.maxsize) 219 | class_acc = accuracy(class_outputs, labels, topk=K) # only class prediction accuracy 220 | 221 | _, pred = class_outputs.topk(1, 1, True, True) 222 | pred = pred.t() 223 | temp = labels.view(1, -1).expand_as(pred) 224 | temp = temp.cuda() 225 | class_eq = pred.eq(temp)[0] 226 | 227 | for i in range(inputs.size(0)): 228 | incorrect_num = int(batch_incorrect_num[i]) 229 | y_correct_by_incorrect_c [incorrect_num] += int(class_eq[i]) 230 | 231 | for m in range(len(class_acc_meter)): 232 | class_acc_meter[m].update(class_acc[m], inputs.size(0)) 233 | 234 | all_class_logits = np.vstack(all_class_logits) 235 | topk_class_outputs = np.vstack(topk_class_outputs) 236 | topk_class_labels = np.vstack(topk_class_labels) 237 | wrong_idx = np.where(np.sum(topk_class_outputs == topk_class_labels, axis=1) == 0)[0] 238 | 239 | yerr_by_cerr = 1 - np.array(y_correct_by_incorrect_c)/(np.array(incorrect_cnt)+1e-6) 240 | 241 | for j in range(len(K)): 242 | print('Average top %d class accuracy: %.5f' % (K[j], class_acc_meter[j].avg)) 243 | 244 | if args.use_attr and not args.no_img: # print some metrics for 
attribute prediction performance 245 | print('Average attribute accuracy: %.5f' % attr_acc_meter[0].avg) 246 | all_attr_outputs_int = np.array(all_attr_outputs_sigmoid) >= 0.5 247 | if args.feature_group_results: 248 | n = len(all_attr_labels) 249 | all_attr_acc, all_attr_f1 = [], [] 250 | for i in range(args.n_attributes): 251 | acc_meter = attr_acc_meter[1 + i] 252 | attr_acc = float(acc_meter.avg) 253 | attr_preds = [all_attr_outputs_int[j] for j in range(n) if j % args.n_attributes == i] 254 | attr_labels = [all_attr_labels[j] for j in range(n) if j % args.n_attributes == i] 255 | attr_f1 = f1_score(attr_labels, attr_preds) 256 | all_attr_acc.append(attr_acc) 257 | all_attr_f1.append(attr_f1) 258 | 259 | bins = np.arange(0, 1.01, 0.1) 260 | acc_bin_ids = np.digitize(np.array(all_attr_acc) / 100.0, bins) 261 | acc_counts_per_bin = [np.sum(acc_bin_ids == (i + 1)) for i in range(len(bins))] 262 | f1_bin_ids = np.digitize(np.array(all_attr_f1), bins) 263 | f1_counts_per_bin = [np.sum(f1_bin_ids == (i + 1)) for i in range(len(bins))] 264 | print("Accuracy bins:") 265 | print(acc_counts_per_bin) 266 | print("F1 bins:") 267 | print(f1_counts_per_bin) 268 | np.savetxt(os.path.join(args.log_dir, 'concepts.txt'), f1_counts_per_bin) 269 | 270 | balanced_acc, report = multiclass_metric(all_attr_outputs_int, all_attr_labels) 271 | f1 = f1_score(all_attr_labels, all_attr_outputs_int) 272 | print("Total 1's predicted:", sum(np.array(all_attr_outputs_sigmoid) >= 0.5) / len(all_attr_outputs_sigmoid)) 273 | print('Avg attribute balanced acc: %.5f' % (balanced_acc)) 274 | print("Avg attribute F1 score: %.5f" % f1) 275 | print(report + '\n') 276 | return class_acc_meter, attr_acc_meter, all_class_labels, topk_class_outputs, all_class_logits, all_attr_labels, all_attr_outputs, all_attr_outputs_sigmoid, wrong_idx, all_attr_outputs2, incorrect_idx_list, yerr_by_cerr 277 | 278 | if __name__ == '__main__': 279 | torch.backends.cudnn.benchmark=True 280 | parser = argparse.ArgumentParser(description='PyTorch Training') 281 | parser.add_argument('-log_dir', default='.', help='where results are stored') 282 | parser.add_argument('-model_dirs', default=None, nargs='+', help='where the trained models are saved') 283 | parser.add_argument('-model_dirs2', default=None, nargs='+', help='where another trained model are saved (for bottleneck only)') 284 | parser.add_argument('-eval_data', default='test', help='Type of data (train/ val/ test) to be used') 285 | parser.add_argument('-use_attr', help='whether to use attributes (FOR COTRAINING ARCHITECTURE ONLY)', action='store_true') 286 | parser.add_argument('-no_img', help='if included, only use attributes (and not raw imgs) for class prediction', action='store_true') 287 | parser.add_argument('-bottleneck', help='whether to predict attributes before class labels', action='store_true') 288 | parser.add_argument('-image_dir', default='images', help='test image folder to run inference on') 289 | parser.add_argument('-n_class_attr', type=int, default=2, help='whether attr prediction is a binary or triary classification') 290 | parser.add_argument('-data_dir', default='', help='directory to the data used for evaluation') 291 | parser.add_argument('-n_attributes', type=int, default=N_ATTRIBUTES, help='whether to apply bottlenecks to only a few attributes') 292 | parser.add_argument('-attribute_group', default=None, help='file listing the (trained) model directory for each attribute group') 293 | parser.add_argument('-feature_group_results', help='whether to print out 
performance of individual atttributes', action='store_true') 294 | parser.add_argument('-use_relu', help='Whether to include relu activation before using attributes to predict Y. For end2end & bottleneck model', action='store_true') 295 | parser.add_argument('-use_sigmoid', help='Whether to include sigmoid activation before using attributes to predict Y. For end2end & bottleneck model', action='store_true') 296 | parser.add_argument('-inference_mode', default='soft', help='mode of inference') 297 | parser.add_argument('-mc_samples', default=5, type=int, help='Number of MC samples for samp inference mode') 298 | args = parser.parse_args() 299 | args.batch_size = 16 300 | 301 | print(args) 302 | y_results, c_results = [], [] 303 | yerr_by_cerr_results = [] 304 | for i, model_dir in enumerate(args.model_dirs): 305 | args.model_dir = model_dir 306 | args.model_dir2 = args.model_dirs2[i] if args.model_dirs2 else None 307 | result = eval(args) 308 | class_acc_meter, attr_acc_meter = result[0], result[1] 309 | yerr_by_cerr = result[-1] 310 | y_results.append(1 - class_acc_meter[0].avg[0].item() / 100.) 311 | if attr_acc_meter is not None: 312 | c_results.append(1 - attr_acc_meter[0].avg.item() / 100.) 313 | else: 314 | c_results.append(-1) 315 | yerr_by_cerr_results.append(yerr_by_cerr) 316 | values = (np.mean(y_results), np.std(y_results), np.mean(c_results), np.std(c_results)) 317 | output_string = '%.4f %.4f %.4f %.4f' % values 318 | print_string = 'Error of y: %.4f +- %.4f, Error of C: %.4f +- %.4f' % values 319 | print(print_string) 320 | 321 | yerr_by_cerr_mean = np.array(yerr_by_cerr_results).mean(axis=0) 322 | yerr_by_cerr_std = np.array(yerr_by_cerr_results).std(axis=0) 323 | 324 | print("yerr_by_cerr_mean", yerr_by_cerr_mean.tolist()) 325 | print("yerr_by_cerr_std", yerr_by_cerr_std.tolist()) 326 | output = open(os.path.join(args.log_dir, 'results.txt'), 'w') 327 | output.write(output_string) -------------------------------------------------------------------------------- /CUB/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train InceptionV3 Network using the CUB-200-2011 dataset 3 | """ 4 | import pdb 5 | import os 6 | import sys 7 | import argparse 8 | 9 | from CUB.template_model import MLP 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | import math 13 | import torch 14 | import numpy as np 15 | from analysis import Logger, AverageMeter, accuracy, binary_accuracy 16 | 17 | from CUB import tti 18 | from CUB.dataset import load_data, find_class_imbalance 19 | from CUB.config import BASE_DIR, N_CLASSES, N_ATTRIBUTES, UPWEIGHT_RATIO, MIN_LR, LR_DECAY_SIZE, GROUP_DICT 20 | from CUB.models import ModelXtoCY, ModelXtoChat_ChatToY, ModelXtoY, ModelXtoC, ModelOracleCtoY, ModelXtoCtoY 21 | 22 | def run_epoch_simple(model, optimizer, loader, loss_meter, acc_meter, criterion, args, is_training): 23 | """ 24 | A -> Y: Predicting class labels using only attributes with MLP 25 | """ 26 | if is_training: 27 | model.train() 28 | else: 29 | model.eval() 30 | for _, data in enumerate(loader): 31 | inputs, labels = data 32 | if isinstance(inputs, list): 33 | inputs = torch.stack(inputs).t().float() 34 | inputs = torch.flatten(inputs, start_dim=1).float() 35 | inputs_var = torch.autograd.Variable(inputs).cuda() 36 | inputs_var = inputs_var.cuda() if torch.cuda.is_available() else inputs_var 37 | labels_var = torch.autograd.Variable(labels).cuda() 38 | labels_var = labels_var.cuda() if 
torch.cuda.is_available() else labels_var 39 | 40 | outputs = model(inputs_var) 41 | loss = criterion(outputs, labels_var) 42 | acc = accuracy(outputs, labels, topk=(1,)) 43 | loss_meter.update(loss.item(), inputs.size(0)) 44 | acc_meter.update(acc[0], inputs.size(0)) 45 | 46 | if is_training: 47 | optimizer.zero_grad() #zero the parameter gradients 48 | loss.backward() 49 | optimizer.step() #optimizer step to update parameters 50 | return loss_meter, acc_meter 51 | 52 | def run_epoch(model, optimizer, loader, loss_meter, acc_meter, criterion, attr_criterion, args, is_training): 53 | """ 54 | For the rest of the networks (X -> A, cotraining, simple finetune) 55 | """ 56 | if is_training: 57 | model.train() 58 | else: 59 | model.eval() 60 | 61 | for _, data in enumerate(loader): 62 | if attr_criterion is None: 63 | inputs, labels = data 64 | attr_labels, attr_labels_var = None, None 65 | else: 66 | inputs, labels, attr_labels = data 67 | if args.n_attributes > 1: 68 | attr_labels = [i.long() for i in attr_labels] 69 | attr_labels = torch.stack(attr_labels).t()#.float() #N x 312 70 | else: 71 | if isinstance(attr_labels, list): 72 | attr_labels = attr_labels[0] 73 | attr_labels = attr_labels.unsqueeze(1) 74 | attr_labels_var = torch.autograd.Variable(attr_labels).float() 75 | attr_labels_var = attr_labels_var.cuda() if torch.cuda.is_available() else attr_labels_var 76 | 77 | inputs_var = torch.autograd.Variable(inputs) 78 | inputs_var = inputs_var.cuda() if torch.cuda.is_available() else inputs_var 79 | labels_var = torch.autograd.Variable(labels) 80 | labels_var = labels_var.cuda() if torch.cuda.is_available() else labels_var 81 | 82 | if is_training and args.use_aux: 83 | outputs, aux_outputs = model(inputs_var) 84 | losses = [] 85 | out_start = 0 86 | if not args.bottleneck: #loss main is for the main task label (always the first output) 87 | loss_main = 1.0 * criterion(outputs[0], labels_var) + 0.4 * criterion(aux_outputs[0], labels_var) 88 | losses.append(loss_main) 89 | out_start = 1 90 | if attr_criterion is not None and args.attr_loss_weight > 0: #X -> A, cotraining, end2end 91 | for i in range(len(attr_criterion)): 92 | losses.append(args.attr_loss_weight * (1.0 * attr_criterion[i](outputs[i+out_start].squeeze().type(torch.cuda.FloatTensor), attr_labels_var[:, i]) \ 93 | + 0.4 * attr_criterion[i](aux_outputs[i+out_start].squeeze().type(torch.cuda.FloatTensor), attr_labels_var[:, i]))) 94 | else: #testing or no aux logits 95 | outputs = model(inputs_var) 96 | losses = [] 97 | out_start = 0 98 | if not args.bottleneck: 99 | loss_main = criterion(outputs[0], labels_var) 100 | losses.append(loss_main) 101 | out_start = 1 102 | if attr_criterion is not None and args.attr_loss_weight > 0: #X -> A, cotraining, end2end 103 | for i in range(len(attr_criterion)): 104 | losses.append(args.attr_loss_weight * attr_criterion[i](outputs[i+out_start].squeeze().type(torch.cuda.FloatTensor), attr_labels_var[:, i])) 105 | 106 | if args.bottleneck: #attribute accuracy 107 | sigmoid_outputs = torch.nn.Sigmoid()(torch.cat(outputs, dim=1)) 108 | acc = binary_accuracy(sigmoid_outputs, attr_labels) 109 | acc_meter.update(acc.data.cpu().numpy(), inputs.size(0)) 110 | else: 111 | acc = accuracy(outputs[0], labels, topk=(1,)) #only care about class prediction accuracy 112 | acc_meter.update(acc[0], inputs.size(0)) 113 | 114 | if attr_criterion is not None: 115 | if args.bottleneck: 116 | total_loss = sum(losses)/ args.n_attributes 117 | else: #cotraining, loss by class prediction and loss by attribute 
prediction have the same weight 118 | total_loss = losses[0] + sum(losses[1:]) 119 | if args.normalize_loss: 120 | total_loss = total_loss / (1 + args.attr_loss_weight * args.n_attributes) 121 | else: #finetune 122 | total_loss = sum(losses) 123 | 124 | loss_meter.update(total_loss.item(), inputs.size(0)) 125 | 126 | if is_training: 127 | optimizer.zero_grad() 128 | total_loss.backward() 129 | optimizer.step() 130 | 131 | return loss_meter, acc_meter 132 | 133 | def train(model, args): 134 | # Determine imbalance 135 | imbalance = None 136 | if args.use_attr and not args.no_img and args.weighted_loss: 137 | train_data_path = os.path.join(BASE_DIR, args.data_dir, 'train.pkl') 138 | if args.weighted_loss == 'multiple': 139 | imbalance = find_class_imbalance(train_data_path, True) 140 | else: 141 | imbalance = find_class_imbalance(train_data_path, False) 142 | 143 | if os.path.exists(args.log_dir): # job restarted by cluster 144 | for f in os.listdir(args.log_dir): 145 | os.remove(os.path.join(args.log_dir, f)) 146 | else: 147 | os.makedirs(args.log_dir) 148 | 149 | logger = Logger(os.path.join(args.log_dir, 'log.txt')) 150 | logger.write(str(args) + '\n') 151 | logger.write(str(imbalance) + '\n') 152 | logger.flush() 153 | 154 | model = model.cuda() 155 | 156 | criterion = torch.nn.CrossEntropyLoss() 157 | if args.use_attr and not args.no_img: 158 | attr_criterion = [] #separate criterion (loss function) for each attribute 159 | if args.weighted_loss: 160 | assert(imbalance is not None) 161 | for ratio in imbalance: 162 | attr_criterion.append(torch.nn.BCEWithLogitsLoss(weight=torch.FloatTensor([ratio]).cuda())) 163 | else: 164 | for i in range(args.n_attributes): 165 | attr_criterion.append(torch.nn.CrossEntropyLoss()) 166 | else: 167 | attr_criterion = None 168 | 169 | if args.optimizer == 'Adam': 170 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay) 171 | elif args.optimizer == 'RMSprop': 172 | optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) 173 | else: 174 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) 175 | 176 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step, gamma=0.1) 177 | stop_epoch = int(math.log(MIN_LR / args.lr) / math.log(LR_DECAY_SIZE)) * args.scheduler_step 178 | print("Stop epoch: ", stop_epoch) 179 | 180 | train_data_path = os.path.join(BASE_DIR, args.data_dir, 'train.pkl') 181 | val_data_path = train_data_path.replace('train.pkl', 'val.pkl') 182 | logger.write('train data path: %s\n' % train_data_path) 183 | 184 | if args.ckpt: #retraining 185 | train_loader = load_data([train_data_path, val_data_path], args.use_attr, args.no_img, args.batch_size, args.uncertain_labels, image_dir=args.image_dir, \ 186 | n_class_attr=args.n_class_attr, resampling=args.resampling) 187 | val_loader = None 188 | else: 189 | train_loader = load_data([train_data_path], args.use_attr, args.no_img, args.batch_size, args.uncertain_labels, image_dir=args.image_dir, \ 190 | n_class_attr=args.n_class_attr, resampling=args.resampling) 191 | val_loader = load_data([val_data_path], args.use_attr, args.no_img, args.batch_size, image_dir=args.image_dir, n_class_attr=args.n_class_attr) 192 | 193 | best_val_epoch = -1 194 | best_val_loss = float('inf') 195 | best_val_acc = 0 196 | 197 | 198 | for epoch 
in range(0, args.epochs): 199 | train_loss_meter = AverageMeter() 200 | train_acc_meter = AverageMeter() 201 | if args.no_img: 202 | train_loss_meter, train_acc_meter = run_epoch_simple(model, optimizer, train_loader, train_loss_meter, train_acc_meter, criterion, args, is_training=True) 203 | else: 204 | train_loss_meter, train_acc_meter = run_epoch(model, optimizer, train_loader, train_loss_meter, train_acc_meter, criterion, attr_criterion, args, is_training=True) 205 | 206 | if not args.ckpt: # evaluate on val set 207 | val_loss_meter = AverageMeter() 208 | val_acc_meter = AverageMeter() 209 | 210 | with torch.no_grad(): 211 | if args.no_img: 212 | val_loss_meter, val_acc_meter = run_epoch_simple(model, optimizer, val_loader, val_loss_meter, val_acc_meter, criterion, args, is_training=False) 213 | else: 214 | val_loss_meter, val_acc_meter = run_epoch(model, optimizer, val_loader, val_loss_meter, val_acc_meter, criterion, attr_criterion, args, is_training=False) 215 | 216 | else: #retraining 217 | val_loss_meter = train_loss_meter 218 | val_acc_meter = train_acc_meter 219 | 220 | 221 | if best_val_acc < val_acc_meter.avg: 222 | best_val_epoch = epoch 223 | best_val_acc = val_acc_meter.avg 224 | logger.write('New model best model at epoch %d\n' % epoch) 225 | torch.save(model, os.path.join(args.log_dir, 'best_model_%d.pth' % args.seed)) 226 | 227 | 228 | train_loss_avg = train_loss_meter.avg 229 | val_loss_avg = val_loss_meter.avg 230 | 231 | logger.write('Epoch [%d]:\tTrain loss: %.4f\tTrain accuracy: %.4f\t' 232 | 'Val loss: %.4f\tVal acc: %.4f\t' 233 | 'Best val epoch: %d\n' 234 | % (epoch, train_loss_avg, train_acc_meter.avg, val_loss_avg, val_acc_meter.avg, best_val_epoch)) 235 | logger.flush() 236 | 237 | if epoch <= stop_epoch: 238 | scheduler.step(epoch) #scheduler step to update lr at the end of epoch 239 | #inspect lr 240 | if epoch % 10 == 0: 241 | print('Current lr:', scheduler.get_lr()) 242 | 243 | if epoch >= 100 and val_acc_meter.avg < 3: 244 | print("Early stopping because of low accuracy") 245 | break 246 | if epoch - best_val_epoch >= 100: 247 | print("Early stopping because acc hasn't improved for a long time") 248 | break 249 | 250 | def train_X_to_C(args): 251 | model = ModelXtoC(pretrained=args.pretrained, freeze=args.freeze, num_classes=N_CLASSES, use_aux=args.use_aux, 252 | n_attributes=args.n_attributes, expand_dim=args.expand_dim, three_class=args.three_class) 253 | train(model, args) 254 | 255 | def train_oracle_C_to_y_and_test_on_Chat(args): 256 | model = ModelOracleCtoY(n_class_attr=args.n_class_attr, n_attributes=args.n_attributes, 257 | num_classes=N_CLASSES, expand_dim=args.expand_dim) 258 | train(model, args) 259 | 260 | def train_Chat_to_y_and_test_on_Chat(args): 261 | model = ModelXtoChat_ChatToY(n_class_attr=args.n_class_attr, n_attributes=args.n_attributes, 262 | num_classes=N_CLASSES, expand_dim=args.expand_dim) 263 | train(model, args) 264 | 265 | def train_X_to_C_to_y(args): 266 | model = ModelXtoCtoY(n_class_attr=args.n_class_attr, pretrained=args.pretrained, freeze=args.freeze, 267 | num_classes=N_CLASSES, use_aux=args.use_aux, n_attributes=args.n_attributes, 268 | expand_dim=args.expand_dim, use_relu=args.use_relu, use_sigmoid=args.use_sigmoid) 269 | train(model, args) 270 | 271 | def train_X_to_y(args): 272 | model = ModelXtoY(pretrained=args.pretrained, freeze=args.freeze, num_classes=N_CLASSES, use_aux=args.use_aux) 273 | train(model, args) 274 | 275 | def train_X_to_Cy(args): 276 | model = ModelXtoCY(pretrained=args.pretrained, 
freeze=args.freeze, num_classes=N_CLASSES, use_aux=args.use_aux,
277 |                        n_attributes=args.n_attributes, three_class=args.three_class, connect_CY=args.connect_CY)
278 |     train(model, args)
279 |
280 | def test_time_intervention(args):
281 |     tti.run(args)
282 |
283 | def parse_arguments(experiment):
284 |     # Get argparse configs from user
285 |     parser = argparse.ArgumentParser(description='CUB Training')
286 |     parser.add_argument('dataset', type=str, help='Name of the dataset.')
287 |     parser.add_argument('exp', type=str,
288 |                         choices=['Concept_XtoC', 'Independent_CtoY', 'Sequential_CtoY',
289 |                                  'Standard', 'Multitask', 'Joint', 'TTI'],
290 |                         help='Name of experiment to run.')
291 |     parser.add_argument('--seed', required=True, type=int, help='Numpy and torch seed.')
292 |
293 |     if experiment == 'TTI':
294 |         return (tti.parse_arguments(parser),)
295 |
296 |     else:
297 |         parser.add_argument('-log_dir', default=None, help='where the trained model is saved')
298 |         parser.add_argument('-batch_size', '-b', type=int, help='mini-batch size')
299 |         parser.add_argument('-epochs', '-e', type=int, help='epochs for training process')
300 |         parser.add_argument('-save_step', default=1000, type=int, help='number of epochs between model checkpoints')
301 |         parser.add_argument('-lr', type=float, help="learning rate")
302 |         parser.add_argument('-weight_decay', type=float, default=5e-5, help='weight decay for optimizer')
303 |         parser.add_argument('-pretrained', '-p', action='store_true',
304 |                             help='whether to load pretrained model & just fine-tune')
305 |         parser.add_argument('-freeze', action='store_true', help='whether to freeze the bottom part of inception network')
306 |         parser.add_argument('-use_aux', action='store_true', help='whether to use aux logits')
307 |         parser.add_argument('-use_attr', action='store_true',
308 |                             help='whether to use attributes (FOR COTRAINING ARCHITECTURE ONLY)')
309 |         parser.add_argument('-attr_loss_weight', default=1.0, type=float, help='weight for loss by predicting attributes')
310 |         parser.add_argument('-no_img', action='store_true',
311 |                             help='if included, only use attributes (and not raw imgs) for class prediction')
312 |         parser.add_argument('-bottleneck', help='whether to predict attributes before class labels', action='store_true')
313 |         parser.add_argument('-weighted_loss', default='', # note: may need to reduce lr
314 |                             help='Whether to use weighted loss for single attribute or multiple ones')
315 |         parser.add_argument('-uncertain_labels', action='store_true',
316 |                             help='whether to use (normalized) attribute certainties as labels')
317 |         parser.add_argument('-n_attributes', type=int, default=N_ATTRIBUTES,
318 |                             help='number of attributes used in the bottleneck')
319 |         parser.add_argument('-expand_dim', type=int, default=0,
320 |                             help='dimension of hidden layer (if we want to increase model capacity) - for bottleneck only')
321 |         parser.add_argument('-n_class_attr', type=int, default=2,
322 |                             help='whether attr prediction is a binary or ternary classification')
323 |         parser.add_argument('-data_dir', default='official_datasets', help='directory to the training data')
324 |         parser.add_argument('-image_dir', default='images', help='test image folder to run inference on')
325 |         parser.add_argument('-resampling', help='Whether to use resampling', action='store_true')
326 |         parser.add_argument('-end2end', action='store_true',
327 |                             help='Whether to train X -> A -> Y end to end.
Train cmd is the same as cotraining + this arg')
328 |         parser.add_argument('-optimizer', default='SGD', help='Type of optimizer to use, options incl SGD, RMSProp, Adam')
329 |         parser.add_argument('-ckpt', default='', help='For retraining on both train + val set')
330 |         parser.add_argument('-scheduler_step', type=int, default=1000,
331 |                             help='Number of epochs before decaying current learning rate by a factor of 0.1')
332 |         parser.add_argument('-normalize_loss', action='store_true',
333 |                             help='Whether to normalize loss by taking attr_loss_weight into account')
334 |         parser.add_argument('-use_relu', action='store_true',
335 |                             help='Whether to include relu activation before using attributes to predict Y. '
336 |                                  'For end2end & bottleneck model')
337 |         parser.add_argument('-use_sigmoid', action='store_true',
338 |                             help='Whether to include sigmoid activation before using attributes to predict Y. '
339 |                                  'For end2end & bottleneck model')
340 |         parser.add_argument('-connect_CY', action='store_true',
341 |                             help='Whether to use concepts as auxiliary features (in multitasking) to predict Y')
342 |         parser.add_argument('-fix_seed', action='store_true', help='whether to fix seed')
343 |         args = parser.parse_args()
344 |         args.three_class = (args.n_class_attr == 3)
345 |         return (args,)
--------------------------------------------------------------------------------
/SKINCON/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate trained models on the SKINCON test set
3 | """
4 | import os
5 | import sys
6 | import torch
7 | import joblib
8 | import argparse
9 | import numpy as np
10 | from sklearn.metrics import f1_score, balanced_accuracy_score
11 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
12 |
13 | from SKINCON.dataset import load_data
14 | from SKINCON.config import BASE_DIR, N_CLASSES, N_ATTRIBUTES
15 | from analysis import AverageMeter, multiclass_metric, accuracy, binary_accuracy
16 |
17 | K = [1] #top k class accuracies to compute
18 |
19 | def eval(args):
20 |     """
21 |     Run inference using model (and model2 if bottleneck)
22 |     Returns: (for notebook analysis)
23 |     all_class_labels: flattened list of class labels for each image
24 |     topk_class_outputs: array of top k class ids predicted for each image.
Shape = size of test set * max(K)
25 |     all_class_outputs: array of all logit outputs for class prediction, shape = N_TEST * N_CLASS
26 |     all_attr_labels: flattened list of labels for each attribute for each image (length = N_ATTRIBUTES * N_TEST)
27 |     all_attr_outputs: flattened list of attribute logits (after ReLU/Sigmoid, respectively) predicted for each attribute for each image (length = N_ATTRIBUTES * N_TEST)
28 |     all_attr_outputs_sigmoid: flattened list of attribute logits predicted (after Sigmoid) for each attribute for each image (length = N_ATTRIBUTES * N_TEST)
29 |     wrong_idx: image ids where the model got the wrong class prediction (to compare with other models)
30 |     incorrect_idx_list: list of lists where the i-th list contains the images which predicted i concepts incorrectly
31 |     all_img_paths: array of image paths
32 |     """
33 |     if args.model_dir:
34 |         model = torch.load(args.model_dir)
35 |     else:
36 |         model = None
37 |
38 |     if not hasattr(model, 'use_relu'):
39 |         if args.use_relu:
40 |             model.use_relu = True
41 |         else:
42 |             model.use_relu = False
43 |     if not hasattr(model, 'use_sigmoid'):
44 |         if args.use_sigmoid:
45 |             model.use_sigmoid = True
46 |         else:
47 |             model.use_sigmoid = False
48 |     if not hasattr(model, 'cy_fc'):
49 |         model.cy_fc = None
50 |     model.eval()
51 |
52 |     incorrect_idx_list = [[] for _ in range(16)]
53 |
54 |     if args.model_dir2:
55 |         if 'rf' in args.model_dir2:
56 |             model2 = joblib.load(args.model_dir2)
57 |         else:
58 |             model2 = torch.load(args.model_dir2)
59 |             if not hasattr(model2, 'use_relu'):
60 |                 if args.use_relu:
61 |                     model2.use_relu = True
62 |                 else:
63 |                     model2.use_relu = False
64 |             if not hasattr(model2, 'use_sigmoid'):
65 |                 if args.use_sigmoid:
66 |                     model2.use_sigmoid = True
67 |                 else:
68 |                     model2.use_sigmoid = False
69 |             model2.eval()
70 |     else:
71 |         model2 = None
72 |
73 |     if args.use_attr:
74 |         attr_acc_meter = [AverageMeter()]
75 |         if args.feature_group_results: # compute acc for each feature individually in addition to the overall accuracy
76 |             for _ in range(args.n_attributes):
77 |                 attr_acc_meter.append(AverageMeter())
78 |     else:
79 |         attr_acc_meter = None
80 |
81 |     class_acc_meter = []
82 |     for j in range(len(K)):
83 |         class_acc_meter.append(AverageMeter())
84 |
85 |     data_dir = os.path.join(BASE_DIR, args.data_dir, args.class_label, args.eval_data + '.pkl')
86 |     loader = load_data([data_dir], args.use_attr, args.no_img, args.batch_size, image_dir=args.image_dir,
87 |                        n_class_attr=args.n_class_attr, return_path=True, class_label=args.class_label)
88 |     all_outputs, all_targets = [], []
89 |     all_attr_labels, all_attr_outputs, all_attr_outputs_sigmoid, all_attr_outputs2 = [], [], [], []
90 |     all_class_labels, all_class_outputs, all_class_logits = [], [], []
91 |     topk_class_labels, topk_class_outputs = [], []
92 |
93 |     all_image_paths = []
94 |
95 |     data_size = 0
96 |
97 |     incorrect_cnt = np.zeros(args.n_attributes) # i-th element: number of images where i number of concepts are mispredicted
98 |     y_correct_by_incorrect_c = np.zeros(args.n_attributes) # i-th element: number of correctly classified images when i number of concepts are mispredicted
99 |
100 |
101 |     for data_idx, data in enumerate(loader):
102 |         if args.use_attr:
103 |             if args.no_img: # A -> Y
104 |                 inputs, labels = data
105 |                 if isinstance(inputs, list):
106 |                     inputs = torch.stack(inputs).t().float()
107 |                 inputs = inputs.float()
108 |                 # inputs = torch.flatten(inputs, start_dim=1).float()
109 |             else:
110 |                 inputs, labels, attr_labels, img_path = data
111 |                 attr_labels =
torch.stack(attr_labels).t() # N x n_attributes
112 |         else: # simple finetune
113 |             inputs, labels = data
114 |
115 |         inputs_var = torch.autograd.Variable(inputs).cuda()
116 |         labels_var = torch.autograd.Variable(labels).cuda()
117 |
118 |         if args.attribute_group:
119 |             outputs = []
120 |             with open(args.attribute_group, 'r') as f:
121 |                 for line in f:
122 |                     attr_model = torch.load(line.strip())
123 |                     outputs.extend(attr_model(inputs_var))
124 |         else:
125 |             outputs = model(inputs_var)
126 |
127 |         if args.use_attr:
128 |             if args.no_img: # A -> Y
129 |                 class_outputs = outputs
130 |             else:
131 |                 if args.bottleneck:
132 |                     if args.use_relu:
133 |                         attr_outputs = [torch.nn.ReLU()(o) for o in outputs]
134 |                         attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs]
135 |                     elif args.use_sigmoid:
136 |                         attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs]
137 |                         attr_outputs_sigmoid = attr_outputs
138 |                     else:
139 |                         attr_outputs = outputs
140 |                         attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs]
141 |                     if model2:
142 |                         stage2_inputs = torch.cat(attr_outputs, dim=1)
143 |                         if args.inference_mode == 'soft' or args.inference_mode == 'hard':
144 |                             if args.inference_mode == 'hard':
145 |                                 stage2_inputs = stage2_inputs >= (torch.ones_like(stage2_inputs) * 0.5)
146 |                                 stage2_inputs = stage2_inputs.float()
147 |                             class_outputs = model2(stage2_inputs)
148 |                         elif args.inference_mode == 'samp':
149 |                             class_outputs_all = []
150 |                             for _ in range(args.mc_samples):
151 |                                 rand_num = torch.rand(*stage2_inputs.size())
152 |                                 sampled_stage2_inputs = rand_num.cuda() < stage2_inputs
153 |                                 sampled_stage2_inputs = sampled_stage2_inputs.float()
154 |                                 _class_outputs = model2(sampled_stage2_inputs)
155 |                                 class_outputs_all.append(_class_outputs)
156 |                             class_outputs = torch.mean(torch.stack(class_outputs_all, dim=0), dim=0)
157 |                     else: # for debugging bottleneck performance without running stage 2
158 |                         class_outputs = torch.zeros([inputs.size(0), N_CLASSES],
159 |                                                     dtype=torch.float64).cuda() # ignore this
160 |                 else: # cotraining, end2end
161 |                     if args.use_relu:
162 |                         attr_outputs = [torch.nn.ReLU()(o) for o in outputs[1:]]
163 |                         attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs[1:]]
164 |                     elif args.use_sigmoid:
165 |                         attr_outputs = [torch.nn.Sigmoid()(o) for o in outputs[1:]]
166 |                         attr_outputs_sigmoid = attr_outputs
167 |                     else:
168 |                         attr_outputs = outputs[1:]
169 |                         attr_outputs_sigmoid = [torch.nn.Sigmoid()(o) for o in outputs[1:]]
170 |                     class_outputs = outputs[0]
171 |
172 |                 batch_incorrect_num = np.zeros(inputs.size(0))
173 |
174 |                 for i in range(args.n_attributes):
175 |
176 |                     pred = attr_outputs_sigmoid[i].squeeze().cpu() >= 0.5
177 |                     eq = (pred.int()).eq(attr_labels[:,i].int())
178 |
179 |                     not_eq = [int(not ele) for ele in eq]
180 |
181 |                     batch_incorrect_num += not_eq
182 |
183 |                     acc = binary_accuracy(attr_outputs_sigmoid[i].squeeze(), attr_labels[:, i])
184 |                     acc = acc.data.cpu().numpy()
185 |                     # acc = accuracy(attr_outputs_sigmoid[i], attr_labels[:, i], topk=(1,))
186 |                     attr_acc_meter[0].update(acc, inputs.size(0))
187 |                     if args.feature_group_results: # keep track of accuracy of individual attributes
188 |                         attr_acc_meter[i + 1].update(acc, inputs.size(0))
189 |
190 |
191 |                 attr_outputs = torch.cat([o.unsqueeze(1) for o in attr_outputs], dim=1)
192 |                 attr_outputs_sigmoid = torch.cat([o for o in attr_outputs_sigmoid], dim=1)
193 |                 all_attr_outputs.extend(list(attr_outputs.flatten().data.cpu().numpy()))
194 |                 all_attr_outputs_sigmoid.extend(list(attr_outputs_sigmoid.flatten().data.cpu().numpy()))
195 |
all_attr_labels.extend(list(attr_labels.flatten().data.cpu().numpy()))
196 |
197 |                 for i in range(inputs.size(0)):
198 |                     incorrect_cnt[int(batch_incorrect_num[i])] += 1
199 |
200 |
201 |                 for i in range(inputs.size(0)):
202 |                     incorrect = int(batch_incorrect_num[i])
203 |                     if incorrect >= 15:
204 |                         incorrect_idx_list[15].append(data_size + i)
205 |                     else:
206 |                         incorrect_idx_list[incorrect].append(data_size + i)
207 |                 data_size += inputs.size(0)
208 |
209 |         else:
210 |             class_outputs = outputs[0]
211 |
212 |         _, topk_preds = class_outputs.topk(max(K), 1, True, True)
213 |         _, preds = class_outputs.topk(1, 1, True, True)
214 |
215 |
216 |
217 |
218 |         all_class_outputs.extend(list(preds.detach().cpu().numpy().flatten()))
219 |         all_class_labels.extend(list(labels.data.cpu().numpy()))
220 |         all_class_logits.extend(class_outputs.detach().cpu().numpy())
221 |
222 |         topk_class_outputs.extend(topk_preds.detach().cpu().numpy())
223 |         topk_class_labels.extend(labels.view(-1, 1).expand_as(preds))
224 |
225 |         # may need to define scaled version
226 |
227 |         if not args.no_img:
228 |             all_image_paths.extend(img_path)
229 |
230 |         np.set_printoptions(threshold=sys.maxsize)
231 |
232 |
233 |         class_acc = accuracy(class_outputs, labels, topk=K) # only class prediction accuracy
234 |         for m in range(len(class_acc_meter)):
235 |             class_acc_meter[m].update(class_acc[m], inputs.size(0))
236 |
237 |         _, pred = class_outputs.topk(1, 1, True, True)
238 |         pred = pred.t()
239 |         temp = labels.view(1, -1).expand_as(pred)
240 |         temp = temp.cuda()
241 |         class_eq = pred.eq(temp)[0]
242 |
243 |         if not args.no_img:
244 |             for i in range(inputs.size(0)):
245 |                 incorrect_num = int(batch_incorrect_num[i])
246 |                 y_correct_by_incorrect_c[incorrect_num] += int(class_eq[i])
247 |
248 |     all_class_logits = np.vstack(all_class_logits)
249 |
250 |     if args.class_label != 'binary':
251 |         topk_class_outputs = np.vstack(topk_class_outputs)
252 |         topk_class_labels = np.vstack(topk_class_labels)
253 |         wrong_idx = np.where(np.sum(topk_class_outputs == topk_class_labels, axis=1) == 0)[0]
254 |     else:
255 |         topk_class_outputs = all_class_outputs
256 |         wrong_idx = np.where(np.array(all_class_outputs) != np.array(all_class_labels))[0]
257 |
258 |
259 |
260 |     yerr_by_cerr = 1 - np.array(y_correct_by_incorrect_c)/(np.array(incorrect_cnt)+1e-6)
261 |
262 |     for j in range(len(K)):
263 |         print('Average top %d class accuracy: %.5f' % (K[j], class_acc_meter[j].avg))
264 |
265 |     if args.use_attr and not args.no_img: # print some metrics for attribute prediction performance
266 |         print('Average attribute accuracy: %.5f' % attr_acc_meter[0].avg)
267 |         all_attr_outputs_int = np.array(all_attr_outputs_sigmoid) >= 0.5
268 |         if args.feature_group_results:
269 |             n = len(all_attr_labels)
270 |             all_attr_acc, all_attr_f1 = [], []
271 |             for i in range(args.n_attributes):
272 |                 acc_meter = attr_acc_meter[1 + i]
273 |                 attr_acc = float(acc_meter.avg)
274 |                 attr_preds = [all_attr_outputs_int[j] for j in range(n) if j % args.n_attributes == i]
275 |                 attr_labels = [all_attr_labels[j] for j in range(n) if j % args.n_attributes == i]
276 |                 attr_f1 = f1_score(attr_labels, attr_preds)
277 |                 all_attr_acc.append(attr_acc)
278 |                 all_attr_f1.append(attr_f1)
279 |
280 |             bins = np.arange(0, 1.01, 0.1)
281 |             acc_bin_ids = np.digitize(np.array(all_attr_acc) / 100.0, bins)
282 |             acc_counts_per_bin = [np.sum(acc_bin_ids == (i + 1)) for i in range(len(bins))]
283 |             f1_bin_ids =
np.digitize(np.array(all_attr_f1), bins)
284 |             f1_counts_per_bin = [np.sum(f1_bin_ids == (i + 1)) for i in range(len(bins))]
285 |             print("Accuracy bins:")
286 |             print(acc_counts_per_bin)
287 |             print("F1 bins:")
288 |             print(f1_counts_per_bin)
289 |             np.savetxt(os.path.join(args.log_dir, 'concepts.txt'), f1_counts_per_bin)
290 |
291 |         balanced_acc, report = multiclass_metric(all_attr_outputs_int, all_attr_labels)
292 |         f1 = f1_score(all_attr_labels, all_attr_outputs_int)
293 |         print("Total 1's predicted:", sum(np.array(all_attr_outputs_sigmoid) >= 0.5) / len(all_attr_outputs_sigmoid))
294 |         print('Avg attribute balanced acc: %.5f' % (balanced_acc))
295 |         print("Avg attribute F1 score: %.5f" % f1)
296 |         print(report + '\n')
297 |
298 |     class_f1 = f1_score(all_class_labels, all_class_outputs, average='macro')
299 |     class_balanced_acc = balanced_accuracy_score(all_class_labels, all_class_outputs)
300 |
301 |     return class_acc_meter, attr_acc_meter, all_class_labels, topk_class_outputs, all_class_logits, all_attr_labels, all_attr_outputs, all_attr_outputs_sigmoid, wrong_idx, all_attr_outputs2, incorrect_idx_list, all_image_paths, yerr_by_cerr, class_f1, class_balanced_acc
302 |
303 | if __name__ == '__main__':
304 |     torch.backends.cudnn.benchmark = True
305 |     parser = argparse.ArgumentParser(description='PyTorch Training')
306 |     parser.add_argument('-log_dir', default='.', help='where results are stored')
307 |     parser.add_argument('-model_dirs', default=None, nargs='+', help='where the trained models are saved')
308 |     parser.add_argument('-model_dirs2', default=None, nargs='+', help='where the second-stage trained models are saved (for bottleneck only)')
309 |     parser.add_argument('-eval_data', default='test', help='Type of data (train/val/test) to be used')
310 |     parser.add_argument('-use_attr', help='whether to use attributes (FOR COTRAINING ARCHITECTURE ONLY)', action='store_true')
311 |     parser.add_argument('-no_img', help='if included, only use attributes (and not raw imgs) for class prediction', action='store_true')
312 |     parser.add_argument('-bottleneck', help='whether to predict attributes before class labels', action='store_true')
313 |     parser.add_argument('-image_dir', default='images', help='test image folder to run inference on')
314 |     parser.add_argument('-n_class_attr', type=int, default=2, help='whether attr prediction is a binary or ternary classification')
315 |     parser.add_argument('-data_dir', default='', help='directory to the data used for evaluation')
316 |     parser.add_argument('-n_attributes', type=int, default=N_ATTRIBUTES, help='number of attributes used in the bottleneck')
317 |     parser.add_argument('-attribute_group', default=None, help='file listing the (trained) model directory for each attribute group')
318 |     parser.add_argument('-feature_group_results', help='whether to print out performance of individual attributes', action='store_true')
319 |     parser.add_argument('-use_relu', help='Whether to include relu activation before using attributes to predict Y. For end2end & bottleneck model', action='store_true')
320 |     parser.add_argument('-use_sigmoid', help='Whether to include sigmoid activation before using attributes to predict Y.
For end2end & bottleneck model', action='store_true')
321 |     parser.add_argument('-inference_mode', default='soft', help='inference mode for the stage 2 model: soft, hard, or samp')
322 |     parser.add_argument('-mc_samples', type=int, default=5, help='number of Monte Carlo samples of the concepts when inference_mode is samp') # used by eval(); the default value here is an assumption
323 |     parser.add_argument('-class_label', type=str, default='binary', help='which class label to use')
324 |     args = parser.parse_args()
325 |     args.batch_size = 16
326 |
327 |     print(args)
328 |     y_results, c_results = [], []
329 |     yerr_by_cerr_results = []
330 |     y_f1_results, y_balanced_acc_results = [], []
331 |     for i, model_dir in enumerate(args.model_dirs):
332 |         args.model_dir = model_dir
333 |         args.model_dir2 = args.model_dirs2[i] if args.model_dirs2 else None
334 |         result = eval(args)
335 |         class_acc_meter, attr_acc_meter = result[0], result[1]
336 |         yerr_by_cerr = result[-4]
337 |         y_results.append(1 - class_acc_meter[0].avg[0].item() / 100.)
338 |         if attr_acc_meter is not None:
339 |             c_results.append(1 - attr_acc_meter[0].avg.item() / 100.)
340 |         else:
341 |             c_results.append(-1)
342 |         y_f1_results.append(result[-2])
343 |         y_balanced_acc_results.append(result[-1])
344 |         yerr_by_cerr_results.append(yerr_by_cerr)
345 |     values = (np.mean(y_results), np.std(y_results), np.mean(c_results), np.std(c_results))
346 |     output_string = '%.4f %.4f %.4f %.4f' % values
347 |     print_string = 'Error of y: %.4f +- %.4f, Error of C: %.4f +- %.4f' % values
348 |     print(print_string)
349 |
350 |     yerr_by_cerr_mean = np.array(yerr_by_cerr_results).mean(axis=0)
351 |     yerr_by_cerr_std = np.array(yerr_by_cerr_results).std(axis=0)
352 |
353 |     print("yerr_by_cerr_mean", yerr_by_cerr_mean.tolist())
354 |     print("yerr_by_cerr_std", yerr_by_cerr_std.tolist())
355 |
356 |     print("y_f1", np.array(y_f1_results).mean())
357 |     print("y_balanced_acc", np.array(y_balanced_acc_results).mean())
358 |     with open(os.path.join(args.log_dir, 'results.txt'), 'w') as output:
359 |         output.write(output_string)
--------------------------------------------------------------------------------
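The `-inference_mode` flag in `SKINCON/inference.py` selects how predicted concepts are fed to the second-stage model: `soft` passes the sigmoid probabilities directly, `hard` thresholds them at 0.5, and `samp` averages the stage-2 predictions over Bernoulli samples of the concepts. The snippet below is a minimal, self-contained sketch of those three modes; `predict_label`, `c_to_y`, and all other names here are illustrative stand-ins, not part of the repository's API.

```
import torch

def predict_label(c_probs, c_to_y, mode='soft', mc_samples=5):
    """Toy illustration of soft / hard / sampled C -> Y inference."""
    if mode == 'soft':
        # feed concept probabilities to the label predictor as-is
        return c_to_y(c_probs)
    if mode == 'hard':
        # threshold concepts to {0, 1} before the label predictor
        return c_to_y((c_probs >= 0.5).float())
    # 'samp': average label predictions over Bernoulli samples of the concepts
    outs = []
    for _ in range(mc_samples):
        sampled = (torch.rand_like(c_probs) < c_probs).float()
        outs.append(c_to_y(sampled))
    return torch.stack(outs, dim=0).mean(dim=0)

# usage with a stand-in linear C -> Y head: batch of 2 images, 4 concepts, 3 classes
c_to_y = torch.nn.Linear(4, 3)
c_probs = torch.sigmoid(torch.randn(2, 4))
for mode in ('soft', 'hard', 'samp'):
    print(mode, predict_label(c_probs, c_to_y, mode).shape)
```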