├── README.md ├── configs ├── eval.gin ├── recording.gin ├── training.gin ├── training_guided.gin └── training_random.gin ├── env.sh ├── eval.py ├── src ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── balanced_mmtm.cpython-38.pyc │ ├── callbacks.cpython-38.pyc │ ├── dataset.cpython-38.pyc │ ├── framework.cpython-38.pyc │ ├── model.cpython-38.pyc │ ├── training_loop.cpython-38.pyc │ └── utils.cpython-38.pyc ├── balanced_mmtm.py ├── callbacks.py ├── dataset.py ├── framework.py ├── model.py ├── training_loop.py └── utils.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Characterizing and overcoming the greedy nature of learning in multi-modal deep neural networks 2 | 3 | We provide the source code for the **balanced multi-modal learning algorithm** proposed in the above paper, along with implementations for the derived metrics, **conditional utilization rate** and **conditional learning speed**. 4 | 5 | **Accepted by ICML 2022 ** [[Paper]](https://arxiv.org/abs/2202.05306.pdf) 6 | 7 | ## Dependencies: 8 | * Python 3.8 / gin-config / numpy / pandas / pytorch / scikit-learn / scipy / torchvision / skimage / PIL 9 | 10 | ## Workflow 11 | 12 | We take the 3D object classification task using the [ModelNet40 dataset](http://maxwell.cs.umass.edu/mvcnn-data/) as an example. One can train the multi-modal DNN via the *balanced multi-modal learning algorithm* : 13 | 14 | * `python3 train.py $RESULTS_DIR/random configs/training_guided.gin` 15 | 16 | or its *random* version: 17 | 18 | * `python3 train.py $RESULTS_DIR/random configs/training_random.gin` 19 | 20 | To analysis multi-modal DNNs' *conditional utilization rate*, run the following two scripts consecutively: 21 | 22 | * `python3 eval.py $RESULTS_DIR/random configs/recording.gin` 23 | * `python3 eval.py $RESULTS_DIR/random configs/eval.gin` 24 | 25 | ## Citation 26 | Please cite this work if you find the analysis or the proposed method useful for your research. 27 | 28 | ``` 29 | @misc{wu2022characterizing, 30 | title={Characterizing and overcoming the greedy nature of learning in multi-modal deep neural networks}, 31 | author={Nan Wu and Stanisław Jastrzębski and Kyunghyun Cho and Krzysztof J. Geras}, 32 | year={2022}, 33 | eprint={2202.05306}, 34 | archivePrefix={arXiv}, 35 | primaryClass={cs.LG} 36 | } 37 | ``` 38 | 39 | -------------------------------------------------------------------------------- /configs/eval.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | MMTM_MVCNN.pretraining=False 3 | MMTM_MVCNN.num_views=2 4 | MMTM_MVCNN.mmtm_off=True 5 | MMTM_MVCNN.mmtm_rescale_eval_file_path='/gpfs/data/geraslab/Nan/public_repo/greedymml/saves/guided/eval_history_batch' 6 | MMTM_MVCNN.mmtm_rescale_training_file_path='/gpfs/data/geraslab/Nan/public_repo/greedymml/saves/guided' 7 | MMTM_MVCNN.device='cuda:0' 8 | 9 | # Train configuration 10 | eval_.target_data_split='test' 11 | eval_.batch_size=8 12 | eval_.pretrained_weights_path='/gpfs/data/geraslab/Nan/public_repo/greedymml/saves/guided/model_best_val.pt' 13 | ProgressionCallback.other_metrics=[] 14 | 15 | # Training loop 16 | evalution_loop.use_gpu=True 17 | evalution_loop.device_numbers=[0] 18 | evalution_loop.save_with_structure=False 19 | 20 | # Dataset 21 | get_mvdcndata.make_npy_files=False 22 | get_mvdcndata.num_views=2 23 | get_mvdcndata.num_workers=0 24 | get_mvdcndata.specific_views=[0, 6] -------------------------------------------------------------------------------- /configs/recording.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | MMTM_MVCNN.pretraining=False 3 | MMTM_MVCNN.num_views=2 4 | MMTM_MVCNN.saving_mmtm_squeeze_array=True 5 | 6 | # Train configuration 7 | eval_.target_data_split='train' 8 | eval_.batch_size=8 9 | eval_.pretrained_weights_path='/gpfs/data/geraslab/Nan/public_repo/greedymml/saves/guided/model_best_val.pt' 10 | ProgressionCallback.other_metrics=[] 11 | 12 | # Training loop 13 | evalution_loop.use_gpu=True 14 | evalution_loop.device_numbers=[0] 15 | evalution_loop.save_with_structure=True 16 | 17 | # Dataset 18 | get_mvdcndata.valid_size=0 19 | get_mvdcndata.make_npy_files=False 20 | get_mvdcndata.num_views=2 21 | get_mvdcndata.num_workers=0 22 | get_mvdcndata.specific_views=[0, 6] -------------------------------------------------------------------------------- /configs/training.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | MMTM_MVCNN.pretraining=False 3 | MMTM_MVCNN.num_views=2 4 | 5 | # Train configuration 6 | train.batch_size=8 7 | train.lr=0.1 8 | train.wd=0.0 9 | train.momentum=0 10 | train.callbacks=['CompletedStopping', 'ReduceLROnPlateau_PyTorch', 'Bias_Mitigation_Strong'] 11 | ReduceLROnPlateau_PyTorch.metric='loss' 12 | CompletedStopping.patience=5 13 | CompletedStopping.monitor='acc' 14 | Bias_Mitigation_Strong.epsilon=0.01 15 | Bias_Mitigation_Strong.curation_windowsize=5 16 | Bias_Mitigation_Strong.starting_epoch=2 17 | Bias_Mitigation_Strong.branchnames=['net_view_0', 'net_view_1'] 18 | Bias_Mitigation_Strong.MMTMnames = ['visual', 'skeleton'] 19 | ProgressionCallback.other_metrics=['acc_modal_0', 'acc_modal_1', 'val_acc_modal_0', 'val_acc_modal_1', 'd_BDR', 'curation_mode', 'caring_modality'] 20 | 21 | 22 | # Training loop 23 | training_loop.nummodalities=2 24 | training_loop.n_epochs=300 25 | training_loop.use_gpu=True 26 | training_loop.device_numbers=[0] 27 | training_loop.checkpoint_monitor='val_acc' 28 | 29 | # Dataset 30 | get_mvdcndata.make_npy_files=False 31 | get_mvdcndata.num_views=2 32 | get_mvdcndata.num_workers=20 33 | get_mvdcndata.specific_views=[0, 6] -------------------------------------------------------------------------------- /configs/training_guided.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | MMTM_MVCNN.pretraining=False 3 | MMTM_MVCNN.num_views=2 4 | 5 | # Train configuration 6 | train.batch_size=8 7 | train.lr=0.1 8 | train.wd=0.0 9 | train.momentum=0 10 | train.callbacks=['CompletedStopping', 'ReduceLROnPlateau_PyTorch', 'Bias_Mitigation_Strong'] 11 | ReduceLROnPlateau_PyTorch.metric='loss' 12 | CompletedStopping.patience=5 13 | CompletedStopping.monitor='acc' 14 | Bias_Mitigation_Strong.epsilon=0.01 15 | Bias_Mitigation_Strong.curation_windowsize=5 16 | Bias_Mitigation_Strong.starting_epoch=2 17 | Bias_Mitigation_Strong.branchnames=['net_view_0', 'net_view_1'] 18 | Bias_Mitigation_Strong.MMTMnames = ['visual', 'skeleton'] 19 | ProgressionCallback.other_metrics=['acc_modal_0', 'acc_modal_1', 'val_acc_modal_0', 'val_acc_modal_1', 'd_BDR', 'curation_mode', 'caring_modality'] 20 | 21 | 22 | # Training loop 23 | training_loop.nummodalities=2 24 | training_loop.n_epochs=300 25 | training_loop.use_gpu=True 26 | training_loop.device_numbers=[0] 27 | training_loop.checkpoint_monitor='val_acc' 28 | 29 | # Dataset 30 | get_mvdcndata.make_npy_files=False 31 | get_mvdcndata.num_views=2 32 | get_mvdcndata.num_workers=0 33 | get_mvdcndata.specific_views=[0, 6] -------------------------------------------------------------------------------- /configs/training_random.gin: -------------------------------------------------------------------------------- 1 | # Model 2 | MMTM_MVCNN.pretraining=False 3 | MMTM_MVCNN.num_views=2 4 | 5 | # Train configuration 6 | train.batch_size=8 7 | train.lr=0.1 8 | train.wd=0.0 9 | train.momentum=0 10 | train.callbacks=['CompletedStopping', 'ReduceLROnPlateau_PyTorch', 'Bias_Mitigation_Random'] 11 | ReduceLROnPlateau_PyTorch.metric='loss' 12 | CompletedStopping.patience=5 13 | CompletedStopping.monitor='acc' 14 | ProgressionCallback.other_metrics=['acc_modal_0', 'acc_modal_1', 'val_acc_modal_0', 'val_acc_modal_1', 'd_BDR', 'curation_mode', 'caring_modality'] 15 | 16 | 17 | # Training loop 18 | training_loop.nummodalities=2 19 | training_loop.n_epochs=300 20 | training_loop.use_gpu=True 21 | training_loop.device_numbers=[0] 22 | training_loop.checkpoint_monitor='val_acc' 23 | 24 | # Dataset 25 | get_mvdcndata.make_npy_files=False 26 | get_mvdcndata.num_views=2 27 | get_mvdcndata.num_workers=20 28 | get_mvdcndata.specific_views=[0, 6] -------------------------------------------------------------------------------- /env.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export PNAME="multimodal_env" 3 | export ROOT="greedy_multi_modal" 4 | 5 | export PYTHONPATH=$PYTHONPATH:$ROOT 6 | 7 | export RESULTS_DIR=$ROOT/saves 8 | export DATA_DIR=$ROOT/modelnet40_images_new_12x 9 | 10 | # Switches off importing out of environment packages 11 | export PYTHONNOUSERSITE=1 12 | 13 | # if [ ! -d "${DATA_DIR}" ]; then 14 | # echo "Creating ${DATA_DIR}" 15 | # mkdir -p ${DATA_DIR} 16 | # fi 17 | 18 | if [ ! -d "${RESULTS_DIR}" ]; then 19 | echo "Creating ${RESULTS_DIR}" 20 | mkdir -p ${RESULTS_DIR} 21 | fi 22 | 23 | echo "Welcome to MULTIMODAL ($1) PROJECT:)" 24 | echo "rooted at $ROOT" 25 | echo "...With PYTHONPATH: $PYTHONPATH" 26 | echo "...With RESULTS_DIR: $RESULTS_DIR" 27 | echo "...With DATA_DIR: $DATA_DIR" -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Trainer script. Example run command: bin/train.py save_to_folder configs/cnn.gin. 5 | """ 6 | import os 7 | import gin 8 | from gin.config import _CONFIG 9 | import torch 10 | import pickle 11 | import logging 12 | from functools import partial 13 | logger = logging.getLogger(__name__) 14 | 15 | from src import dataset 16 | from src import callbacks as avail_callbacks 17 | from src.model import MMTM_MVCNN 18 | from src.training_loop import evalution_loop 19 | from src.utils import gin_wrap 20 | 21 | from train import blend_loss, acc 22 | 23 | @gin.configurable 24 | def eval_(save_path, 25 | target_data_split, 26 | pretrained_weights_path, 27 | batch_size=128, 28 | callbacks=[], 29 | ): 30 | 31 | model = MMTM_MVCNN() 32 | train, val, testing = dataset.get_mvdcndata(batch_size=batch_size) 33 | 34 | if target_data_split == 'test': 35 | target_data = testing 36 | elif target_data_split == 'train': 37 | target_data = train 38 | elif target_data_split == 'val': 39 | target_data = val 40 | else: 41 | raise NotImplementedError 42 | 43 | # Create dynamically callbacks 44 | callbacks_constructed = [] 45 | for name in callbacks: 46 | if name in avail_callbacks.__dict__: 47 | clbk = avail_callbacks.__dict__[name]() 48 | callbacks_constructed.append(clbk) 49 | 50 | evalution_loop(model=model, 51 | loss_function=blend_loss, 52 | metrics=[acc], 53 | config=_CONFIG, 54 | save_path=save_path, 55 | test=target_data, 56 | test_steps=len(target_data), 57 | custom_callbacks=callbacks_constructed, 58 | pretrained_weights_path=pretrained_weights_path) 59 | 60 | 61 | if __name__ == "__main__": 62 | gin_wrap(eval_) 63 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Relevant constants and configurations 3 | """ 4 | import os 5 | import matplotlib.style 6 | import matplotlib as mpl 7 | from .utils import configure_logger 8 | 9 | os.environ['KERAS_BACKEND'] = 'tensorflow' 10 | 11 | # Configure paths 12 | DATA_DIR = os.environ.get("DATA_DIR", os.path.join(os.path.dirname(__file__), "data")) 13 | RESULTS_DIR = os.environ.get("RESULTS_DIR", os.path.join(os.path.dirname(__file__), "results")) 14 | 15 | # Configure logger 16 | configure_logger('') 17 | -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyukat/greedy_multimodal_learning/a190e03755bd8d0883b2c5c0cd7921d2561ce2ee/src/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/balanced_mmtm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyukat/greedy_multimodal_learning/a190e03755bd8d0883b2c5c0cd7921d2561ce2ee/src/__pycache__/balanced_mmtm.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/callbacks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyukat/greedy_multimodal_learning/a190e03755bd8d0883b2c5c0cd7921d2561ce2ee/src/__pycache__/callbacks.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyukat/greedy_multimodal_learning/a190e03755bd8d0883b2c5c0cd7921d2561ce2ee/src/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/framework.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyukat/greedy_multimodal_learning/a190e03755bd8d0883b2c5c0cd7921d2561ce2ee/src/__pycache__/framework.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyukat/greedy_multimodal_learning/a190e03755bd8d0883b2c5c0cd7921d2561ce2ee/src/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/training_loop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyukat/greedy_multimodal_learning/a190e03755bd8d0883b2c5c0cd7921d2561ce2ee/src/__pycache__/training_loop.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyukat/greedy_multimodal_learning/a190e03755bd8d0883b2c5c0cd7921d2561ce2ee/src/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /src/balanced_mmtm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import torchvision.models as models 9 | import glob 10 | import gin 11 | from gin.config import _CONFIG 12 | 13 | from src.utils import numpy_to_torch, torch_to 14 | 15 | @gin.configurable 16 | class MMTM_mitigate(nn.Module): 17 | def __init__(self, 18 | dim_visual, 19 | dim_skeleton, 20 | ratio, 21 | device=0, 22 | SEonly=False, 23 | shareweight=False): 24 | super(MMTM_mitigate, self).__init__() 25 | dim = dim_visual + dim_skeleton 26 | dim_out = int(2*dim/ratio) 27 | self.SEonly = SEonly 28 | self.shareweight = shareweight 29 | 30 | self.running_avg_weight_visual = torch.zeros(dim_visual).to("cuda:{}".format(device)) 31 | self.running_avg_weight_skeleton = torch.zeros(dim_visual).to("cuda:{}".format(device)) 32 | self.step = 0 33 | 34 | if self.SEonly: 35 | self.fc_squeeze_visual = nn.Linear(dim_visual, dim_out) 36 | self.fc_squeeze_skeleton = nn.Linear(dim_skeleton, dim_out) 37 | else: 38 | self.fc_squeeze = nn.Linear(dim, dim_out) 39 | 40 | if self.shareweight: 41 | assert dim_visual == dim_skeleton 42 | self.fc_excite = nn.Linear(dim_out, dim_visual) 43 | else: 44 | self.fc_visual = nn.Linear(dim_out, dim_visual) 45 | self.fc_skeleton = nn.Linear(dim_out, dim_skeleton) 46 | self.relu = nn.ReLU() 47 | self.sigmoid = nn.Sigmoid() 48 | 49 | def forward(self, 50 | visual, 51 | skeleton, 52 | return_scale=False, 53 | return_squeezed_mps = False, 54 | turnoff_cross_modal_flow = False, 55 | average_squeezemaps = None, 56 | curation_mode=False, 57 | caring_modality=0, 58 | ): 59 | 60 | if self.SEonly: 61 | tview = visual.view(visual.shape[:2] + (-1,)) 62 | squeeze = torch.mean(tview, dim=-1) 63 | excitation = self.fc_squeeze_visual(squeeze) 64 | vis_out = self.fc_visual(self.relu(excitation)) 65 | 66 | tview = skeleton.view(skeleton.shape[:2] + (-1,)) 67 | squeeze = torch.mean(tview, dim=-1) 68 | excitation = self.fc_squeeze_skeleton(squeeze) 69 | sk_out = self.fc_skeleton(self.relu(excitation)) 70 | 71 | else: 72 | if turnoff_cross_modal_flow: 73 | 74 | tview = visual.view(visual.shape[:2] + (-1,)) 75 | squeeze = torch.cat([torch.mean(tview, dim=-1), 76 | torch.stack(visual.shape[0]*[average_squeezemaps[1]])], 1) 77 | excitation = self.relu(self.fc_squeeze(squeeze)) 78 | 79 | if self.shareweight: 80 | vis_out = self.fc_excite(excitation) 81 | else: 82 | vis_out = self.fc_visual(excitation) 83 | 84 | tview = skeleton.view(skeleton.shape[:2] + (-1,)) 85 | squeeze = torch.cat([torch.stack(skeleton.shape[0]*[average_squeezemaps[0]]), 86 | torch.mean(tview, dim=-1)], 1) 87 | excitation = self.relu(self.fc_squeeze(squeeze)) 88 | if self.shareweight: 89 | sk_out = self.fc_excite(excitation) 90 | else: 91 | sk_out = self.fc_skeleton(excitation) 92 | 93 | else: 94 | squeeze_array = [] 95 | for tensor in [visual, skeleton]: 96 | tview = tensor.view(tensor.shape[:2] + (-1,)) 97 | squeeze_array.append(torch.mean(tview, dim=-1)) 98 | 99 | squeeze = torch.cat(squeeze_array, 1) 100 | excitation = self.fc_squeeze(squeeze) 101 | excitation = self.relu(excitation) 102 | 103 | if self.shareweight: 104 | sk_out = self.fc_excite(excitation) 105 | vis_out = self.fc_excite(excitation) 106 | else: 107 | vis_out = self.fc_visual(excitation) 108 | sk_out = self.fc_skeleton(excitation) 109 | 110 | vis_out = self.sigmoid(vis_out) 111 | sk_out = self.sigmoid(sk_out) 112 | 113 | self.running_avg_weight_visual = (vis_out.mean(0) + self.running_avg_weight_visual*self.step).detach()/(self.step+1) 114 | self.running_avg_weight_skeleton = (vis_out.mean(0) + self.running_avg_weight_skeleton*self.step).detach()/(self.step+1) 115 | 116 | self.step +=1 117 | 118 | if return_scale: 119 | scales = [vis_out.cpu(), sk_out.cpu()] 120 | else: 121 | scales = None 122 | 123 | if return_squeezed_mps: 124 | squeeze_array = [x.cpu() for x in squeeze_array] 125 | else: 126 | squeeze_array = None 127 | 128 | if not curation_mode: 129 | dim_diff = len(visual.shape) - len(vis_out.shape) 130 | vis_out = vis_out.view(vis_out.shape + (1,) * dim_diff) 131 | 132 | dim_diff = len(skeleton.shape) - len(sk_out.shape) 133 | sk_out = sk_out.view(sk_out.shape + (1,) * dim_diff) 134 | 135 | else: 136 | if caring_modality==0: 137 | dim_diff = len(skeleton.shape) - len(sk_out.shape) 138 | sk_out = sk_out.view(sk_out.shape + (1,) * dim_diff) 139 | 140 | dim_diff = len(visual.shape) - len(vis_out.shape) 141 | vis_out = torch.stack(vis_out.shape[0]*[ 142 | self.running_avg_weight_visual 143 | ]).view(vis_out.shape + (1,) * dim_diff) 144 | 145 | elif caring_modality==1: 146 | dim_diff = len(visual.shape) - len(vis_out.shape) 147 | vis_out = vis_out.view(vis_out.shape + (1,) * dim_diff) 148 | 149 | dim_diff = len(skeleton.shape) - len(sk_out.shape) 150 | sk_out = torch.stack(sk_out.shape[0]*[ 151 | self.running_avg_weight_skeleton 152 | ]).view(sk_out.shape + (1,) * dim_diff) 153 | 154 | return visual * vis_out, skeleton * sk_out, scales, squeeze_array 155 | 156 | 157 | def get_mmtm_outputs(eval_save_path, mmtm_recorded, key): 158 | with open(os.path.join(eval_save_path, 'history.pickle'), 'rb') as f: 159 | his_epo = pickle.load(f) 160 | 161 | print(his_epo.keys()) 162 | data = [] 163 | for batch in his_epo[key][0]: 164 | assert mmtm_recorded == len(batch) 165 | 166 | for mmtmid in range(len(batch)): 167 | if len(data)self.epsilon: 245 | biased_direction=np.sign(self.d_BDR) 246 | self.model_pytoune.curation_mode = True 247 | self.curation_step = 0 248 | 249 | if biased_direction==-1: #BDR0BDR1 252 | self.model_pytoune.caring_modality = 0 253 | else: 254 | self.model_pytoune.curation_mode = False 255 | self.model_pytoune.caring_modality = 0 256 | else: 257 | self.curation_step +=1 258 | if self.curation_step==self.curation_windowsize: 259 | self.model_pytoune.curation_mode=False 260 | else: 261 | self.d_BDR = self.compute_BDR() 262 | self.model_pytoune.curation_mode = False 263 | self.model_pytoune.caring_modality = 0 264 | 265 | def on_epoch_begin(self, epoch, logs): 266 | if epoch>=self.starting_epoch: 267 | self.unlock=True 268 | 269 | @gin.configurable 270 | class Bias_Mitigation_Random(Callback): 271 | 272 | def on_train_begin(self, logs): 273 | self.model_pytoune.curation_mode = False 274 | self.model_pytoune.caring_modality = None 275 | self.unlock=False 276 | self.starting_epoch=2 277 | 278 | def on_batch_end(self, batch, logs): 279 | logs['curation_mode'] = float(self.model_pytoune.curation_mode) 280 | logs['caring_modality'] = self.model_pytoune.caring_modality 281 | 282 | def on_backward_end(self, batch): 283 | if self.unlock: 284 | 285 | mode=random.choice([0,1,2]) 286 | if mode==0: 287 | self.model_pytoune.curation_mode = False 288 | self.model_pytoune.caring_modality = 0 289 | elif mode==1: 290 | self.model_pytoune.curation_mode = True 291 | self.model_pytoune.caring_modality = 1 292 | else: 293 | self.model_pytoune.curation_mode = True 294 | self.model_pytoune.caring_modality = 0 295 | 296 | else: 297 | self.model_pytoune.curation_mode = False 298 | self.model_pytoune.caring_modality = 0 299 | 300 | def on_epoch_begin(self, epoch, logs): 301 | if epoch>=self.starting_epoch: 302 | self.unlock=True 303 | 304 | 305 | @gin.configurable 306 | class CompletedStopping(Callback): 307 | def __init__(self, *, monitor='acc', patience=5, verbose=True): 308 | super(CompletedStopping, self).__init__() 309 | self.monitor = monitor 310 | self.patience = patience 311 | self.verbose = verbose 312 | self.stopped_epoch = 0 313 | 314 | def on_train_begin(self, logs): 315 | self.stopped_epoch = 0 316 | self.counter = 0 317 | 318 | def on_epoch_end(self, epoch, logs): 319 | current = logs[self.monitor] 320 | if current == 100: 321 | self.counter +=1 322 | 323 | if self.counter>=self.patience: 324 | 325 | self.stopped_epoch = epoch 326 | self.model_pytoune.stop_training = True 327 | 328 | def on_train_end(self, logs): 329 | if self.stopped_epoch > 0 and self.verbose: 330 | print('Epoch %05d: completed stopping' % (self.stopped_epoch + 1)) 331 | 332 | 333 | @gin.configurable 334 | class ReduceLROnPlateau_PyTorch(Callback): 335 | def __init__(self, metric, factor=0.3, patience=10): 336 | self.metric = metric 337 | self.factor = factor 338 | self.patience=patience 339 | 340 | def on_train_begin(self, logs): 341 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 342 | mode='min', 343 | factor=self.factor , 344 | patience=self.patience, 345 | verbose=True, threshold=0.001, threshold_mode='rel', cooldown=0, min_lr=1e-6, eps=1e-08) 346 | 347 | def on_epoch_end(self, epoch, logs): 348 | self.scheduler.step(logs[self.metric]) 349 | 350 | 351 | class LambdaCallback(Callback): 352 | def __init__(self, 353 | on_epoch_begin=None, 354 | on_epoch_end=None, 355 | on_batch_begin=None, 356 | on_batch_end=None, 357 | on_train_begin=None, 358 | on_train_end=None): 359 | super(LambdaCallback, self).__init__() 360 | if on_epoch_begin is not None: 361 | self.on_epoch_begin = on_epoch_begin 362 | else: 363 | self.on_epoch_begin = lambda epoch, logs: None 364 | if on_epoch_end is not None: 365 | self.on_epoch_end = on_epoch_end 366 | else: 367 | self.on_epoch_end = lambda epoch, logs: None 368 | if on_batch_begin is not None: 369 | self.on_batch_begin = on_batch_begin 370 | else: 371 | self.on_batch_begin = lambda batch, logs: None 372 | if on_batch_end is not None: 373 | self.on_batch_end = on_batch_end 374 | else: 375 | self.on_batch_end = lambda batch, logs: None 376 | if on_train_begin is not None: 377 | self.on_train_begin = on_train_begin 378 | else: 379 | self.on_train_begin = lambda logs: None 380 | if on_train_end is not None: 381 | self.on_train_end = on_train_end 382 | else: 383 | self.on_train_end = lambda logs: None 384 | 385 | 386 | class ModelCheckpoint(Callback): 387 | def __init__(self, filepath, monitor='val_loss', verbose=0, 388 | save_best_only=False, 389 | mode='auto', period=1): 390 | super(ModelCheckpoint, self).__init__() 391 | self.monitor = monitor 392 | self.verbose = verbose 393 | self.filepath = filepath 394 | self.save_best_only = save_best_only 395 | self.period = period 396 | self.epochs_since_last_save = 0 397 | 398 | if mode not in ['auto', 'min', 'max']: 399 | mode = 'auto' 400 | 401 | if mode == 'min': 402 | self.monitor_op = np.less 403 | self.best = np.Inf 404 | elif mode == 'max': 405 | self.monitor_op = np.greater 406 | self.best = -np.Inf 407 | else: 408 | if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): 409 | self.monitor_op = np.greater 410 | self.best = -np.Inf 411 | else: 412 | self.monitor_op = np.less 413 | self.best = np.Inf 414 | 415 | def __getstate__(self): 416 | state = self.__dict__.copy() 417 | del state['model'] 418 | del state['optimizer'] 419 | return state 420 | 421 | def __setstate__(self, newstate): 422 | newstate['model'] = self.model 423 | newstate['optimizer'] = self.optimizer 424 | self.__dict__.update(newstate) 425 | 426 | def on_epoch_end(self, epoch, logs=None): 427 | logs = logs or {} 428 | self.epochs_since_last_save += 1 429 | if self.epochs_since_last_save >= self.period: 430 | self.epochs_since_last_save = 0 431 | if self.save_best_only: 432 | current = logs.get(self.monitor) 433 | if current is None: 434 | logging.warning('Can save best model only with %s available, ' 435 | 'skipping.' % (self.monitor), RuntimeWarning) 436 | else: 437 | if self.monitor_op(current, self.best): 438 | if self.verbose > 0: 439 | print('Epoch %05d: %s improved from %0.5f to %0.5f,' 440 | ' saving model to %s' 441 | % (epoch, self.monitor, self.best, 442 | current, self.filepath)) 443 | self.best = current 444 | save_weights(self.model, self.optimizer, self.filepath) 445 | else: 446 | if self.verbose > 0: 447 | print('Epoch %05d: %s did not improve' % 448 | (epoch, self.monitor)) 449 | else: 450 | if self.verbose > 0: 451 | print('Epoch %05d: saving model to %s' % (epoch, self.filepath)) 452 | save_weights(self.model, self.optimizer, self.filepath) 453 | 454 | 455 | @gin.configurable 456 | class ProgressionCallback(Callback): 457 | def __init__(self, 458 | other_metrics = ['average_iol_current_epoch', 'average_iol']): 459 | 460 | self.other_metrics = [] 461 | for me in other_metrics: 462 | self.other_metrics.append(me) 463 | 464 | def on_train_begin(self, logs): 465 | self.metrics = ['loss'] + self.model_pytoune.metrics_names 466 | self.epochs = self.params['epochs'] 467 | self.steps = self.params['steps'] 468 | 469 | def on_epoch_begin(self, epoch, logs): 470 | self.step_times_sum = 0. 471 | self.epoch = epoch 472 | sys.stdout.write("\rEpoch %d/%d" % (self.epoch, self.epochs)) 473 | sys.stdout.flush() 474 | 475 | def on_epoch_end(self, epoch, logs): 476 | epoch_total_time = logs['time'] 477 | 478 | metrics_str = self._get_metrics_string(logs) 479 | iol_str = self._get_iol_string(logs) 480 | if self.steps is not None: 481 | print("\rEpoch %d/%d %.2fs/%.2fs: Step %d/%d: %s. %s" % 482 | (self.epoch, self.epochs, epoch_total_time, timeit.default_timer()-logs['epoch_begin_time'], self.steps, self.steps, metrics_str, iol_str)) 483 | 484 | else: 485 | print("\rEpoch %d/%d %.2fs/%.2fs: Step %d/%d: %s. %s" % 486 | (self.epoch, self.epochs, epoch_total_time, timeit.default_timer()-logs['epoch_begin_time'], self.last_step, self.last_step, metrics_str, iol_str)) 487 | 488 | def on_batch_end(self, batch, logs): 489 | self.step_times_sum += timeit.default_timer()-logs['batch_begin_time'] 490 | 491 | metrics_str = self._get_metrics_string(logs) 492 | iol_str = self._get_iol_string(logs) 493 | #print(iol_str) 494 | times_mean = self.step_times_sum / batch 495 | if self.steps is not None: 496 | remaining_time = times_mean * (self.steps - batch) 497 | 498 | sys.stdout.write("\rEpoch %d/%d ETA %.2fs Step %d/%d: %s. %s" % 499 | (self.epoch, self.epochs, remaining_time, batch, self.steps, metrics_str, iol_str)) 500 | if 'cumsum_iol' in iol_str: sys.stdout.write("\n") 501 | sys.stdout.flush() 502 | else: 503 | sys.stdout.write("\rEpoch %d/%d %.2fs/step Step %d: %s. %s" % 504 | (self.epoch, self.epochs, times_mean, batch, metrics_str, iol_str)) 505 | sys.stdout.flush() 506 | self.last_step = batch 507 | 508 | def _get_metrics_string(self, logs): 509 | train_metrics_str_gen = ('{}: {:f}'.format(k, logs[k]) for k in self.metrics if logs.get(k) is not None) 510 | val_metrics_str_gen = ('{}: {:f}'.format('val_' + k, logs['val_' + k]) for k in self.metrics 511 | if logs.get('val_' + k) is not None) 512 | return ', '.join(itertools.chain(train_metrics_str_gen, val_metrics_str_gen)) 513 | 514 | def _get_iol_string(self, logs): 515 | str_gen = ['{}: {:f}'.format(k, logs[k]) for k in self.other_metrics if logs.get(k) is not None] 516 | #print(str_gen, '\n',[(k, logs[k]) for k in ['average_iol_current_epoch', 'average_iol']]) 517 | return ', '.join(str_gen) 518 | 519 | class ValidationProgressionCallback(Callback): 520 | def __init__(self, 521 | phase, 522 | metrics_names, 523 | steps=None): 524 | self.params = {} 525 | self.params['steps'] = steps 526 | self.params['phase'] = phase 527 | self.metrics = metrics_names 528 | 529 | super(ValidationProgressionCallback, self).__init__() 530 | 531 | def _get_metrics_string(self, logs): 532 | metrics_str_gen = ('{}: {:f}'.format(self.params['phase'] + '_' + k, logs[k]) for k in self.metrics 533 | if logs.get(k) is not None) 534 | return ', '.join(metrics_str_gen) 535 | 536 | def on_batch_begin(self, batch, logs): 537 | if batch==1: 538 | self.step_times_sum = 0. 539 | 540 | self.steps = self.params['steps'] 541 | 542 | def on_batch_end(self, batch, logs): 543 | self.step_times_sum += timeit.default_timer()-logs['batch_begin_time'] 544 | 545 | metrics_str = self._get_metrics_string(logs) 546 | times_mean = self.step_times_sum / batch 547 | if self.steps is not None: 548 | remaining_time = times_mean * (self.steps - batch) 549 | 550 | sys.stdout.write("\r%s ETA %.2fs Step %d/%d: %s." % 551 | (self.params['phase'], remaining_time, batch, self.steps, metrics_str)) 552 | sys.stdout.flush() 553 | else: 554 | sys.stdout.write("\r%s %.2fs/step Step %d: %s." % 555 | (self.params['phase'], times_mean, batch, metrics_str)) 556 | sys.stdout.flush() 557 | self.last_step = batch 558 | 559 | 560 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import torch.utils.data 4 | import os 5 | import math 6 | import pandas as pd 7 | from skimage import io, transform, util 8 | from PIL import Image 9 | import torch 10 | import torchvision as vision 11 | from torchvision import transforms, datasets 12 | import random 13 | import gin 14 | from gin.config import _CONFIG 15 | import copy 16 | 17 | import timeit 18 | 19 | SEED_FIXED = 100000 20 | 21 | def load_modelviews(files): 22 | imgs = [] 23 | for f in files: 24 | im = np.array(Image.open(f).convert('RGB')) 25 | imgs.append(im) 26 | return np.stack(imgs) 27 | 28 | def load_modelinto_numpy(root_dir, classnames, ending='/*.png', numviews = 12): 29 | set_ = root_dir.split('/')[-1] 30 | parent_dir = root_dir.rsplit('/',2)[0] 31 | filepaths = [] 32 | 33 | data = [] 34 | labels = [] 35 | 36 | for i in range(len(classnames)): 37 | all_files = sorted(glob.glob(parent_dir+'/'+classnames[i]+'/'+set_+ending)) 38 | 39 | nummodels = int(len(all_files)/12) 40 | print('Transformming %d models of class %d - %s into tensor'%(nummodels, i, classnames[i])) 41 | starting_time = timeit.default_timer() 42 | 43 | for m_ind in range(nummodels): 44 | modelimgs = load_modelviews(all_files[m_ind*12:m_ind*12+12]) 45 | #print(all_files[m_ind*12].rsplit('.',2)[0]+'.npy') 46 | with open(all_files[m_ind*12].rsplit('.',2)[0]+'.npy', 'wb') as f: 47 | torch.save(modelimgs, f) 48 | 49 | print('... finished in %.2fs'%(timeit.default_timer() - starting_time)) 50 | 51 | 52 | class MultiviewModelDataset(torch.utils.data.Dataset): 53 | def __init__(self, root_dir, ending='/*.png', 54 | num_views=12, shuffle=True, specific_view=None, transform=None): 55 | 56 | self.classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair', 57 | 'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box', 58 | 'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand', 59 | 'person','piano','plant','radio','range_hood','sink','sofa','stairs', 60 | 'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox'] 61 | self.root_dir = root_dir 62 | 63 | self.num_views = num_views 64 | self.specific_view = specific_view 65 | 66 | self.transform = transform 67 | self.init_filepaths(ending) 68 | 69 | def init_filepaths(self, ending): 70 | self.filepaths = [] 71 | for i in range(len(self.classnames)): 72 | all_files = sorted(glob.glob(self.root_dir.rsplit('/',2)[0]+'/'+self.classnames[i]+'/'+self.root_dir.split('/')[-1]+ending)) 73 | files = [] 74 | for file in all_files: 75 | files.append(file.split('.obj.')[0]) 76 | 77 | files = list(np.unique(np.array(files))) 78 | self.filepaths.extend(files) 79 | 80 | def __len__(self): 81 | return len(self.filepaths) 82 | 83 | def __getitem__(self, idx): 84 | path = self.filepaths[idx] 85 | class_name = path.split('/')[-3] 86 | class_id = self.classnames.index(class_name) 87 | imgs = torch.load(path+'.obj.npy') 88 | trans_imgs = [] 89 | for img, view in zip(imgs[self.specific_view], self.specific_view): 90 | if self.transform: 91 | img = self.transform(img) 92 | trans_imgs.append(img) 93 | data = torch.stack(trans_imgs) 94 | return idx, data, class_id 95 | 96 | 97 | @gin.configurable 98 | def get_mvdcndata( 99 | ending = '/*.png', 100 | root_dir = os.environ['DATA_DIR'], 101 | make_npy_files = False, 102 | valid_size=0.2, 103 | batch_size=8, 104 | random_seed_for_validation = 10, 105 | num_views=12, 106 | num_workers=0, 107 | specific_views=None, 108 | seed=777, 109 | use_cuda=True, 110 | ): 111 | random.seed(seed) 112 | np.random.seed(seed) # cpu vars 113 | torch.manual_seed(seed) # cpu vars 114 | if use_cuda: torch.cuda.manual_seed_all(seed) 115 | 116 | test_transform = transforms.Compose([ 117 | transforms.ToTensor(), 118 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 119 | std=[0.229, 0.224, 0.225]) 120 | ]) 121 | 122 | if make_npy_files: 123 | classnames = ['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair', 124 | 'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box', 125 | 'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand', 126 | 'person','piano','plant','radio','range_hood','sink','sofa','stairs', 127 | 'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox'] 128 | 129 | load_modelinto_numpy(os.path.join(root_dir, '*/test'), classnames, ending='/*.png', numviews = 12) 130 | load_modelinto_numpy(os.path.join(root_dir, '*/train'), classnames, ending='/*.png', numviews = 12) 131 | 132 | train_transform = transforms.Compose([ 133 | transforms.ToPILImage(), 134 | transforms.RandomHorizontalFlip(), 135 | transforms.ToTensor(), 136 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 137 | std=[0.229, 0.224, 0.225]) 138 | ]) 139 | 140 | test_dataset = MultiviewModelDataset(os.path.join(root_dir, '*', 'test'), 141 | ending=ending, 142 | num_views=num_views, 143 | specific_view=specific_views, 144 | transform=test_transform) 145 | 146 | test_loader = torch.utils.data.DataLoader(test_dataset, 147 | batch_size=batch_size, 148 | shuffle=False, 149 | num_workers=num_workers) 150 | 151 | training = MultiviewModelDataset(os.path.join(root_dir, '*', 'train'), 152 | ending=ending, 153 | num_views=num_views, 154 | specific_view=specific_views, 155 | transform=train_transform) 156 | 157 | num_train = len(training) 158 | indices = list(range(num_train)) 159 | training_idx = indices 160 | 161 | error_msg = "[!] valid_size should be in the range [0, 1]." 162 | assert ((valid_size >= 0) and (valid_size <= 1)), error_msg 163 | 164 | split = int(np.floor(valid_size * num_train)) 165 | random.Random(random_seed_for_validation).shuffle(indices) 166 | training_idx, valid_idx = indices[split:], indices[:split] 167 | 168 | valid_sub = torch.utils.data.Subset(training, valid_idx) 169 | valid_loader = torch.utils.data.DataLoader(valid_sub, 170 | batch_size=batch_size, 171 | shuffle=False, 172 | num_workers=num_workers, 173 | ) 174 | 175 | training_sub = torch.utils.data.Subset(training, training_idx) 176 | 177 | training_loader = torch.utils.data.DataLoader(training_sub, 178 | batch_size=batch_size, 179 | shuffle=True, 180 | num_workers=num_workers, 181 | ) 182 | 183 | return training_loader, valid_loader, test_loader 184 | 185 | 186 | 187 | 188 | 189 | -------------------------------------------------------------------------------- /src/framework.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import timeit 4 | import torch 5 | from functools import partial 6 | import math 7 | import itertools 8 | 9 | from src.callbacks import ( 10 | ValidationProgressionCallback, 11 | ProgressionCallback, 12 | CallbackList, 13 | Callback 14 | ) 15 | from src.utils import numpy_to_torch, torch_to_numpy, torch_to 16 | 17 | import logging 18 | logger = logging.getLogger(__name__) 19 | 20 | warning_settings = { 21 | 'batch_size': 'warn' 22 | } 23 | 24 | def cycle(iterable): 25 | while True: 26 | for x in iterable: 27 | yield x 28 | 29 | 30 | def _get_step_iterator(steps, generator): 31 | count_iterator = range(1, steps + 1) if steps is not None else itertools.count(1) 32 | generator = cycle(generator) if steps is not None else generator 33 | return zip(count_iterator, generator) 34 | 35 | 36 | class StepIterator: 37 | def __init__(self, generator, steps_per_epoch, callback, metrics_names, nummodalities): 38 | self.generator = generator 39 | self.steps_per_epoch = steps_per_epoch 40 | self.callback = callback 41 | self.metrics_names = metrics_names 42 | self.nummodalities = nummodalities 43 | 44 | self.losses_sum = 0. 45 | self.metrics_sum = np.zeros(len(self.metrics_names)) 46 | self.metrics_permodal_sum = np.zeros((nummodalities, len(self.metrics_names))) 47 | self.sizes_sum = 0. 48 | self.extra_lists = {} 49 | self.indices_list = [] 50 | 51 | self.defaultfields = ['indices', 52 | 'loss', 53 | 'metrics', 54 | 'viewwises_metrics', 55 | 'number', 56 | 'size' 57 | ] 58 | 59 | @property 60 | def loss(self): 61 | if self.sizes_sum==0: 62 | return 0 63 | else: 64 | return self.losses_sum / self.sizes_sum 65 | 66 | @property 67 | def metrics(self): 68 | if self.sizes_sum==0: 69 | return dict(zip(self.metrics_names, np.zeros(len(self.metrics_names)))) 70 | else: 71 | metrics_dict = dict(zip(self.metrics_names, self.metrics_sum / self.sizes_sum)) 72 | for i in range(self.nummodalities): 73 | names = [f'{x}_modal_{i}' for x in self.metrics_names] 74 | metrics_dict.update(dict(zip(names, self.metrics_permodal_sum[i]/self.sizes_sum))) 75 | 76 | return metrics_dict 77 | 78 | @property 79 | def indices(self): 80 | if self.sizes_sum==0: 81 | return [] 82 | elif self.indices_list[0] is None: 83 | return [] 84 | else: 85 | return np.concatenate(self.indices_list, axis=0) 86 | 87 | def __iter__(self): 88 | for batch_ind, data in _get_step_iterator(self.steps_per_epoch, self.generator): 89 | batch_begin_time = timeit.default_timer() 90 | self.callback.on_batch_begin(batch_ind, {}) 91 | self.callback.on_forward_begin(batch_ind, data) 92 | 93 | step_data = {'number': batch_ind} 94 | step_data['indices'] = data[0] 95 | yield step_data, data[1:] 96 | 97 | self.losses_sum += step_data['loss'] * step_data['size'] 98 | self.metrics_sum += step_data['metrics'] * step_data['size'] 99 | self.metrics_permodal_sum += step_data['viewwises_metrics'] * step_data['size'] 100 | self.sizes_sum += step_data['size'] 101 | self.indices_list.append(step_data['indices']) 102 | 103 | metrics_dict = dict(zip(self.metrics_names, step_data['metrics'])) 104 | 105 | for i in range(self.nummodalities): 106 | names = [f'{x}_modal_{i}' for x in self.metrics_names] 107 | metrics_dict.update(dict(zip(names, step_data['viewwises_metrics'][i]))) 108 | 109 | for key, value in step_data.items(): 110 | if key not in self.defaultfields: 111 | if key in self.extra_lists: 112 | self.extra_lists[key].append(value) 113 | else: 114 | self.extra_lists[key] = [value] 115 | 116 | batch_total_time = timeit.default_timer() - batch_begin_time 117 | 118 | batch_logs = {'batch': batch_ind, 'size': step_data['size'], 119 | 'time': batch_total_time, 'batch_begin_time': batch_begin_time, 120 | 'loss': step_data['loss'], **metrics_dict} 121 | 122 | self.callback.on_batch_end(batch_ind, batch_logs) 123 | 124 | 125 | class Model_: 126 | def __init__(self, model, optimizer, loss_function, nummodalities, *, metrics=[], 127 | verbose=True, hyper_optim=None, vg=None): 128 | self.model = model 129 | self.optimizer = optimizer 130 | self.loss_function = loss_function 131 | self.metrics = metrics 132 | self.metrics_names = [metric.__name__ for metric in self.metrics] 133 | self.device = None 134 | self.verbose = verbose 135 | self.verbose_logs = {} 136 | self.nummodalities = nummodalities 137 | self.curation_mode=False 138 | self.caring_modality=None 139 | 140 | def _compute_loss_and_metrics(self, x, y): 141 | x, y = self._process_input(x, y) 142 | x = x if isinstance(x, (list, tuple)) else (x, ) 143 | 144 | self.minibatch_data = (x, y) 145 | 146 | pred_y_eval, pred_y, scales, squeezed_mps = self.model(*x, 147 | curation_mode=self.curation_mode, 148 | caring_modality=self.caring_modality) 149 | 150 | loss = self.loss_function(pred_y, y) 151 | 152 | record = {} 153 | 154 | with torch.no_grad(): 155 | record['metrics'] = self._compute_metrics(pred_y_eval, y) 156 | record['viewwises_metrics'] = self._compute_metrics_multiple_inputs(pred_y, y) 157 | 158 | if self.model.saving_mmtm_scales: 159 | record['mmtmscales_list'] = scales 160 | if self.model.saving_mmtm_squeeze_array: 161 | record['squeezedmaps_array_list'] = squeezed_mps 162 | 163 | return loss, record 164 | 165 | def _process_input(self, *args): 166 | args = numpy_to_torch(args) 167 | if self.device is not None: 168 | args = torch_to(args, self.device) 169 | return args[0] if len(args) == 1 else args 170 | 171 | def _compute_metrics(self, pred_y, y): 172 | return np.array([float(metric(pred_y, y)) for metric in self.metrics]) 173 | 174 | def _compute_metrics_multiple_inputs(self, list_pred_y, y): 175 | return np.array([self._compute_metrics(pred_y, y) for pred_y in list_pred_y]) 176 | 177 | def _get_batch_size(self, x, y): 178 | if torch.is_tensor(x) or isinstance(x, np.ndarray): 179 | return len(x) 180 | if torch.is_tensor(y) or isinstance(y, np.ndarray): 181 | return len(y) 182 | if warning_settings['batch_size'] == 'warn': 183 | warnings.warn("When 'x' or 'y' are not tensors nor Numpy arrays, " 184 | "the batch size is set to 1 and, thus, the computed " 185 | "loss and metrics at the end of each epoch is the " 186 | "mean of the batches' losses and metrics. To disable " 187 | "this warning, set\n" 188 | "from poutyne.framework import warning_settings\n" 189 | "warning_settings['batch_size'] = 'ignore'") 190 | return 1 191 | 192 | def _transfer_optimizer_state_to_right_device(self): 193 | # Since the optimizer state is loaded on CPU, it will crashed when the 194 | # optimizer will receive gradient for parameters not on CPU. Thus, for 195 | # each parameter, we transfer its state in the optimizer on the same 196 | # device as the parameter itself just before starting the optimization. 197 | for group in self.optimizer.param_groups: 198 | for p in group['params']: 199 | if p in self.optimizer.state: 200 | for _, v in self.optimizer.state[p].items(): 201 | if torch.is_tensor(v) and p.device != v.device: 202 | v.data = v.data.to(p.device) 203 | 204 | def to(self, device): 205 | self.device = device 206 | self.model.to(self.device) 207 | if isinstance(self.loss_function, torch.nn.Module): 208 | self.loss_function.to(self.device) 209 | 210 | for metric in self.metrics: 211 | if isinstance(metric, torch.nn.Module): 212 | metric.to(self.device) 213 | 214 | return self 215 | 216 | def _eval_generator(self, generator, phase, *, steps=None): 217 | if steps is None: 218 | steps = len(generator) 219 | 220 | step_iterator = StepIterator( 221 | generator, 222 | steps, 223 | ValidationProgressionCallback( 224 | phase=phase, 225 | steps=steps, 226 | metrics_names=['loss'] + self.metrics_names 227 | ), 228 | self.metrics_names, 229 | self.nummodalities 230 | ) 231 | 232 | self.model.eval() 233 | with torch.no_grad(): 234 | for step, (x, y) in step_iterator: 235 | step['size'] = self._get_batch_size(x, y) 236 | loss_tensor, info = self._compute_loss_and_metrics(x, y) 237 | step['loss'] = float(loss_tensor) 238 | step.update(info) 239 | 240 | metrics_dict = { 241 | f'{phase}_{metric_name}' : metric for metric_name, metric in step_iterator.metrics.items() 242 | } 243 | 244 | info_dict = {f'{phase}_loss' : step_iterator.loss, 245 | f'{phase}_indices': step_iterator.indices, 246 | **{f'{phase}_{k}':v for k, v in step_iterator.extra_lists.items()}, 247 | **metrics_dict 248 | } 249 | 250 | return info_dict 251 | 252 | def eval_loop(self, test_generator, *, test_steps=None, epochs=1, callbacks=[]): 253 | callback_list = CallbackList(callbacks) 254 | callback_list.set_model_pytoune(self) 255 | callback_list.on_train_begin({}) 256 | epoch = 0 257 | while epoch <=epochs: 258 | epoch_begin_time = timeit.default_timer() 259 | callback_list.on_epoch_begin(epoch, {}) 260 | test_dict = self._eval_generator(test_generator, 'test', steps=test_steps) 261 | 262 | test_dict['epoch'] = epoch 263 | test_dict['time'] = timeit.default_timer() - epoch_begin_time 264 | test_dict['epoch_begin_time'] = epoch_begin_time 265 | 266 | callback_list.on_epoch_end(epoch, test_dict) 267 | 268 | epoch+=1 269 | 270 | def train_loop(self, 271 | train_generator, 272 | test_generator=None, 273 | valid_generator=None, 274 | *, 275 | epochs=1000, 276 | steps_per_epoch=None, 277 | validation_steps=None, 278 | test_steps=None, 279 | callbacks=[], 280 | ): 281 | 282 | self._transfer_optimizer_state_to_right_device() 283 | 284 | callback_list = CallbackList(callbacks) 285 | callback_list.append(ProgressionCallback()) 286 | callback_list.set_model_pytoune(self) 287 | callback_list.set_params({'epochs': epochs, 'steps': steps_per_epoch}) 288 | 289 | self.stop_training = False 290 | 291 | callback_list.on_train_begin({}) 292 | val_dict, test_dict = {}, {} 293 | for epoch in range(1, epochs+1): 294 | callback_list.on_epoch_begin(epoch, {}) 295 | 296 | epoch_begin_time = timeit.default_timer() 297 | 298 | # training 299 | train_step_iterator = StepIterator(train_generator, 300 | steps_per_epoch, 301 | callback_list, 302 | self.metrics_names, 303 | self.nummodalities 304 | ) 305 | self.model.train(True) 306 | with torch.enable_grad(): 307 | for step, (x, y) in train_step_iterator: 308 | step['size'] = self._get_batch_size(x, y) 309 | 310 | self.optimizer.zero_grad() 311 | loss_tensor, info = self._compute_loss_and_metrics(x, y) 312 | 313 | loss_tensor.backward() 314 | callback_list.on_backward_end(step['number']) 315 | self.optimizer.step() 316 | 317 | loss = loss_tensor.item() 318 | step.update(info) 319 | step['loss'] = loss 320 | 321 | if math.isnan(step['loss']): 322 | self.stop_training = True 323 | 324 | train_dict = {'loss': train_step_iterator.loss, 325 | 'train_indices': train_step_iterator.indices, 326 | **{f'train_{k}':v for k, v in train_step_iterator.extra_lists.items()}, 327 | **train_step_iterator.metrics} 328 | 329 | # validation 330 | val_dict = self._eval_generator(valid_generator, 'val', steps=validation_steps) 331 | # test 332 | test_dict = self._eval_generator(test_generator, 'test', steps=test_steps) 333 | 334 | epoch_log = { 335 | 'epoch': epoch, 336 | 'time': timeit.default_timer() - epoch_begin_time, 337 | 'epoch_begin_time': epoch_begin_time, 338 | **train_dict, **val_dict, **test_dict 339 | } 340 | 341 | callback_list.on_epoch_end(epoch, epoch_log) 342 | 343 | if self.stop_training: break 344 | 345 | callback_list.on_train_end({}) 346 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | import glob 9 | import gin 10 | from gin.config import _CONFIG 11 | 12 | from src.balanced_mmtm import MMTM_mitigate as MMTM 13 | from src.balanced_mmtm import get_rescale_weights 14 | 15 | @gin.configurable 16 | class MMTM_MVCNN(nn.Module): 17 | def __init__(self, 18 | nclasses=40, 19 | num_views=2, 20 | pretraining=False, 21 | mmtm_off=False, 22 | mmtm_rescale_eval_file_path = None, 23 | mmtm_rescale_training_file_path = None, 24 | device='cuda:0', 25 | saving_mmtm_scales=False, 26 | saving_mmtm_squeeze_array=False, 27 | ): 28 | super(MMTM_MVCNN, self).__init__() 29 | 30 | self.classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair', 31 | 'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box', 32 | 'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand', 33 | 'person','piano','plant','radio','range_hood','sink','sofa','stairs', 34 | 'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox'] 35 | 36 | self.nclasses = nclasses 37 | self.num_views = num_views 38 | 39 | self.mmtm_off = mmtm_off 40 | if self.mmtm_off: 41 | self.mmtm_rescale = get_rescale_weights( 42 | mmtm_rescale_eval_file_path, 43 | mmtm_rescale_training_file_path, 44 | validation=False, 45 | starting_mmtmindice = 1, 46 | mmtmpositions=4, 47 | device=torch.device(device), 48 | ) 49 | 50 | self.saving_mmtm_scales = saving_mmtm_scales 51 | self.saving_mmtm_squeeze_array = saving_mmtm_squeeze_array 52 | 53 | self.net_view_0 = models.resnet18(pretrained=pretraining) 54 | self.net_view_0.fc = nn.Linear(512, nclasses) 55 | self.net_view_1 = models.resnet18(pretrained=pretraining) 56 | self.net_view_1.fc = nn.Linear(512, nclasses) 57 | 58 | self.mmtm2 = MMTM(128, 128, 4) 59 | self.mmtm3 = MMTM(256, 256, 4) 60 | self.mmtm4 = MMTM(512, 512, 4) 61 | 62 | 63 | def forward(self, x, curation_mode=False, caring_modality=None): 64 | 65 | frames_view_0 = self.net_view_0.conv1(x[:, 0, :]) 66 | frames_view_0 = self.net_view_0.bn1(frames_view_0) 67 | frames_view_0 = self.net_view_0.relu(frames_view_0) 68 | frames_view_0 = self.net_view_0.maxpool(frames_view_0) 69 | 70 | frames_view_1 = self.net_view_1.conv1(x[:, 1, :]) 71 | frames_view_1 = self.net_view_1.bn1(frames_view_1) 72 | frames_view_1 = self.net_view_1.relu(frames_view_1) 73 | frames_view_1 = self.net_view_1.maxpool(frames_view_1) 74 | 75 | frames_view_0 = self.net_view_0.layer1(frames_view_0) 76 | frames_view_1 = self.net_view_1.layer1(frames_view_1) 77 | 78 | scales = [] 79 | squeezed_mps = [] 80 | 81 | for i in [2, 3, 4]: 82 | 83 | frames_view_0 = getattr(self.net_view_0, f'layer{i}')(frames_view_0) 84 | frames_view_1 = getattr(self.net_view_1, f'layer{i}')(frames_view_1) 85 | 86 | frames_view_0, frames_view_1, scale, squeezed_mp = getattr(self, f'mmtm{i}')( 87 | frames_view_0, 88 | frames_view_1, 89 | self.saving_mmtm_scales, 90 | self.saving_mmtm_squeeze_array, 91 | turnoff_cross_modal_flow = True if self.mmtm_off else False, 92 | average_squeezemaps = self.mmtm_rescale[i-1] if self.mmtm_off else None, 93 | curation_mode = curation_mode, 94 | caring_modality = caring_modality 95 | ) 96 | scales.append(scale) 97 | squeezed_mps.append(squeezed_mp) 98 | 99 | frames_view_0 = self.net_view_0.avgpool(frames_view_0) 100 | frames_view_1 = self.net_view_1.avgpool(frames_view_1) 101 | 102 | x_0 = torch.flatten(frames_view_0, 1) 103 | x_0 = self.net_view_0.fc(x_0) 104 | 105 | x_1 = torch.flatten(frames_view_1, 1) 106 | x_1 = self.net_view_1.fc(x_1) 107 | 108 | return (x_0+x_1)/2, [x_0, x_1], scales, squeezed_mps 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /src/training_loop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A gorgeous, self-contained, training loop. Uses Poutyne implementation, but this can be swapped later. 4 | """ 5 | 6 | import logging 7 | import os 8 | import tqdm 9 | import pickle 10 | from functools import partial 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import torch 15 | import gin 16 | 17 | from src.callbacks import ModelCheckpoint, LambdaCallback 18 | from src.utils import save_weights 19 | from src.framework import Model_ 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | types_of_instance_to_save_in_csv = (int, float, complex, np.int64, np.int32, np.float32, np.float64, np.float128, str) 24 | types_of_instance_to_save_in_history = (int, float, complex, np.int64, np.int32, np.float32, np.float64, np.ndarray, np.float128,str) 25 | 26 | def _construct_default_callbacks(model, optimizer, H, save_path, checkpoint_monitor, save_with_structure=False): 27 | callbacks = [] 28 | callbacks.append(LambdaCallback(on_epoch_end=partial(_append_to_history_csv, H=H))) 29 | 30 | callbacks.append( 31 | LambdaCallback( 32 | on_epoch_end=partial(_save_history_csv, 33 | save_path=save_path, 34 | H=H, 35 | save_with_structure=save_with_structure) 36 | ) 37 | ) 38 | 39 | callbacks.append(ModelCheckpoint(monitor=checkpoint_monitor, 40 | save_best_only=True, 41 | mode='max', 42 | filepath=os.path.join(save_path, "model_best_val.pt"))) 43 | 44 | def save_weights_fnc(epoch, logs): 45 | logger.info("Saving model from epoch " + str(epoch)) 46 | save_weights(model, optimizer, os.path.join(save_path, "model_last_epoch.pt")) 47 | 48 | callbacks.append(LambdaCallback(on_epoch_end=save_weights_fnc)) 49 | 50 | return callbacks 51 | 52 | 53 | def _save_history_csv(epoch, logs, save_path, H, save_with_structure = False): 54 | out = "" 55 | for key, value in logs.items(): 56 | if isinstance(value, types_of_instance_to_save_in_csv): 57 | out += "{key}={value}\t".format(key=key, value=value) 58 | logger.info(out) 59 | logger.info("Saving history to " + os.path.join(save_path, "history.csv")) 60 | H_tosave = {} 61 | for key, value in H.items(): 62 | if isinstance(value[-1], types_of_instance_to_save_in_csv): 63 | H_tosave[key] = value 64 | pd.DataFrame(H_tosave).to_csv(os.path.join(save_path, "history.csv"), index=False) 65 | if save_with_structure: 66 | with open(os.path.join(save_path, "history.pickle"), 'wb') as f: 67 | pickle.dump(H, f, pickle.HIGHEST_PROTOCOL) 68 | 69 | 70 | def _append_to_history_csv(epoch, logs, H): 71 | for key, value in logs.items(): 72 | if key not in H: 73 | H[key] = [value] 74 | else: 75 | H[key].append(value) 76 | 77 | 78 | def _load_pretrained_model(model, save_path): 79 | checkpoint = torch.load(save_path) 80 | model_dict = model.state_dict() 81 | model_dict.update(checkpoint['model']) 82 | model.load_state_dict(model_dict, strict=False) 83 | logger.info("Done reloading!") 84 | 85 | 86 | @gin.configurable 87 | def training_loop(model, loss_function, metrics, optimizer, config, 88 | save_path, steps_per_epoch, 89 | train=None, valid=None, test=None, 90 | test_steps=None, validation_steps=None, 91 | use_gpu = False, device_numbers = [0], 92 | custom_callbacks=[], 93 | checkpoint_monitor="val_acc", 94 | n_epochs=100, 95 | verbose=True, 96 | nummodalities=2): 97 | 98 | callbacks = list(custom_callbacks) 99 | 100 | history_csv_path = os.path.join(save_path, "history.csv") 101 | history_pkl_path = os.path.join(save_path, "history.pkl") 102 | 103 | logger.info("Removing {} and {}".format(history_pkl_path, history_csv_path)) 104 | os.system("rm " + history_pkl_path) 105 | os.system("rm " + history_csv_path) 106 | 107 | H = {} 108 | 109 | callbacks += _construct_default_callbacks(model, optimizer, H, 110 | save_path, checkpoint_monitor, custom_callbacks) 111 | 112 | # Configure callbacks 113 | for clbk in callbacks: 114 | clbk.set_save_path(save_path) 115 | clbk.set_model(model, ignore=False) # TODO: Remove this trick 116 | clbk.set_optimizer(optimizer) 117 | clbk.set_config(config) 118 | 119 | model = Model_(model=model, 120 | optimizer=optimizer, 121 | loss_function=loss_function, 122 | metrics=metrics, 123 | verbose=verbose, 124 | nummodalities=nummodalities, 125 | ) 126 | 127 | for clbk in callbacks: 128 | clbk.set_model_pytoune(model) 129 | 130 | if use_gpu and torch.cuda.is_available(): 131 | base_device = torch.device("cuda:{}".format(device_numbers[0])) 132 | model.to(base_device) 133 | logger.info("Sending model to {}".format(base_device)) 134 | 135 | _ = model.train_loop(train, 136 | valid_generator=valid, 137 | test_generator=test, 138 | test_steps=test_steps, 139 | validation_steps=validation_steps, 140 | steps_per_epoch=steps_per_epoch, 141 | epochs=n_epochs - 1, 142 | callbacks=callbacks, 143 | ) 144 | 145 | def _construct_default_eval_callbacks(H, save_path, save_with_structure): 146 | 147 | history_batch = os.path.join(save_path, 'eval_history_batch') 148 | if not os.path.exists(history_batch): 149 | os.mkdir(history_batch) 150 | 151 | callbacks = [] 152 | callbacks.append(LambdaCallback(on_epoch_end=partial(_append_to_history_csv, H=H))) 153 | 154 | callbacks.append(LambdaCallback(on_epoch_end=partial(_save_history_csv, 155 | save_path=history_batch, 156 | H=H, 157 | save_with_structure=save_with_structure))) 158 | 159 | return callbacks 160 | 161 | @gin.configurable 162 | def evalution_loop(model, loss_function, metrics, config, 163 | save_path, 164 | test=None, test_steps=None, 165 | use_gpu = False, device_numbers = [0], 166 | custom_callbacks=[], 167 | pretrained_weights_path=None, 168 | save_with_structure=False, 169 | nummodalities=2, 170 | ): 171 | 172 | 173 | _load_pretrained_model(model, pretrained_weights_path) 174 | 175 | history_csv_path = os.path.join(save_path, "eval_history.csv") 176 | history_pkl_path = os.path.join(save_path, "eval_history.pkl") 177 | 178 | logger.info("Removing {} and {}".format(history_pkl_path, history_csv_path)) 179 | os.system("rm " + history_pkl_path) 180 | os.system("rm " + history_csv_path) 181 | 182 | H = {} 183 | callbacks = list(custom_callbacks) 184 | callbacks += _construct_default_eval_callbacks( 185 | H, 186 | save_path, 187 | save_with_structure 188 | ) 189 | 190 | # Configure callbacks 191 | for clbk in callbacks: 192 | clbk.set_save_path(save_path) 193 | clbk.set_model(model, ignore=False) # TODO: Remove this trick 194 | clbk.set_config(config) 195 | 196 | model = Model_(model=model, 197 | optimizer=None, 198 | loss_function=loss_function, 199 | metrics=metrics, 200 | nummodalities=nummodalities) 201 | 202 | if use_gpu and torch.cuda.is_available(): 203 | base_device = torch.device("cuda:{}".format(device_numbers[0])) 204 | model.to(base_device) 205 | logger.info("Sending model to {}".format(base_device)) 206 | 207 | model.eval_loop( 208 | test, 209 | epochs=0, 210 | test_steps=test_steps, 211 | callbacks=callbacks 212 | ) 213 | 214 | 215 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from functools import reduce, partial, wraps 3 | import warnings 4 | import logging 5 | import copy 6 | import argh 7 | import gin 8 | from gin.config import _OPERATIVE_CONFIG 9 | from gin.config import _CONFIG 10 | 11 | import torch 12 | import numpy as np 13 | 14 | from contextlib import contextmanager 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | class Fork(object): 19 | def __init__(self, file1, file2): 20 | self.file1 = file1 21 | self.file2 = file2 22 | 23 | def write(self, data): 24 | self.file1.write(data) 25 | self.file2.write(data) 26 | 27 | def flush(self): 28 | self.file1.flush() 29 | self.file2.flush() 30 | 31 | 32 | @contextmanager 33 | def replace_logging_stream(file_): 34 | root = logging.getLogger() 35 | if len(root.handlers) != 1: 36 | print(root.handlers) 37 | raise ValueError("Don't know what to do with many handlers") 38 | if not isinstance(root.handlers[0], logging.StreamHandler): 39 | raise ValueError 40 | stream = root.handlers[0].stream 41 | root.handlers[0].stream = file_ 42 | try: 43 | yield 44 | finally: 45 | root.handlers[0].stream = stream 46 | 47 | 48 | @contextmanager 49 | def replace_standard_stream(stream_name, file_): 50 | stream = getattr(sys, stream_name) 51 | setattr(sys, stream_name, file_) 52 | try: 53 | yield 54 | finally: 55 | setattr(sys, stream_name, stream) 56 | 57 | 58 | def gin_wrap(fnc): 59 | def main(save_path, config, bindings=""): 60 | # You can pass many configs (think of them as mixins), and many bindings. Both ";" separated. 61 | gin.parse_config_files_and_bindings(config.split("#"), bindings.replace("#", "\n")) 62 | if not os.path.exists(save_path): 63 | logger.info("Creating folder " + save_path) 64 | os.system("mkdir -p " + save_path) 65 | run_with_redirection(os.path.join(save_path, "stdout.txt"), 66 | os.path.join(save_path, "stderr.txt"), 67 | fnc)(save_path) 68 | argh.dispatch_command(main) 69 | 70 | 71 | def run_with_redirection(stdout_path, stderr_path, func): 72 | def func_wrapper(*args, **kwargs): 73 | with open(stdout_path, 'a', 1) as out_dst: 74 | with open(stderr_path, 'a', 1) as err_dst: 75 | out_fork = Fork(sys.stdout, out_dst) 76 | err_fork = Fork(sys.stderr, err_dst) 77 | with replace_standard_stream('stderr', err_fork): 78 | with replace_standard_stream('stdout', out_fork): 79 | with replace_logging_stream(err_fork): 80 | func(*args, **kwargs) 81 | 82 | return func_wrapper 83 | 84 | 85 | def _apply(obj, func): 86 | if isinstance(obj, (list, tuple)): 87 | return type(obj)(_apply(el, func) for el in obj) 88 | if isinstance(obj, dict): 89 | return {k: _apply(el, func) for k, el in obj.items()} 90 | return func(obj) 91 | 92 | 93 | def torch_apply(obj, func): 94 | fn = lambda t: func(t) if torch.is_tensor(t) else t 95 | return _apply(obj, fn) 96 | 97 | 98 | def torch_to(obj, *args, **kargs): 99 | return torch_apply(obj, lambda t: t.to(*args, **kargs)) 100 | 101 | 102 | def numpy_to_torch(obj): 103 | fn = lambda a: torch.from_numpy(a) if isinstance(a, np.ndarray) else a 104 | return _apply(obj, fn) 105 | 106 | 107 | def save_weights(model, optimizer, filename): 108 | """ 109 | Save all weights necessary to resume training 110 | """ 111 | state = { 112 | 'model': model.state_dict(), 113 | 'optimizer': optimizer.state_dict(), 114 | } 115 | torch.save(state, filename) 116 | 117 | 118 | def numpy_to_torch(obj): 119 | fn = lambda a: torch.from_numpy(a) if isinstance(a, np.ndarray) else a 120 | return _apply(obj, fn) 121 | 122 | 123 | def torch_to_numpy(obj, copy=False): 124 | if copy: 125 | func = lambda t: t.cpu().detach().numpy().copy() 126 | else: 127 | func = lambda t: t.cpu().detach().numpy() 128 | return torch_apply(obj, func) 129 | 130 | 131 | def configure_logger(name='', 132 | console_logging_level=logging.INFO, 133 | file_logging_level=None, 134 | log_file=None): 135 | """ 136 | Configures logger 137 | :param name: logger name (default=module name, __name__) 138 | :param console_logging_level: level of logging to console (stdout), None = no logging 139 | :param file_logging_level: level of logging to log file, None = no logging 140 | :param log_file: path to log file (required if file_logging_level not None) 141 | :return instance of Logger class 142 | """ 143 | 144 | if file_logging_level is None and log_file is not None: 145 | print("Didnt you want to pass file_logging_level?") 146 | 147 | if len(logging.getLogger(name).handlers) != 0: 148 | print("Already configured logger '{}'".format(name)) 149 | return 150 | 151 | if console_logging_level is None and file_logging_level is None: 152 | return # no logging 153 | 154 | logger = logging.getLogger(name) 155 | logger.handlers = [] 156 | logger.setLevel(logging.DEBUG) 157 | format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 158 | 159 | if console_logging_level is not None: 160 | ch = logging.StreamHandler(sys.stdout) 161 | ch.setFormatter(format) 162 | ch.setLevel(console_logging_level) 163 | logger.addHandler(ch) 164 | 165 | if file_logging_level is not None: 166 | if log_file is None: 167 | raise ValueError("If file logging enabled, log_file path is required") 168 | fh = logging.handlers.RotatingFileHandler(log_file, maxBytes=(1048576 * 5), backupCount=7) 169 | fh.setFormatter(format) 170 | logger.addHandler(fh) 171 | 172 | logger.info("Logging configured!") 173 | 174 | return logger 175 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Trainer script. Example run command: train.py save_to_folder configs/cnn.gin. 5 | """ 6 | import os 7 | import gin 8 | from gin.config import _CONFIG 9 | import torch 10 | import pickle 11 | import logging 12 | from functools import partial 13 | logger = logging.getLogger(__name__) 14 | 15 | from src import dataset 16 | from src import callbacks as avail_callbacks 17 | from src.model import MMTM_MVCNN 18 | from src.training_loop import training_loop 19 | from src.utils import gin_wrap 20 | 21 | 22 | def blend_loss(y_hat, y): 23 | loss_func = torch.nn.CrossEntropyLoss() 24 | losses = [] 25 | for y_pred in y_hat: 26 | losses.append(loss_func(y_pred, y)) 27 | 28 | return sum(losses) 29 | 30 | 31 | def acc(y_pred, y_true): 32 | if isinstance(y_pred, list): 33 | y_pred = torch.mean(torch.stack([out.data for out in y_pred], 0), 0) 34 | _, y_pred = y_pred.max(1) 35 | if len(y_true)==2: 36 | acc_pred = (y_pred == y_true[0]).float().mean() 37 | else: 38 | acc_pred = (y_pred == y_true).float().mean() 39 | return acc_pred * 100 40 | 41 | 42 | @gin.configurable 43 | def train(save_path, wd, lr, momentum, batch_size, callbacks=[]): 44 | model = MMTM_MVCNN() 45 | train, valid, test = dataset.get_mvdcndata(batch_size=batch_size) 46 | 47 | optimizer = torch.optim.SGD(model.parameters(), 48 | lr=lr, 49 | weight_decay=wd, 50 | momentum=momentum) 51 | 52 | callbacks_constructed = [] 53 | for name in callbacks: 54 | if name in avail_callbacks.__dict__: 55 | clbk = avail_callbacks.__dict__[name]() 56 | callbacks_constructed.append(clbk) 57 | 58 | training_loop(model=model, 59 | optimizer=optimizer, 60 | loss_function=blend_loss, 61 | metrics=[acc], 62 | train=train, valid=valid, test=test, 63 | steps_per_epoch=len(train), 64 | validation_steps=len(valid), 65 | test_steps=len(test), 66 | save_path=save_path, 67 | config=_CONFIG, 68 | custom_callbacks=callbacks_constructed 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | gin_wrap(train) 74 | --------------------------------------------------------------------------------