├── .gitignore ├── LICENSE ├── README.md ├── a-domainbed ├── main.py ├── prepare_data.sh ├── run_da.sh ├── run_ood.sh └── visual.py ├── a-imageclef ├── main.py ├── prepare_data.sh ├── run_da.sh ├── run_ood.sh └── visual.py ├── a-mnist ├── main.py ├── makedata.py ├── makedata.sh ├── run_da.sh └── run_ood.sh ├── arch ├── __init__.py ├── backbone.py ├── cnn.py ├── mlp.py └── mlpstru.json ├── csg-intro.png ├── distr ├── __init__.py ├── base.py ├── instances.py ├── tools.py └── utils.py ├── methods ├── __init__.py ├── cnbb.py ├── semvar.py ├── supvae.py └── xdistr.py ├── requirements.txt ├── test ├── distr_test.ipynb ├── distr_test.py └── utils_test.py └── utils ├── __init__.py ├── preprocess ├── __init__.py ├── data_list.py ├── data_loader.py └── data_provider.py ├── reprun.sh ├── utils.py └── utils_main.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Customized 132 | *~ 133 | *.swp 134 | 135 | ckpt_* 136 | data/ 137 | DomainBed/ 138 | 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Chang Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Causal Semantic Representation for Out-of-Distribution Prediction 2 | 3 | This repository is the official implementation of "[Learning Causal Semantic Representation for Out-of-Distribution Prediction](https://arxiv.org/abs/2011.01681)" (NeurIPS 2021). 4 | 5 | [Chang Liu][changliu] \<\>, 6 | Xinwei Sun, Jindong Wang, Haoyue Tang, Tao Li, Tao Qin, Wei Chen, Tie-Yan Liu.\ 7 | \[[Paper & Appendix](https://changliu00.github.io/causupv/causupv.pdf)\] 8 | \[[Slides](https://changliu00.github.io/causupv/causupv-slides.pdf)\] 9 | \[[Video](https://recorder-v3.slideslive.com/?share=52713&s=7a03cf16-4993-4e27-8502-7461239c487d)\] 10 | \[[Poster](https://changliu00.github.io/causupv/causupv-poster.pdf)\] 11 | 12 | ## Introduction 13 | 14 | ![graphical summary](./csg-intro.png) 15 | 16 | The work proposes a Causal Semantic Generative model (CSG) for OOD generalization (_single-source_ domain generalization) and domain adaptation. 17 | The model is developed following a causal reasoning process, and prediction is made by leveraging the _causal invariance principle_. 18 | Training and prediction algorithms are developed based on variational Bayes with a novel design. 19 | Theoretical guarantees on the identifiability of the causal factor and the benefits for OOD prediction are presented. 
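In brief (an informal sketch in the paper's notation, with `s` the latent semantic factor and `v` the latent variation factor), CSG factorizes the data-generating process as
```
p(s, v, x, y) = p(s, v) p(x | s, v) p(y | s)
```
and the causal invariance principle keeps the mechanisms `p(x | s, v)` and `p(y | s)` fixed across domains, so that only the prior `p(s, v)` may change: it is replaced by an independent prior for OOD generalization (CSG-ind), or adapted on unlabeled target data for domain adaptation (CSG-DA).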
20 | 21 | This codebase implements the CSG methods and implements or integrates various baselines. 22 | Most domain adaptation baselines (except [BNM](https://github.com/cuishuhao/BNM)) use the [dalib](https://github.com/thuml/Transfer-Learning-Library) package. 23 | The experiment setups on the PACS and VLCS datasets are adopted from the [domainbed](https://github.com/facebookresearch/DomainBed) repository. 24 | Authorship is clarified in each file or module. 25 | 26 | ## Requirements 27 | 28 | The code requires Python >= 3.6 and is based on [PyTorch](https://github.com/pytorch/pytorch). To install the requirements: 29 | 30 | ```setup 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | ## Usage 35 | 36 | The folder `a-mnist` contains scripts to run the experiments on the **Shifted-MNIST** dataset, 37 | `a-imageclef` on the [**ImageCLEF-DA**](http://imageclef.org/2014/adaptation) dataset, 38 | and `a-domainbed` on the [**PACS**](https://openaccess.thecvf.com/content_ICCV_2017/papers/Li_Deeper_Broader_and_ICCV_2017_paper.pdf) and [**VLCS**](https://openaccess.thecvf.com/content_iccv_2013/papers/Fang_Unbiased_Metric_Learning_2013_ICCV_paper.pdf) datasets 39 | (the prefix `a-` stands for "application"). 40 | 41 | Go to the respective folder and run the `prepare_data.sh` or `makedata.sh` script there to prepare the datasets. 42 | Run the `run_ood.sh` (for OOD generalization methods) or `run_da.sh` (for domain adaptation methods) script to train the models. 43 | Evaluation results (accuracy on the test domain) are printed and written to disk along with the model and configurations. 44 | See the commands in the script files or `python3 main.py --help` for customized usage and hyperparameter tuning. 45 | 46 | ## Citation 47 | ``` 48 | @inproceedings{liu2021learning, 49 | author = {Liu, Chang and Sun, Xinwei and Wang, Jindong and Tang, Haoyue and Li, Tao and Qin, Tao and Chen, Wei and Liu, Tie-Yan}, 50 | booktitle = {Advances in Neural Information Processing Systems}, 51 | editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J.
Wortman Vaughan}, 52 | pages = {6155--6170}, 53 | publisher = {Curran Associates, Inc.}, 54 | title = {Learning Causal Semantic Representation for Out-of-Distribution Prediction}, 55 | url = {https://proceedings.neurips.cc/paper/2021/file/310614fca8fb8e5491295336298c340f-Paper.pdf}, 56 | volume = {34}, 57 | year = {2021} 58 | } 59 | ``` 60 | 61 | [changliu]: https://changliu00.github.io/ 62 | 63 | -------------------------------------------------------------------------------- /a-domainbed/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | import warnings 3 | import sys 4 | import torch as tc 5 | sys.path.append("..") 6 | from utils.utils_main import main_stem, get_parser, is_ood, process_continue_run 7 | from utils.preprocess import data_loader 8 | from utils.utils import boolstr, ZipLongest 9 | from DomainBed.domainbed import datasets 10 | from DomainBed.domainbed.lib import misc 11 | from DomainBed.domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader 12 | 13 | __author__ = "Chang Liu" 14 | __email__ = "changliu@microsoft.com" 15 | # tc.autograd.set_detect_anomaly(True) 16 | 17 | class MergeIters: 18 | def __init__(self, *itrs): 19 | self.itrs = itrs 20 | self.zipped = ZipLongest(*itrs) 21 | self.len = len(self.zipped) 22 | 23 | def __iter__(self): 24 | for vals in self.zipped: 25 | yield tuple(tc.cat([val[i] for val in vals]) for i in range(len(vals[0]))) 26 | 27 | def __len__(self): return self.len 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = get_parser() 32 | parser.add_argument("--data_root", type = str, default = "./DomainBed/domainbed/data/") 33 | parser.add_argument('--dataset', type = str, default = "PACS") 34 | parser.add_argument("--testdoms", type = int, nargs = '+', default = [0]) 35 | parser.add_argument("--n_bat_test", type = int, default = None) 36 | parser.add_argument("--traindoms", type = int, nargs = '+', default = None) # default: 'other' if `excl_test` else 'all' 37 | parser.add_argument("--excl_test", type = boolstr, default = True) # only active when `traindoms` is None (by default) 38 | parser.add_argument("--uda_frac", type = float, default = 1.) 
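# `uda_frac` sets the fraction of each test env's 'in-split' that is held out as the unlabeled 'uda-split' for domain adaptation (it becomes `uda_holdout_fraction` below)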
39 | parser.add_argument("--data_aug", type = boolstr, default = True) 40 | 41 | parser.add_argument("--dim_s", type = int, default = 512) 42 | parser.add_argument("--dim_v", type = int, default = 128) 43 | parser.add_argument("--dim_btnk", type = int, default = 1024) # for discr_model 44 | parser.add_argument("--dims_bb2bn", type = int, nargs = '*') # for discr_model 45 | parser.add_argument("--dims_bn2s", type = int, nargs = '*') # for discr_model 46 | parser.add_argument("--dims_s2y", type = int, nargs = '*') # for discr_model 47 | parser.add_argument("--dims_bn2v", type = int, nargs = '*') # for discr_model 48 | parser.add_argument("--vbranch", type = boolstr, default = False) # for discr_model 49 | parser.add_argument("--dim_feat", type = int, default = 256) # for gen_model 50 | 51 | parser.set_defaults(discrstru = "DBresnet50", genstru = "DCGANpretr", 52 | n_bat = 32, n_epk = 40, eval_interval = 1, 53 | optim = "Adam", lr = 5e-5, wl2 = 5e-4, 54 | # momentum = .9, nesterov = True, lr_expo = .75, lr_wdatum = 6.25e-6, # only when "lr" is "SGD" 55 | sig_s = 3e+1, sig_v = 3e+1, corr_sv = .7, tgt_mvn_prior = True, src_mvn_prior = True, 56 | pstd_x = 1e-1, qstd_s = -1., qstd_v = -1., 57 | wgen = 1e-7, wsup = 0., wlogpi = 1., 58 | wda = .25, 59 | domdisc_dimh = 1024, # for {dann, cdan, mdd} only 60 | cdan_rand = False, # for cdan only 61 | ker_alphas = [.5, 1., 2.], # for dan only 62 | mdd_margin = 4. # for mdd only 63 | ) 64 | ag = parser.parse_args() 65 | if ag.wlogpi is None: ag.wlogpi = ag.wgen 66 | if ag.n_bat_test is None: ag.n_bat_test = ag.n_bat 67 | ag, ckpt = process_continue_run(ag) 68 | IS_OOD = is_ood(ag.mode) 69 | 70 | ag.data_dir = ag.data_root 71 | ag.test_envs = ag.testdoms 72 | ag.holdout_fraction = 1. - ag.tr_val_split 73 | ag.uda_holdout_fraction = ag.uda_frac 74 | ag.trial_seed = 0. 75 | hparams = {'batch_size': ag.n_bat, 'class_balanced': False, 'data_augmentation': ag.data_aug} 76 | 77 | # BEGIN: from 'domainbed.scripts.train.py' 78 | if ag.dataset in vars(datasets): 79 | dataset = vars(datasets)[ag.dataset](ag.data_dir, 80 | ag.test_envs, hparams) 81 | else: 82 | raise NotImplementedError 83 | 84 | # (customed plugin) 85 | if ag.traindoms is None: 86 | ag.traindoms = list(i for i in range(len(dataset)) if not ag.excl_test or i not in ag.test_envs) 87 | ag.traindom = ag.traindoms # for printing info in `main_stem` 88 | # (end) 89 | 90 | # Split each env into an 'in-split' and an 'out-split'. We'll train on 91 | # each in-split except the test envs, and evaluate on all splits. 92 | 93 | # To allow unsupervised domain adaptation experiments, we split each test 94 | # env into 'in-split', 'uda-split' and 'out-split'. The 'in-split' is used 95 | # by collect_results.py to compute classification accuracies. The 96 | # 'out-split' is used by the Oracle model selectino method. The unlabeled 97 | # samples in 'uda-split' are passed to the algorithm at training time if 98 | # args.task == "domain_adaptation". If we are interested in comparing 99 | # domain generalization and domain adaptation results, then domain 100 | # generalization algorithms should create the same 'uda-splits', which will 101 | # be discared at training. 
102 | in_splits = [] 103 | out_splits = [] 104 | uda_splits = [] 105 | for env_i, env in enumerate(dataset): 106 | uda = [] 107 | 108 | out, in_ = misc.split_dataset(env, 109 | int(len(env)*ag.holdout_fraction), 110 | misc.seed_hash(ag.trial_seed, env_i)) 111 | 112 | if env_i in ag.test_envs: 113 | uda, in_ = misc.split_dataset(in_, 114 | int(len(in_)*ag.uda_holdout_fraction), 115 | misc.seed_hash(ag.trial_seed, env_i)) 116 | 117 | if hparams['class_balanced']: 118 | in_weights = misc.make_weights_for_balanced_classes(in_) 119 | out_weights = misc.make_weights_for_balanced_classes(out) 120 | if uda is not None: 121 | uda_weights = misc.make_weights_for_balanced_classes(uda) 122 | else: 123 | in_weights, out_weights, uda_weights = None, None, None 124 | in_splits.append((in_, in_weights)) 125 | out_splits.append((out, out_weights)) 126 | if len(uda): 127 | uda_splits.append((uda, uda_weights)) 128 | # Now `in_splits` and `out_splits` contain the train/validation splits for all envs, and `uda_splits` contains, for the test envs only, the part of their 'in-split' reserved for unsupervised domain adaptation. 129 | 130 | if len(uda_splits) == 0: # args.task == "domain_adaptation" and len(uda_splits) == 0: 131 | raise ValueError("Not enough unlabeled samples for domain adaptation.") 132 | 133 | train_loaders = [FastDataLoader( # InfiniteDataLoader( 134 | dataset=env, 135 | # weights=env_weights, 136 | batch_size=hparams['batch_size'], 137 | num_workers=dataset.N_WORKERS) 138 | for i, (env, env_weights) in enumerate(in_splits) 139 | if i in ag.traindoms] 140 | 141 | val_loaders = [FastDataLoader( # InfiniteDataLoader( 142 | dataset=env, 143 | # weights=env_weights, 144 | batch_size=ag.n_bat_test, # hparams['batch_size'], 145 | num_workers=dataset.N_WORKERS) 146 | for i, (env, env_weights) in enumerate(out_splits) 147 | if i in ag.traindoms] 148 | 149 | uda_loaders = [FastDataLoader( # InfiniteDataLoader( 150 | dataset=env, 151 | # weights=env_weights, 152 | batch_size = hparams['batch_size'] * len(train_loaders), # =hparams['batch_size'], 153 | num_workers=dataset.N_WORKERS) 154 | for i, (env, env_weights) in enumerate(uda_splits) 155 | # if i in args.test_envs 156 | ] 157 | 158 | # eval_loaders = [FastDataLoader( 159 | # dataset=env, 160 | # batch_size=64, 161 | # num_workers=dataset.N_WORKERS) 162 | # for env, _ in (in_splits + out_splits + uda_splits)] 163 | # eval_weights = [None for _, weights in (in_splits + out_splits + uda_splits)] 164 | # eval_loader_names = ['env{}_in'.format(i) 165 | # for i in range(len(in_splits))] 166 | # eval_loader_names += ['env{}_out'.format(i) 167 | # for i in range(len(out_splits))] 168 | # eval_loader_names += ['env{}_uda'.format(i) 169 | # for i in range(len(uda_splits))] 170 | # END 171 | 172 | archtype = "cnn" 173 | shape_x = dataset.input_shape 174 | dim_y = dataset.num_classes 175 | tr_src_loader = MergeIters(*train_loaders) 176 | val_src_loader = MergeIters(*val_loaders) 177 | ls_ts_tgt_loader = uda_loaders 178 | if not IS_OOD: 179 | ls_tr_tgt_loader = uda_loaders 180 | 181 | if IS_OOD: 182 | main_stem( ag, ckpt, archtype, shape_x, dim_y, 183 | tr_src_loader, val_src_loader, ls_ts_tgt_loader ) 184 | else: 185 | for testdom, tr_tgt_loader, ts_tgt_loader in zip( 186 | ag.testdoms, ls_tr_tgt_loader, ls_ts_tgt_loader): 187 | main_stem( ag, ckpt, archtype, shape_x, dim_y, 188 | tr_src_loader, val_src_loader, None, 189 | tr_tgt_loader, ts_tgt_loader, testdom ) 190 | 191 | -------------------------------------------------------------------------------- /a-domainbed/prepare_data.sh:
-------------------------------------------------------------------------------- 1 | git clone https://github.com/facebookresearch/DomainBed 2 | cd DomainBed/ 3 | git checkout 2deb150 4 | python3 -m pip install gdown==3.13.0 5 | # python3 -m pip install wilds==1.1.0 torch_scatter # Installing `torch_scatter` seems quite involved. Turn to edit the files to exclude the import. 6 | vi "+14norm I# " "+15norm I# " "+271norm I# " "+267s/# //" "+270s/# //" "+wq" domainbed/scripts/download.py 7 | vi "+12norm I# " "+13norm I# " "+wq" domainbed/datasets.py 8 | python3 -m domainbed.scripts.download --data_dir=./domainbed/data 9 | 10 | -------------------------------------------------------------------------------- /a-domainbed/run_da.sh: -------------------------------------------------------------------------------- 1 | REPRUN="../utils/reprun.sh 10" 2 | PYCMD="python3 main.py" 3 | 4 | case $1 in 5 | dann) 6 | # Results taken from "In search of lost domain generalization" (ICLR'21). 7 | ;; 8 | cdan) 9 | # Results taken from "In search of lost domain generalization" (ICLR'21). 10 | ;; 11 | dan) 12 | ##dataset PACS 13 | $REPRUN $PYCMD $1 --testdoms 0 --wl2 5e-4 --wsup 1. --wda 1e-2 14 | $REPRUN $PYCMD $1 --testdoms 1 --wl2 5e-4 --wsup 1. --wda 1e-2 15 | $REPRUN $PYCMD $1 --testdoms 2 --wl2 5e-4 --wsup 1. --wda 1e-2 16 | $REPRUN $PYCMD $1 --testdoms 3 --wl2 5e-4 --wsup 1. --wda 1e-2 17 | ##dataset VLCS 18 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 0 --wl2 5e-4 --wsup 1. --wda 1e-2 19 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 1 --wl2 5e-4 --wsup 1. --wda 1e-2 20 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 2 --wl2 5e-4 --wsup 1. --wda 1e-2 21 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 3 --wl2 5e-4 --wsup 1. --wda 1e-2 22 | ;; 23 | mdd) 24 | ##dataset PACS 25 | $REPRUN $PYCMD $1 --testdoms 0 --wl2 5e-4 --wsup 1. --wda 1e-2 26 | $REPRUN $PYCMD $1 --testdoms 1 --wl2 5e-4 --wsup 1. --wda 1e-2 27 | $REPRUN $PYCMD $1 --testdoms 2 --wl2 5e-4 --wsup 1. --wda 1e-2 28 | $REPRUN $PYCMD $1 --testdoms 3 --wl2 5e-4 --wsup 1. --wda 1e-2 29 | ##dataset VLCS 30 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 0 --wl2 5e-4 --wsup 1. --wda 1e-2 31 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 1 --wl2 5e-4 --wsup 1. --wda 1e-2 32 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 2 --wl2 5e-4 --wsup 1. --wda 1e-2 33 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 3 --wl2 5e-4 --wsup 1. --wda 1e-2 34 | ;; 35 | bnm) 36 | ##dataset PACS 37 | $REPRUN $PYCMD $1 --testdoms 0 --wl2 5e-4 --wsup 1. --wda 1. 38 | $REPRUN $PYCMD $1 --testdoms 1 --wl2 5e-4 --wsup 1. --wda 1. 39 | $REPRUN $PYCMD $1 --testdoms 2 --wl2 5e-4 --wsup 1. --wda 1. 40 | $REPRUN $PYCMD $1 --testdoms 3 --wl2 5e-4 --wsup 1. --wda 1. 41 | ##dataset VLCS 42 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 0 --wl2 5e-4 --wsup 1. --wda 1. 43 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 1 --wl2 5e-4 --wsup 1. --wda 1. 44 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 2 --wl2 5e-4 --wsup 1. --wda 1. 45 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 3 --wl2 5e-4 --wsup 1. --wda 1. 46 | ;; 47 | svae-da) # CSGz-DA 48 | ##dataset PACS 49 | $REPRUN $PYCMD $1 --testdoms 0 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 50 | $REPRUN $PYCMD $1 --testdoms 1 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 51 | $REPRUN $PYCMD $1 --testdoms 2 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 52 | $REPRUN $PYCMD $1 --testdoms 3 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. 
--wgen 1e-8 --wda 1e-8 53 | ##dataset VLCS 54 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 0 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 55 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 1 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 56 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 2 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 57 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 3 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 58 | ;; 59 | svgm-da) # CSG-DA 60 | ##dataset PACS 61 | $REPRUN $PYCMD $1 --testdoms 0 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 62 | $REPRUN $PYCMD $1 --testdoms 1 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 63 | $REPRUN $PYCMD $1 --testdoms 2 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 64 | $REPRUN $PYCMD $1 --testdoms 3 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 65 | ##dataset VLCS 66 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 0 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 67 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 1 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 68 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 2 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 69 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 3 --pstd_x 3e-1 --wl2 5e-4 --wsup 1. --wlogpi 0. --wgen 1e-8 --wda 1e-8 70 | ;; 71 | *) 72 | echo "unknown argument $1" 73 | ;; 74 | esac 75 | 76 | -------------------------------------------------------------------------------- /a-domainbed/run_ood.sh: -------------------------------------------------------------------------------- 1 | REPRUN="../utils/reprun.sh 10" 2 | PYCMD="python3 main.py" 3 | 4 | case $1 in 5 | discr) # CE 6 | # Results taken from "In search of lost domain generalization" (ICLR'21). 7 | ;; 8 | cnbb) 9 | ##dataset PACS 10 | $REPRUN $PYCMD $1 --testdoms 0 --reg_w 1e-4 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 11 | $REPRUN $PYCMD $1 --testdoms 1 --reg_w 1e-4 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 12 | $REPRUN $PYCMD $1 --testdoms 2 --reg_w 1e-4 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 13 | $REPRUN $PYCMD $1 --testdoms 3 --reg_w 1e-4 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 14 | ##dataset VLCS 15 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 0 --reg_w 1e-4 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 16 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 1 --reg_w 1e-4 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 17 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 2 --reg_w 1e-4 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 18 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 3 --reg_w 1e-4 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 19 | ;; 20 | svae) # CSGz 21 | ##dataset PACS 22 | $REPRUN $PYCMD $1 --testdoms 0 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 23 | $REPRUN $PYCMD $1 --testdoms 1 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 24 | $REPRUN $PYCMD $1 --testdoms 2 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 25 | $REPRUN $PYCMD $1 --testdoms 3 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 26 | ##dataset VLCS 27 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 0 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 28 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 1 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 29 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 2 --wsup 1. --wlogpi 0. 
--wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 30 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 3 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 31 | ;; 32 | svgm) # CSG 33 | ##dataset PACS 34 | $REPRUN $PYCMD $1 --testdoms 0 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 35 | $REPRUN $PYCMD $1 --testdoms 1 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 36 | $REPRUN $PYCMD $1 --testdoms 2 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 37 | $REPRUN $PYCMD $1 --testdoms 3 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 38 | ##dataset VLCS 39 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 0 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 40 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 1 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 41 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 2 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 42 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 3 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 43 | ;; 44 | svgm-ind) # CSG-ind 45 | ##dataset PACS 46 | $REPRUN $PYCMD $1 --testdoms 0 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 47 | $REPRUN $PYCMD $1 --testdoms 1 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 48 | $REPRUN $PYCMD $1 --testdoms 2 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 49 | $REPRUN $PYCMD $1 --testdoms 3 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 50 | ##dataset VLCS 51 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 0 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 52 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 1 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 53 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 2 --wsup 1. --wlogpi 0. --wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 54 | $REPRUN $PYCMD $1 --dataset VLCS --testdoms 3 --wsup 1. --wlogpi 0. 
--wl2 5e-4 --wgen 1e-7 --pstd_x 3e-1 55 | ;; 56 | *) 57 | echo "unknown argument $1" 58 | ;; 59 | esac 60 | 61 | -------------------------------------------------------------------------------- /a-imageclef/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | import warnings 3 | import sys 4 | sys.path.append("..") 5 | from utils.utils_main import main_stem, get_parser, is_ood, process_continue_run 6 | from utils.preprocess import data_loader 7 | from utils.utils import boolstr 8 | 9 | __author__ = "Chang Liu" 10 | __email__ = "changliu@microsoft.com" 11 | # tc.autograd.set_detect_anomaly(True) 12 | 13 | if __name__ == "__main__": 14 | parser = get_parser() 15 | parser.add_argument("--data_root", type = str, default = "./data/image_CLEF/") 16 | parser.add_argument("--traindom", type = str, default = "b") 17 | parser.add_argument("--testdoms", type = str, nargs = '+', default = ["b", "c", "i", "p"]) 18 | 19 | parser.add_argument("--dim_s", type = int, default = 1024) 20 | parser.add_argument("--dim_v", type = int, default = 256) 21 | parser.add_argument("--dim_btnk", type = int, default = 1024) # for discr_model 22 | parser.add_argument("--dims_bb2bn", type = int, nargs = '*') # for discr_model 23 | parser.add_argument("--dims_bn2s", type = int, nargs = '*') # for discr_model 24 | parser.add_argument("--dims_s2y", type = int, nargs = '*') # for discr_model 25 | parser.add_argument("--dims_bn2v", type = int, nargs = '*') # for discr_model 26 | parser.add_argument("--vbranch", type = boolstr, default = False) # for discr_model 27 | parser.add_argument("--dim_feat", type = int, default = 128) # for gen_model 28 | 29 | parser.set_defaults(discrstru = "ResNet50", genstru = "DCGANvar", 30 | n_bat = 32, n_epk = 100, 31 | optim = "SGD", lr = 4e-3, wl2 = 5e-4, 32 | momentum = .9, nesterov = True, lr_expo = .75, lr_wdatum = 6.25e-6, # only used when `optim` is "SGD" 33 | wda = .25, 34 | domdisc_dimh = 1024, # for {dann, cdan, mdd} only 35 | cdan_rand = False, # for cdan only 36 | ker_alphas = [.5, 1., 2.], # for dan only 37 | mdd_margin = 4. # for mdd only 38 | ) 39 | ag = parser.parse_args() 40 | if ag.wlogpi is None: ag.wlogpi = ag.wgen 41 | ag, ckpt = process_continue_run(ag) 42 | IS_OOD = is_ood(ag.mode) 43 | 44 | archtype = "cnn" 45 | # Dataset 46 | shape_x = (3, 224, 224) # determined by the loader 47 | dim_y = 12 48 | kwargs = {'num_workers': 4, 'pin_memory': True} 49 | tr_src_loader, val_src_loader = data_loader.load_training( 50 | ag.data_root, ag.traindom, ag.n_bat, kwargs, 51 | ag.tr_val_split, rand_split=True ) # a random split is needed; otherwise some classes would be unseen in training.
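# Test loaders are built for every domain in `testdoms`, including the train domain itself; the same-domain case is skipped (with a warning) in the domain-adaptation loop below.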
52 | ls_ts_tgt_loader = [data_loader.load_testing( 53 | ag.data_root, testdom, ag.n_bat, kwargs) 54 | for testdom in ag.testdoms] 55 | if not IS_OOD: 56 | ls_tr_tgt_loader = [data_loader.load_training( 57 | ag.data_root, testdom, ag.n_bat, kwargs, -1) 58 | for testdom in ag.testdoms] 59 | 60 | if IS_OOD: 61 | main_stem( ag, ckpt, archtype, shape_x, dim_y, 62 | tr_src_loader, val_src_loader, ls_ts_tgt_loader ) 63 | else: 64 | for testdom, tr_tgt_loader, ts_tgt_loader in zip( 65 | ag.testdoms, ls_tr_tgt_loader, ls_ts_tgt_loader): 66 | if testdom != ag.traindom: 67 | main_stem( ag, ckpt, archtype, shape_x, dim_y, 68 | tr_src_loader, val_src_loader, None, 69 | tr_tgt_loader, ts_tgt_loader, testdom ) 70 | else: 71 | warnings.warn("same domain adaptation ignored") 72 | 73 | -------------------------------------------------------------------------------- /a-imageclef/prepare_data.sh: -------------------------------------------------------------------------------- 1 | DATAFILE=image_CLEF.zip 2 | wget https://transferlearningdrive.blob.core.windows.net/teamdrive/dataset/$DATAFILE 3 | mkdir -p data 4 | unzip $DATAFILE -d data/ 5 | mv $DATAFILE data/ 6 | 7 | -------------------------------------------------------------------------------- /a-imageclef/run_da.sh: -------------------------------------------------------------------------------- 1 | REPRUN="../utils/reprun.sh 10" 2 | PYCMD="python3 main.py" 3 | 4 | case $1 in 5 | dann) 6 | # Results taken from "Conditional adversarial domain adaptation" (NeurIPS'18) 7 | ;; 8 | cdan) 9 | # Results taken from "Conditional adversarial domain adaptation" (NeurIPS'18) 10 | ;; 11 | dan) 12 | # Results taken from "Conditional adversarial domain adaptation" (NeurIPS'18) 13 | ;; 14 | mdd) 15 | $REPRUN $PYCMD $1 --traindom c --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --wda 1e-2 --n_epk 20 --eval_interval 1 16 | $REPRUN $PYCMD $1 --traindom i --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --wda 1e-2 --n_epk 20 --eval_interval 1 17 | $REPRUN $PYCMD $1 --traindom p --testdoms c i --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --wda 1e-2 --n_epk 20 --eval_interval 1 18 | ;; 19 | bnm) 20 | $REPRUN $PYCMD $1 --traindom c --testdoms p --wl2 5e-4 --wsup 1. --wda 1. --n_epk 20 --eval_interval 1 21 | $REPRUN $PYCMD $1 --traindom i --testdoms p --wl2 5e-4 --wsup 1. --wda 1. --n_epk 20 --eval_interval 1 22 | $REPRUN $PYCMD $1 --traindom p --testdoms c i --wl2 5e-4 --wsup 1. --wda 1. --n_epk 20 --eval_interval 1 23 | ;; 24 | svae-da) # CSGz-DA 25 | $REPRUN $PYCMD $1 --traindom c --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 0 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --wda 1e-8 --sig_s 3e+1 --sig_v 0. --corr_sv 0. --pstd_x 1e-1 --qstd_s=-1. --qstd_v=0. 
--tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 26 | $REPRUN $PYCMD $1 --traindom i --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 0 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --wda 1e-8 --sig_s 3e+1 --sig_v 0. --corr_sv 0. --pstd_x 1e-1 --qstd_s=-1. --qstd_v=0. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 27 | $REPRUN $PYCMD $1 --traindom p --testdoms c i --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 0 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --wda 1e-8 --sig_s 3e+1 --sig_v 0. --corr_sv 0. --pstd_x 1e-1 --qstd_s=-1. --qstd_v=0. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 28 | ;; 29 | svgm-da) # CSG-DA 30 | $REPRUN $PYCMD $1 --traindom c --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --wda 1e-8 --sig_s 3e+1 --sig_v 3e+1 --corr_sv .7 --pstd_x 1e-1 --qstd_s=-1. --qstd_v=-1. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 31 | $REPRUN $PYCMD $1 --traindom i --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --wda 1e-8 --sig_s 3e+1 --sig_v 3e+1 --corr_sv .7 --pstd_x 1e-1 --qstd_s=-1. --qstd_v=-1. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 32 | $REPRUN $PYCMD $1 --traindom p --testdoms c i --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 256 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --wda 1e-8 --sig_s 3e+1 --sig_v 3e+1 --corr_sv .7 --pstd_x 3e-1 --qstd_s=-1. --qstd_v=-1. 
--tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 33 | ;; 34 | *) 35 | echo "unknown argument $1" 36 | ;; 37 | esac 38 | 39 | -------------------------------------------------------------------------------- /a-imageclef/run_ood.sh: -------------------------------------------------------------------------------- 1 | REPRUN="../utils/reprun.sh 10" 2 | PYCMD="python3 main.py" 3 | 4 | case $1 in 5 | discr) # CE 6 | # Results taken from "Conditional adversarial domain adaptation" (NeurIPS'18) 7 | ;; 8 | cnbb) 9 | $REPRUN $PYCMD $1 --traindom c --testdoms p --n_bat 32 --dim_btnk 1024 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --reg_w 1e-6 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 --n_epk 20 --eval_interval 1 10 | $REPRUN $PYCMD $1 --traindom i --testdoms p --n_bat 32 --dim_btnk 1024 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --reg_w 1e-6 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 --n_epk 20 --eval_interval 1 11 | $REPRUN $PYCMD $1 --traindom p --testdoms c i --n_bat 32 --dim_btnk 1024 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --reg_w 1e-6 --reg_s 3e-6 --lr_w 1e-4 --n_iter_w 4 --n_epk 20 --eval_interval 1 12 | ;; 13 | svae) # CSGz 14 | $REPRUN $PYCMD $1 --traindom c --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 0 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --sig_s 3e+1 --sig_v 0. --corr_sv 0. --pstd_x 1e-1 --qstd_s=-1. --qstd_v=0. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-7 --genstru DCGANpretr --n_epk 20 --eval_interval 1 15 | $REPRUN $PYCMD $1 --traindom i --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 0 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --sig_s 3e+1 --sig_v 0. --corr_sv 0. --pstd_x 1e-1 --qstd_s=-1. --qstd_v=0. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-7 --genstru DCGANpretr --n_epk 20 --eval_interval 1 16 | $REPRUN $PYCMD $1 --traindom p --testdoms c i --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 0 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --sig_s 3e+1 --sig_v 0. --corr_sv 0. --pstd_x 1e-1 --qstd_s=-1. --qstd_v=0. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-7 --genstru DCGANpretr --n_epk 20 --eval_interval 1 17 | ;; 18 | svgm) # CSG 19 | $REPRUN $PYCMD $1 --traindom c --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --sig_s 3e+1 --sig_v 3e+1 --corr_sv .7 --pstd_x 1e-1 --qstd_s=-1. --qstd_v=-1. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 20 | $REPRUN $PYCMD $1 --traindom i --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --sig_s 3e+1 --sig_v 3e+1 --corr_sv .7 --pstd_x 1e-1 --qstd_s=-1. --qstd_v=-1. 
--tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 21 | $REPRUN $PYCMD $1 --traindom p --testdoms c i --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --sig_s 3e+1 --sig_v 3e+1 --corr_sv .7 --pstd_x 1e-1 --qstd_s=-1. --qstd_v=-1. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 22 | ;; 23 | svgm-ind) # CSG-ind 24 | $REPRUN $PYCMD $1 --traindom c --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --sig_s 3e+1 --sig_v 3e+1 --corr_sv .7 --pstd_x 1e-1 --qstd_s=-1. --qstd_v=-1. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 25 | $REPRUN $PYCMD $1 --traindom i --testdoms p --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --sig_s 3e+1 --sig_v 3e+1 --corr_sv .7 --pstd_x 1e-1 --qstd_s=-1. --qstd_v=-1. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 26 | $REPRUN $PYCMD $1 --traindom p --testdoms c i --n_bat 32 --dims_bb2bn --dim_btnk 1024 --dim_v 128 --vbranch 0 --dims_bn2v --dims_bn2s 1024 --dim_s 256 --dims_s2y --dim_feat 128 --optim SGD --lr 1e-3 --wl2 5e-4 --lr_expo .75 --lr_wdatum 6.25e-6 --sig_s 3e+1 --sig_v 3e+1 --corr_sv .7 --pstd_x 1e-1 --qstd_s=-1. --qstd_v=-1. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-8 --genstru DCGANpretr --n_epk 20 --eval_interval 1 27 | ;; 28 | *) 29 | echo "unknown argument $1" 30 | ;; 31 | esac 32 | 33 | -------------------------------------------------------------------------------- /a-imageclef/visual.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | import warnings 3 | import sys 4 | sys.path.append("..") 5 | from utils.utils_main import main_stem, get_parser, is_ood, process_continue_run, get_models, da_methods, ResultsContainer, ood_methods 6 | from utils.preprocess import data_loader 7 | from utils.utils import boolstr 8 | 9 | from distr import edic 10 | from arch import mlp, cnn 11 | from methods import CNBBLoss, SemVar, SupVAE 12 | from utils import Averager, unique_filename, boolstr, zip_longer # This imports from 'utils/__init__.py' 13 | 14 | from dalib.modules.domain_discriminator import DomainDiscriminator 15 | from dalib.modules.kernels import GaussianKernel 16 | from dalib.adaptation.dann import DomainAdversarialLoss 17 | from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss 18 | from dalib.adaptation.dan import MultipleKernelMaximumMeanDiscrepancy 19 | from dalib.adaptation.mdd import MarginDisparityDiscrepancy 20 | 21 | import torch as tc 22 | from functools import partial 23 | import os 24 | 25 | MODES_TWIST = {"svgm-ind", "svae-da", "svgm-da"} 26 | 27 | __author__ = "Chang Liu" 28 | __email__ = "changliu@microsoft.com" 29 | # tc.autograd.set_detect_anomaly(True) 30 | from torchvision import models, transforms 31 | 32 | def get_input_transform(): 33 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 34 | std=[0.229, 0.224, 0.225]) 35 | transf = transforms.Compose([ 36 | transforms.Resize((256, 256)), 37 | 
transforms.CenterCrop(224), 38 | transforms.ToTensor(), 39 | normalize 40 | ]) 41 | return transf 42 | 43 | def get_visual(ag, ckpt, archtype, shape_x, dim_y, 44 | tr_src_loader, val_src_loader, 45 | ls_ts_tgt_loader = None, # for ood 46 | tr_tgt_loader = None, ts_tgt_loader = None, testdom = None # for da 47 | ): 48 | print(ag) 49 | IS_OOD = is_ood(ag.mode) 50 | device = tc.device("cuda:"+str(ag.gpu) if tc.cuda.is_available() else "cpu") 51 | 52 | # Datasets 53 | dim_x = tc.tensor(shape_x).prod().item() 54 | if IS_OOD: n_per_epk = len(tr_src_loader) 55 | else: n_per_epk = max(len(tr_src_loader), len(tr_tgt_loader)) 56 | 57 | # Models 58 | res = get_models(archtype, edic(locals()) | vars(ag), ckpt, device) 59 | if ag.mode.endswith("-da2"): 60 | discr, gen, frame, discr_src = res 61 | discr_src.train() 62 | else: 63 | discr, gen, frame = res 64 | 65 | # get pictures 66 | discr.eval() 67 | if gen is not None: gen.eval() 68 | 69 | # Methods and Losses 70 | if IS_OOD: 71 | lossfn = ood_methods(discr, frame, ag, dim_y, cnbb_actv="Sigmoid") # Actually the activation is ReLU, but there is no `is_treat` rule for ReLU in CNBB. 72 | domdisc = None 73 | else: 74 | lossfn, domdisc, dalossobj = da_methods(discr, frame, ag, dim_x, dim_y, device, ckpt, 75 | discr_src if ag.mode.endswith("-da2") else None) 76 | 77 | epk0 = 1; i_bat0 = 1 78 | if ckpt is not None: 79 | epk0 = ckpt['epochs'][-1] + 1 if ckpt['epochs'] else 1 80 | i_bat0 = ckpt['i_bat'] 81 | res = ResultsContainer( len(ag.testdoms) if IS_OOD else None, 82 | frame, ag, dim_y==1, device, ckpt ) 83 | print(f"Run in mode '{ag.mode}' for {ag.n_epk:3d} epochs:") 84 | try: 85 | if ag.mode.endswith("-da2"): discr_src.eval(); true_discr = discr_src 86 | elif ag.mode in MODES_TWIST and ag.true_sup_val: true_discr = partial(frame.logit_y1x_src, n_mc_q=ag.n_mc_q) 87 | else: true_discr = discr 88 | res.evaluate(true_discr, "val "+str(ag.traindom), 'val', val_src_loader, 'src') 89 | 90 | if IS_OOD: 91 | for i, (testdom, ts_tgt_loader) in enumerate(zip(ag.testdoms, ls_ts_tgt_loader)): 92 | res.evaluate(discr, "test "+str(testdom), 'ts', ts_tgt_loader, 'tgt', i) 93 | else: 94 | res.evaluate(discr, "test "+str(testdom), 'ts', ts_tgt_loader, 'tgt') 95 | print() 96 | 97 | def batch_predict(images): 98 | import torch.nn.functional as F 99 | if tc.tensor(images[0]).size()[-1] == 3: 100 | images = [tc.tensor(pic, dtype=tc.float).permute(2, 0, 1) for pic in images] 101 | batch = tc.stack(tuple(i for i in images), dim=0) 102 | batch = batch.to(device) 103 | 104 | logits = discr(batch) 105 | probs = F.softmax(logits, dim=1) 106 | 107 | return probs.detach().cpu().numpy() 108 | 109 | if IS_OOD: 110 | test_loader = ls_ts_tgt_loader[0] 111 | else: 112 | test_loader = ts_tgt_loader 113 | 114 | iter_tr, iter_ts = iter(tr_src_loader), iter(test_loader) 115 | train_batch, train_label = next(iter_tr) 116 | test_batch, test_label = next(iter_ts) 117 | 118 | os.makedirs(ag.mode, exist_ok=True) 119 | 120 | # search for the first accurate predict: 121 | cursor_train, cursor_test = 0, 0 122 | for i in range(400): 123 | cursor_test += 1 124 | if cursor_test >= test_batch.size()[0]: 125 | cursor_test = 0 126 | test_batch, test_label = next(iter_ts) 127 | while True: 128 | x_test = test_batch[cursor_test] 129 | test_pred = batch_predict([x_test]) 130 | if cursor_test < test_batch.size()[0] and test_label[cursor_test] == test_pred.squeeze().argmax(): 131 | break; 132 | else: 133 | cursor_test = cursor_test + 1 134 | if cursor_test >= test_batch.size()[0]: 135 | cursor_test = 0 136 | 
test_batch, test_label = next(iter_ts) 137 | 138 | selected_pic, selected_label = test_batch[cursor_test], test_label[cursor_test] 139 | 140 | cursor_train += 1 141 | if cursor_train >= train_batch.size()[0]: 142 | cursor_train = 0 143 | train_batch, train_label = next(iter_tr) 144 | while True: 145 | x_train = train_batch[cursor_train] 146 | test_pred = batch_predict([x_train]) 147 | if cursor_train < train_batch.size()[0] and train_label[cursor_train] == test_pred.squeeze().argmax(): 148 | break; 149 | else: 150 | cursor_train = cursor_train + 1 151 | if cursor_train >= train_batch.size()[0]: 152 | cursor_train = 0 153 | train_batch, train_label = next(iter_tr) 154 | 155 | selected_train_pic = train_batch[cursor_train] 156 | 157 | from lime import lime_image 158 | import numpy as np 159 | 160 | explainer = lime_image.LimeImageExplainer() 161 | explanation = explainer.explain_instance(np.array(selected_pic.permute(1, 2, 0), dtype=np.double), 162 | batch_predict, # classification function 163 | top_labels=5, 164 | hide_color=0, 165 | num_samples=1000) # number of images that will be sent to classification function 166 | from skimage.segmentation import mark_boundaries 167 | test_pic, mask_test_pic = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False) 168 | 169 | explanation_train = explainer.explain_instance(np.array(selected_train_pic.permute(1, 2, 0), dtype=np.double), 170 | batch_predict, # classification function 171 | top_labels=5, 172 | hide_color=0, 173 | num_samples=1000) # number of images that will be sent to classification function 174 | train_pic, mask_train_pic = explanation_train.get_image_and_mask(explanation_train.top_labels[0], positive_only=True, num_features=5, hide_rest=False) 175 | 176 | def vis_pic_trans(pic): 177 | pic = tc.tensor(pic).permute(2, 0, 1) 178 | invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ], 179 | std = [ 1/0.229, 1/0.224, 1/0.225 ]), 180 | transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ], 181 | std = [ 1., 1., 1. 
]), 182 | ]) 183 | pic = invTrans(pic.unsqueeze(0)).squeeze() 184 | return pic.permute(1, 2, 0).numpy() 185 | 186 | test_pic = mark_boundaries(vis_pic_trans(test_pic), mask_test_pic) 187 | train_pic = mark_boundaries(vis_pic_trans(train_pic), mask_train_pic) 188 | 189 | import matplotlib.pyplot as plt 190 | 191 | plt.imshow(train_pic) 192 | plt.savefig(ag.mode+"/train-"+str(i)+".png") 193 | plt.imshow(test_pic) 194 | plt.savefig(ag.mode+"/test-"+str(i)+".png") 195 | 196 | except (KeyboardInterrupt, SystemExit): pass 197 | 198 | if __name__ == "__main__": 199 | parser = get_parser() 200 | parser.add_argument("--data_root", type = str, default = "./data/image_CLEF/") 201 | parser.add_argument("--traindom", type = str, default = "b") 202 | parser.add_argument("--testdoms", type = str, nargs = '+', default = ["b", "c", "i", "p"]) 203 | 204 | parser.add_argument("--dim_s", type = int, default = 1024) 205 | parser.add_argument("--dim_v", type = int, default = 256) 206 | parser.add_argument("--dim_btnk", type = int, default = 1024) # for discr_model 207 | parser.add_argument("--dims_bb2bn", type = int, nargs = '*') # for discr_model 208 | parser.add_argument("--dims_bn2s", type = int, nargs = '*') # for discr_model 209 | parser.add_argument("--dims_s2y", type = int, nargs = '*') # for discr_model 210 | parser.add_argument("--dims_bn2v", type = int, nargs = '*') # for discr_model 211 | parser.add_argument("--vbranch", type = boolstr, default = False) # for discr_model 212 | parser.add_argument("--dim_feat", type = int, default = 128) # for gen_model 213 | # parser.add_argument("--gpu", type=int, default=0) 214 | 215 | parser.set_defaults(discrstru = "ResNet50", genstru = "DCGANvar", 216 | n_bat = 100, n_epk = 100, 217 | optim = "SGD", lr = 4e-3, wl2 = 5e-4, 218 | momentum = .9, nesterov = True, lr_expo = .75, lr_wdatum = 6.25e-6, # only used when `optim` is "SGD" 219 | wda = .25, 220 | domdisc_dimh = 1024, # for {dann, cdan, mdd} only 221 | cdan_rand = False, # for cdan only 222 | ker_alphas = [.5, 1., 2.], # for dan only 223 | mdd_margin = 4. # for mdd only 224 | ) 225 | ag = parser.parse_args() 226 | bat_size = ag.n_bat 227 | if ag.wlogpi is None: ag.wlogpi = ag.wgen 228 | ag, ckpt = process_continue_run(ag) 229 | IS_OOD = is_ood(ag.mode) 230 | ag.n_bat = bat_size 231 | 232 | archtype = "cnn" 233 | # Dataset 234 | shape_x = (3, 224, 224) # determined by the loader 235 | dim_y = 12 236 | kwargs = {'num_workers': 4, 'pin_memory': True} 237 | tr_src_loader = data_loader.load_testing( 238 | ag.data_root, ag.traindom, ag.n_bat, kwargs) # test-mode (deterministic) loader of the train domain; no random split is needed for visualization.
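# Visualization only needs fixed evaluation batches, so this train-domain loader is reused as the 'val' loader below.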
239 | val_src_loader = tr_src_loader 240 | ls_ts_tgt_loader = [data_loader.load_testing( 241 | ag.data_root, testdom, ag.n_bat, kwargs) 242 | for testdom in ag.testdoms] 243 | print(ag.testdoms, len(ls_ts_tgt_loader), len(tr_src_loader), len(ls_ts_tgt_loader[0])) 244 | if not IS_OOD: 245 | ls_tr_tgt_loader = [data_loader.load_testing( 246 | ag.data_root, testdom, ag.n_bat, kwargs) 247 | for testdom in ag.testdoms] 248 | 249 | if IS_OOD: 250 | get_visual( ag, ckpt, archtype, shape_x, dim_y, 251 | tr_src_loader, val_src_loader, ls_ts_tgt_loader ) 252 | else: 253 | for testdom, tr_tgt_loader, ts_tgt_loader in zip( 254 | ag.testdoms, ls_tr_tgt_loader, ls_ts_tgt_loader): 255 | if testdom != ag.traindom: 256 | get_visual( ag, ckpt, archtype, shape_x, dim_y, 257 | tr_src_loader, val_src_loader, None, 258 | tr_tgt_loader, ts_tgt_loader, testdom ) 259 | else: 260 | warnings.warn("same domain adaptation ignored") 261 | 262 | -------------------------------------------------------------------------------- /a-mnist/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | import warnings 3 | import sys 4 | import torch as tc 5 | sys.path.append("..") 6 | from utils.utils_main import main_stem, get_parser, is_ood, process_continue_run 7 | from utils.utils import boolstr 8 | 9 | __author__ = "Chang Liu" 10 | __email__ = "changliu@microsoft.com" 11 | # tc.autograd.set_detect_anomaly(True) 12 | 13 | if __name__ == "__main__": 14 | parser = get_parser() 15 | parser.add_argument("--data_root", type = str, default = "./data/MNIST/processed/") 16 | parser.add_argument("--traindom", type = str) # MNIST-01 train split: 12665 images = 5923 '0's (46.77%) + 6742 '1's 17 | parser.add_argument("--testdoms", type = str, nargs = '+') # MNIST-01 test split: 2115 images = 980 '0's (46.34%) + 1135 '1's 18 | parser.add_argument("--shuffle", type = boolstr, default = True) 19 | 20 | parser.add_argument("--mlpstrufile", type = str, default = "../arch/mlpstru.json") 21 | parser.add_argument("--actv", type = str, default = "Sigmoid") 22 | parser.add_argument("--after_actv", type = boolstr, default = True) 23 | 24 | parser.set_defaults(discrstru = "lite", genstru = None, 25 | n_bat = 128, n_epk = 100, 26 | mu_s = .5, mu_v = .5, 27 | pstd_x = 3e-2, qstd_s = 3e-2, qstd_v = 3e-2, 28 | optim = "RMSprop", lr = 1e-3, wl2 = 1e-5, 29 | momentum = 0., nesterov = False, lr_expo = .5, lr_wdatum = 6.25e-6, # only used when `optim` is "SGD" 30 | wda = 1., 31 | domdisc_dimh = 1024, # for {dann, cdan, mdd} only 32 | cdan_rand = False, # for cdan only 33 | ker_alphas = [.5, 1., 2.], # for dan only 34 | mdd_margin = 4. # for mdd only 35 | ) 36 | ag = parser.parse_args() 37 | if ag.wlogpi is None: ag.wlogpi = ag.wgen 38 | ag, ckpt = process_continue_run(ag) 39 | IS_OOD = is_ood(ag.mode) 40 | 41 | archtype = "mlp" 42 | # Dataset 43 | src_x, src_y = tc.load(ag.data_root+ag.traindom) # `x` is already a tc.Tensor in range [0., 1.]
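# Each .pt file stores an (x, y) pair; the images are flattened into vectors below, since the MNIST experiments use MLP architectures (archtype = "mlp").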
44 | dim_x = tc.tensor(src_x.shape[1:]).prod().item() 45 | src_x = src_x.reshape(-1, dim_x) 46 | shape_x = (dim_x,) 47 | dim_y = src_y.max().long().item() + 1 48 | if dim_y == 2: dim_y = 1 49 | ## tr-val split 50 | len_src = len(src_x) 51 | assert len_src == len(src_y) 52 | ids_src = tc.randperm(len_src) 53 | len_tr_src = int(len_src * ag.tr_val_split) 54 | ids_tr_src, ids_val_src = ids_src[:len_tr_src], ids_src[len_tr_src:] 55 | tr_src_x, tr_src_y = src_x[ids_tr_src], src_y[ids_tr_src] 56 | val_src_x, val_src_y = src_x[ids_val_src], src_y[ids_val_src] 57 | ## dataloaders 58 | kwargs = {'num_workers': 4, 'pin_memory': True} 59 | tr_src_loader = tc.utils.data.DataLoader( 60 | tc.utils.data.TensorDataset(tr_src_x, tr_src_y), 61 | ag.n_bat, ag.shuffle, **kwargs ) 62 | val_src_loader = tc.utils.data.DataLoader( 63 | tc.utils.data.TensorDataset(val_src_x, val_src_y), 64 | ag.n_bat, ag.shuffle, **kwargs ) 65 | ## tgt (ts) domain 66 | ls_tgt_xy_raw = [tc.load(ag.data_root+testdom) for testdom in ag.testdoms] # (x,y). `x` already tc.Tensor in range [0., 1.] 67 | ls_tgt_xy = [(xy[0].reshape(len(xy[0]), dim_x), xy[1]) for xy in ls_tgt_xy_raw] 68 | ls_ts_tgt_loader = [tc.utils.data.DataLoader( 69 | tc.utils.data.TensorDataset(*xy), 70 | ag.n_bat, ag.shuffle, **kwargs ) 71 | for xy in ls_tgt_xy] 72 | if not IS_OOD: 73 | ls_tr_tgt_loader = [tc.utils.data.DataLoader( 74 | tc.utils.data.TensorDataset(*xy), 75 | ag.n_bat, ag.shuffle, **kwargs ) 76 | for xy in ls_tgt_xy] 77 | 78 | if IS_OOD: 79 | main_stem( ag, ckpt, archtype, shape_x, dim_y, 80 | tr_src_loader, val_src_loader, ls_ts_tgt_loader ) 81 | else: 82 | for testdom, tr_tgt_loader, ts_tgt_loader in zip( 83 | ag.testdoms, ls_tr_tgt_loader, ls_ts_tgt_loader): 84 | if testdom != ag.traindom: 85 | main_stem( ag, ckpt, archtype, shape_x, dim_y, 86 | tr_src_loader, val_src_loader, None, 87 | tr_tgt_loader, ts_tgt_loader, testdom ) 88 | else: 89 | warnings.warn("same domain adaptation ignored") 90 | 91 | -------------------------------------------------------------------------------- /a-mnist/makedata.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | '''For generating MNIST-01 and its shifted interventional datasets. 3 | ''' 4 | import torch as tc 5 | import torchvision as tv 6 | import torchvision.transforms.functional as tvtf 7 | import argparse 8 | 9 | __author__ = "Chang Liu" 10 | __email__ = "changliu@microsoft.com" 11 | 12 | def select_xy(dataset, selected_y = (0,1), piltransf = None, ytransf = None): 13 | dataset_selected = [( 14 | tvtf.to_tensor( img if piltransf is None else piltransf(img, label) ), 15 | label if ytransf is None else ytransf(label) 16 | ) for img, label in dataset if label in selected_y] 17 | xs, ys = tuple(zip(*dataset_selected)) 18 | return tc.cat(xs, dim=0), tc.tensor(ys) 19 | 20 | def get_shift_transf(pleft: list, distr: str, loc: float, scale: float): 21 | return lambda img, label: tvtf.affine(img, angle=0, translate=( 22 | scale * getattr(tc, distr)(()) + loc * (1. - 2. * tc.bernoulli(tc.tensor(pleft[label]))), 0. 23 | ), scale=1., shear=0, fillcolor=0) 24 | 25 | if __name__ == "__main__": 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("mode", type = str, choices = {"train", "test"}) 28 | parser.add_argument("--pleft", type = float, nargs = '+', default = [0.5, 0.5]) 29 | parser.add_argument("--distr", type = str, choices = {"randn", "rand"}) 30 | parser.add_argument("--loc", type = float, default = 4.) 
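# Horizontal shift per digit: scale * <noise from `distr`> + loc * (±1), where the sign is negative with probability pleft[label] (see `get_shift_transf` above).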
31 | parser.add_argument("--scale", type = float, default = 1.) 32 | parser.add_argument("--procroot", type = str, default = "./data/MNIST/processed/") 33 | ag = parser.parse_args() 34 | 35 | dataset = tv.datasets.MNIST(root="./data", train = ag.mode=="train", download=True) # as PIL 36 | piltransf = get_shift_transf(ag.pleft, ag.distr, ag.loc, ag.scale) 37 | selected_y = tuple(range(len(ag.pleft))) 38 | shift_x, shift_y = select_xy(dataset, selected_y, piltransf) 39 | filename = ag.procroot + ag.mode + "".join(str(y) for y in selected_y) + ( 40 | "_" + "_".join(f"{p:.1f}" for p in ag.pleft) + 41 | "_" + ag.distr + f"_{ag.loc:.1f}_{ag.scale:.1f}.pt" ) 42 | tc.save((shift_x, shift_y), filename) 43 | print("Processed data saved to '" + filename + "'") 44 | 45 | -------------------------------------------------------------------------------- /a-mnist/makedata.sh: -------------------------------------------------------------------------------- 1 | python3 makedata.py train --pleft 1. 0. --distr randn --loc 5. --scale 1. 2 | python3 makedata.py test --pleft .5 .5 --distr randn --loc 0. --scale 0. 3 | mv ./data/MNIST/processed/test01_0.5_0.5_randn_0.0_0.0.pt ./data/MNIST/processed/test_01.pt 4 | python3 makedata.py test --pleft .5 .5 --distr randn --loc 0. --scale 2. 5 | -------------------------------------------------------------------------------- /a-mnist/run_da.sh: -------------------------------------------------------------------------------- 1 | REPRUN="../utils/reprun.sh 10" 2 | PYCMD="python3 main.py" 3 | TRAIN="--traindom train01_1.0_0.0_randn_5.0_1.0.pt" 4 | TEST="--testdoms test_01.pt test01_0.5_0.5_randn_0.0_2.0.pt" 5 | 6 | case $1 in 7 | dann) 8 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite-1.5x --optim RMSprop --lr 3e-4 --wl2 1e-5 --wda 1e-4 9 | ;; 10 | cdan) 11 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite-1.5x --optim RMSprop --lr 3e-4 --wl2 1e-5 --wda 1e-6 12 | ;; 13 | dan) 14 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite-1.5x --optim RMSprop --lr 3e-4 --wl2 1e-5 --wda 1e-8 15 | ;; 16 | mdd) 17 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite-1.5x --optim RMSprop --lr 3e-4 --wl2 1e-5 --wda 1e-6 18 | ;; 19 | bnm) 20 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite-1.5x --optim RMSprop --lr 3e-4 --wl2 1e-5 --wda 1e-7 21 | ;; 22 | svae-da) # CSGz-DA 23 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite --optim RMSprop --lr 3e-4 --wl2 1e-5 --wda 1e-4 --mu_s .5 --sig_s .5 --pstd_x 3e-2 --qstd_s=-1. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-4 --wsup 1. 24 | ;; 25 | svgm-da) # CSG-DA 26 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite --optim RMSprop --lr 3e-4 --wl2 1e-5 --wda 1e-4 --mu_s .5 --sig_s .5 --mu_v .5 --sig_v .5 --corr_sv .9 --pstd_x 3e-2 --qstd_s=-1. --qstd_v=-1. --tgt_mvn_prior 1 --src_mvn_prior 1 --wgen 1e-4 --wsup 1. 
27 | ;; 28 | *) 29 | echo "unknown argument $1" 30 | ;; 31 | esac 32 | 33 | -------------------------------------------------------------------------------- /a-mnist/run_ood.sh: -------------------------------------------------------------------------------- 1 | REPRUN="../utils/reprun.sh 10" 2 | PYCMD="python3 main.py" 3 | TRAIN="--traindom train01_1.0_0.0_randn_5.0_1.0.pt" 4 | TEST="--testdoms test_01.pt test01_0.5_0.5_randn_0.0_2.0.pt" 5 | 6 | case $1 in 7 | discr) # CE 8 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite-1.5x --optim RMSprop --lr 1e-3 --wl2 1e-5 9 | ;; 10 | cnbb) 11 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite-1.5x --optim RMSprop --lr 1e-3 --wl2 1e-5 --reg_w 1e-4 --reg_s 3e-6 --lr_w 1e-3 --n_iter_w 4 12 | ;; 13 | svae) # CSGz 14 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite --optim RMSprop --lr 1e-3 --wl2 1e-5 --mu_s .5 --sig_s .5 --pstd_x 3e-2 --qstd_s=-1. --wgen 1e-4 --wsup 1. --mvn_prior 1 15 | ;; 16 | svgm) # CSG 17 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite --optim RMSprop --lr 1e-3 --wl2 1e-5 --mu_s .5 --sig_s .5 --mu_v .5 --sig_v .5 --corr_sv .9 --pstd_x 3e-2 --qstd_s=-1. --qstd_v=-1. --wgen 1e-4 --wsup 1. --mvn_prior 1 18 | ;; 19 | svgm-ind) # CSG-ind 20 | $REPRUN $PYCMD $1 $TRAIN $TEST --discrstru lite --optim RMSprop --lr 1e-3 --wl2 1e-5 --mu_s .5 --sig_s .5 --mu_v .5 --sig_v .5 --corr_sv .9 --pstd_x 3e-2 --qstd_s=-1. --qstd_v=-1. --wgen 1e-4 --wsup 1. --mvn_prior 1 21 | ;; 22 | *) 23 | echo "unknown argument $1" 24 | ;; 25 | esac 26 | 27 | -------------------------------------------------------------------------------- /arch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changliu00/causal-semantic-generative-model/05c10d6790f3db8d4847efd98d18a6ecafb469cb/arch/__init__.py -------------------------------------------------------------------------------- /arch/backbone.py: -------------------------------------------------------------------------------- 1 | """ This file is adapted from 2 | """ 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | from torchvision import models 8 | from torch.autograd import Variable 9 | 10 | 11 | # convnet without the last layer 12 | class AlexNetFc(nn.Module): 13 | def __init__(self): 14 | super(AlexNetFc, self).__init__() 15 | model_alexnet = models.alexnet(pretrained=True) 16 | self.features = model_alexnet.features 17 | self.classifier = nn.Sequential() 18 | for i in range(6): 19 | self.classifier.add_module("classifier"+str(i), model_alexnet.classifier[i]) 20 | self.__in_features = model_alexnet.classifier[6].in_features 21 | 22 | def forward(self, x): 23 | x = self.features(x) 24 | x = x.view(x.size(0), 256*6*6) 25 | x = self.classifier(x) 26 | return x 27 | 28 | def output_num(self): 29 | return self.__in_features 30 | 31 | 32 | class ResNet18Fc(nn.Module): 33 | def __init__(self): 34 | super(ResNet18Fc, self).__init__() 35 | model_resnet18 = models.resnet18(pretrained=True) 36 | self.conv1 = model_resnet18.conv1 37 | self.bn1 = model_resnet18.bn1 38 | self.relu = model_resnet18.relu 39 | self.maxpool = model_resnet18.maxpool 40 | self.layer1 = model_resnet18.layer1 41 | self.layer2 = model_resnet18.layer2 42 | self.layer3 = model_resnet18.layer3 43 | self.layer4 = model_resnet18.layer4 44 | self.avgpool = model_resnet18.avgpool 45 | self.__in_features = model_resnet18.fc.in_features 46 | 47 | def forward(self, x): 48 | x = self.conv1(x) 49 | x = self.bn1(x) 50 | x = self.relu(x) 51 | x = 
self.maxpool(x) 52 | x = self.layer1(x) 53 | x = self.layer2(x) 54 | x = self.layer3(x) 55 | x = self.layer4(x) 56 | x = self.avgpool(x) 57 | x = x.view(x.size(0), -1) 58 | return x 59 | 60 | def output_num(self): 61 | return self.__in_features 62 | 63 | 64 | class ResNet34Fc(nn.Module): 65 | def __init__(self): 66 | super(ResNet34Fc, self).__init__() 67 | model_resnet34 = models.resnet34(pretrained=True) 68 | self.conv1 = model_resnet34.conv1 69 | self.bn1 = model_resnet34.bn1 70 | self.relu = model_resnet34.relu 71 | self.maxpool = model_resnet34.maxpool 72 | self.layer1 = model_resnet34.layer1 73 | self.layer2 = model_resnet34.layer2 74 | self.layer3 = model_resnet34.layer3 75 | self.layer4 = model_resnet34.layer4 76 | self.avgpool = model_resnet34.avgpool 77 | self.__in_features = model_resnet34.fc.in_features 78 | 79 | def forward(self, x): 80 | x = self.conv1(x) 81 | x = self.bn1(x) 82 | x = self.relu(x) 83 | x = self.maxpool(x) 84 | x = self.layer1(x) 85 | x = self.layer2(x) 86 | x = self.layer3(x) 87 | x = self.layer4(x) 88 | x = self.avgpool(x) 89 | x = x.view(x.size(0), -1) 90 | return x 91 | 92 | def output_num(self): 93 | return self.__in_features 94 | 95 | 96 | class ResNet50Fc(nn.Module): 97 | def __init__(self): 98 | super(ResNet50Fc, self).__init__() 99 | model_resnet50 = models.resnet50(pretrained=True) 100 | self.conv1 = model_resnet50.conv1 101 | self.bn1 = model_resnet50.bn1 102 | self.relu = model_resnet50.relu 103 | self.maxpool = model_resnet50.maxpool 104 | self.layer1 = model_resnet50.layer1 105 | self.layer2 = model_resnet50.layer2 106 | self.layer3 = model_resnet50.layer3 107 | self.layer4 = model_resnet50.layer4 108 | self.avgpool = model_resnet50.avgpool 109 | self.__in_features = model_resnet50.fc.in_features 110 | 111 | def forward(self, x): 112 | x = self.conv1(x) 113 | x = self.bn1(x) 114 | x = self.relu(x) 115 | x = self.maxpool(x) 116 | x = self.layer1(x) 117 | x = self.layer2(x) 118 | x = self.layer3(x) 119 | x = self.layer4(x) 120 | x = self.avgpool(x) 121 | x = x.view(x.size(0), -1) 122 | return x 123 | 124 | def output_num(self): 125 | return self.__in_features 126 | 127 | 128 | class ResNet101Fc(nn.Module): 129 | def __init__(self): 130 | super(ResNet101Fc, self).__init__() 131 | model_resnet101 = models.resnet101(pretrained=True) 132 | self.conv1 = model_resnet101.conv1 133 | self.bn1 = model_resnet101.bn1 134 | self.relu = model_resnet101.relu 135 | self.maxpool = model_resnet101.maxpool 136 | self.layer1 = model_resnet101.layer1 137 | self.layer2 = model_resnet101.layer2 138 | self.layer3 = model_resnet101.layer3 139 | self.layer4 = model_resnet101.layer4 140 | self.avgpool = model_resnet101.avgpool 141 | self.__in_features = model_resnet101.fc.in_features 142 | 143 | def forward(self, x): 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.relu(x) 147 | x = self.maxpool(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.avgpool(x) 153 | x = x.view(x.size(0), -1) 154 | return x 155 | 156 | def output_num(self): 157 | return self.__in_features 158 | 159 | 160 | class ResNet152Fc(nn.Module): 161 | def __init__(self): 162 | super(ResNet152Fc, self).__init__() 163 | model_resnet152 = models.resnet152(pretrained=True) 164 | self.conv1 = model_resnet152.conv1 165 | self.bn1 = model_resnet152.bn1 166 | self.relu = model_resnet152.relu 167 | self.maxpool = model_resnet152.maxpool 168 | self.layer1 = model_resnet152.layer1 169 | self.layer2 = model_resnet152.layer2 170 | 
self.layer3 = model_resnet152.layer3 171 | self.layer4 = model_resnet152.layer4 172 | self.avgpool = model_resnet152.avgpool 173 | self.__in_features = model_resnet152.fc.in_features 174 | 175 | def forward(self, x): 176 | x = self.conv1(x) 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | x = self.maxpool(x) 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | x = self.layer4(x) 184 | x = self.avgpool(x) 185 | x = x.view(x.size(0), -1) 186 | return x 187 | 188 | def output_num(self): 189 | return self.__in_features 190 | 191 | 192 | network_dict = {"AlexNet": AlexNetFc, 193 | "ResNet18": ResNet18Fc, 194 | "ResNet34": ResNet34Fc, 195 | "ResNet50": ResNet50Fc, 196 | "ResNet101": ResNet101Fc, 197 | "ResNet152": ResNet152Fc} 198 | -------------------------------------------------------------------------------- /arch/cnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | ''' CNN Architecture 3 | 4 | Based on the architecture in MDD , 5 | and leverage the repositories of `domainbed` 6 | and `pytorch_GAN_zoo` . 7 | Architectures organized and enhanced for the use for the Causal Semantic Generative model. 8 | ''' 9 | __author__ = "Chang Liu" 10 | __email__ = "changliu@microsoft.com" 11 | 12 | import sys, os 13 | from warnings import warn 14 | import torch as tc 15 | import torch.nn as nn 16 | import torchvision as tv 17 | from . import backbone 18 | from . import mlp 19 | sys.path.append('..') 20 | from distr import tensorify, is_same_tensor, wrap4_multi_batchdims 21 | dbpath = '../a-domainbed/DomainBed' 22 | if os.path.isdir(dbpath): 23 | sys.path.append(dbpath) 24 | import domainbed # 54c2f8c 25 | from domainbed.networks import Featurizer 26 | 27 | def init_linear(nnseq, wmean, wstd, bval): 28 | for mod in nnseq: 29 | if type(mod) is nn.Linear: 30 | mod.weight.data.normal_(wmean, wstd) 31 | mod.bias.data.fill_(bval) 32 | 33 | # The inference/discriminative models / encoders. 34 | class CNNsvy1x(nn.Module): 35 | def __init__(self, backbone_stru: str, dim_bottleneck: int, dim_s: int, dim_y: int, dim_v: int, 36 | std_v1x_val: float, std_s1vx_val: float, # if <= 0, then learn the std. 
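            # (a negative value such as -1. thus requests a learned std network; the
            # run scripts' --qstd_s=-1. / --qstd_v=-1. flags presumably map here)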
37 | dims_bb2bn: list=None, dims_bn2s: list=None, dims_s2y: list=None, 38 | vbranch: bool=False, dims_bn2v: list=None): 39 | """ Based on MDD from 40 | if not vbranch: 41 | (bb) (bn) (med) (cls) 42 | /-> v -\ 43 | x ====> -| |-> s ----> y 44 | \-> parav -/ 45 | else: 46 | (bb) (bn) (med) (cls) 47 | x ====> ----> ----> s ----> y 48 | \ (vbr) 49 | \----> v 50 | """ 51 | if not vbranch: assert dim_v <= dim_bottleneck 52 | super(CNNsvy1x, self).__init__() 53 | self.dim_s = dim_s; self.dim_v = dim_v; self.dim_y = dim_y 54 | self.shape_s = (dim_s,); self.shape_v = (dim_v,) 55 | self.vbranch = vbranch 56 | self.std_v1x_val = std_v1x_val; self.std_s1vx_val = std_s1vx_val 57 | self.learn_std_v1x = std_v1x_val <= 0 if type(std_v1x_val) is float else (std_v1x_val <= 0).any() 58 | self.learn_std_s1vx = std_s1vx_val <= 0 if type(std_s1vx_val) is float else (std_s1vx_val <= 0).any() 59 | 60 | self._x_cache_bb = self._bb_cache = None 61 | self._x_cache_bn = self._bn_cache = None 62 | self._param_groups = [] 63 | 64 | if 'domainbed' in globals() and backbone_stru.startswith("DB"): 65 | self.nn_backbone = Featurizer((3,224,224), 66 | {'resnet18': backbone_stru[2:]=='resnet18', 'resnet_dropout': 0.}) 67 | self._param_groups += [{"params": self.nn_backbone.parameters(), "lr_ratio": 1.0}] 68 | self.nn_backbone.output_num = lambda: self.nn_backbone.n_outputs 69 | if dim_bottleneck is None: dim_bottleneck = self.nn_backbone.output_num() // 2 70 | if dim_s is None: dim_s = self.nn_backbone.output_num() // 4 71 | else: 72 | self.nn_backbone = backbone.network_dict[backbone_stru]() 73 | self._param_groups += [{"params": self.nn_backbone.parameters(), "lr_ratio": 0.1}] 74 | self.f_backbone = wrap4_multi_batchdims(self.nn_backbone, ndim_vars=3) 75 | 76 | if dims_bb2bn is None: dims_bb2bn = [] 77 | self.nn_bottleneck = mlp.mlp_constructor( 78 | [self.nn_backbone.output_num()] + dims_bb2bn + [dim_bottleneck], 79 | nn.ReLU, lastactv = False 80 | ) 81 | init_linear(self.nn_bottleneck, 0., 5e-3, 0.1) 82 | self._param_groups += [{"params": self.nn_bottleneck.parameters(), "lr_ratio": 1.}] 83 | self.f_bottleneck = self.nn_bottleneck 84 | 85 | if dims_bn2s is None: dims_bn2s = [] 86 | self.nn_mediate = nn.Sequential( 87 | *([] if backbone_stru.startswith("DB") else [nn.BatchNorm1d(dim_bottleneck)]), 88 | nn.ReLU(), 89 | # nn.Dropout(0.5), 90 | mlp.mlp_constructor( 91 | [dim_bottleneck] + dims_bn2s + [dim_s], 92 | nn.ReLU, lastactv = False) 93 | ) 94 | init_linear(self.nn_mediate, 0., 1e-2, 0.) 95 | self._param_groups += [{"params": self.nn_mediate.parameters(), "lr_ratio": 1.}] 96 | self.f_mediate = wrap4_multi_batchdims(self.nn_mediate, ndim_vars=1) # required by `BatchNorm1d` 97 | 98 | if dims_s2y is None: dims_s2y = [] 99 | self.nn_classifier = nn.Sequential( 100 | nn.ReLU(), 101 | # nn.Dropout(0.5), 102 | mlp.mlp_constructor( 103 | [dim_s] + dims_s2y + [dim_y], 104 | nn.ReLU, lastactv = False) 105 | ) 106 | init_linear(self.nn_classifier, 0., 1e-2, 0.) 107 | self._param_groups += [{"params": self.nn_classifier.parameters(), "lr_ratio": 1.}] 108 | self.f_classifier = self.nn_classifier 109 | 110 | if vbranch: 111 | if dims_bn2v is None: dims_bn2v = [] 112 | self.nn_vbranch = nn.Sequential( 113 | nn.BatchNorm1d(dim_bottleneck), 114 | nn.ReLU(), 115 | # nn.Dropout(0.5), 116 | mlp.mlp_constructor( 117 | [dim_bottleneck] + dims_bn2v + [dim_v], 118 | nn.ReLU, lastactv = False) 119 | ) 120 | init_linear(self.nn_vbranch, 0., 1e-2, 0.) 
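            # with `vbranch`, v gets its own head off the bottleneck instead of being
            # sliced off the bottleneck's first dim_v units (cf. `v1x` below)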
121 | self._param_groups += [{"params": self.nn_vbranch.parameters(), "lr_ratio": 1.}] 122 | self.f_vbranch = wrap4_multi_batchdims(self.nn_vbranch, ndim_vars=1) 123 | 124 | ## std models 125 | if self.learn_std_v1x: 126 | if not vbranch: 127 | self.nn_std_v = nn.Sequential( 128 | mlp.mlp_constructor( 129 | [self.nn_backbone.output_num()] + dims_bb2bn + [dim_v], 130 | nn.ReLU, lastactv = False), 131 | nn.Softplus() 132 | ) 133 | else: 134 | self.nn_std_v = nn.Sequential( 135 | mlp.mlp_constructor( 136 | [dim_bottleneck] + dims_bn2v + [dim_v], 137 | nn.ReLU, lastactv = False), 138 | nn.Softplus() 139 | ) 140 | init_linear(self.nn_std_v, 0., 1e-2, 0.) 141 | self._param_groups += [{"params": self.nn_std_v.parameters(), "lr_ratio": 1.}] 142 | self.f_std_v = self.nn_std_v 143 | 144 | if self.learn_std_s1vx: 145 | self.nn_std_s = nn.Sequential( 146 | nn.BatchNorm1d(dim_bottleneck), 147 | nn.ReLU(), 148 | # nn.Dropout(0.5), 149 | mlp.mlp_constructor( 150 | [dim_bottleneck] + dims_bn2s + [dim_s], 151 | nn.ReLU, lastactv = False), 152 | nn.Softplus() 153 | ) 154 | init_linear(self.nn_std_s, 0., 1e-2, 0.) 155 | self._param_groups += [{"params": self.nn_std_s.parameters(), "lr_ratio": 1.}] 156 | self.f_std_s = wrap4_multi_batchdims(self.nn_std_s, ndim_vars=1) 157 | 158 | def _get_bb(self, x): 159 | if not is_same_tensor(x, self._x_cache_bb): 160 | self._x_cache_bb = x 161 | self._bb_cache = self.f_backbone(x) 162 | return self._bb_cache 163 | 164 | def _get_bn(self, x): 165 | if not is_same_tensor(x, self._x_cache_bn): 166 | self._x_cache_bn = x 167 | self._bn_cache = self.f_bottleneck(self._get_bb(x)) 168 | return self._bn_cache 169 | 170 | def v1x(self, x): 171 | bn = self._get_bn(x) 172 | if not self.vbranch: return bn[..., :self.dim_v] 173 | else: return self.f_vbranch(bn) 174 | def std_v1x(self, x): 175 | if self.learn_std_v1x: 176 | if not self.vbranch: return self.f_std_v(self._get_bb(x)) 177 | else: return self.f_std_v(self._get_bn(x)) 178 | else: 179 | return tensorify(x.device, self.std_v1x_val)[0].expand(x.shape[:-3]+(self.dim_v,)) 180 | 181 | def s1vx(self, v, x): 182 | if not self.vbranch: 183 | bn = self._get_bn(x) 184 | bn_synth = tc.cat([v, bn[..., self.dim_v:]], dim=-1) 185 | return self.f_mediate(bn_synth) 186 | else: 187 | return self.s1x(x) 188 | def std_s1vx(self, v, x): 189 | if self.learn_std_s1vx: 190 | if not self.vbranch: 191 | bn = self._get_bn(x) 192 | bn_synth = tc.cat([v, bn[..., self.dim_v:]], dim=-1) 193 | return self.f_std_s(bn_synth) 194 | else: 195 | return self.std_s1x(x) 196 | else: 197 | return tensorify(x.device, self.std_s1vx_val)[0].expand(x.shape[:-3]+(self.dim_s,)) 198 | 199 | def s1x(self, x): 200 | return self.f_mediate(self._get_bn(x)) 201 | def std_s1x(self, x): 202 | if self.learn_std_s1vx: 203 | return self.f_std_s(self._get_bn(x)) 204 | else: 205 | return tensorify(x.device, self.std_s1vx_val)[0].expand(x.shape[:-3]+(self.dim_s,)) 206 | 207 | def y1s(self, s): 208 | return self.f_classifier(s).squeeze(-1) # squeeze for binary y 209 | 210 | def ys1x(self, x): 211 | s = self.s1x(x) 212 | return self.y1s(s), s 213 | 214 | def forward(self, x): 215 | return self.y1s(self.s1x(x)) 216 | 217 | def parameter_groups(self): 218 | return self._param_groups 219 | 220 | def save(self, path): tc.save(self.state_dict(), path) 221 | def load(self, path): 222 | self.load_state_dict(tc.load(path)) 223 | self.eval() 224 | 225 | # The generative models / decoders. 
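# (Before the decoders — a minimal usage sketch of the CNNsvy1x encoder defined
# above; the dims are hypothetical, and negative stds request learned std networks:)
#   enc = CNNsvy1x("ResNet18", dim_bottleneck=256, dim_s=64, dim_y=10, dim_v=64,
#                  std_v1x_val=-1., std_s1vx_val=-1.)
#   x = tc.randn(8, 3, 224, 224)
#   v = enc.v1x(x)       # (8, 64); backbone/bottleneck outputs are cached per input tensor
#   s = enc.s1vx(v, x)   # (8, 64)
#   logits = enc.y1s(s)  # (8, 10)
# `weights_init` below is the standard DCGAN initialization: N(0, 0.02) for conv
# weights, N(1, 0.02) for BatchNorm scales.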
226 | def weights_init(m): 227 | classname = m.__class__.__name__ 228 | if classname.find('Conv') != -1: 229 | nn.init.normal_(m.weight.data, 0.0, 0.02) 230 | elif classname.find('BatchNorm') != -1: 231 | nn.init.normal_(m.weight.data, 1.0, 0.02) 232 | nn.init.constant_(m.bias.data, 0) 233 | 234 | ## Instances 235 | class CNN_DCGANvar_224(nn.Module): 236 | # Based on the decoder of DCGAN. 237 | def __init__(self, dim_in, dim_feat, dim_chanl = 3): 238 | super(CNN_DCGANvar_224, self).__init__() 239 | self.nn_main = nn.Sequential( 240 | # l_out = stride*(l_in - 1) + l_kernel - 2*padding. (*, *, l_kernel, stride, padding) 241 | # input is Z, going into a convolution 242 | nn.ConvTranspose2d( dim_in, dim_feat * 8, 7, 1, 0, bias=False), 243 | nn.BatchNorm2d(dim_feat * 8), 244 | nn.ReLU(True), 245 | # state size. (dim_feat*8) x 7 x 7 246 | nn.ConvTranspose2d(dim_feat * 8, dim_feat * 4, 4, 4, 0, bias=False), 247 | nn.BatchNorm2d(dim_feat * 4), 248 | nn.ReLU(True), 249 | # state size. (dim_feat*4) x 28 x 28 250 | nn.ConvTranspose2d( dim_feat * 4, dim_feat * 2, 4, 2, 1, bias=False), 251 | nn.BatchNorm2d(dim_feat * 2), 252 | nn.ReLU(True), 253 | # state size. (dim_feat*2) x 56 x 56 254 | nn.ConvTranspose2d( dim_feat * 2, dim_feat, 4, 2, 1, bias=False), 255 | nn.BatchNorm2d(dim_feat), 256 | nn.ReLU(True), 257 | # state size. (dim_feat) x 112 x 112 258 | nn.ConvTranspose2d( dim_feat, dim_chanl, 4, 2, 1, bias=True), # False), 259 | # nn.Tanh() 260 | # state size. (dim_chanl) x 224 x 224 261 | ) 262 | self.apply(weights_init) 263 | self._param_groups = [{"params": self.nn_main.parameters(), "lr_ratio": 1.}] 264 | self.f_main = wrap4_multi_batchdims(self.nn_main, ndim_vars=3) 265 | 266 | def forward(self, val): 267 | # `val` should be of shape (..., dim_in) 268 | return self.f_main(val[..., None, None]) 269 | 270 | class CNN_DCGANpretr_224(nn.Module): 271 | def __init__(self, dim_in, dim_feat = None, dim_chanl = None): 272 | # if dim_feat is not None or dim_chanl is not None: 273 | # warn(f"`dim_feat` {dim_feat} and `dim_chanl` {dim_chanl} ignored") 274 | super(CNN_DCGANpretr_224, self).__init__() 275 | self._param_groups = [] 276 | self.nn_pre = nn.Sequential( 277 | nn.Linear(dim_in, dim_feat), nn.Tanh(), 278 | nn.Linear(dim_feat, 120), nn.Tanh() 279 | ) 280 | self.nn_pre.apply(weights_init) 281 | self._param_groups += [{"params": self.nn_pre.parameters(), "lr_ratio": 1.}] 282 | self.f_pre = self.nn_pre 283 | 284 | self.nn_backbone = tc.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', # force_reload=True, 285 | pretrained=True, useGPU=False, model_name='cifar10').getOriginalG() 286 | self._param_groups += [{"params": self.nn_backbone.parameters(), "lr_ratio": 0.1}] 287 | self.f_backbone = wrap4_multi_batchdims(self.nn_backbone, ndim_vars=1) 288 | 289 | self.nn_post = nn.Sequential( 290 | nn.BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 291 | nn.ReLU(inplace=True), 292 | nn.ConvTranspose2d(3, 3, 4, 4, 16, bias=False) 293 | ) 294 | self.nn_post.apply(weights_init) 295 | self._param_groups += [{"params": self.nn_post.parameters(), "lr_ratio": 1.}] 296 | self.f_post = wrap4_multi_batchdims(self.nn_post, ndim_vars=3) 297 | 298 | def forward(self, val): 299 | # `val` should be of shape (..., dim_in) 300 | return self.f_post(self.f_backbone(self.f_pre(val))) 301 | 302 | def parameter_groups(self): 303 | return self._param_groups 304 | 305 | class CNN_PGANpretr_224(nn.Module): 306 | def __init__(self, dim_in, dim_feat = None, dim_chanl = None): 307 | # if dim_feat is 
not None or dim_chanl is not None: 308 | # warn(f"`dim_feat` {dim_feat} and `dim_chanl` {dim_chanl} ignored") 309 | super(CNN_PGANpretr_224, self).__init__() 310 | self._param_groups = [] 311 | self.nn_pre = nn.Sequential( 312 | nn.Linear(dim_in, dim_feat), nn.Tanh(), 313 | nn.Linear(dim_feat, 512), nn.Tanh() 314 | ) 315 | self.nn_pre.apply(weights_init) 316 | self._param_groups += [{"params": self.nn_pre.parameters(), "lr_ratio": 1.}] 317 | self.f_pre = self.nn_pre 318 | 319 | self.nn_backbone = tc.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'PGAN', # force_reload=True, 320 | pretrained=True, useGPU=False, model_name='celebAHQ-256').getOriginalG() # 'cifar10' unavailable for PGAN 321 | self._param_groups += [{"params": self.nn_backbone.parameters(), "lr_ratio": 0.1}] 322 | self.f_backbone = wrap4_multi_batchdims(self.nn_backbone, ndim_vars=1) 323 | 324 | self.f_post = tv.transforms.CenterCrop(224) 325 | 326 | def forward(self, val): 327 | # `val` should be of shape (..., dim_in) 328 | return self.f_post(self.f_backbone(self.f_pre(val))) 329 | 330 | def parameter_groups(self): 331 | return self._param_groups 332 | 333 | ## Uniform interfaces 334 | class CNNx1sv(nn.Module): 335 | def __init__(self, dim_xside: int, dim_s: int, dim_v: int, dim_feat: int, dectype: str="DCGANvar"): 336 | super(CNNx1sv, self).__init__() 337 | self.net = globals()["CNN_" + dectype + "_" + str(dim_xside)](dim_s + dim_v, dim_feat) 338 | # Parameters of `self.net` automatically included in `self.parameters()` 339 | 340 | def x1sv(self, s, v): 341 | return self.net(tc.cat([s, v], dim=-1)) 342 | def forward(self, s, v): return self.x1sv(s, v) 343 | def parameter_groups(self): 344 | return self.net._param_groups 345 | 346 | def save(self, path): tc.save(self.state_dict(), path) 347 | def load(self, path): 348 | self.load_state_dict(tc.load(path)) 349 | self.eval() 350 | 351 | class CNNx1s(nn.Module): 352 | def __init__(self, dim_xside: int, dim_s: int, dim_feat: int, dectype: str="DCGANvar"): 353 | super(CNNx1s, self).__init__() 354 | self.net = globals()["CNN_" + dectype + "_" + str(dim_xside)](dim_s, dim_feat) 355 | # Parameters of `self.net` automatically included in `self.parameters()` 356 | 357 | def x1s(self, s): return self.net(s) 358 | def forward(self, s): return self.x1s(s) 359 | def parameter_groups(self): 360 | return self.net._param_groups 361 | 362 | def save(self, path): tc.save(self.state_dict(), path) 363 | def load(self, path): 364 | self.load_state_dict(tc.load(path)) 365 | self.eval() 366 | 367 | -------------------------------------------------------------------------------- /arch/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | '''Multi-Layer Perceptron Architecture. 3 | 4 | For causal discriminative model and the corresponding generative model. 5 | ''' 6 | import sys, os 7 | import json 8 | import torch as tc 9 | import torch.nn as nn 10 | sys.path.append('..') 11 | from distr import tensorify, is_same_tensor, wrap4_multi_batchdims 12 | 13 | __author__ = "Chang Liu" 14 | __email__ = "changliu@microsoft.com" 15 | 16 | def init_linear(nnseq, wmean, wstd, bval): 17 | for mod in nnseq: 18 | if type(mod) is nn.Linear: 19 | mod.weight.data.normal_(wmean, wstd) 20 | mod.bias.data.fill_(bval) 21 | 22 | def mlp_constructor(dims, actv = "Sigmoid", lastactv = True): # `Sequential()`, or `Sequential(*[])`, is the identity map for any shape! 
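    # e.g. mlp_constructor([784, 400, 10], "ReLU", lastactv=False)
    #   == Sequential(Linear(784, 400), ReLU(), Linear(400, 10))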
23 | if type(actv) is str: actv = getattr(nn, actv) 24 | if len(dims) <= 1: return nn.Sequential() 25 | else: return nn.Sequential(*( 26 | sum([[nn.Linear(dims[i], dims[i+1]), actv()] for i in range(len(dims)-2)], []) + \ 27 | [nn.Linear(dims[-2], dims[-1])] + ([actv()] if lastactv else []) 28 | )) 29 | 30 | class MLPBase(nn.Module): 31 | def save(self, path): tc.save(self.state_dict(), path) 32 | def load(self, path): 33 | self.load_state_dict(tc.load(path)) 34 | self.eval() 35 | def load_or_save(self, filename): 36 | dirname = "init_models_mlp/" 37 | os.makedirs(dirname, exist_ok=True) 38 | path = dirname + filename 39 | if os.path.exists(path): self.load(path) 40 | else: self.save(path) 41 | 42 | class MLP(MLPBase): 43 | def __init__(self, dims, actv = "Sigmoid"): 44 | if type(actv) is str: actv = getattr(nn, actv) 45 | super(MLP, self).__init__() 46 | self.f_x2y = mlp_constructor(dims, actv, lastactv = False) 47 | def forward(self, x): return self.f_x2y(x).squeeze(-1) 48 | 49 | class MLPsvy1x(MLPBase): 50 | def __init__(self, dim_x, dims_postx2prev, dim_v, dim_parav, dims_postv2s, dims_posts2prey, dim_y, actv = "Sigmoid", 51 | std_v1x_val: float=-1., std_s1vx_val: float=-1., after_actv: bool=True): # if <= 0, then learn the std. 52 | """ 53 | /-> v -\ 54 | x ====> prev -| |==> s ==> y 55 | \-> parav -/ 56 | """ 57 | super(MLPsvy1x, self).__init__() 58 | if type(actv) is str: actv = getattr(nn, actv) 59 | self.dim_x, self.dim_v, self.dim_y = dim_x, dim_v, dim_y 60 | dim_prev, dim_s = dims_postx2prev[-1], dims_postv2s[-1] 61 | self.dim_prev, self.dim_s = dim_prev, dim_s 62 | self.shape_x, self.shape_v, self.shape_s = (dim_x,), (dim_v,), (dim_s,) 63 | self.dims_postx2prev, self.dim_parav, self.dims_postv2s, self.dims_posts2prey, self.actv \ 64 | = dims_postx2prev, dim_parav, dims_postv2s, dims_posts2prey, actv 65 | self.f_x2prev = mlp_constructor([dim_x] + dims_postx2prev, actv) 66 | if after_actv: 67 | self.f_prev2v = nn.Sequential( nn.Linear(dim_prev, dim_v), actv() ) 68 | self.f_prev2parav = nn.Sequential( nn.Linear(dim_prev, dim_parav), actv() ) 69 | self.f_vparav2s = mlp_constructor([dim_v + dim_parav] + dims_postv2s, actv) 70 | self.f_s2y = mlp_constructor([dim_s] + dims_posts2prey + [dim_y], actv, lastactv = False) 71 | else: 72 | self.f_prev2v = nn.Linear(dim_prev, dim_v) 73 | self.f_prev2parav = nn.Linear(dim_prev, dim_parav) 74 | self.f_vparav2s = nn.Sequential( actv(), mlp_constructor([dim_v + dim_parav] + dims_postv2s, actv, lastactv = False) ) 75 | self.f_s2y = nn.Sequential( actv(), mlp_constructor([dim_s] + dims_posts2prey + [dim_y], actv, lastactv = False) ) 76 | 77 | self.std_v1x_val = std_v1x_val; self.std_s1vx_val = std_s1vx_val 78 | self.learn_std_v1x = std_v1x_val <= 0 if type(std_v1x_val) is float else (std_v1x_val <= 0).any() 79 | self.learn_std_s1vx = std_s1vx_val <= 0 if type(std_s1vx_val) is float else (std_s1vx_val <= 0).any() 80 | 81 | self._prev_cache = self._x_cache_prev = None 82 | self._v_cache = self._x_cache_v = None 83 | self._parav_cache = self._x_cache_parav = None 84 | 85 | ## std models 86 | if self.learn_std_v1x: 87 | self.nn_std_v = nn.Sequential( 88 | mlp_constructor( 89 | [dim_prev, dim_v], 90 | nn.ReLU, lastactv = False), 91 | nn.Softplus() 92 | ) 93 | init_linear(self.nn_std_v, 0., 1e-2, 0.) 
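            # the terminal Softplus above keeps the learned std positive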
94 | self.f_std_v = self.nn_std_v 95 | 96 | if self.learn_std_s1vx: 97 | self.nn_std_s = nn.Sequential( 98 | nn.BatchNorm1d(dim_v + dim_parav), 99 | nn.ReLU(), 100 | # nn.Dropout(0.5), 101 | mlp_constructor( 102 | [dim_v + dim_parav] + dims_postv2s, 103 | nn.ReLU, lastactv = False), 104 | nn.Softplus() 105 | ) 106 | init_linear(self.nn_std_s, 0., 1e-2, 0.) 107 | self.f_std_s = wrap4_multi_batchdims(self.nn_std_s, ndim_vars=1) 108 | 109 | def _get_prev(self, x): 110 | if not is_same_tensor(x, self._x_cache_prev): 111 | self._x_cache_prev = x 112 | self._prev_cache = self.f_x2prev(x) 113 | return self._prev_cache 114 | 115 | def v1x(self, x): 116 | if not is_same_tensor(x, self._x_cache_v): 117 | self._x_cache_v = x 118 | self._v_cache = self.f_prev2v(self._get_prev(x)) 119 | return self._v_cache 120 | def std_v1x(self, x): 121 | if self.learn_std_v1x: 122 | return self.f_std_v(self._get_prev(x)) 123 | else: 124 | return tensorify(x.device, self.std_v1x_val)[0].expand(x.shape[:-1]+(self.dim_v,)) 125 | 126 | def _get_parav(self, x): 127 | if not is_same_tensor(x, self._x_cache_parav): 128 | self._x_cache_parav = x 129 | self._parav_cache = self.f_prev2parav(self._get_prev(x)) 130 | return self._parav_cache 131 | 132 | def s1vx(self, v, x): 133 | parav = self._get_parav(x) 134 | return self.f_vparav2s(tc.cat([v, parav], dim=-1)) 135 | def std_s1vx(self, v, x): 136 | if self.learn_std_s1vx: 137 | parav = self._get_parav(x) 138 | return self.f_std_s(tc.cat([v, parav], dim=-1)) 139 | else: 140 | return tensorify(x.device, self.std_s1vx_val)[0].expand(x.shape[:-1]+(self.dim_s,)) 141 | 142 | def s1x(self, x): 143 | return self.s1vx(self.v1x(x), x) 144 | def std_s1x(self, x): 145 | return self.std_s1vx(self.v1x(x), x) 146 | 147 | def y1s(self, s): 148 | return self.f_s2y(s).squeeze(-1) # squeeze for binary y 149 | 150 | def ys1x(self, x): 151 | s = self.s1x(x) 152 | return self.y1s(s), s 153 | 154 | def forward(self, x): 155 | return self.y1s(self.s1x(x)) 156 | 157 | class MLPx1sv(MLPBase): 158 | def __init__(self, dim_s = None, dims_pres2parav = None, dim_v = None, dims_prev2postx = None, dim_x = None, 159 | actv = None, *, discr = None): 160 | if dim_s is None: dim_s = discr.dim_s 161 | if dim_v is None: dim_v = discr.dim_v 162 | if dim_x is None: dim_x = discr.dim_x 163 | if actv is None: actv = discr.actv if hasattr(discr, "actv") else "Sigmoid" 164 | if type(actv) is str: actv = getattr(nn, actv) 165 | if dims_pres2parav is None: dims_pres2parav = discr.dims_postv2s[::-1][1:] + [discr.dim_parav] 166 | if dims_prev2postx is None: dims_prev2postx = discr.dims_postx2prev[::-1] 167 | super(MLPx1sv, self).__init__() 168 | self.dim_s, self.dim_v, self.dim_x = dim_s, dim_v, dim_x 169 | self.dims_pres2parav, self.dims_prev2postx, self.actv = dims_pres2parav, dims_prev2postx, actv 170 | self.f_s2parav = mlp_constructor([dim_s] + dims_pres2parav, actv) 171 | self.f_vparav2x = mlp_constructor([dim_v + dims_pres2parav[-1]] + dims_prev2postx + [dim_x], actv) 172 | 173 | def x1sv(self, s, v): return self.f_vparav2x(tc.cat([v, self.f_s2parav(s)], dim=-1)) 174 | def forward(self, s, v): return self.x1sv(s, v) 175 | 176 | class MLPx1s(MLPBase): 177 | def __init__(self, dim_s = None, dims_pres2postx = None, dim_x = None, 178 | actv = None, *, discr = None): 179 | if dim_s is None: dim_s = discr.dim_s 180 | if dim_x is None: dim_x = discr.dim_x 181 | if actv is None: actv = discr.actv if hasattr(discr, "actv") else "Sigmoid" 182 | if type(actv) is str: actv = getattr(nn, actv) 183 | if dims_pres2postx is None: 
184 | dims_pres2postx = discr.dims_postv2s[::-1][1:] + [discr.dim_v + discr.dim_parav] + discr.dims_postx2prev[::-1] 185 | super(MLPx1s, self).__init__() 186 | self.dim_s, self.dim_x, self.dims_pres2postx, self.actv = dim_s, dim_x, dims_pres2postx, actv 187 | self.f_s2x = mlp_constructor([dim_s] + dims_pres2postx + [dim_x], actv) 188 | 189 | def x1s(self, s): return self.f_s2x(s) 190 | def forward(self, s): return self.x1s(s) 191 | 192 | class MLPv1s(MLPBase): 193 | def __init__(self, dim_s = None, dims_pres2postv = None, dim_v = None, 194 | actv = None, *, discr = None): 195 | if dim_s is None: dim_s = discr.dim_s 196 | if dim_v is None: dim_v = discr.dim_v 197 | if actv is None: actv = discr.actv if hasattr(discr, "actv") else "Sigmoid" 198 | if type(actv) is str: actv = getattr(nn, actv) 199 | if dims_pres2postv is None: dims_pres2postv = discr.dims_postv2s[::-1][1:] 200 | super(MLPv1s, self).__init__() 201 | self.dim_s, self.dim_v, self.dims_pres2postv, self.actv = dim_s, dim_v, dims_pres2postv, actv 202 | self.f_s2v = mlp_constructor([dim_s] + dims_pres2postv + [dim_v], actv) 203 | 204 | def v1s(self, s): return self.f_s2v(s) 205 | def forward(self, s): return self.v1s(s) 206 | 207 | def create_discr_from_json(stru_name: str, dim_x: int, dim_y: int, actv: str=None, 208 | std_v1x_val: float=-1., std_s1vx_val: float=-1., after_actv: bool=True, jsonfile: str="mlpstru.json"): 209 | stru = json.load(open(jsonfile))['MLPsvy1x'][stru_name] 210 | if actv is not None: stru['actv'] = actv 211 | return MLPsvy1x(dim_x=dim_x, dim_y=dim_y, std_v1x_val=std_v1x_val, std_s1vx_val=std_s1vx_val, 212 | after_actv=after_actv, **stru) 213 | 214 | def create_gen_from_json(model_type: str="MLPx1sv", discr: MLPsvy1x=None, stru_name: str=None, dim_x: int=None, actv: str=None, jsonfile: str="mlpstru.json"): 215 | if stru_name is None: 216 | return eval(model_type)(dim_x=dim_x, discr=discr, actv=actv) 217 | else: 218 | stru = json.load(open(jsonfile))[model_type][stru_name] 219 | if actv is not None: stru['actv'] = actv 220 | return eval(model_type)(dim_x=dim_x, discr=discr, **stru) 221 | 222 | -------------------------------------------------------------------------------- /arch/mlpstru.json: -------------------------------------------------------------------------------- 1 | { 2 | "MLPsvy1x": { 3 | "lite": { 4 | "dims_postx2prev": [400], 5 | "dim_v": 100, 6 | "dim_parav": 100, 7 | "dims_postv2s": [50], 8 | "dims_posts2prey": [] 9 | }, 10 | "lite-1.5x": { 11 | "dims_postx2prev": [600], 12 | "dim_v": 150, 13 | "dim_parav": 150, 14 | "dims_postv2s": [75], 15 | "dims_posts2prey": [] 16 | }, 17 | "bt882": { 18 | "dims_postx2prev": [400, 200], 19 | "dim_v": 8, 20 | "dim_parav": 8, 21 | "dims_postv2s": [2], 22 | "dims_posts2prey": [] 23 | }, 24 | "irm": { 25 | "dims_postx2prev": [256], 26 | "dim_v": 128, 27 | "dim_parav": 128, 28 | "dims_postv2s": [256], 29 | "dims_posts2prey": [] 30 | }, 31 | "bt882-256": { 32 | "dims_postx2prev": [256, 256], 33 | "dim_v": 8, 34 | "dim_parav": 8, 35 | "dims_postv2s": [2], 36 | "dims_posts2prey": [] 37 | } 38 | }, 39 | "MLPx1sv": { 40 | }, 41 | "MLPx1s": { 42 | }, 43 | "MLPv1s": { 44 | "s2v0": { 45 | "dims_pres2postv": [] 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /csg-intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changliu00/causal-semantic-generative-model/05c10d6790f3db8d4847efd98d18a6ecafb469cb/csg-intro.png 
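A minimal usage sketch of the `distr` package below (illustrative names, dims and values; assumes the repository root is on `sys.path`):

    import torch as tc
    from distr import Normal

    p_z = Normal('z', mean=0., std=1., shape=(2,))  # root distribution p(z)
    p_x1z = Normal('x', mean=lambda z: 2. * z, std=.1, shape=(2,))  # p(x|z): the parent set is read off the parameter function's argument names
    p_zx = p_z * p_x1z  # joint p(z, x) by distribution multiplication
    vals = p_zx.draw(tc.Size((5,)))  # an `edic` holding 'z' and 'x', each of shape (5, 2)
    logp = p_zx.logp(vals)  # joint log-density, shape (5,)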
-------------------------------------------------------------------------------- /distr/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | '''Probabilistic Programming Package. 3 | 4 | The prototype is distributions, which can be a conditional one with 5 | functions for parameters to define the dependency. Distribution 6 | multiplication is implemented, as well as the mean, expectation, 7 | sampling with backprop capability, and log-probability. 8 | ''' 9 | 10 | __author__ = "Chang Liu" 11 | __version__ = "1.0.1" 12 | __email__ = "changliu@microsoft.com" 13 | 14 | from .base import Distr, DistrElem 15 | from .instances import Determ, Normal, MVNormal, Catg, Bern 16 | 17 | from .utils import ( append_attrs, 18 | edic, edicify, 19 | fargnames, fedic, wrap4_multi_batchdims, 20 | tensorify, is_scalar, is_same_tensor, 21 | expand_front, flatten_last, reduce_last, swap_dim_ranges, expand_middle, 22 | tcsizeify, tcsize_div, tcsize_broadcast, 23 | ) 24 | from .tools import elbo, elbo_z2xy, elbo_z2xy_twist, elbo_zy2x 25 | 26 | -------------------------------------------------------------------------------- /distr/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | '''Probabilistic Programming Package. 3 | 4 | The prototype is distributions, which can be a conditional one with 5 | functions for parameters to define the dependency. Distribution 6 | multiplication is implemented, as well as the mean, expectation, 7 | sampling with backprop capability, and log-probability. 8 | ''' 9 | 10 | import math 11 | import torch as tc 12 | from .utils import edic, edicify, expand_front, fargnames, fedic, swap_dim_ranges, tcsizeify, tcsize_broadcast, tcsize_div, tensorify 13 | 14 | __author__ = "Chang Liu" 15 | __version__ = "1.0.1" 16 | __email__ = "changliu@microsoft.com" 17 | 18 | class Distr: 19 | ''' 20 | names: set <-> vals: edic(name:tensor) 21 | parents: set <-> conds: edic(name:tensor) 22 | ''' 23 | default_device = None 24 | _shapes_all_vars = {} # dict 25 | 26 | @staticmethod 27 | def all_names() -> set: return set(Distr._shapes_all_vars) 28 | @staticmethod 29 | def clear(): Distr._shapes_all_vars.clear() 30 | @staticmethod 31 | def has_name(name: str) -> bool: return name in Distr._shapes_all_vars 32 | @staticmethod 33 | def has_names(names: set) -> bool: 34 | return all(name in Distr._shapes_all_vars for name in names) 35 | @staticmethod 36 | def shape_var(name: str) -> tc.Size: return Distr._shapes_all_vars[name] 37 | @staticmethod 38 | def shapes_var(names: set) -> dict: 39 | return {name: Distr._shapes_all_vars[name] for name in names} 40 | @staticmethod 41 | def shape_bat(conds: edic) -> tc.Size: 42 | for name, ten in conds.items(): 43 | if Distr.has_name(name): return tcsize_div(ten.shape, Distr.shape_var(name)) 44 | return tc.Size() 45 | @staticmethod 46 | def shape_bat_broadcast(conds: edic) -> tc.Size: 47 | shapes_bat = [tcsize_div(ten.shape, Distr.shape_var(name)) 48 | for name, ten in conds.items() if Distr.has_name(name)] 49 | return tcsize_broadcast(*shapes_bat) 50 | @staticmethod 51 | def broadcast_vars(conds: edic) -> edic: 52 | shape_bat = Distr.shape_bat_broadcast(conds) 53 | return edic({name: ten.expand(shape_bat + Distr.shape_var(name)) 54 | for name, ten in conds.items() if Distr.has_name(name)}) 55 | 56 | def __init__(self, *, names: set=set(), names_shapes: dict={}, parents: set=set()): 57 | for name, shape in names_shapes.items(): 58 
| shape, = tcsizeify(shape,) 59 | if Distr.has_name(name): 60 | if shape != Distr.shape_var(name): raise ValueError(f"shape not match for existing variable '{name}'") 61 | else: Distr._shapes_all_vars[name] = shape 62 | for name in names: 63 | if not Distr.has_name(name): raise ValueError(f"new variable '{name}' needs a shape") 64 | names = names | set(names_shapes) 65 | if not names: raise ValueError("name(s) have to be provided") 66 | names_comm = names & parents 67 | if names_comm: raise ValueError(f"common variable(s) '{names_comm}' found in `names` and `parents`") 68 | self._names, self._parents, self._is_root = names, parents, (not bool(parents)) 69 | 70 | @property 71 | def names(self) -> set: return self._names 72 | @property 73 | def parents(self) -> set: return self._parents 74 | @property 75 | def is_root(self) -> bool: return self._is_root 76 | 77 | def __repr__(self) -> str: 78 | return "p(" + ", ".join(self.names) + (" | " + ", ".join(self.parents) + ")" if self.parents else ")") 79 | 80 | def mean(self, conds: edic=edic(), n_mc: int=10, repar: bool=True) -> edic: 81 | # [shape_bat, shape_cond] -> [shape_bat, shape_var] 82 | raise NotImplementedError 83 | 84 | def expect(self, fn, conds: edic=edic(), n_mc: int=10, repar: bool=True, reducefn = tc.mean) -> tc.Tensor: 85 | # [shape_bat] -> [shape_bat] 86 | if n_mc == 0: 87 | vals = self.mean(conds, 0, repar) 88 | return fn(conds|vals) 89 | elif n_mc > 0: 90 | vals = self.draw(tc.Size((n_mc,)), conds, repar) 91 | return reducefn(fn( edicify(conds)[0].expand_front((n_mc,)) | vals ), dim=0) 92 | else: raise ValueError(f"For {self}, negative `n_mc` {n_mc} encountered") 93 | 94 | def rdraw(self, shape_mc: tc.Size=tc.Size(), conds: edic=edic()) -> edic: # vals 95 | # shape_mc, [shape_bat, shape_cond] -> [shape_mc, shape_bat, shape_var] 96 | return self.draw(shape_mc, conds, True) 97 | 98 | def draw(self, shape_mc: tc.Size=tc.Size(), conds: edic=edic(), repar: bool=False) -> edic: # vals 99 | # shape_mc, [shape_bat, shape_cond] -> [shape_mc, shape_bat, shape_var] 100 | raise NotImplementedError 101 | 102 | def logp(self, vals: edic, conds: edic=edic()) -> tc.Tensor: # log_probs 103 | # [shape_bat, shape_var], [shape_bat, shape_cond] -> [shape_bat] 104 | raise NotImplementedError 105 | 106 | def logp_cartes(self, vals: edic, conds: edic=edic()) -> tc.Tensor: # log_probs 107 | # [shape_mc, shape_var], [shape_bat, shape_cond] -> [shape_mc, shape_bat] 108 | vals, conds = edicify(vals, conds) 109 | shape_mc, shape_bat = Distr.shape_bat(vals), Distr.shape_bat(conds) 110 | len_mc, len_bat = len(shape_mc), len(shape_bat) 111 | vals_expd = vals.sub( self.names, 112 | lambda v: swap_dim_ranges(expand_front(v, shape_bat), (0, len_bat), (len_bat, len_bat+len_mc)) ) 113 | conds_expd = conds.sub_expand_front(self.parents, shape_mc) 114 | return self.logp(vals_expd, conds_expd) 115 | 116 | def __mul__(self, other): # p(z|w1) * p(x|z,w2) -> p(z,x|w1,w2) 117 | if not self.names.isdisjoint(other.names): 118 | raise ValueError(f"cycle found for {self} * {other}: common variable") 119 | if not self.parents.isdisjoint(other.names): 120 | if not self.names.isdisjoint(other.parents): 121 | raise ValueError(f"cycle found for {self} * {other}: cyclic generation") 122 | else: p_marg, p_cond = other, self 123 | else: p_marg, p_cond = self, other 124 | indep = p_marg.names.isdisjoint(p_cond.parents) 125 | p_joint = Distr(names = p_marg.names | p_cond.names, 126 | parents = p_marg.parents | p_cond.parents - p_marg.names) 127 | p_joint._p_marg, p_joint._p_cond, 
p_joint._indep = p_marg, p_cond, indep 128 | 129 | def _mean(conds: edic=edic(), n_mc: int=10, repar: bool=True) -> edic: 130 | mean_marg = p_marg.mean(conds, n_mc, repar) # [shape_bat, shape_var_marg] 131 | # mean_cond = p_marg.expect(lambda dc: p_cond.mean(dc,dc), conds, n_mc, repar, tc.mean) # Commented for the lack of tc.mean(edic) and efficiency concern 132 | if indep: 133 | mean_cond = p_cond.mean(conds, n_mc, repar) # [shape_bat, shape_var_cond] 134 | else: 135 | if n_mc == 0: 136 | mean_cond = p_cond.mean(conds|mean_marg, 0, repar) # [shape_bat, shape_var_cond] 137 | elif n_mc > 0: 138 | vals_marg = p_marg.draw((n_mc,), conds, repar) # [n_mc, shape_bat, shape_var_marg] 139 | mean_cond = p_cond.mean(edicify(conds)[0].sub_expand_front(p_joint.parents, (n_mc,)) | vals_marg, 140 | n_mc, repar).mean(dim=0) # [shape_bat, shape_var_cond] 141 | else: raise ValueError(f"For {p_joint}, negative `n_mc` {n_mc} encountered") 142 | return mean_marg | mean_cond 143 | p_joint.mean = _mean 144 | 145 | def _draw(shape_mc: tc.Size=tc.Size(), conds: edic=edic(), repar: bool=False) -> edic: # vals 146 | vals_marg = p_marg.draw(shape_mc, conds, repar) # [shape_mc, shape_bat, shape_var_marg] 147 | if indep: 148 | vals_cond = p_cond.draw(shape_mc, conds, repar) # [shape_mc, shape_bat, shape_var_cond] 149 | else: 150 | vals_cond = p_cond.draw(tc.Size(), 151 | edicify(conds)[0].sub_expand_front(p_joint.parents, shape_mc) | vals_marg, 152 | repar) # [shape_mc, shape_bat, shape_var_cond] 153 | return vals_marg | vals_cond 154 | p_joint.draw = _draw 155 | 156 | def _logp(vals: edic, conds: edic=edic()) -> tc.Tensor: # log_probs 157 | return p_marg.logp(vals, conds) + p_cond.logp(vals, edicify(conds)[0]|vals) # [shape_bat] + [shape_bat] guaranteed 158 | p_joint.logp = _logp 159 | 160 | def _entropy(conds: edic=edic(), n_mc: int=10, repar: bool=True) -> tc.Tensor: 161 | # [shape_bat, shape_cond] -> [shape_bat] 162 | return p_marg.entropy(conds) + ( 163 | p_cond.entropy(conds) if indep else 164 | p.marg.expect(p_cond.entropy, conds, n_mc, repar, tc.mean) ) 165 | if indep: p_joint.entropy = _entropy 166 | 167 | return p_joint 168 | 169 | def marg(self, mnames: set, n_mc: int=10): 170 | if mnames == self.names: return self 171 | if not mnames: raise ValueError(f"'mnames' empty for {self}") 172 | names_irrelev = mnames - self.names; 173 | if names_irrelev: raise ValueError(f"irrelevant variable(s) {names_irrelev} found for {self}") 174 | # To filter out non-elem distr that is formed by MC marginalization 175 | if hasattr(self, '_p_marg') and hasattr(self, '_p_cond'): 176 | # Marg p(ym, yo, xm, xo, zm, zo) = p(xm, xo, zm, zo) p(ym, yo | xm, xo) 177 | # for mnames = {ym, xm, zm}. Other conditioned vars omitted. 178 | p_marg, p_cond = self._p_marg, self._p_cond 179 | # ym empty. Marg p(xm, xo, zm, zo) for {xm, zm} 180 | if mnames <= p_marg.names: return p_marg.marg(mnames, n_mc) 181 | # ym not empty 182 | mnames_marg, mnames_cond = (mnames & p_marg.names), (mnames & p_cond.names) # {xm, zm}, {ym} 183 | p_condm = p_cond.marg(mnames_cond, n_mc) # p(ym | xm, xo) 184 | names_intsec = p_marg.names & p_condm.parents # {xm, xo} 185 | if not names_intsec: # {xm, xo} empty. Marg p(zm, zo) p(ym) for {ym, zm} 186 | return p_marg.marg(mnames_marg, n_mc) * p_condm if mnames_marg else p_condm # p(zm) p(ym) if zm not empty else p(ym) 187 | else: # {xm, xo} not empty 188 | p_margm = p_marg.marg(mnames_marg | names_intsec, n_mc) # p(xm, xo, zm) 189 | if not mnames_marg: # label: L0 190 | # {xm, zm} empty. 
Marg p(xo) p(ym | xo) for ym 191 | p_joint = p_margm * p_condm # p(xo) p(ym | xo) 192 | p_res = Distr(names = mnames, parents = p_joint.parents) 193 | def _mean(conds: edic=edic(), n_mc: int=10, repar: bool=True) -> edic: 194 | return p_joint.mean(conds, n_mc, repar).sub(mnames) 195 | p_res.mean = _mean 196 | def _draw(shape_mc: tc.Size=tc.Size(), conds: edic=edic(), repar: bool=False) -> edic: # vals 197 | return p_joint.draw(shape_mc, conds, repar).sub(mnames) 198 | p_res.draw = _draw 199 | def _logp(vals: edic, conds: edic=edic()) -> tc.Tensor: # log_probs 200 | return p_margm.expect(lambda dc: p_condm.logp(dc,dc), # `dc` contains properly expanded `conds` and `vals` 201 | conds|vals, n_mc, True, tc.logsumexp) - math.log(n_mc) 202 | p_res.logp = _logp 203 | return p_res 204 | else: # {xm, zm} not empty. Marg p(xm, xo, zm) p(ym | xm, xo) for {ym, xm, zm} 205 | if names_intsec <= mnames_marg: # xo is empty. Marg p(xm, zm) p(ym | xm) for {ym, xm, zm} 206 | return p_margm * p_condm # p(xm, zm) p(ym | xm) 207 | else: # xo is not empty 208 | if hasattr(p_margm, '_p_marg') and hasattr(p_margm, '_p_cond'): 209 | if p_margm._p_marg.names == mnames_marg: # p(xm, xo, zm) = p(xm, zm) p(xo | xm, zm) 210 | # Marg p(xm, zm) p(xo | xm, zm) p(ym | xm, xo) for {ym, xm, zm} is p(xm, zm) p(ym | xm, zm), 211 | # where p(ym | xm, zm) is from marg p(xo | xm, zm) p(ym | xm, xo) for ym 212 | return p_margm._p_marg * (p_margm._p_cond * p_condm).marg(mnames_cond, n_mc) # goto L0 213 | elif p_margm._p_cond.names == mnames_marg: 214 | if p_margm._indep: # p(xm, xo, zm) = p(xo) p(xm, zm) 215 | # Similar to the above. p(xm, zm) p(ym | xm, zm), 216 | # where p(ym | xm, zm) is from marg p(xo) p(ym | xm, xo) for ym 217 | return p_margm._p_cond * (p_margm._p_marg * p_condm).marg(mnames_cond, n_mc) # goto L0 218 | else: # p(xm, xo, zm) = p(xo) p(xm, zm | xo) 219 | # Marg p(xo) p(xm, zm | xo) p(ym | xm, xo) for {ym, xm, zm} 220 | return (p_margm._p_marg * (p_margm._p_cond * p_condm)).marg(mnames, n_mc) # goto L0 221 | raise RuntimeError(f"Unable to marginalize {self} for {mnames}. 
Check the model or try other factorizations.") 222 | 223 | class DistrElem(Distr): 224 | def __init__(self, name: str, shape: tc.Size, device = None, **params): 225 | # for distributions whose parameter and random variable (or, one sample) have the same shape 226 | if device is None: device = Distr.default_device 227 | fnnames, fnvals, tennames, tenvals = [], [], [], [] 228 | for pmname, pmval in params.items(): 229 | if callable(pmval): 230 | fnnames.append(pmname) 231 | fnvals.append(pmval) 232 | else: 233 | tennames.append(pmname) 234 | tenvals.append(pmval) 235 | tenvals = tensorify(device, *tenvals) 236 | if shape is None: 237 | if Distr.has_name(name): shape = Distr.shape_var(name) 238 | elif tenvals: shape = tc.broadcast_tensors(*tenvals)[0].shape 239 | else: shape = tc.Size() 240 | parents = set() 241 | for fname, fval in zip(fnnames, fnvals): 242 | parents_inc = fargnames(fval) 243 | if parents_inc: 244 | parents |= parents_inc 245 | setattr(self, fname, fedic(fval)) 246 | else: 247 | setattr( self, fname, 248 | lambda conds, fval=fval: tensorify(device, fval())[0].expand(Distr.shape_bat(conds) + shape) ) 249 | for tname, tval in zip(tennames, tenvals): 250 | setattr( self, tname, 251 | lambda conds, tval=tval: tval.expand(Distr.shape_bat(conds) + shape) ) 252 | super(DistrElem, self).__init__(names_shapes = {name: shape}, parents = parents) 253 | self._name, self._shape, self._device = name, shape, device 254 | 255 | @property 256 | def name(self): return self._name 257 | @property 258 | def shape(self): return self._shape 259 | @property 260 | def device(self): return self._device 261 | 262 | def mode(self, conds: edic=edic(), repar: bool=True) -> edic: 263 | # [shape_bat, shape_cond] -> [shape_bat, shape_var] 264 | raise NotImplementedError 265 | 266 | -------------------------------------------------------------------------------- /distr/instances.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | '''Probabilistic Programming Package. 3 | 4 | The prototype is distributions, which can be a conditional one with 5 | functions for parameters to define the dependency. Distribution 6 | multiplication is implemented, as well as the mean, expectation, 7 | sampling with backprop capability, and log-probability. 8 | 9 | This file is greatly inspired by `torch.distributions`, with some components adopted. 
10 | ''' 11 | 12 | import warnings 13 | import math 14 | from contextlib import suppress 15 | import torch as tc 16 | import torch.distributions.utils as tcdu 17 | from .base import Distr, DistrElem 18 | from .utils import edic, edicify, expand_front, fargnames, fedic, is_scalar, normalize_logits, normalize_probs, precision_to_scale_tril, reduce_last, tcsize_div, tensorify 19 | 20 | __author__ = "Chang Liu" 21 | __version__ = "1.0.1" 22 | __email__ = "changliu@microsoft.com" 23 | 24 | class Determ(DistrElem): 25 | def __init__(self, name: str, val, shape = None, checkval: bool=False, device = None): 26 | super(Determ, self).__init__(name, shape, device, _fn = val) 27 | self._checkval = checkval 28 | 29 | def mean(self, conds: edic=edic(), n_mc: int=10, repar: bool=True) -> edic: 30 | # [shape_bat, shape_cond] -> [shape_bat, shape_var] 31 | return edic({self.name: self._fn(conds)}) # `repar` is only for `draw()` 32 | 33 | def mode(self, conds: edic=edic(), repar: bool=True) -> edic: 34 | # [shape_bat, shape_cond] -> [shape_bat, shape_var] 35 | return self.mean(conds, 0, repar) 36 | 37 | def draw(self, shape_mc: tc.Size=tc.Size(), conds: edic=edic(), repar: bool=False) -> edic: # vals 38 | # shape_mc, [shape_bat, shape_cond] -> [shape_mc, shape_bat, shape_var] 39 | with suppress() if repar else tc.no_grad(): 40 | return edic({self.name: expand_front(self._fn(conds), shape_mc)}) 41 | 42 | def logp(self, vals: edic, conds: edic=edic()) -> tc.Tensor: # log_probs 43 | # [shape_bat, shape_var], [shape_bat, shape_cond] -> [shape_bat] 44 | val = vals[self.name] 45 | if self._checkval: 46 | equals_all = (val - self._fn(conds)).abs() <= 1e-8 + 1e-5*val.abs() # shape matches 47 | probs = reduce_last(tc.all, equals_all, len(self.shape)).type(val.dtype) # 0. or 1. 48 | return probs.log() # -inf or 0. 
(probs can be recovered by tc.exp) 49 | else: 50 | return tc.zeros(tcsize_div(val.shape, self.shape), 51 | dtype=val.dtype, device=val.device, layout=val.layout) 52 | 53 | def entropy(self, conds: edic=edic()) -> tc.Tensor: 54 | # [shape_bat, shape_cond] -> [shape_bat] 55 | return tc.zeros(size=Distr.shape_bat(conds), device=self.device) 56 | 57 | class Normal(DistrElem): 58 | def __init__(self, name: str, mean = 0., std = 1., shape = None, device = None): 59 | super(Normal, self).__init__(name, shape, device, _meanfn = mean, _stdfn = std) 60 | self._log_const = .5 * math.log(2*math.pi) * tc.tensor(self.shape).prod() 61 | 62 | def mean(self, conds: edic=edic(), n_mc: int=10, repar: bool=True) -> edic: 63 | # [shape_bat, shape_cond] -> [shape_bat, shape_var] 64 | return edic({self.name: self._meanfn(conds)}) 65 | 66 | def mode(self, conds: edic=edic(), repar: bool=True) -> edic: 67 | # [shape_bat, shape_cond] -> [shape_bat, shape_var] 68 | return self.mean(conds, 0, repar) 69 | 70 | def draw(self, shape_mc: tc.Size=tc.Size(), conds: edic=edic(), repar: bool=False) -> edic: # vals 71 | # shape_mc, [shape_bat, shape_cond] -> [shape_mc, shape_bat, shape_var] 72 | with suppress() if repar else tc.no_grad(): 73 | meanval = expand_front(self._meanfn(conds), shape_mc) # [shape_mc, shape_bat, shape_var] 74 | return edic({self.name: self._stdfn(conds) * tc.randn_like(meanval) + meanval}) 75 | 76 | def logp(self, vals: edic, conds: edic=edic()) -> tc.Tensor: # log_probs 77 | # [shape_bat, shape_var], [shape_bat, shape_cond] -> [shape_bat] 78 | meanval, stdval = self._meanfn(conds), self._stdfn(conds) 79 | normalized_vals = (vals[self.name] - meanval) / stdval 80 | quads = reduce_last(tc.sum, normalized_vals ** 2, len(self.shape)) 81 | half_log_det = reduce_last(tc.sum, stdval.log(), len(self.shape)) 82 | return -.5 * quads - half_log_det - self._log_const 83 | 84 | def entropy(self, conds: edic=edic()) -> tc.Tensor: 85 | # [shape_bat, shape_cond] -> [shape_bat] 86 | stdval = self._stdfn(conds) 87 | half_log_det = reduce_last(tc.sum, stdval.log(), len(self.shape)) 88 | return half_log_det + self._log_const + .5 * tc.tensor(self.shape).prod() 89 | 90 | class MVNormal(DistrElem): 91 | def __init__(self, name: str, mean = 0., cov = None, prec = None, std_tril = None, shape = None, device = None): 92 | input_indicator = (cov is not None) + (prec is not None) + (std_tril is not None) 93 | if input_indicator > 1: raise ValueError(f"For {self}, at most one of covariance_matrix or precision_matrix or scale_tril can be specified") 94 | if input_indicator == 0: std_tril = 1. 
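        # (so with no matrix argument, the scale factor defaults to 1., i.e. an
        # identity covariance up to the broadcasting below)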
95 | if device is None: device = Distr.default_device 96 | if shape is None and Distr.has_name(name): shape = Distr.shape_var(name) 97 | 98 | def _vecterize_mean(mean): 99 | if mean.ndim == 0: 100 | warnings.warn("shape of `mean` expanded by 1 ndim") 101 | return mean[None] 102 | else: return mean 103 | 104 | def _matrixize_std(std_arg): 105 | if std_arg.ndim == 0: 106 | warnings.warn("shape of `std_arg` expanded by 2 ndims") 107 | return std_arg[None, None] 108 | elif std_arg.ndim == 1: 109 | warnings.warn("shape of `std_arg` expanded by 1 ndim") 110 | return tc.diag_embed(std_arg) 111 | else: return std_arg 112 | 113 | if cov is not None: std_arg = cov 114 | elif prec is not None: std_arg = prec 115 | else: std_arg = std_tril 116 | if callable(std_arg): 117 | if callable(mean): 118 | if shape is None: raise RuntimeError(f"For {self}, argument `shape` has to be provided when both parameters are functions") 119 | parents = fargnames(mean) | fargnames(std_arg) 120 | if fargnames(mean): self._meanfn = fedic(mean) 121 | else: self._meanfn = lambda conds: tensorify(device, mean())[0].expand(Distr.shape_bat(conds) + shape) 122 | else: 123 | parents = fargnames(std_arg) 124 | mean, = tensorify(device, mean,) 125 | mean = _vecterize_mean(mean) 126 | if shape is None: shape = mean.shape 127 | self._meanfn = lambda conds: mean.expand(Distr.shape_bat(conds) + shape) 128 | # un-indent 129 | if fargnames(std_arg): _std_argfn = fedic(std_arg) 130 | else: _std_argfn = lambda conds: tensorify(device, std_arg())[0].expand(Distr.shape_bat(conds) + shape + shape[-1:]) 131 | if cov is not None: self._std_trilfn = lambda conds: tc.cholesky(_std_argfn(conds)) 132 | elif prec is not None: self._std_trilfn = lambda conds: precision_to_scale_tril(_std_argfn(conds)) 133 | else: self._std_trilfn = _std_argfn 134 | else: 135 | if callable(mean): 136 | parents = fargnames(mean) 137 | std_arg, = tensorify(device, std_arg,) 138 | std_arg = _matrixize_std(std_arg) 139 | if shape is None: shape = std_arg.shape[:-1] 140 | if parents: self._meanfn = fedic(mean) 141 | else: self._meanfn = lambda conds: tensorify(device, mean())[0].expand(Distr.shape_bat(conds) + shape) 142 | else: 143 | parents = set() 144 | mean, std_arg = tensorify(device, mean, std_arg) 145 | mean = _vecterize_mean(mean) 146 | std_arg = _matrixize_std(std_arg) 147 | if shape is None: shape = tc.broadcast_tensors(mean.unsqueeze(-1), std_arg)[0].shape[:-1] 148 | self._meanfn = lambda conds: mean.expand(Distr.shape_bat(conds) + shape) 149 | # un-indent 150 | if cov is not None: std_tril = tc.cholesky(std_arg) 151 | elif prec is not None: std_tril = precision_to_scale_tril(std_arg) 152 | else: std_tril = std_arg 153 | self._std_trilfn = lambda conds: std_tril.expand(Distr.shape_bat(conds) + shape + shape[-1:]) 154 | super(DistrElem, self).__init__(names_shapes = {name: shape}, parents = parents) 155 | self._name, self._shape, self._device = name, shape, device 156 | self._log_const = .5 * math.log(2*math.pi) * tc.tensor(shape).prod() 157 | 158 | def mean(self, conds: edic=edic(), n_mc: int=10, repar: bool=True) -> edic: 159 | # [shape_bat, shape_cond] -> [shape_bat, shape_var] 160 | return edic({self.name: self._meanfn(conds)}) 161 | 162 | def mode(self, conds: edic=edic(), repar: bool=True) -> edic: 163 | # [shape_bat, shape_cond] -> [shape_bat, shape_var] 164 | return self.mean(conds, 0, repar) 165 | 166 | def draw(self, shape_mc: tc.Size=tc.Size(), conds: edic=edic(), repar: bool=False) -> edic: # vals 167 | # shape_mc, [shape_bat, shape_cond] -> 
[shape_mc, shape_bat, shape_var] 168 | with suppress() if repar else tc.no_grad(): 169 | meanval = expand_front(self._meanfn(conds), shape_mc) # [shape_mc, shape_bat, shape_var] 170 | eps_ = tc.randn_like(meanval).unsqueeze(-1) 171 | return edic({self.name: (self._std_trilfn(conds) @ eps_).squeeze(-1) + meanval}) 172 | 173 | def logp(self, vals: edic, conds: edic=edic()) -> tc.Tensor: # log_probs 174 | # [shape_bat, shape_var], [shape_bat, shape_cond] -> [shape_bat] 175 | meanval, std_trilval = self._meanfn(conds), self._std_trilfn(conds) 176 | centered_vals_ = (vals[self.name] - meanval).unsqueeze(-1) 177 | normalized_vals = tc.triangular_solve(centered_vals_, std_trilval, upper=False)[0].squeeze(-1) 178 | quads = reduce_last(tc.sum, normalized_vals ** 2, len(self.shape)) 179 | half_log_det = reduce_last(tc.sum, std_trilval.diagonal(dim1=-2, dim2=-1).log(), len(self.shape)) 180 | return -.5 * quads - half_log_det - self._log_const 181 | 182 | def entropy(self, conds: edic=edic()) -> tc.Tensor: 183 | # [shape_bat, shape_cond] -> [shape_bat] 184 | std_trilval = self._std_trilfn(conds) 185 | half_log_det = reduce_last(tc.sum, std_trilval.diagonal(dim1=-2, dim2=-1).log(), len(self.shape)) 186 | return half_log_det + self._log_const + .5 * tc.tensor(self.shape).prod() 187 | 188 | class Catg(DistrElem): 189 | def __init__(self, name: str, *, probs = None, logits = None, shape = None, normalized: bool=False, device = None): 190 | # This logit should be normalized (sumexp == 1). logit == log prob. 191 | if (probs is None) == (logits is None): 192 | raise ValueError(f"For {self}, one and only one of `probs` and `logits` required") 193 | params = logits if probs is None else probs 194 | if device is None: device = Distr.default_device 195 | if shape is None and Distr.has_name(name): shape = Distr.shape_var(name) 196 | if callable(params): 197 | if shape is None: shape = tc.Size() 198 | parents = fargnames(params) 199 | if parents: paramsfn = fedic(params) 200 | else: paramsfn = lambda conds: tensorify(device, params())[0].expand(Distr.shape_bat(conds) + shape + (-1,)) 201 | if probs is None: 202 | self._logitsfn = paramsfn if normalized else lambda conds: normalize_logits(paramsfn(conds)) 203 | else: 204 | self._probsfn = paramsfn if normalized else lambda conds: normalize_probs(paramsfn(conds)) 205 | else: 206 | params, = tensorify(device, params,) 207 | if params.ndim < 1 or params.shape[-1] <= 1: raise ValueError(f"For {self}, use `Bern` for binary variables") 208 | if shape is None: shape = params.shape[:-1] 209 | parents = set() 210 | if not normalized: 211 | params = normalize_logits(params) if probs is None else normalize_probs(params) 212 | setattr( self, '_logitsfn' if probs is None else '_probsfn', 213 | lambda conds: params.expand(Distr.shape_bat(conds) + shape + (-1,)) ) 214 | super(DistrElem, self).__init__(names_shapes = {name: shape}, parents = parents) 215 | self._name, self._shape, self._device = name, shape, device 216 | 217 | @tcdu.lazy_property 218 | def _probsfn(self): return lambda conds: tcdu.logits_to_probs(self._logitsfn(conds)) 219 | @tcdu.lazy_property 220 | def _logitsfn(self): return lambda conds: tcdu.probs_to_logits(self._probsfn(conds)) 221 | # No `mean()`. 
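    # (a mean of integer category labels would not be meaningful, hence only
    # mode/draw/logp/entropy are provided)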
222 | 223 | def mode(self, conds: edic=edic(), repar: bool=True) -> edic: 224 | # [shape_bat, shape_cond] -> [shape_bat, shape_var] 225 | return edic({self.name: self._logitsfn(conds).argmax(dim=-1)}) 226 | 227 | def draw(self, shape_mc: tc.Size=tc.Size(), conds: edic=edic(), repar: bool=False) -> edic: # vals 228 | # shape_mc, [shape_bat, shape_cond] -> [shape_mc, shape_bat, shape_var] 229 | if repar: warnings.warn(f"For categorical {self}, reparameterization for `draw` is not allowed") 230 | with tc.no_grad(): 231 | probs = expand_front(self._probsfn(conds), shape_mc) 232 | shape_out, n_catg = probs.shape[:-1], probs.shape[-1] 233 | probs_flat = probs.reshape(-1, n_catg) 234 | return edic({self.name: 235 | tc.multinomial(probs_flat, num_samples=1).reshape(shape_out)}) 236 | 237 | def logp(self, vals: edic, conds: edic=edic()) -> tc.Tensor: # log_probs 238 | # [shape_bat, shape_var], [shape_bat, shape_cond] -> [shape_bat] 239 | val = vals[self.name] 240 | logits = self._logitsfn(conds) # logit == log prob 241 | logits = logits.expand(val.shape + logits.shape[-1:]) 242 | logps_all = logits.gather(dim=-1, index=val.unsqueeze(-1)).squeeze(-1) 243 | return reduce_last(tc.sum, logps_all, len(self.shape)) 244 | 245 | def entropy(self, conds: edic=edic()) -> tc.Tensor: 246 | # [shape_bat, shape_cond] -> [shape_bat] 247 | logits = self._logitsfn(conds) # logit == log prob 248 | return - reduce_last(tc.sum, logits.exp() * logits, len(self.shape) + 1) 249 | 250 | class Bern(DistrElem): 251 | def __init__(self, name: str, *, probs = None, logits = None, shape = None, device = None): 252 | # This logit is NOT normalized (it has the logit of 0 being 0). So logit != log prob. 253 | if (probs is None) == (logits is None): 254 | raise ValueError(f"For {self}, one and only one of `probs` and `logits` required") 255 | super(Bern, self).__init__(name, shape, device, **({'_logitsfn':logits} if probs is None else {'_probsfn':probs})) 256 | 257 | @tcdu.lazy_property 258 | def _probsfn(self): return lambda conds: tcdu.logits_to_probs(self._logitsfn(conds), is_binary=True) 259 | @tcdu.lazy_property 260 | def _logitsfn(self): return lambda conds: tcdu.probs_to_logits(self._probsfn(conds), is_binary=True) 261 | # No `mean()`. 
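# With the unnormalized logit l = log(p/(1-p)) used by this class,
#   log p(1) = -log(1 + exp(-l))   and   log p(0) = log p(1) - l,
# which is the identity `logp` and `entropy` below exploit via `logprobs - logits`.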
262 | 263 | def mode(self, conds: edic=edic(), repar: bool=True) -> edic: 264 | # [shape_bat, shape_cond] -> [shape_bat, shape_var] 265 | return edic({self.name: (self._logitsfn(conds) > 0).long()}) 266 | 267 | def draw(self, shape_mc: tc.Size=tc.Size(), conds: edic=edic(), repar: bool=False) -> edic: # vals 268 | # shape_mc, [shape_bat, shape_cond] -> [shape_mc, shape_bat, shape_var] 269 | if repar: warnings.warn(f"For Bernoulli {self}, reparameterization for `draw` is not allowed") 270 | with tc.no_grad(): 271 | return edic({self.name: 272 | tc.bernoulli( expand_front(self._probsfn(conds), shape_mc) ).long()}) # tc.bernoulli returns float32, and allows backprop (.long() doesn't) 273 | 274 | def logp(self, vals: edic, conds: edic=edic()) -> tc.Tensor: # log_probs 275 | # [shape_bat, shape_var], [shape_bat, shape_cond] -> [shape_bat] 276 | logits = self._logitsfn(conds) 277 | logprobs = -tc.log1p(tc.exp(-logits)) # logit != log prob 278 | logps_all = tc.where(vals[self.name].bool(), logprobs, logprobs - logits) # shape matches 279 | return reduce_last(tc.sum, logps_all, len(self.shape)) 280 | 281 | def entropy(self, conds: edic=edic()) -> tc.Tensor: 282 | # [shape_bat, shape_cond] -> [shape_bat] 283 | logits = self._logitsfn(conds) 284 | logprobs = -tc.log1p(tc.exp(-logits)) # logit != log prob 285 | return reduce_last(tc.sum, -logprobs + logits / (1+tc.exp(logits)), len(self.shape)) 286 | 287 | -------------------------------------------------------------------------------- /distr/tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | '''Probabilistic Programming Package. 3 | 4 | The prototype is distributions, which can be a conditional one with 5 | functions for parameters to define the dependency. Distribution 6 | multiplication is implemented, as well as the mean, expectation, 7 | sampling with backprop capability, and log-probability. 8 | ''' 9 | 10 | import math 11 | import torch as tc 12 | from .utils import edic 13 | from .base import Distr 14 | 15 | __author__ = "Chang Liu" 16 | __version__ = "1.0.1" 17 | __email__ = "changliu@microsoft.com" 18 | 19 | def elbo(p_joint: Distr, q_cond: Distr, obs: edic, n_mc: int=10, repar: bool=True) -> tc.Tensor: # [shape_bat] -> [shape_bat] 20 | if hasattr(q_cond, "entropy"): 21 | return q_cond.expect(lambda dc: p_joint.logp(dc,dc), obs, n_mc, repar) + q_cond.entropy(obs) 22 | else: 23 | return q_cond.expect(lambda dc: p_joint.logp(dc,dc) - q_cond.logp(dc,dc), obs, n_mc, repar) 24 | 25 | def elbo_z2xy(p_zx: Distr, p_y1z: Distr, q_z1x: Distr, obs_xy: edic, n_mc: int=0, repar: bool=True) -> tc.Tensor: 26 | """ For supervised VAE with structure x <- z -> y. 27 | Observations are supervised (x,y) pairs. 28 | For unsupervised observations of x data, use `elbo(p_zx, q_z1x, obs_x)` as VAE z -> x. 
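    A sketch of the quantity computed (for `n_mc > 0`), writing q(y|x) := E_{q(z|x)} p(y|z):
        log q(y|x) + E_{q(z|x)}[ (p(y|z) / q(y|x)) * (log p(z,x) - log q(z|x)) ],
    i.e. the ELBO of log p(x,y) under the variational posterior q(z|x,y) ∝ q(z|x) p(y|z);
    the `n_mc == 0` branch computes the same quantities without explicit MC summation.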
""" 29 | if n_mc == 0: 30 | q_y1x_logpval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc), obs_xy, 0, repar) #, reducefn=tc.logsumexp) 31 | if hasattr(q_z1x, "entropy"): # No difference for Gaussian 32 | expc_val = q_z1x.expect(lambda dc: p_zx.logp(dc,dc), obs_xy, 0, repar) + q_z1x.entropy(obs_xy) 33 | else: 34 | expc_val = q_z1x.expect(lambda dc: p_zx.logp(dc,dc) - q_z1x.logp(dc,dc), obs_xy, 0, repar) 35 | return q_y1x_logpval + expc_val 36 | else: 37 | q_y1x_pval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc).exp(), obs_xy, n_mc, repar) 38 | expc_val = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc).exp() * (p_zx.logp(dc,dc) - q_z1x.logp(dc,dc)), 39 | obs_xy, n_mc, repar) 40 | return q_y1x_pval.log() + expc_val / q_y1x_pval 41 | # q_y1x_logpval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc), obs_xy, n_mc, repar, 42 | # reducefn=tc.logsumexp) - math.log(n_mc) 43 | # expc_logval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc) + (p_zx.logp(dc,dc) - q_z1x.logp(dc,dc)).log(), 44 | # obs_xy, n_mc, repar, reducefn=tc.logsumexp) - math.log(n_mc) 45 | # return q_y1x_logpval + (expc_logval - q_y1x_logpval).exp() 46 | 47 | def elbo_z2xy_twist(pt_zx: Distr, p_y1z: Distr, p_z: Distr, pt_z: Distr, qt_z1x: Distr, obs_xy: edic, n_mc: int=0, repar: bool=True) -> tc.Tensor: 48 | vwei_p_y1z_logp = lambda dc: p_z.logp(dc,dc) - pt_z.logp(dc,dc) + p_y1z.logp(dc,dc) # z, y: 49 | if n_mc == 0: 50 | r_y1x_logpval = qt_z1x.expect(vwei_p_y1z_logp, obs_xy, 0, repar) #, reducefn=tc.logsumexp) 51 | if hasattr(qt_z1x, "entropy"): # No difference for Gaussian 52 | expc_val = qt_z1x.expect(lambda dc: pt_zx.logp(dc,dc), obs_xy, 0, repar) + qt_z1x.entropy(obs_xy) 53 | else: 54 | expc_val = qt_z1x.expect(lambda dc: pt_zx.logp(dc,dc) - qt_z1x.logp(dc,dc), obs_xy, 0, repar) 55 | return r_y1x_logpval + expc_val 56 | else: 57 | r_y1x_pval = qt_z1x.expect(lambda dc: vwei_p_y1z_logp(dc).exp(), obs_xy, n_mc, repar) 58 | expc_val = qt_z1x.expect( lambda dc: # z, x, y: 59 | vwei_p_y1z_logp(dc).exp() * (pt_zx.logp(dc,dc) - qt_z1x.logp(dc,dc)), 60 | obs_xy, n_mc, repar) 61 | return r_y1x_pval.log() + expc_val / r_y1x_pval 62 | # r_y1x_logpval = qt_z1x.expect(vwei_p_y1z_logp, obs_xy, n_mc, repar, 63 | # reducefn=tc.logsumexp) - math.log(n_mc) # z, y: 64 | # expc_logval = qt_z1x.expect(lambda dc: # z, x, y: 65 | # vwei_p_y1z_logp(dc) + (pt_zx.logp(dc,dc) - qt_z1x.logp(dc,dc)).log(), 66 | # obs_xy, n_mc, repar, reducefn=tc.logsumexp) - math.log(n_mc) 67 | # return r_y1x_logpval + (expc_logval - r_y1x_logpval).exp() 68 | 69 | def elbo_zy2x(p_zyx: Distr, q_y1x: Distr, q_z1xy: Distr, obs_x: edic, n_mc: int=0, repar: bool=True) -> tc.Tensor: 70 | """ For supervised VAE with structure z -> x <- y (Kingma's semi-supervised VAE, M2). (z,y) correlation also allowed. 71 | Observations are unsupervised x data. 72 | For supervised observations of (x,y) pairs, use `elbo(p_zyx, q_z1xy, obs_xy)` as VAE z -> (x,y). """ 73 | if hasattr(q_y1x, "entropy"): 74 | return q_y1x.expect(lambda dc: elbo(p_zyx, q_z1xy, dc, n_mc, repar), 75 | obs_x, n_mc, repar) + q_y1x.entropy(obs_x) 76 | else: 77 | return q_y1x.expect(lambda dc: elbo(p_zyx, q_z1xy, dc, n_mc, repar) - q_y1x.logp(dc,dc), 78 | obs_x, n_mc, repar) 79 | 80 | -------------------------------------------------------------------------------- /distr/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | '''Probabilistic Programming Package. 
3 | 4 | The prototype is distributions, which can be a conditional one with 5 | functions for parameters to define the dependency. Distribution 6 | multiplication is implemented, as well as the mean, expectation, 7 | sampling with backprop capability, and log-probability. 8 | ''' 9 | 10 | import torch as tc 11 | from functools import partial, wraps 12 | from inspect import signature 13 | 14 | __author__ = "Chang Liu" 15 | __version__ = "1.0.1" 16 | __email__ = "changliu@microsoft.com" 17 | 18 | # enhanced dictionary 19 | class edic(dict): 20 | def __and__(self, other): return edic({k:other[k] for k in set(self) & set(other)}) # & 21 | def __rand__(self, other): return edic({k:self[k] for k in set(other) & set(self)}) # & 22 | def __or__(self, other): return edic({**self, **other}) # | 23 | def __ror__(self, other): return edic({**other, **self}) # | 24 | def __sub__(self, other): return edic({k:v for k,v in self.items() if k not in other}) # - 25 | def __rsub__(self, other): return edic({k:v for k,v in other.items() if k not in self}) # - 26 | def isdisjoint(self, other) -> bool: return set(self).isdisjoint(set(other)) 27 | 28 | def sub(self, it, fn = None): 29 | return edic({k:self[k] for k in it} if fn is None else {k:fn(self[k]) for k in it}) 30 | def sub_expand_front(self, it, shape: tc.Size): 31 | return self.sub(it, partial(expand_front, shape = shape)) 32 | def subedic(self, it, fn = None, use_default = False, default = None): 33 | if not use_default: 34 | if fn is None: return edic({k:self[k] for k in it if k in self}) 35 | else: return edic({k:fn(self[k]) for k in it if k in self}) 36 | else: 37 | if fn is None: return edic({k: (self[k] if k in self else default) for k in it}) 38 | else: return edic({k: fn(self[k] if k in self else default) for k in it}) 39 | def sublist(self, it, fn = None, use_default = False, default = None): 40 | if not use_default: 41 | if fn is None: return [self[k] for k in it if k in self] 42 | else: return [fn(self[k]) for k in it if k in self] 43 | else: 44 | if fn is None: return [(self[k] if k in self else default) for k in it] 45 | else: return [fn(self[k] if k in self else default) for k in it] 46 | 47 | def key0(self): return next(iter(self)) 48 | def value0(self): return next(iter(self.values())) 49 | def item0(self): return next(iter(self.items())) 50 | 51 | def mean(self, dim, keepdim: bool=False): 52 | return edic({k: v.mean(dim, keepdim) for k,v in self.items()}) 53 | def expand_front(self, shape: tc.Size): 54 | return edic({k: v.expand(shape + v.shape) for k,v in self.items()}) 55 | def broadcast(self): 56 | return edic(zip( self.keys(), tc.broadcast_tensors(*self.values()) )) 57 | 58 | def edicify(*args) -> tuple: 59 | return tuple(arg if type(arg) is edic else edic(arg) for arg in args) 60 | 61 | # helper functions 62 | def fargnames(fn) -> set: 63 | # return set(fn.__code__.co_varnames) - {'self'} # Also includes temporary local variables 64 | return set(signature(fn).parameters.keys()) # Do not need to subtract 'self' 65 | 66 | def fedic(fn): 67 | return wraps(fn)( lambda dc: fn(**( fargnames(fn) & edicify(dc)[0] )) ) 68 | # pms = signature(fn).parameters 69 | # return wraps(fn)( lambda dc: fn(**( 70 | # edic({k:v.default for k,v in pms.items() if v.default is not v.empty}) 71 | # | (set(pms.keys()) & edicify(dc)[0]) )) ) 72 | 73 | def wrap4_multi_batchdims(fn, ndim_vars = 1): 74 | """ 75 | Function decorator to allow multiple batch dims at the front of input tensors. 
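    E.g. (a usage sketch, assuming a plain `torch.nn.Conv2d`):
        conv_md = wrap4_multi_batchdims(tc.nn.Conv2d(3, 8, 3), ndim_vars=3)
        y = conv_md(tc.rand(5, 7, 3, 32, 32))  # batch dims (5, 7) flattened to 35, then restored
        # y.shape == (5, 7, 8, 30, 30)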
76 | For incorporating functions (e.g., `torch.nn.Conv*`, `torch.nn.BatchNorm*`) that require only one batch dim 77 | into the `distr` package. 78 | """ 79 | allowed_types = [int, list, tuple] 80 | if type(ndim_vars) not in allowed_types: 81 | raise ValueError(f"`ndim_vars` must be within types {allowed_types}") 82 | 83 | def fn_new(*args, **kwargs): 84 | keys = tuple(kwargs.keys()) 85 | args += tuple(kwargs.values()) 86 | if type(ndim_vars) is int: ndims = [ndim_vars] * len(args) 87 | else: ndims = ndim_vars 88 | shapes_bat_var = [ 89 | ( (arg.shape[:-ndim], arg.shape[-ndim:]) if ndim else (arg.shape, tc.Size()) ) 90 | if type(arg) is tc.Tensor else 91 | (None, None) 92 | for arg, ndim in zip(args, ndims)] 93 | args_batflat = [ 94 | arg.reshape(-1, *shape_var) 95 | if type(arg) is tc.Tensor and len(shape_bat) > 1 else 96 | arg 97 | for arg, (shape_bat, shape_var) in zip(args, shapes_bat_var)] 98 | 99 | if len(keys): 100 | outs_batflat = fn( *args_batflat[:-len(keys)], 101 | **dict(zip(keys, args_batflat[-len(keys):])) ) 102 | else: 103 | outs_batflat = fn(*args_batflat) 104 | single_out = type(outs_batflat) not in {list, tuple} 105 | if single_out: outs_batflat = (outs_batflat,) 106 | outs = tuple( 107 | out.reshape(*shapes_bat_var[0][0], *out.shape[1:]) # use `shape_bat` of the first input var 108 | if type(out) is tc.Tensor else 109 | out 110 | for out in outs_batflat) 111 | if single_out: return outs[0] 112 | else: return outs 113 | return fn_new 114 | 115 | def append_attrs(obj, vardict, attrs): # obj = self, vardict = locals(), attrs = set(locals.keys()) - {'self'} 116 | for attr in attrs: setattr(obj, attr, vardict[attr]) 117 | 118 | # for tc.Tensor 119 | def tensorify(device=None, *args) -> tuple: 120 | return tuple(arg.to(device) if type(arg) is tc.Tensor else tc.tensor(arg, device=device) for arg in args) 121 | 122 | def is_scalar(ten: tc.Tensor) -> bool: 123 | return ten.squeeze().ndim == 0 124 | 125 | def is_same_tensor(ten1: tc.Tensor, ten2: tc.Tensor) -> bool: 126 | return (ten1 is ten2) or ( 127 | type(ten1) == type(ten2) == tc.Tensor 128 | and ten1.data_ptr() == ten2.data_ptr() and ten1.shape == ten2.shape) 129 | 130 | def expand_front(ten: tc.Tensor, shape: tc.Size) -> tc.Tensor: 131 | return ten.expand(shape + ten.shape) 132 | 133 | def flatten_last(ten: tc.Tensor, ndims: int): 134 | if ndims <= 1: return ten 135 | else: return ten.reshape(ten.shape[:-ndims] + (-1,)) 136 | 137 | def reduce_last(reducefn, ten: tc.Tensor, ndims: int): # tc.distributions.utils 138 | if ndims < 1: return ten 139 | else: return reducefn(ten.reshape(ten.shape[:-ndims] + (-1,)), dim=-1) 140 | # return reducefn(ten, dim=list(range(-1, -ndims-1, -1))) if ndims else ten # doesn't work for `tc.all` 141 | 142 | def swap_dim_ranges(ten: tc.Tensor, dims1: tuple, dims2: tuple) -> tc.Tensor: 143 | if len(dims1) != 2 or len(dims2) != 2: 144 | raise ValueError("`dims1` and `dims2` must be 2-tuples of integers") 145 | dims1 = tuple(dim if dim >= 0 else dim+ten.ndim for dim in dims1) 146 | dims2 = tuple(dim if dim >= 0 else dim+ten.ndim for dim in dims2) 147 | if dims1[0] > dims1[1]: dims1 = (dims1[1], dims1[0]) 148 | if dims2[0] > dims2[1]: dims2 = (dims2[1], dims2[0]) 149 | if dims2[0] < dims1[1] and dims1[0] < dims2[1]: 150 | raise ValueError("`dims1` and `dims2` must define disjoint intervals") 151 | if dims2[1] <= dims1[0]: dims1, dims2 = dims2, dims1 152 | dimord = list(range(0, dims1[0])) + list(range(*dims2)) \ 153 | + list(range(dims1[1], dims2[0])) \ 154 | + list(range(*dims1)) + 
list(range(dims2[1], ten.ndim)) 155 | return ten.permute(*dimord) 156 | 157 | def expand_middle(ten: tc.Tensor, shape: tc.Size, pos: int) -> tc.Tensor: 158 | # Expand with `shape` in front of dim `pos`. 159 | if len(shape) == 0: return ten 160 | if pos < 0: pos += ten.ndim 161 | ten_expd = expand_front(ten, shape) 162 | if pos == 0: return ten_expd 163 | else: return swap_dim_ranges(ten_expd, (0, len(shape)), (len(shape), len(shape)+pos)) 164 | 165 | # for tc.Size 166 | def tcsizeify(*args) -> tuple: 167 | return tuple(arg if type(arg) is tc.Size else tc.Size(arg) for arg in args) 168 | 169 | def tcsize_div(sz1: tc.Size, sz2: tc.Size) -> tc.Size: 170 | if not sz2 or sz1[-len(sz2):] == sz2: return sz1[:(len(sz1)-len(sz2))] 171 | else: raise ValueError("sizes do not match") 172 | 173 | def tcsize_broadcast(*sizes) -> tc.Size: 174 | szfinal = tc.Size() 175 | for sz in sizes: 176 | szlong, szshort = (list(szfinal), sz) if len(szfinal) >= len(sz) else (list(sz), szfinal) # lists, since `tc.Size` is an immutable tuple 177 | for i in range(1, 1 + len(szshort)): 178 | if szshort[-i] != 1: 179 | if szlong[-i] == 1: szlong[-i] = szshort[-i] 180 | elif szshort[-i] != szlong[-i]: raise ValueError("sizes do not match") 181 | szfinal = tc.Size(szlong) 182 | return szfinal 183 | 184 | # specific distribution utilities 185 | def normalize_probs(probs: tc.Tensor) -> tc.Tensor: 186 | return probs / probs.sum(dim=-1, keepdim=True) 187 | 188 | def normalize_logits(logits: tc.Tensor) -> tc.Tensor: 189 | return logits - logits.logsumexp(dim=-1, keepdim=True) 190 | 191 | def precision_to_scale_tril(P: tc.Tensor) -> tc.Tensor: 192 | # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril 193 | Lf = tc.cholesky(tc.flip(P, dims=(-2, -1))) 194 | L_inv = tc.transpose(tc.flip(Lf, dims=(-2, -1)), -2, -1) 195 | L = tc.triangular_solve(tc.eye(P.shape[-1], dtype=P.dtype, device=P.device), 196 | L_inv, upper=False)[0] 197 | return L 198 | 199 | -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | __author__ = "Chang Liu" 3 | __email__ = "changliu@microsoft.com" 4 | 5 | from .semvar import SemVar 6 | from .supvae import SupVAE 7 | from .cnbb import CNBBLoss 8 | 9 | -------------------------------------------------------------------------------- /methods/cnbb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | """ Implementation of the CNBB method "ConvNet with Batch Balancing". 3 | 4 | Based on the original description in . No official code found. 
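In brief, as implemented below: for each minibatch, per-sample weights are first
optimized so that, for every feature, the weighted means of the remaining features
are balanced between that feature's activated ('treated') and inactivated ('control')
groups; the returned loss is then the weight-averaged cross-entropy, plus a term
encouraging saturated (near-binary) activations.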
5 | """ 6 | import torch as tc 7 | from torch.nn.functional import normalize 8 | 9 | __author__ = "Chang Liu" 10 | __email__ = "changliu@microsoft.com" 11 | # tc.autograd.set_detect_anomaly(True) 12 | 13 | class CNBBLoss: 14 | def __init__(self, f_feat, actv, f_logit, dim_y, reg_w, reg_s, lr, n_iter): 15 | if actv not in {"Sigmoid", "Tanh"}: raise ValueError(f"unknown activation type '{actv}'") 16 | if dim_y == 1: 17 | celossobj = tc.nn.BCEWithLogitsLoss(reduction='none') 18 | self.celoss = lambda logits, y: celossobj(logits, y.float()) 19 | else: self.celoss = tc.nn.CrossEntropyLoss(reduction='none') 20 | self.f_feat, self.actv, self.f_logit, self.reg_w, self.reg_s, self.lr, self.n_iter \ 21 | = f_feat, actv, f_logit, reg_w, reg_s, lr, n_iter 22 | 23 | def __call__(self, x, y): 24 | n_bat = x.shape[0] 25 | # Inner iteration for weight 26 | with tc.no_grad(): feat = self.f_feat(x) 27 | feat = feat.reshape(n_bat, -1) 28 | if self.actv == "Sigmoid": is_treat = feat > .5 29 | elif self.actv == "Tanh": is_treat = feat > 0. 30 | weight = tc.full([n_bat], 1/n_bat, device=x.device, requires_grad=True) 31 | proj = (tc.eye(n_bat) - tc.ones(n_bat, n_bat) / n_bat).to(x.device) 32 | for it in range(self.n_iter): 33 | loss = ((feat.T @ ( 34 | normalize(weight[:,None] * is_treat, p=1, dim=0) 35 | - normalize(weight[:,None] * ~is_treat, p=1, dim=0) 36 | ) * ~tc.eye(feat.shape[-1], dtype=bool, device=x.device))**2).sum() \ 37 | + self.reg_w * (weight**2).sum() 38 | loss.backward() 39 | with tc.no_grad(): 40 | weight -= self.lr * (proj @ weight.grad) 41 | weight.abs_() 42 | weight /= weight.sum() 43 | weight.grad.zero_() 44 | # Optimize the model 45 | if self.actv == "Sigmoid": sqnorm_s = ((self.f_feat(x) - .5)**2).sum() 46 | elif self.actv == "Tanh": sqnorm_s = (self.f_feat(x)**2).sum() 47 | loss = weight.detach() @ self.celoss(self.f_logit(x), y) - self.reg_s * sqnorm_s 48 | return loss 49 | 50 | -------------------------------------------------------------------------------- /methods/semvar.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | ''' The Semantic-Variation Generative Model. 3 | 4 | I.e., the proposed Causal Semantic Generative model (CSG). 5 | ''' 6 | import sys 7 | import math 8 | import torch as tc 9 | sys.path.append('..') 10 | import distr as ds 11 | from . 
import xdistr as xds 12 | 13 | __author__ = "Chang Liu" 14 | __email__ = "changliu@microsoft.com" 15 | 16 | class SemVar: 17 | @staticmethod 18 | def _get_priors(mean_s, std_s, shape_s, mean_v, std_v, shape_v, corr_sv, mvn_prior: bool = False, device = None): 19 | if not mvn_prior: 20 | if not callable(mean_s): mean_s_val = mean_s; mean_s = lambda: mean_s_val 21 | if not callable(std_s): std_s_val = std_s; std_s = lambda: std_s_val 22 | if not callable(mean_v): mean_v_val = mean_v; mean_v = lambda: mean_v_val 23 | if not callable(std_v): std_v_val = std_v; std_v = lambda: std_v_val 24 | if not callable(corr_sv): 25 | if not corr_sv**2 < 1.: raise ValueError("correlation coefficient larger than 1") 26 | corr_sv_val = corr_sv; corr_sv = lambda: corr_sv_val 27 | 28 | p_s = ds.Normal('s', mean=mean_s, std=std_s, shape=shape_s) 29 | dim_s, dim_v = tc.tensor(shape_s).prod(), tc.tensor(shape_v).prod() 30 | def mean_v1s(s): 31 | shape_bat = s.shape[:-len(shape_s)] if len(shape_s) else s.shape 32 | s_normal_flat = ((s - mean_s()) / std_s()).reshape(shape_bat+(dim_s,)) 33 | v_normal_flat = s_normal_flat[..., :dim_v] if dim_v <= dim_s \ 34 | else tc.cat([s_normal_flat, tc.zeros(shape_bat+(dim_v-dim_s,), dtype=s.dtype, device=s.device)], dim=-1) 35 | return mean_v() + corr_sv() * std_v() * v_normal_flat.reshape(shape_bat+shape_v) 36 | def std_v1s(s): 37 | corr_sv_val = ds.tensorify(device, corr_sv())[0] 38 | return ( std_v() * (1. - corr_sv_val**2).sqrt() 39 | ).expand( (s.shape[:-len(shape_s)] if len(shape_s) else s.shape) + shape_v ) 40 | p_v1s = ds.Normal('v', mean=mean_v1s, std=std_v1s, shape=shape_v) 41 | p_v = ds.Normal('v', mean=mean_v, std=std_v, shape=shape_v) 42 | prior_params_list = [] 43 | else: 44 | if len(shape_s) != 1 or len(shape_v) != 1: 45 | raise RuntimeError("only 1-dim vectors are supported for `s` and `v` in `mvn_prior` mode") 46 | dim_s = shape_s[0]; dim_v = shape_v[0] 47 | mean_s = tc.zeros(shape_s, device=device) if callable(mean_s) else ds.tensorify(device, mean_s)[0].expand(shape_s).clone().detach() 48 | mean_v = tc.zeros(shape_v, device=device) if callable(mean_v) else ds.tensorify(device, mean_v)[0].expand(shape_v).clone().detach() 49 | # Sigma_sv = L_sv L_sv^T, L_sv = (L_ss, 0; M_vs, L_vv) 50 | std_s_offdiag = tc.zeros((dim_s, dim_s), device=device) # lower triangular of L_ss (excl. diag) 51 | std_v_offdiag = tc.zeros((dim_v, dim_v), device=device) # lower triangular of L_vv (excl. 
diag) 52 | if callable(std_s): # for diag of L_ss 53 | std_s_diag_param = tc.zeros(shape_s, device=device) 54 | else: 55 | std_s = ds.tensorify(device, std_s)[0].expand(shape_s) 56 | std_s_diag_param = std_s.log().clone().detach() 57 | if callable(std_v): # for diag of L_vv 58 | std_v_diag_param = tc.zeros(shape_v, device=device) 59 | else: 60 | std_v = ds.tensorify(device, std_v)[0].expand(shape_v) 61 | std_v_diag_param = std_v.log().clone().detach() 62 | if any(callable(var) for var in [std_s, std_v, corr_sv]): # M_vs 63 | std_vs_mat = tc.zeros(dim_v, dim_s, device=device) 64 | else: 65 | std_vs_mat = tc.eye(dim_v, dim_s, device=device) 66 | dim_min = min(dim_s, dim_v) 67 | std_vs_diag = (ds.tensorify(device, corr_sv)[0].expand((dim_min,)) * std_s[:dim_min] * std_v[:dim_min]).sqrt() 68 | if dim_min == dim_s: std_vs_mat = (std_vs_mat @ std_vs_diag.diagflat()).clone().detach() 69 | else: std_vs_mat = (std_vs_diag.diagflat() @ std_vs_mat).clone().detach() 70 | prior_params_list = [mean_s, std_s_diag_param, std_s_offdiag, mean_v, std_v_diag_param, std_v_offdiag, std_vs_mat] 71 | 72 | def std_s_tril(): # L_ss 73 | return std_s_offdiag.tril(-1) + std_s_diag_param.exp().diagflat() 74 | p_s = ds.MVNormal('s', mean=mean_s, std_tril=std_s_tril, shape=shape_s) 75 | 76 | def mean_v1s(s): 77 | return mean_v + ( std_vs_mat @ tc.triangular_solve( 78 | (s - mean_s).unsqueeze(-1), std_s_tril(), upper=False)[0] ).squeeze(-1) 79 | def std_v1s_tril(s): # L_vv 80 | return ( std_v_offdiag.tril(-1) + std_v_diag_param.exp().diagflat() 81 | ).expand( (s.shape[:-len(shape_s)] if len(shape_s) else s.shape) + (dim_v, dim_v) ) 82 | p_v1s = ds.MVNormal('v', mean=mean_v1s, std_tril=std_v1s_tril, shape=shape_v) 83 | 84 | def cov_v(): # M_vs M_vs^T + L_vv L_vv^T 85 | L_vv = std_v_offdiag.tril(-1) + std_v_diag_param.exp().diagflat() 86 | return std_vs_mat @ std_vs_mat.T + L_vv @ L_vv.T 87 | p_v = ds.MVNormal('v', mean=mean_v, cov=cov_v, shape=shape_v) 88 | return p_s, p_v1s, p_v, prior_params_list 89 | 90 | def __init__(self, shape_s, shape_v, shape_x, dim_y, 91 | mean_x1sv, std_x1sv, logit_y1s, 92 | mean_v1x = None, std_v1x = None, mean_s1vx = None, std_s1vx = None, 93 | tmean_v1x = None, tstd_v1x = None, tmean_s1vx = None, tstd_s1vx = None, 94 | mean_s = 0., std_s = 1., mean_v = 0., std_v = 1., corr_sv = .5, 95 | learn_tprior = False, src_mvn_prior = False, tgt_mvn_prior = False, device = None): 96 | if device is not None: ds.Distr.default_device = device 97 | self._parameter_dict = {} 98 | self.shape_s, self.shape_v, self.shape_x, self.dim_y = shape_s, shape_v, shape_x, dim_y 99 | self.learn_tprior = learn_tprior 100 | 101 | self.p_x1sv = ds.Normal('x', mean=mean_x1sv, std=std_x1sv, shape=shape_x) 102 | self.p_y1s = getattr(ds, 'Bern' if dim_y == 1 else 'Catg')('y', logits=logit_y1s) 103 | 104 | self.p_s, self.p_v1s, self.p_v, prior_params_list = self._get_priors( 105 | mean_s, std_s, shape_s, mean_v, std_v, shape_v, corr_sv, src_mvn_prior, device) 106 | if src_mvn_prior: self._parameter_dict.update(zip([ 107 | 'mean_s', 'std_s_diag_param', 'std_s_offdiag', 'mean_v', 'std_v_diag_param', 'std_v_offdiag', 'std_vs_mat' 108 | ], prior_params_list)) 109 | self.p_sv = self.p_s * self.p_v1s 110 | self.p_svx = self.p_sv * self.p_x1sv 111 | 112 | if mean_v1x is not None: 113 | self.q_v1x = ds.Normal('v', mean=mean_v1x, std=std_v1x, shape=shape_v) 114 | self.q_s1vx = ds.Normal('s', mean=mean_s1vx, std=std_s1vx, shape=shape_s) 115 | self.q_sv1x = self.q_v1x * self.q_s1vx 116 | else: self.q_v1x, self.q_s1vx, self.q_sv1x = None, 
None, None 117 | 118 | if tmean_v1x is not None: 119 | self.qt_v1x = ds.Normal('v', mean=tmean_v1x, std=tstd_v1x, shape=shape_v) 120 | self.qt_s1vx = ds.Normal('s', mean=tmean_s1vx, std=tstd_s1vx, shape=shape_s) 121 | self.qt_sv1x = self.qt_v1x * self.qt_s1vx 122 | else: self.qt_v1x, self.qt_s1vx, self.qt_sv1x = None, None, None 123 | 124 | if learn_tprior: 125 | if not tgt_mvn_prior: 126 | tmean_s = tc.zeros(shape_s, device=device) if callable(mean_s) else ds.tensorify(device, mean_s)[0].expand(shape_s).clone().detach() 127 | tmean_v = tc.zeros(shape_v, device=device) if callable(mean_v) else ds.tensorify(device, mean_v)[0].expand(shape_v).clone().detach() 128 | tstd_s_param = tc.zeros(shape_s, device=device) if callable(std_s) else ds.tensorify(device, std_s)[0].expand(shape_s).log().clone().detach() 129 | tstd_v_param = tc.zeros(shape_v, device=device) if callable(std_v) else ds.tensorify(device, std_v)[0].expand(shape_v).log().clone().detach() 130 | if callable(corr_sv): tcorr_sv_param = tc.zeros((), device=device) 131 | else: 132 | val = (ds.tensorify(device, corr_sv)[0].reshape(()) + 1.) / 2. 133 | tcorr_sv_param = (val / (1-val)).clone().log().detach() 134 | self._parameter_dict.update({'tmean_s': tmean_s, 'tmean_v': tmean_v, 135 | 'tstd_s_param': tstd_s_param, 'tstd_v_param': tstd_v_param, 'tcorr_sv_param': tcorr_sv_param}) 136 | 137 | def tstd_s(): return tc.exp(tstd_s_param) 138 | def tstd_v(): return tc.exp(tstd_v_param) 139 | def tcorr_sv(): return 2. * tc.sigmoid(tcorr_sv_param) - 1. 140 | self.pt_s, self.pt_v1s, self.pt_v, tprior_params_list = self._get_priors( 141 | tmean_s, tstd_s, shape_s, tmean_v, tstd_v, shape_v, tcorr_sv, False, device) 142 | else: 143 | self.pt_s, self.pt_v1s, self.pt_v, tprior_params_list = self._get_priors( 144 | mean_s, std_s, shape_s, mean_v, std_v, shape_v, corr_sv, True, device) 145 | self._parameter_dict.update(zip([ 146 | 'tmean_s', 'tstd_s_diag_param', 'tstd_s_offdiag', 'tmean_v', 'tstd_v_diag_param', 'tstd_v_offdiag', 'tstd_vs_mat' 147 | ], tprior_params_list)) 148 | else: self.pt_s, self.pt_v1s, self.pt_v = self.p_s, self.p_v, self.p_v # independent prior 149 | self.pt_sv = self.pt_s * self.pt_v1s 150 | self.pt_svx = self.pt_sv * self.p_x1sv 151 | for param in self._parameter_dict.values(): param.requires_grad_() 152 | 153 | def parameters(self): 154 | for param in self._parameter_dict.values(): yield param 155 | 156 | def state_dict(self): 157 | return self._parameter_dict 158 | 159 | def load_state_dict(self, state_dict: dict): 160 | for name in list(self._parameter_dict.keys()): 161 | with tc.no_grad(): self._parameter_dict[name].copy_(state_dict[name]) 162 | 163 | def get_lossfn(self, n_mc_q: int=0, reduction: str="mean", mode: str="defl", weight_da: float=None, wlogpi: float=None): 164 | if reduction == "mean": reducefn = tc.mean 165 | elif reduction == "sum": reducefn = tc.sum 166 | elif reduction is None or reduction == "none": reducefn = lambda x: x 167 | else: raise ValueError(f"unknown `reduction` '{reduction}'") 168 | 169 | if self.q_sv1x is not None: # svgm, svgm-da2 170 | def lossfn_src(x: tc.Tensor, y: tc.LongTensor) -> tc.Tensor: 171 | return -reducefn( xds.elbo_z2xy(self.p_svx, self.p_y1s, self.q_sv1x, {'x':x, 'y':y}, n_mc_q, wlogpi) ) 172 | else: 173 | if self.learn_tprior: # svgm-da 174 | def lossfn_src(x: tc.Tensor, y: tc.LongTensor) -> tc.Tensor: 175 | return -reducefn( xds.elbo_z2xy_twist(self.pt_svx, self.p_y1s, self.p_sv, self.pt_sv, self.qt_sv1x, {'x':x, 'y':y}, n_mc_q, wlogpi) ) 176 | # return -reducefn( 
xds.elbo_z2xy_twist_fixpt(self.p_x1sv, self.p_y1s, self.p_sv, self.pt_sv, self.qt_sv1x, {'x':x, 'y':y}, n_mc_q, wlogpi) ) 177 | else: # svgm-ind 178 | def lossfn_src(x: tc.Tensor, y: tc.LongTensor) -> tc.Tensor: 179 | return -reducefn( xds.elbo_z2xy_twist(self.pt_svx, self.p_y1s, self.p_v1s, self.p_v, self.qt_sv1x, {'x':x, 'y':y}, n_mc_q, wlogpi) ) 180 | 181 | def lossfn_tgt(xt: tc.Tensor) -> tc.Tensor: 182 | return -reducefn( ds.elbo(self.pt_svx, self.qt_sv1x, {'x': xt}, n_mc_q) ) 183 | # return -reducefn( xds.elbo_fixllh(self.pt_sv, self.p_x1sv, self.qt_sv1x, {'x': xt}, n_mc_q) ) 184 | 185 | if mode == "src": return lossfn_src 186 | elif mode == "tgt": return lossfn_tgt # may be invalid 187 | elif not mode or mode == "defl": 188 | if self.learn_tprior: 189 | def lossfn(x: tc.Tensor, y: tc.LongTensor, xt: tc.Tensor) -> tc.Tensor: 190 | return lossfn_src(x,y) + weight_da * lossfn_tgt(xt) 191 | return lossfn 192 | else: return lossfn_src 193 | else: raise ValueError(f"unknown `mode` '{mode}'") 194 | 195 | # Utilities 196 | def llh(self, x: tc.Tensor, y: tc.LongTensor=None, n_mc_marg: int=64, use_q: bool=True, mode: str="src") -> float: 197 | if mode == "src": 198 | p_joint = self.p_svx 199 | q_cond = self.q_sv1x if self.q_sv1x else self.qt_sv1x 200 | elif mode == "tgt": 201 | p_joint = self.pt_svx 202 | q_cond = self.qt_sv1x if self.qt_sv1x else self.q_sv1x 203 | else: raise ValueError(f"unknown `mode` '{mode}'") 204 | if not use_q: 205 | if y is None: llh_vals = p_joint.marg({'x'}, n_mc_marg).logp({'x': x}) 206 | else: llh_vals = (p_joint * self.p_y1s).marg({'x', 'y'}, n_mc_marg).logp({'x': x, 'y': y}) 207 | else: 208 | if y is None: llh_vals = q_cond.expect(lambda dc: p_joint.logp(dc) - q_cond.logp(dc,dc), 209 | {'x': x}, n_mc_marg, reducefn=tc.logsumexp) - math.log(n_mc_marg) 210 | else: llh_vals = q_cond.expect(lambda dc: (p_joint * self.p_y1s).logp(dc) - q_cond.logp(dc,dc), 211 | {'x': x, 'y': y}, n_mc_marg, reducefn=tc.logsumexp) - math.log(n_mc_marg) 212 | return llh_vals.mean().item() 213 | 214 | def logit_y1x_src(self, x: tc.Tensor, n_mc_q: int=0, repar: bool=True): 215 | dim_y = 2 if self.dim_y == 1 else self.dim_y 216 | y_eval = ds.expand_front(tc.arange(dim_y, device=x.device), ds.tcsize_div(x.shape, self.shape_x)) 217 | x_eval = ds.expand_middle(x, (dim_y,), -len(self.shape_x)) 218 | obs_xy = ds.edic({'x': x_eval, 'y': y_eval}) 219 | if self.q_sv1x is not None: 220 | logits = (self.q_sv1x.expect(lambda dc: self.p_y1s.logp(dc,dc), obs_xy, 0, repar) #, reducefn=tc.logsumexp) 221 | ) if n_mc_q == 0 else ( 222 | self.q_sv1x.expect(lambda dc: self.p_y1s.logp(dc,dc), 223 | obs_xy, n_mc_q, repar, reducefn=tc.logsumexp) - math.log(n_mc_q) 224 | ) 225 | else: 226 | vwei_p_y1s_logp = lambda dc: self.p_sv.logp(dc,dc) - self.pt_sv.logp(dc,dc) + self.p_y1s.logp(dc,dc) 227 | logits = (self.qt_sv1x.expect(vwei_p_y1s_logp, obs_xy, 0, repar) #, reducefn=tc.logsumexp) 228 | ) if n_mc_q == 0 else ( 229 | self.qt_sv1x.expect(vwei_p_y1s_logp, obs_xy, n_mc_q, repar, reducefn=tc.logsumexp) - math.log(n_mc_q) 230 | ) 231 | return (logits[..., 1] - logits[..., 0]).squeeze(-1) if self.dim_y == 1 else logits 232 | 233 | def generate(self, shape_mc: tc.Size=tc.Size(), mode: str="src") -> tuple: 234 | if mode == "src": smp_sv = self.p_sv.draw(shape_mc) 235 | elif mode == "tgt": smp_sv = self.pt_sv.draw(shape_mc) 236 | else: raise ValueError(f"unknown 'mode' '{mode}'") 237 | return self.p_x1sv.mode(smp_sv, False)['x'], self.p_y1s.mode(smp_sv, False)['y'] 238 | 239 | 
-------------------------------------------------------------------------------- /methods/supvae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | '''Supervised VAE (no s-v split, the CSGz ablation baseline) 3 | ''' 4 | import sys 5 | import math 6 | import torch as tc 7 | sys.path.append('..') 8 | import distr as ds 9 | from . import xdistr as xds 10 | 11 | __author__ = "Chang Liu" 12 | __email__ = "changliu@microsoft.com" 13 | 14 | class SupVAE: 15 | @staticmethod 16 | def _get_priors(mean_s, std_s, shape_s, mvn_prior: bool = False, device = None): 17 | if not mvn_prior: 18 | p_s = ds.Normal('s', mean=mean_s, std=std_s, shape=shape_s) 19 | prior_params_list = [] 20 | else: 21 | if len(shape_s) != 1: 22 | raise RuntimeError("only 1-dim vector is supported for `s` in `mvn_prior` mode") 23 | dim_s = shape_s[0] 24 | mean_s = tc.zeros(shape_s, device=device) if callable(mean_s) else ds.tensorify(device, mean_s)[0].expand(shape_s).clone().detach() 25 | std_s_offdiag = tc.zeros((dim_s, dim_s), device=device) # lower triangular of L_ss (excl. diag) 26 | if callable(std_s): # for diag of L_ss 27 | std_s_diag_param = tc.zeros(shape_s, device=device) 28 | else: 29 | std_s = ds.tensorify(device, std_s)[0].expand(shape_s) 30 | std_s_diag_param = std_s.log().clone().detach() 31 | prior_params_list = [mean_s, std_s_diag_param, std_s_offdiag] 32 | 33 | def std_s_tril(): # L_ss 34 | return std_s_offdiag.tril(-1) + std_s_diag_param.exp().diagflat() 35 | p_s = ds.MVNormal('s', mean=mean_s, std_tril=std_s_tril, shape=shape_s) 36 | return p_s, prior_params_list 37 | 38 | def __init__(self, shape_s, shape_x, dim_y, 39 | mean_x1s, std_x1s, logit_y1s, 40 | mean_s1x = None, std_s1x = None, 41 | tmean_s1x = None, tstd_s1x = None, 42 | mean_s = 0., std_s = 1., 43 | learn_tprior = False, src_mvn_prior = False, tgt_mvn_prior = False, device = None): 44 | if device is not None: ds.Distr.default_device = device 45 | self._parameter_dict = {} 46 | self.shape_x, self.dim_y, self.shape_s = shape_x, dim_y, shape_s 47 | self.learn_tprior = learn_tprior 48 | 49 | self.p_x1s = ds.Normal('x', mean=mean_x1s, std=std_x1s, shape=shape_x) 50 | self.p_y1s = getattr(ds, 'Bern' if dim_y == 1 else 'Catg')('y', logits=logit_y1s) 51 | 52 | self.p_s, prior_params_list = self._get_priors(mean_s, std_s, shape_s, src_mvn_prior, device) 53 | if src_mvn_prior: self._parameter_dict.update(zip(['mean_s', 'std_s_diag_param', 'std_s_offdiag'], prior_params_list)) 54 | self.p_sx = self.p_s * self.p_x1s 55 | 56 | if mean_s1x is not None: 57 | self.q_s1x = ds.Normal('s', mean=mean_s1x, std=std_s1x, shape=shape_s) 58 | else: self.q_s1x = None 59 | 60 | if tmean_s1x is not None: 61 | self.qt_s1x = ds.Normal('s', mean=tmean_s1x, std=tstd_s1x, shape=shape_s) 62 | else: self.qt_s1x = None 63 | 64 | if learn_tprior: 65 | if not tgt_mvn_prior: 66 | tmean_s = tc.zeros(shape_s, device=device) if callable(mean_s) else ds.tensorify(device, mean_s)[0].expand(shape_s).clone().detach() 67 | tstd_s_param = tc.zeros(shape_s, device=device) if callable(std_s) else ds.tensorify(device, std_s)[0].log().expand(shape_s).clone().detach() 68 | self._parameter_dict.update({'tmean_s': tmean_s, 'tstd_s_param': tstd_s_param}) 69 | def tstd_s(): return tc.exp(tstd_s_param) 70 | self.pt_s, tprior_params_list = self._get_priors(tmean_s, tstd_s, shape_s, False, device) 71 | else: 72 | self.pt_s, tprior_params_list = self._get_priors(mean_s, std_s, shape_s, True, device) 73 | 
self._parameter_dict.update(zip(['tmean_s', 'tstd_s_diag_param', 'tstd_s_offdiag'], tprior_params_list)) 74 | else: self.pt_s = self.p_s 75 | self.pt_sx = self.pt_s * self.p_x1s 76 | for param in self._parameter_dict.values(): param.requires_grad_() 77 | 78 | def parameters(self): 79 | for param in self._parameter_dict.values(): yield param 80 | 81 | def state_dict(self): 82 | return self._parameter_dict 83 | 84 | def load_state_dict(self, state_dict: dict): 85 | for name in list(self._parameter_dict.keys()): 86 | with tc.no_grad(): self._parameter_dict[name].copy_(state_dict[name]) 87 | 88 | def get_lossfn(self, n_mc_q: int=0, reduction: str="mean", mode: str="defl", weight_da: float=None, wlogpi: float=None): 89 | if reduction == "mean": reducefn = tc.mean 90 | elif reduction == "sum": reducefn = tc.sum 91 | elif reduction is None or reduction == "none": reducefn = lambda x: x 92 | else: raise ValueError(f"unknown `reduction` '{reduction}'") 93 | 94 | if self.q_s1x is not None: # svae, svae-da2 95 | def lossfn_src(x: tc.Tensor, y: tc.LongTensor) -> tc.Tensor: 96 | return -reducefn( xds.elbo_z2xy(self.p_sx, self.p_y1s, self.q_s1x, {'x':x, 'y':y}, n_mc_q, wlogpi) ) 97 | else: 98 | if self.learn_tprior: # svae-da 99 | def lossfn_src(x: tc.Tensor, y: tc.LongTensor) -> tc.Tensor: 100 | return -reducefn( xds.elbo_z2xy_twist(self.pt_sx, self.p_y1s, self.p_s, self.pt_s, self.qt_s1x, {'x':x, 'y':y}, n_mc_q, wlogpi) ) 101 | # return -reducefn( xds.elbo_z2xy_twist_fixpt(self.p_x1s, self.p_y1s, self.p_s, self.pt_s, self.qt_s1x, {'x':x, 'y':y}, n_mc_q, wlogpi) ) 102 | else: raise ValueError("Use `q_s1x` for the source loss when no new prior") 103 | 104 | def lossfn_tgt(xt: tc.Tensor) -> tc.Tensor: 105 | return -reducefn( ds.elbo(self.pt_sx, self.qt_s1x, {'x': xt}, n_mc_q) ) 106 | # return -reducefn( xds.elbo_fixllh(self.pt_s, self.p_x1s, self.qt_s1x, {'x': xt}, n_mc_q) ) 107 | 108 | if mode == "src": return lossfn_src 109 | elif mode == "tgt": return lossfn_tgt # may be invalid 110 | elif not mode or mode == "defl": 111 | if self.learn_tprior: 112 | def lossfn(x: tc.Tensor, y: tc.LongTensor, xt: tc.Tensor) -> tc.Tensor: 113 | return lossfn_src(x,y) + weight_da * lossfn_tgt(xt) 114 | return lossfn 115 | else: return lossfn_src 116 | else: raise ValueError(f"unknown `mode` '{mode}'") 117 | 118 | # Utilities 119 | def llh(self, x: tc.Tensor, y: tc.LongTensor=None, n_mc_marg: int=64, use_q: bool=True, mode: str="src") -> float: 120 | if mode == "src": 121 | p_joint = self.p_sx 122 | q_cond = self.q_s1x if self.q_s1x else self.qt_s1x 123 | elif mode == "tgt": 124 | p_joint = self.pt_sx 125 | q_cond = self.qt_s1x if self.qt_s1x else self.q_s1x 126 | else: raise ValueError(f"unknown `mode` '{mode}'") 127 | if not use_q: 128 | if y is None: llh_vals = p_joint.marg({'x'}, n_mc_marg).logp({'x': x}) 129 | else: llh_vals = (p_joint * self.p_y1s).marg({'x', 'y'}, n_mc_marg).logp({'x': x, 'y': y}) 130 | else: 131 | if y is None: llh_vals = q_cond.expect(lambda dc: p_joint.logp(dc) - q_cond.logp(dc,dc), 132 | {'x': x}, n_mc_marg, reducefn=tc.logsumexp) - math.log(n_mc_marg) 133 | else: llh_vals = q_cond.expect(lambda dc: (p_joint * self.p_y1s).logp(dc) - q_cond.logp(dc,dc), 134 | {'x': x, 'y': y}, n_mc_marg, reducefn=tc.logsumexp) - math.log(n_mc_marg) 135 | return llh_vals.mean().item() 136 | 137 | def logit_y1x_src(self, x: tc.Tensor, n_mc_q: int=0, repar: bool=True): 138 | dim_y = 2 if self.dim_y == 1 else self.dim_y 139 | y_eval = ds.expand_front(tc.arange(dim_y, device=x.device), ds.tcsize_div(x.shape, 
self.shape_x)) 140 | x_eval = ds.expand_middle(x, (dim_y,), -len(self.shape_x)) 141 | obs_xy = ds.edic({'x': x_eval, 'y': y_eval}) 142 | if self.q_s1x is not None: 143 | logits = (self.q_s1x.expect(lambda dc: self.p_y1s.logp(dc,dc), obs_xy, 0, repar) #, reducefn=tc.logsumexp) 144 | ) if n_mc_q == 0 else ( 145 | self.q_s1x.expect(lambda dc: self.p_y1s.logp(dc,dc), 146 | obs_xy, n_mc_q, repar, reducefn=tc.logsumexp) - math.log(n_mc_q) 147 | ) 148 | else: 149 | vwei_p_y1s_logp = lambda dc: self.p_s.logp(dc,dc) - self.pt_s.logp(dc,dc) + self.p_y1s.logp(dc,dc) 150 | logits = (self.qt_s1x.expect(vwei_p_y1s_logp, obs_xy, 0, repar) #, reducefn=tc.logsumexp) 151 | ) if n_mc_q == 0 else ( 152 | self.qt_s1x.expect(vwei_p_y1s_logp, obs_xy, n_mc_q, repar, reducefn=tc.logsumexp) - math.log(n_mc_q) 153 | ) 154 | return (logits[..., 1] - logits[..., 0]).squeeze(-1) if self.dim_y == 1 else logits 155 | 156 | def generate(self, shape_mc: tc.Size=tc.Size(), mode: str="src") -> tuple: 157 | if mode == "src": smp_s = self.p_s.draw(shape_mc) 158 | elif mode == "tgt": smp_s = self.pt_s.draw(shape_mc) 159 | else: raise ValueError(f"unknown 'mode' '{mode}'") 160 | return self.p_x1s.mode(smp_s, False)['x'], self.p_y1s.mode(smp_s, False)['y'] 161 | 162 | -------------------------------------------------------------------------------- /methods/xdistr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | """ Modification to the `distr` package for the structure of the 3 | Causal Semantic Generative model. 4 | """ 5 | import sys 6 | import math 7 | import torch as tc 8 | sys.path.append('..') 9 | from distr import Distr, edic 10 | 11 | __author__ = "Chang Liu" 12 | __email__ = "changliu@microsoft.com" 13 | 14 | def elbo_z2xy(p_zx: Distr, p_y1z: Distr, q_z1x: Distr, obs_xy: edic, n_mc: int=0, wlogpi: float=1., repar: bool=True) -> tc.Tensor: 15 | """ For supervised VAE with structure x <- z -> y. 16 | Observations are supervised (x,y) pairs. 17 | For unsupervised observations of x data, use `elbo(p_zx, q_z1x, obs_x)` as VAE z -> x. 
""" 18 | if n_mc == 0: 19 | q_y1x_logpval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc), obs_xy, 0, repar) #, reducefn=tc.logsumexp) 20 | if hasattr(q_z1x, "entropy"): # No difference for Gaussian 21 | expc_val = q_z1x.expect(lambda dc: p_zx.logp(dc,dc), obs_xy, 0, repar) + q_z1x.entropy(obs_xy) 22 | else: 23 | expc_val = q_z1x.expect(lambda dc: p_zx.logp(dc,dc) - q_z1x.logp(dc,dc), obs_xy, 0, repar) 24 | return wlogpi * q_y1x_logpval + expc_val 25 | else: 26 | q_y1x_pval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc).exp(), obs_xy, n_mc, repar) 27 | expc_val = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc).exp() * (p_zx.logp(dc,dc) - q_z1x.logp(dc,dc)), 28 | obs_xy, n_mc, repar) 29 | return wlogpi * q_y1x_pval.log() + expc_val / q_y1x_pval 30 | # q_y1x_logpval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc), obs_xy, n_mc, repar, 31 | # reducefn=tc.logsumexp) - math.log(n_mc) 32 | # expc_logval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc) + (p_zx.logp(dc,dc) - q_z1x.logp(dc,dc)).log(), 33 | # obs_xy, n_mc, repar, reducefn=tc.logsumexp) - math.log(n_mc) 34 | # return wlogpi * q_y1x_logpval + (expc_logval - q_y1x_logpval).exp() 35 | 36 | def elbo_z2xy_twist(pt_zx: Distr, p_y1z: Distr, p_z: Distr, pt_z: Distr, qt_z1x: Distr, obs_xy: edic, n_mc: int=0, wlogpi: float=1., repar: bool=True) -> tc.Tensor: 37 | vwei_p_y1z_logp = lambda dc: p_z.logp(dc,dc) - pt_z.logp(dc,dc) + p_y1z.logp(dc,dc) # z, y: 38 | if n_mc == 0: 39 | r_y1x_logpval = qt_z1x.expect(vwei_p_y1z_logp, obs_xy, 0, repar) #, reducefn=tc.logsumexp) 40 | if hasattr(qt_z1x, "entropy"): # No difference for Gaussian 41 | expc_val = qt_z1x.expect(lambda dc: pt_zx.logp(dc,dc), obs_xy, 0, repar) + qt_z1x.entropy(obs_xy) 42 | else: 43 | expc_val = qt_z1x.expect(lambda dc: pt_zx.logp(dc,dc) - qt_z1x.logp(dc,dc), obs_xy, 0, repar) 44 | return wlogpi * r_y1x_logpval + expc_val 45 | else: 46 | r_y1x_pval = qt_z1x.expect(lambda dc: vwei_p_y1z_logp(dc).exp(), obs_xy, n_mc, repar) 47 | expc_val = qt_z1x.expect( lambda dc: # z, x, y: 48 | vwei_p_y1z_logp(dc).exp() * (pt_zx.logp(dc,dc) - qt_z1x.logp(dc,dc)), 49 | obs_xy, n_mc, repar) 50 | return wlogpi * r_y1x_pval.log() + expc_val / r_y1x_pval 51 | # r_y1x_logpval = qt_z1x.expect(vwei_p_y1z_logp, obs_xy, n_mc, repar, 52 | # reducefn=tc.logsumexp) - math.log(n_mc) # z, y: 53 | # expc_logval = qt_z1x.expect(lambda dc: # z, x, y: 54 | # vwei_p_y1z_logp(dc) + (pt_zx.logp(dc,dc) - qt_z1x.logp(dc,dc)).log(), 55 | # obs_xy, n_mc, repar, reducefn=tc.logsumexp) - math.log(n_mc) 56 | # return wlogpi * r_y1x_logpval + (expc_logval - r_y1x_logpval).exp() 57 | 58 | def elbo_fixllh(p_prior: Distr, p_llh: Distr, q_cond: Distr, obs: edic, n_mc: int=10, repar: bool=True) -> tc.Tensor: # [shape_bat] -> [shape_bat] 59 | def logp_llh_nograd(dc): 60 | with tc.no_grad(): return p_llh.logp(dc,dc) 61 | if hasattr(q_cond, "entropy"): 62 | return q_cond.expect(lambda dc: p_prior.logp(dc,dc) + logp_llh_nograd(dc), 63 | obs, n_mc, repar) + q_cond.entropy(obs) 64 | else: 65 | return q_cond.expect(lambda dc: p_prior.logp(dc,dc) + logp_llh_nograd(dc) - q_cond.logp(dc,dc), 66 | obs, n_mc, repar) 67 | 68 | def elbo_z2xy_twist_fixpt(p_x1z: Distr, p_y1z: Distr, p_z: Distr, pt_z: Distr, qt_z1x: Distr, obs_xy: edic, n_mc: int=0, wlogpi: float=1., repar: bool=True) -> tc.Tensor: 69 | def logpt_z_nograd(dc): 70 | with tc.no_grad(): return pt_z.logp(dc,dc) 71 | vwei_p_y1z_logp = lambda dc: p_z.logp(dc,dc) - logpt_z_nograd(dc) + p_y1z.logp(dc,dc) # z, y: 72 | if n_mc == 0: 73 | r_y1x_logpval = qt_z1x.expect(vwei_p_y1z_logp, obs_xy, 0, repar) #, 
reducefn=tc.logsumexp) 74 | if hasattr(qt_z1x, "entropy"): 75 | expc_val = qt_z1x.expect(lambda dc: logpt_z_nograd(dc) + p_x1z.logp(dc,dc), 76 | obs_xy, 0, repar) + qt_z1x.entropy(obs_xy) 77 | else: 78 | expc_val = qt_z1x.expect(lambda dc: logpt_z_nograd(dc) + p_x1z.logp(dc,dc) - qt_z1x.logp(dc,dc), obs_xy, 0, repar) 79 | return wlogpi * r_y1x_logpval + expc_val 80 | else: 81 | r_y1x_pval = qt_z1x.expect(lambda dc: vwei_p_y1z_logp(dc).exp(), obs_xy, n_mc, repar) 82 | expc_val = qt_z1x.expect( lambda dc: # z, x, y: 83 | vwei_p_y1z_logp(dc).exp() * (logpt_z_nograd(dc) + p_x1z.logp(dc,dc) - qt_z1x.logp(dc,dc)), 84 | obs_xy, n_mc, repar) 85 | return wlogpi * r_y1x_pval.log() + expc_val / r_y1x_pval 86 | # r_y1x_logpval = qt_z1x.expect(vwei_p_y1z_logp, obs_xy, n_mc, repar, 87 | # reducefn=tc.logsumexp) - math.log(n_mc) # z, y: 88 | # expc_logval = qt_z1x.expect(lambda dc: # z, x, y: 89 | # vwei_p_y1z_logp(dc) + (logpt_z_nograd(dc) + p_x1z.logp(dc,dc) - qt_z1x.logp(dc,dc)).log(), 90 | # obs_xy, n_mc, repar, reducefn=tc.logsumexp) - math.log(n_mc) 91 | # return wlogpi * r_y1x_logpval + (expc_logval - r_y1x_logpval).exp() 92 | 93 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cvxopt==1.2.7 2 | dalib==0.2 3 | lime==0.2.0.1 4 | jupyter==1.0.0 5 | matplotlib==3.3.2 6 | numpy==1.19.5 7 | scikit-image==0.17.1 8 | tabulate==0.8.9 9 | torch==1.8.1 10 | torchvision==0.9.1 11 | tqdm==4.60.0 12 | -------------------------------------------------------------------------------- /test/distr_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import numpy as np\n", 11 | "import torch as tc\n", 12 | "sys.path.append('..')\n", 13 | "import distr as ds\n", 14 | "\n", 15 | "shape_s, shape_z = (2,3), (2,2)\n", 16 | "shape_bat = (30000,)\n", 17 | "mu_S, mu_Z = -1., 1.\n", 18 | "std_S, std_Z = 1.3, 1.\n", 19 | "corr_SZ = .7\n", 20 | "dim_s, dim_z = np.array(shape_s).prod(), np.array(shape_z).prod()\n", 21 | "\n", 22 | "S = mu_S + std_S * np.random.randn(*(shape_bat+shape_s)).astype(np.float32)\n", 23 | "S_normal_flat = ((S - mu_S) / std_S).reshape(shape_bat+(dim_s,))\n", 24 | "Z_normal_flat = S_normal_flat[..., :dim_z] if dim_z <= dim_s \\\n", 25 | " else tc.cat([S_normal_flat, tc.zeros(shape_bat+(dim_z-dim_s,), dtype=np.float32)], dim=-1)\n", 26 | "mu_Z1S = mu_Z + corr_SZ*std_Z * Z_normal_flat.reshape(shape_bat+shape_z)\n", 27 | "std_Z1S = std_Z * np.sqrt(1. 
- corr_SZ**2)\n", 28 | "Z = mu_Z1S + std_Z1S * np.random.randn(*(shape_bat+shape_z)).astype(np.float32)\n", 29 | "\n", 30 | "device = tc.device(\"cuda:0\" if tc.cuda.is_available() else \"cpu\")\n", 31 | "S, Z = tc.from_numpy(S).to(device), tc.from_numpy(Z).to(device)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# Learning by Normal" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "mu_s = tc.randn(shape_s, requires_grad=True, device=device)\n", 48 | "std_s_param = tc.randn(shape_s, requires_grad=True, device=device)\n", 49 | "mu_z = tc.randn(shape_z, requires_grad=True, device=device)\n", 50 | "std_z_param = tc.randn(shape_z, requires_grad=True, device=device)\n", 51 | "corr_param = tc.randn(1, requires_grad=True, device=device)\n", 52 | "\n", 53 | "def corr_sz(): return 1. - (2*tc.sigmoid(corr_param)-1.)**2\n", 54 | "def std_s(): return tc.exp(std_s_param)\n", 55 | "def std_z(): return tc.exp(std_z_param)\n", 56 | "\n", 57 | "def mu_z1s(s):\n", 58 | " s_normal_flat = ((s - mu_s) / std_s()).reshape(shape_bat+(dim_s,))\n", 59 | " z_normal_flat = s_normal_flat[..., :dim_z] if dim_z <= dim_s \\\n", 60 | " else tc.cat([s_normal_flat, tc.zeros(shape_bat+(dim_z-dim_s,), dtype=s.dtype, device=s.device)], dim=-1)\n", 61 | " return mu_z + corr_sz()*std_z() * z_normal_flat.reshape(shape_bat+shape_z)\n", 62 | "\n", 63 | "def std_z1s():\n", 64 | " return std_z() * (1. - corr_sz()**2).sqrt()\n", 65 | "\n", 66 | "ds.Distr.clear()\n", 67 | "ds.Distr.default_device = device\n", 68 | "p_s = ds.Normal('s', mean=mu_s, std=std_s, shape=shape_s)\n", 69 | "p_z1s = ds.Normal('z',mean=mu_z1s, std=std_z1s, shape=shape_z)\n", 70 | "p_sz = p_s * p_z1s" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "metadata": { 77 | "scrolled": true 78 | }, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "tensor([[-0.9941, -0.9730, -0.9932],\n", 85 | " [-0.9953, -0.9977, -0.9990]], device='cuda:0')\n", 86 | "tensor([[1.2980, 1.2980, 1.3023],\n", 87 | " [1.3075, 1.3032, 1.2968]], device='cuda:0')\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "# print(std_s_param.data)\n", 93 | "# print(mu_s.data, std_s().data, sep='\\n')\n", 94 | "opt = tc.optim.SGD([mu_s, std_s_param], lr=1e-3)\n", 95 | "for i in range(10000):\n", 96 | " opt.zero_grad()\n", 97 | " mlogp = -p_s.logp({'s': S}).mean()\n", 98 | "# print(i, mlogp.data)\n", 99 | " mlogp.backward()\n", 100 | " opt.step()\n", 101 | "# print(std_s_param.data)\n", 102 | "# print(mu_s.data, std_s().data, sep='\\n')\n", 103 | "print(mu_s.data, std_s().data, sep='\\n')" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": { 110 | "scrolled": false 111 | }, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "tensor([[-1.0678, -1.0178, -1.0556],\n", 118 | " [-1.0065, -1.0068, -0.9978]], device='cuda:0')\n", 119 | "tensor([[1.3031, 1.3017, 1.3065],\n", 120 | " [1.3065, 1.3031, 1.2968]], device='cuda:0')\n", 121 | "tensor([[0.9547, 0.9831],\n", 122 | " [0.9609, 0.9943]], device='cuda:0')\n", 123 | "tensor([[1.0030, 1.0033],\n", 124 | " [1.0072, 1.0018]], device='cuda:0')\n", 125 | "tensor([0.7024], device='cuda:0')\n", 126 | "tensor([[0.7140, 0.7141],\n", 127 | " [0.7169, 0.7131]], device='cuda:0')\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "# print(std_s_param.data, 
std_z_param.data, corr_param.data, sep='\\n')\n", 133 | "# print('-'*5)\n", 134 | "# print(mu_s.data, std_s().data, mu_z.data, std_z().data, corr_sz().data, std_z1s().data, sep='\\n')\n", 135 | "opt = tc.optim.SGD([mu_s, std_s_param, mu_z, std_z_param, corr_param], lr=1e-3)\n", 136 | "for i in range(10000):\n", 137 | " opt.zero_grad()\n", 138 | " mlogp = -p_sz.logp({'s': S, 'z': Z}).mean()\n", 139 | "# print(i, mlogp.data)\n", 140 | " mlogp.backward()\n", 141 | " opt.step()\n", 142 | "# print(std_s_param.data, std_z_param.data, corr_param.data, sep='\\n')\n", 143 | "# print('-'*5)\n", 144 | "# print(mu_s.data, std_s().data, mu_z.data, std_z().data, corr_sz().data, std_z1s().data, sep='\\n')\n", 145 | "print(mu_s.data, std_s().data, mu_z.data, std_z().data, corr_sz().data, std_z1s().data, sep='\\n')" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "# Learning by MVNormal" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 5, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "mu_s = tc.randn(shape_s, requires_grad=True, device=device)\n", 162 | "std_s_diagpm = tc.randn(shape_s, requires_grad=True, device=device)\n", 163 | "std_s_offdiag = tc.randn(shape_s+shape_s[-1:], requires_grad=True, device=device)\n", 164 | "mu_z = tc.randn(shape_z, requires_grad=True, device=device)\n", 165 | "std_z_diagpm = tc.randn(shape_z, requires_grad=True, device=device)\n", 166 | "std_z_offdiag = tc.randn(shape_z+shape_z[-1:], requires_grad=True, device=device)\n", 167 | "corr_param = tc.randn(1, requires_grad=True, device=device)\n", 168 | "\n", 169 | "def corr_sz(): return 1. - (2*tc.sigmoid(corr_param)-1.)**2\n", 170 | "def std_s(): return tc.exp(std_s_diagpm).diag_embed() + std_s_offdiag.tril(diagonal=-1)\n", 171 | "def std_z(): return tc.exp(std_z_diagpm).diag_embed() + std_z_offdiag.tril(diagonal=-1)\n", 172 | "\n", 173 | "def mu_z1s(s):\n", 174 | " s_normal = tc.triangular_solve((s - mu_s).unsqueeze(-1), std_s(), upper=False)[0].squeeze(-1)\n", 175 | " s_normal_flat = s_normal.reshape(shape_bat+(dim_s,))\n", 176 | " z_normal_flat = s_normal_flat[..., :dim_z] if dim_z <= dim_s \\\n", 177 | " else tc.cat([s_normal_flat, tc.zeros(shape_bat+(dim_z-dim_s,), dtype=s.dtype, device=s.device)], dim=-1)\n", 178 | " return mu_z + corr_sz() * (std_z() @ z_normal_flat.reshape(shape_bat+shape_z+(1,))).squeeze(-1)\n", 179 | "\n", 180 | "def std_z1s():\n", 181 | " return std_z() * (1. 
- corr_sz()**2).sqrt()\n", 182 | "\n", 183 | "ds.Distr.clear()\n", 184 | "ds.Distr.default_device = device\n", 185 | "p_s = ds.MVNormal('s', mean=mu_s, std_tril=std_s, shape=shape_s)\n", 186 | "p_z1s = ds.MVNormal('z',mean=mu_z1s, std_tril=std_z1s, shape=shape_z)\n", 187 | "p_sz = p_s * p_z1s" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 6, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "tensor([[-0.9912, -0.9900, -0.9952],\n", 200 | " [-0.9905, -0.9530, -0.9386]], device='cuda:0')\n", 201 | "tensor([[[ 1.2981, 0.0000, 0.0000],\n", 202 | " [-0.0110, 1.2976, 0.0000],\n", 203 | " [ 0.0114, -0.0080, 1.3025]],\n", 204 | "\n", 205 | " [[ 1.3076, 0.0000, 0.0000],\n", 206 | " [-0.0060, 1.3123, 0.0000],\n", 207 | " [ 0.0095, 0.0899, 1.3067]]], device='cuda:0')\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "opt = tc.optim.SGD([mu_s, std_s_diagpm, std_s_offdiag], lr=1e-3)\n", 213 | "for i in range(10000):\n", 214 | " opt.zero_grad()\n", 215 | " mlogp = -p_s.logp({'s': S}).mean()\n", 216 | " mlogp.backward()\n", 217 | " opt.step()\n", 218 | "print(mu_s.data, std_s().data, sep='\\n')" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 7, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "tensor([[-0.9969, -1.0019, -1.0026],\n", 231 | " [-1.0011, -1.0066, -0.9976]], device='cuda:0')\n", 232 | "tensor([[[ 1.2988e+00, 0.0000e+00, 0.0000e+00],\n", 233 | " [-3.6102e-04, 1.2994e+00, 0.0000e+00],\n", 234 | " [ 2.0667e-03, 3.0024e-03, 1.3029e+00]],\n", 235 | "\n", 236 | " [[ 1.3043e+00, 0.0000e+00, 0.0000e+00],\n", 237 | " [-1.8239e-02, 1.3030e+00, 0.0000e+00],\n", 238 | " [-1.4886e-03, -9.8890e-03, 1.2968e+00]]], device='cuda:0')\n", 239 | "tensor([[1.0038, 0.9942],\n", 240 | " [0.9977, 0.9980]], device='cuda:0')\n", 241 | "tensor([[[1.0001e+00, 0.0000e+00],\n", 242 | " [4.5207e-03, 1.0015e+00]],\n", 243 | "\n", 244 | " [[1.0047e+00, 0.0000e+00],\n", 245 | " [2.6764e-04, 1.0001e+00]]], device='cuda:0')\n", 246 | "tensor([0.7009], device='cuda:0')\n", 247 | "tensor([[[7.1335e-01, 0.0000e+00],\n", 248 | " [3.2245e-03, 7.1434e-01]],\n", 249 | "\n", 250 | " [[7.1660e-01, 0.0000e+00],\n", 251 | " [1.9090e-04, 7.1335e-01]]], device='cuda:0')\n" 252 | ] 253 | } 254 | ], 255 | "source": [ 256 | "opt = tc.optim.SGD([mu_s, std_s_diagpm, std_s_offdiag, mu_z, std_z_diagpm, std_z_offdiag, corr_param], lr=1e-3)\n", 257 | "for i in range(10000):\n", 258 | " opt.zero_grad()\n", 259 | " mlogp = -p_sz.logp({'s': S, 'z': Z}).mean()\n", 260 | " mlogp.backward()\n", 261 | " opt.step()\n", 262 | "print(mu_s.data, std_s().data, mu_z.data, std_z().data, corr_sz().data, std_z1s().data, sep='\\n')" 263 | ] 264 | } 265 | ], 266 | "metadata": { 267 | "kernelspec": { 268 | "display_name": "Python 3", 269 | "language": "python", 270 | "name": "python3" 271 | }, 272 | "language_info": { 273 | "codemirror_mode": { 274 | "name": "ipython", 275 | "version": 3 276 | }, 277 | "file_extension": ".py", 278 | "mimetype": "text/x-python", 279 | "name": "python", 280 | "nbconvert_exporter": "python", 281 | "pygments_lexer": "ipython3", 282 | "version": "3.7.7" 283 | } 284 | }, 285 | "nbformat": 4, 286 | "nbformat_minor": 4 287 | } 288 | -------------------------------------------------------------------------------- /test/distr_test.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python3.6 2 | import sys 3 | import torch as tc 4 | sys.path.append('..') 5 | import distr as ds 6 | from distr.utils import expand_front, swap_dim_ranges 7 | '''Test cases of the 'distr' package. 8 | ''' 9 | 10 | __author__ = "Chang Liu" 11 | __email__ = "changliu@microsoft.com" 12 | 13 | shape_x = (1,2) 14 | shape_bat = (3,4) 15 | device = tc.device("cuda:0" if tc.cuda.is_available() else "cpu") 16 | ds.Distr.default_device = device 17 | 18 | def test_fun(title, p_z, p_x1z): 19 | print(title) 20 | print("p_z:", p_z.names, p_z.parents) 21 | print("p_x1z:", p_x1z.names, p_x1z.parents) 22 | p_zx = p_z * p_x1z 23 | print("p_zx:", p_zx.names, p_zx.parents) 24 | smp_z = p_z.draw(shape_bat) 25 | print("sample shape z:", smp_z['z'].shape) 26 | smp_x1z = p_x1z.draw((), smp_z) 27 | print("sample shape x:", smp_x1z['x'].shape) 28 | print("logp match:", tc.allclose( 29 | p_z.logp(smp_z) + p_x1z.logp(smp_x1z, smp_z), 30 | p_zx.logp(smp_z|smp_x1z) )) 31 | smp_zx = p_zx.draw(shape_bat) 32 | print("sample shape z:", smp_zx['z'].shape) 33 | print("sample shape x:", smp_zx['x'].shape) 34 | print("logp match:", tc.allclose( 35 | p_z.logp(smp_zx) + p_x1z.logp(smp_zx, smp_zx), 36 | p_zx.logp(smp_zx) )) 37 | print("logp_cartes shape:", p_x1z.logp_cartes(smp_x1z, smp_z).shape) 38 | print() 39 | ds.Distr.clear() 40 | 41 | # Normal 42 | ndim_x = len(shape_x) 43 | test_fun("Normal:", 44 | p_z = ds.Normal('z', 0., 1.), 45 | p_x1z = ds.Normal('x', shape = shape_x, mean = 46 | lambda z: swap_dim_ranges( expand_front(z, shape_x), (0, ndim_x), (ndim_x, ndim_x+z.ndim) ), 47 | std = 1. 48 | )) 49 | 50 | # MVNormal 51 | test_fun("MVNormal:", 52 | p_z = ds.MVNormal('z', 0., 1.), 53 | p_x1z = ds.MVNormal('x', shape = shape_x, mean = 54 | lambda z: swap_dim_ranges( expand_front(z, shape_x).squeeze(-1), (0, ndim_x), (ndim_x, ndim_x+z.ndim-1) ), 55 | cov = 1. 56 | )) 57 | 58 | # Catg 59 | ncat_z = 3 60 | ncat_x = 4 61 | w_z = tc.rand(ncat_z) 62 | w_z = w_z / w_z.sum() 63 | w_x = tc.rand((ncat_z,) + shape_x + (ncat_x,), device=device) 64 | w_x = w_x / w_x.sum(dim=-1, keepdim=True) 65 | test_fun("Catg:", 66 | p_z = ds.Catg('z', probs = w_z), 67 | p_x1z = ds.Catg('x', shape = shape_x, probs = 68 | lambda z: w_x.index_select(dim=0, index=z.flatten()).reshape(z.shape + w_x.shape[1:]) 69 | )) 70 | 71 | # Bern 72 | w_x = tc.rand(shape_x, device=device) 73 | w_x = tc.stack([1-w_x, w_x], dim=0) 74 | test_fun("Bern:", 75 | p_z = ds.Bern('z', probs = tc.rand(())), 76 | p_x1z = ds.Bern('x', shape = shape_x, probs = 77 | lambda z: w_x.index_select(dim=0, index=z.flatten()).reshape(z.shape + w_x.shape[1:]) 78 | )) 79 | 80 | -------------------------------------------------------------------------------- /test/utils_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | import sys 3 | from time import time 4 | import numpy as np 5 | sys.path.append('..') 6 | import utils 7 | '''Test cases of the 'utils' package. 
8 | ''' 9 | 10 | __author__ = "Chang Liu" 11 | __email__ = "changliu@microsoft.com" 12 | 13 | shape = (500, 200, 50) 14 | 15 | arr = np.random.rand(*shape) 16 | length = arr.shape[-1] 17 | for n_win in range(1, 3*shape[-1] + 2): 18 | print(f"{n_win:3d}", end=", ") 19 | lwid = (n_win - 1) // 2 20 | rwid = n_win//2 + 1 21 | t = time(); ma_slim = utils.moving_average_slim(arr, n_win); print(f"{time() - t:.6f}", end=", ") 22 | t = time(); ma_full = utils.moving_average_full(arr, n_win); print(f"{time() - t:.6f}", end=", ") 23 | t = time(); ma_full_check = utils.moving_average_full_checker(arr, n_win); print(f"{time() - t:.6f}", end=", ") 24 | ma_slim_check = ma_full_check[..., lwid: length-rwid+1] 25 | # print(ma_slim.shape, ma_slim_check.shape, ma_full.shape, ma_full_check.shape) 26 | print(f"slim: {np.allclose(ma_slim, ma_slim_check)}, full: {np.allclose(ma_full, ma_full_check)}") 27 | 28 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | __author__ = "Chang Liu" 3 | __email__ = "changliu@microsoft.com" 4 | 5 | from .utils import * 6 | 7 | -------------------------------------------------------------------------------- /utils/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | """ This module is adapted from Jindong Wang 2 | . 3 | """ 4 | -------------------------------------------------------------------------------- /utils/preprocess/data_list.py: -------------------------------------------------------------------------------- 1 | #from __future__ import print_function, division 2 | 3 | import torch 4 | import numpy as np 5 | from sklearn.preprocessing import StandardScaler 6 | import random 7 | from PIL import Image 8 | import torch.utils.data as data 9 | import os 10 | import os.path 11 | 12 | 13 | def make_dataset(image_list, labels): 14 | if labels: 15 | len_ = len(image_list) 16 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)] 17 | else: 18 | if len(image_list[0].split()) > 2: 19 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list] 20 | else: 21 | images = [(val.split()[0], int(val.split()[1])) for val in image_list] 22 | return images 23 | 24 | 25 | def pil_loader(path): 26 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 27 | with open(path, 'rb') as f: 28 | with Image.open(f) as img: 29 | return img.convert('RGB') 30 | 31 | 32 | def accimage_loader(path): 33 | import accimage 34 | try: 35 | return accimage.Image(path) 36 | except IOError: 37 | # Potentially a decoding problem, fall back to PIL.Image 38 | return pil_loader(path) 39 | 40 | 41 | def default_loader(path): 42 | #from torchvision import get_image_backend 43 | #if get_image_backend() == 'accimage': 44 | # return accimage_loader(path) 45 | #else: 46 | return pil_loader(path) 47 | 48 | 49 | class ImageList(object): 50 | """A generic data loader where the images are arranged in this way: :: 51 | root/dog/xxx.png 52 | root/dog/xxy.png 53 | root/dog/xxz.png 54 | root/cat/123.png 55 | root/cat/nsdf3.png 56 | root/cat/asd932_.png 57 | Args: 58 | root (string): Root directory path. 59 | transform (callable, optional): A function/transform that takes in a PIL image 60 | and returns a transformed version. 
E.g., ``transforms.RandomCrop`` 61 | target_transform (callable, optional): A function/transform that takes in the 62 | target and transforms it. 63 | loader (callable, optional): A function to load an image given its path. 64 | Attributes: 65 | classes (list): List of the class names. 66 | class_to_idx (dict): Dict with items (class_name, class_index). 67 | imgs (list): List of (image path, class_index) tuples 68 | """ 69 | 70 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, 71 | loader=default_loader, root='.'): 72 | imgs = make_dataset(image_list, labels) 73 | if len(imgs) == 0: 74 | raise RuntimeError("Found 0 images in the given image list" 75 | " (root directory: " + root + ")") 76 | 77 | self.root = root 78 | self.imgs = imgs 79 | self.transform = transform 80 | self.target_transform = target_transform 81 | self.loader = loader 82 | 83 | def __getitem__(self, index): 84 | """ 85 | Args: 86 | index (int): Index 87 | Returns: 88 | tuple: (image, target) where target is class_index of the target class. 89 | """ 90 | path, target = self.imgs[index] 91 | img = self.loader(self.root + path) 92 | if self.transform is not None: 93 | img = self.transform(img) 94 | if self.target_transform is not None: 95 | target = self.target_transform(target) 96 | 97 | return img, target 98 | 99 | def __len__(self): 100 | return len(self.imgs) 101 | -------------------------------------------------------------------------------- /utils/preprocess/data_loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from torchvision import datasets, transforms 3 | import torch 4 | from torch.utils import data 5 | import numpy as np 6 | from torchvision import transforms 7 | import os 8 | from PIL import Image 9 | from torch.utils.data.sampler import SubsetRandomSampler 10 | from skimage import io 11 | 12 | 13 | class PlaceCrop(object): 14 | def __init__(self, size, start_x, start_y): 15 | if isinstance(size, int): 16 | self.size = (int(size), int(size)) 17 | else: 18 | self.size = size 19 | self.start_x = start_x 20 | self.start_y = start_y 21 | 22 | def __call__(self, img): 23 | th, tw = self.size 24 | return img.crop((self.start_x, self.start_y, self.start_x + tw, self.start_y + th)) 25 | 26 | 27 | class ResizeImage(): 28 | def __init__(self, size): 29 | if isinstance(size, int): 30 | self.size = (int(size), int(size)) 31 | else: 32 | self.size = size 33 | 34 | def __call__(self, img): 35 | th, tw = self.size 36 | return img.resize((th, tw)) 37 | 38 | 39 | class myDataset(data.Dataset): 40 | def __init__(self, root, transform=None, train=True): 41 | self.train = train 42 | class_dirs = [os.path.join(root, i) for i in os.listdir(root)] 43 | imgs = [] 44 | for i in class_dirs: 45 | imgs += [os.path.join(i, img) for img in os.listdir(i)] 46 | np.random.shuffle(imgs) 47 | imgs_num = len(imgs) 48 | # train : val = 3 : 7 49 | if self.train: 50 | self.imgs = imgs[:int(0.3*imgs_num)] 51 | else: 52 | self.imgs = imgs[int(0.3*imgs_num):] 53 | if transform: 54 | self.transforms = transforms.Compose( 55 | [transforms.Resize([256, 256]), 56 | transforms.RandomCrop(224), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor()]) 59 | else: 60 | start_center = (256 - 224 - 1) / 2 61 | self.transforms = transforms.Compose( 62 | [transforms.Resize([224, 224]), 63 | PlaceCrop(224, start_center, start_center), 64 | transforms.ToTensor()]) 65 | 66 | def __getitem__(self, index): 67 | img_path = 
self.imgs[index] 68 | label = int(img_path.strip().split('/')[10]) # assumes a fixed absolute-path depth for the class directory; adjust the index to the data layout 69 | # print(img_path, label) # leftover debug output, disabled 70 | #data = Image.open(img_path) 71 | data = io.imread(img_path) 72 | data = Image.fromarray(data) 73 | if data.getbands()[0] == 'L': 74 | data = data.convert('RGB') 75 | data = self.transforms(data) 76 | return data, label 77 | 78 | def __len__(self): 79 | return len(self.imgs) 80 | 81 | 82 | def load_training(root_path, domain, batch_size, kwargs, train_val_split=.5, rand_split=True): 83 | kwargs_fin = dict(shuffle=True, drop_last=True) 84 | kwargs_fin.update(kwargs) 85 | normalize = transforms.Normalize( 86 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 87 | transform = transforms.Compose( 88 | [ResizeImage(256), 89 | transforms.Resize(256), 90 | transforms.CenterCrop(224), 91 | transforms.RandomHorizontalFlip(), 92 | transforms.ToTensor(), 93 | normalize]) 94 | data = datasets.ImageFolder(root=os.path.join( 95 | root_path, domain), transform=transform) 96 | if train_val_split <= 0: 97 | train_loader = torch.utils.data.DataLoader( 98 | data, batch_size=batch_size, **kwargs_fin) 99 | return train_loader 100 | else: 101 | train_loader, val_loader = load_train_valid_split( 102 | data, batch_size, kwargs_fin, val_ratio=1.-train_val_split, rand_split=rand_split) 103 | return train_loader, val_loader 104 | 105 | 106 | def load_testing(root_path, domain, batch_size, kwargs): 107 | kwargs_fin = dict(shuffle=False, drop_last=False) 108 | kwargs_fin.update(kwargs) 109 | normalize = transforms.Normalize( 110 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 111 | start_center = (256 - 224 - 1) / 2 112 | transform = transforms.Compose( 113 | [ResizeImage(256), 114 | PlaceCrop(224, start_center, start_center), 115 | transforms.ToTensor(), 116 | normalize]) 117 | dataset = datasets.ImageFolder(root=os.path.join( 118 | root_path, domain), transform=transform) 119 | test_loader = torch.utils.data.DataLoader( 120 | dataset, batch_size=batch_size, **kwargs_fin) 121 | return test_loader 122 | 123 | 124 | def load_train_valid_split(dataset, batch_size, kwargs, val_ratio=0.4, rand_split=True): 125 | dataset_size = len(dataset) 126 | indices = list(range(dataset_size)) 127 | if rand_split: np.random.shuffle(indices) 128 | len_val = int(np.floor(val_ratio * dataset_size)) 129 | train_indices, val_indices = indices[len_val:], indices[:len_val] 130 | 131 | train_sampler = SubsetRandomSampler(train_indices) 132 | valid_sampler = SubsetRandomSampler(val_indices) 133 | 134 | __ = kwargs.pop('shuffle', None) 135 | __ = kwargs.pop('drop_last', None) 136 | train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 137 | sampler=train_sampler, **kwargs, drop_last=True) 138 | validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 139 | sampler=valid_sampler, **kwargs, drop_last=True) 140 | return train_loader, validation_loader 141 | 142 | def load_data(root_path, source_dir, target_dir, batch_size): 143 | kwargs = {'num_workers': 4, 'pin_memory': True} 144 | source_loader = load_training( 145 | root_path, source_dir, batch_size, kwargs) 146 | target_loader = load_training( 147 | root_path, target_dir, batch_size, kwargs) 148 | test_loader = load_testing( 149 | root_path, target_dir, batch_size, kwargs) 150 | return source_loader, target_loader, test_loader 151 | 152 | def load_all_test(root_path, dataset, batch_size, train, kwargs): 153 | ls = [] 154 | domains = {'Office-31': ['amazon', 'dslr', 'webcam'], 155 | 'Office-Home': ['Art', 'Clipart', 'Product', 
'RealWorld']} 156 | for dom in domains[dataset]: 157 | if train: 158 | loader = load_training(root_path, dom, batch_size, kwargs, train_val_split=-1) 159 | else: 160 | loader = load_testing(root_path, dom, batch_size, kwargs) 161 | ls.append(loader) 162 | return ls 163 | -------------------------------------------------------------------------------- /utils/preprocess/data_provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .data_list import ImageList 3 | import torch.utils.data as util_data 4 | from torchvision import transforms 5 | from PIL import Image, ImageOps 6 | 7 | 8 | class ResizeImage(): 9 | def __init__(self, size): 10 | if isinstance(size, int): 11 | self.size = (int(size), int(size)) 12 | else: 13 | self.size = size 14 | def __call__(self, img): 15 | th, tw = self.size 16 | return img.resize((th, tw)) 17 | 18 | 19 | class PlaceCrop(object): 20 | 21 | def __init__(self, size, start_x, start_y): 22 | if isinstance(size, int): 23 | self.size = (int(size), int(size)) 24 | else: 25 | self.size = size 26 | self.start_x = start_x 27 | self.start_y = start_y 28 | 29 | def __call__(self, img): 30 | th, tw = self.size 31 | return img.crop((self.start_x, self.start_y, self.start_x + tw, self.start_y + th)) 32 | 33 | 34 | def load_images(images_file_path, batch_size, resize_size=256, is_train=True, crop_size=224, is_cen=False, num_workers=4): 35 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 36 | if not is_train: 37 | start_center = (resize_size - crop_size - 1) / 2 38 | transformer = transforms.Compose([ 39 | ResizeImage(resize_size), 40 | PlaceCrop(crop_size, start_center, start_center), 41 | transforms.ToTensor(), 42 | normalize]) 43 | images = ImageList(open(images_file_path).readlines(), transform=transformer) 44 | images_loader = util_data.DataLoader(images, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True) 45 | else: 46 | if is_cen: 47 | transformer = transforms.Compose([ResizeImage(resize_size), 48 | transforms.Resize(resize_size), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.CenterCrop(crop_size), 51 | transforms.ToTensor(), 52 | normalize]) 53 | else: 54 | transformer = transforms.Compose([ResizeImage(resize_size), 55 | transforms.RandomResizedCrop(crop_size), 56 | transforms.RandomHorizontalFlip(), 57 | transforms.ToTensor(), 58 | normalize]) 59 | images = ImageList(open(images_file_path).readlines(), transform=transformer) 60 | images_loader = util_data.DataLoader(images, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) 61 | return images_loader 62 | 63 | -------------------------------------------------------------------------------- /utils/reprun.sh: -------------------------------------------------------------------------------- 1 | n=$1 # number of repetitions; usage: reprun.sh <n> <command> [args...] 2 | command=$2 3 | shift 4 | shift 5 | for i in $(seq 1 $n) 6 | do 7 | "$command" "$@" 8 | done 9 | 10 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | import os 3 | import warnings 4 | from itertools import product, chain 5 | import math 6 | import torch as tc 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from tabulate import tabulate 10 | 11 | __author__ = "Chang Liu" 12 | __email__ = "changliu@microsoft.com" 13 | 14 | # General Utilities 15 | def unique_filename(prefix: str="", 
suffix: str="", n_digits: int=2, count_start: int=0) -> str: 16 | fmt = "{:0" + str(n_digits) + "d}" 17 | if prefix and prefix[-1] not in {"/", "\\"}: prefix += "_" 18 | while True: 19 | filename = prefix + fmt.format(count_start) + suffix 20 | if not os.path.exists(filename): return filename 21 | else: count_start += 1 22 | 23 | class Averager: 24 | """Computes and stores the average and current value""" 25 | def __init__(self): 26 | self.reset() 27 | 28 | def reset(self): 29 | self._val = 0 30 | self._avg = 0 31 | self._sum = 0 32 | self._count = 0 33 | 34 | def update(self, val, nrep = 1): 35 | self._val = val 36 | self._sum += val * nrep 37 | self._count += nrep 38 | self._avg = self._sum / self._count 39 | 40 | @property 41 | def val(self): return self._val 42 | @property 43 | def avg(self): return self._avg 44 | @property 45 | def sum(self): return self._sum 46 | @property 47 | def count(self): return self._count 48 | 49 | def repeat_iterator(itr, n_repeat): 50 | # The built-in `itertools.cycle` stores all results over `itr` and does not initialize `itr` again. 51 | return chain.from_iterable([itr] * n_repeat) 52 | 53 | class RepeatIterator: 54 | def __init__(self, itr, n_repeat): 55 | self.itr = itr 56 | self.n_repeat = n_repeat 57 | self.len = len(itr) * n_repeat 58 | 59 | def __iter__(self): 60 | return chain.from_iterable([self.itr] * self.n_repeat) 61 | 62 | def __len__(self): return self.len 63 | 64 | def zip_longer(itr1, itr2): 65 | len_ratio = len(itr1) / len(itr2) 66 | if len_ratio > 1: 67 | return zip(itr1, repeat_iterator(itr2, math.ceil(len_ratio))) 68 | elif len_ratio < 1: 69 | return zip(repeat_iterator(itr1, math.ceil(1/len_ratio)), itr2) 70 | else: return zip(itr1, itr2) 71 | 72 | def zip_longest(*itrs): 73 | itr_longest = max(itrs, key=len) 74 | len_longest = len(itr_longest) 75 | return zip(*[itr if len(itr) == len_longest 76 | else repeat_iterator(itr, math.ceil(len_longest / len(itr))) 77 | for itr in itrs]) 78 | 79 | class ZipLongest: 80 | def __init__(self, *itrs): 81 | self.itrs = itrs 82 | self.itr_longest = max(itrs, key=len) 83 | self.len = len(self.itr_longest) 84 | 85 | def __iter__(self): 86 | return zip(*[itr if len(itr) == self.len 87 | else repeat_iterator(itr, math.ceil(self.len / len(itr))) 88 | for itr in self.itrs]) 89 | 90 | def __len__(self): return self.len 91 | 92 | 93 | class CyclicLoader: 94 | def __init__(self, datalist: list, shuffle: bool=True, cycle: int=None): 95 | self.len = len(datalist[0]) 96 | if shuffle: 97 | ids = np.random.permutation(self.len) 98 | self.datalist = [data[ids] for data in datalist] 99 | else: self.datalist = datalist 100 | self.cycle = self.len if cycle is None else cycle 101 | self.head = 0 102 | def iter(self): 103 | self.head = 0 104 | return self 105 | def next(self, n: int) -> tuple: 106 | ids = [i % self.cycle for i in range(self.head, self.head + n)] 107 | self.head = (self.head + n) % self.cycle 108 | return tuple(data[ids] for data in self.datalist) 109 | def back(self, n: int): 110 | self.head = (self.head - n) % self.cycle 111 | return self 112 | 113 | def boolstr(s: str) -> bool: 114 | # for argparse argument of type bool 115 | if isinstance(s, str): 116 | true_strings = {'1', 'true', 'True', 'T', 'yes', 'Yes', 'Y'} 117 | false_strings = {'0', 'false', 'False', 'F', 'no', 'No', 'N'} 118 | if s not in true_strings | false_strings: 119 | raise ValueError('Not a valid boolean string') 120 | return s in true_strings 121 | else: 122 | return bool(s) 123 | 124 | ## For lists/tuples 125 | def getlist(ls: list, 
ids: list) -> list: 126 | return [ls[i] for i in ids] 127 | 128 | def interleave(*lists) -> list: 129 | return [v for row in zip(*lists) for v in row] 130 | 131 | def flatten(ls: list, depth: int=None) -> list: 132 | i = 0 133 | while (depth is None or i < depth) and bool(ls) and all( 134 | type(row) is list or type(row) is tuple for row in ls): 135 | ls = [v for row in ls if bool(row) for v in row] 136 | i += 1 137 | return ls 138 | 139 | class SlicesAt: 140 | def __init__(self, axis: int, ndim: int): 141 | if ndim <= 0: raise ValueError(f"`ndim` (which is {ndim}) should be a positive integer") 142 | if not -ndim <= axis < ndim: raise ValueError(f"`axis` (which is {axis}) should be within [{-ndim}, {ndim})") 143 | self._axis, self._ndim = axis % ndim, ndim 144 | 145 | def __getitem__(self, idx): 146 | slices = [slice(None)] * self._ndim 147 | slices[self._axis] = idx 148 | return tuple(slices) 149 | 150 | # For numpy/torch 151 | def moving_average_slim(arr: np.ndarray, n_win: int=2, axis: int=-1) -> np.ndarray: 152 | # `(n_win-1)` shorter. Good for any positive `n_win`. Good if `arr` is empty in `axis` 153 | if n_win <= 0: raise ValueError(f"nonpositive `n_win` {n_win} not allowed") 154 | slc = SlicesAt(axis, arr.ndim) 155 | concatfn = tc.cat if type(arr) is tc.Tensor else np.concatenate 156 | cum = arr.cumsum(axis) # , dtype=float) 157 | return concatfn([ cum[slc[n_win-1:n_win]], cum[slc[n_win:]]-cum[slc[:-n_win]] ], axis) / float(n_win) 158 | 159 | def moving_average_full(arr: np.ndarray, n_win: int=2, axis: int=-1) -> np.ndarray: 160 | # Same length as `arr`. Good for any positive `n_win`. Good if `arr` is empty in `axis` 161 | if n_win <= 0: raise ValueError(f"nonpositive `n_win` {n_win} not allowed") 162 | slc = SlicesAt(axis, arr.ndim) 163 | concatfn = tc.cat if type(arr) is tc.Tensor else np.concatenate 164 | cum = arr.cumsum(axis) # , dtype=float) 165 | stem = concatfn([ cum[slc[n_win-1:n_win]], cum[slc[n_win:]]-cum[slc[:-n_win]] ], axis) / float(n_win) 166 | length = arr.shape[axis] 167 | lwid = (n_win - 1) // 2 168 | rwid = n_win//2 + 1 169 | return concatfn([ 170 | *[ cum[slc[j-1: j]] / float(j) for i in range(min(lwid, length)) for j in [min(i+rwid, length)] ], 171 | stem, 172 | *[ (cum[slc[-1:]] - cum[slc[i-lwid-1: i-lwid]] if i-lwid > 0 else cum[slc[-1:]]) / float(length-i+lwid) 173 | for i in range(max(length-rwid+1, lwid), length) ] 174 | ], axis) 175 | 176 | def moving_average_full_checker(arr: np.ndarray, n_win: int=2, axis: int=-1) -> np.ndarray: 177 | # Same length as `arr`. Good for any positive `n_win`. 
Good if `arr` is empty in `axis` 178 | if n_win <= 0: raise ValueError(f"nonpositive `n_win` {n_win} not allowed") 179 | if arr.shape[axis] < 2: return arr 180 | slc = SlicesAt(axis, arr.ndim) 181 | concatfn = tc.cat if type(arr) is tc.Tensor else np.concatenate 182 | lwid = (n_win - 1) // 2 183 | rwid = n_win//2 + 1 184 | return concatfn([ arr[slc[max(0, i-lwid): (i+rwid)]].mean(axis, keepdims=True) for i in range(arr.shape[axis]) ], axis) 185 | 186 | # Plotting Utilities 187 | class Plotter: 188 | def __init__(self, var_xlab: dict, metr_ylab: dict, tab_items: list=None, 189 | check_var: bool=False, check_tab: bool=False, loader = tc.load): 190 | for var in var_xlab: 191 | if not var_xlab[var]: var_xlab[var] = var 192 | for metr in metr_ylab: 193 | if not metr_ylab[metr]: metr_ylab[metr] = metr 194 | self.var_xlab, self.metr_ylab = var_xlab, metr_ylab 195 | self.variables, self.metrics = list(var_xlab), list(metr_ylab) 196 | if tab_items is None: self.tab_items = [] 197 | else: self.tab_items = tab_items 198 | self.check_var, self.check_tab = check_var, check_tab 199 | self.loader = loader 200 | self._plt_data, self._tab_data = [], [] 201 | 202 | def _get_res(self, dataholder): # does not change `self` 203 | res_x = {var: [] for var in self.variables} 204 | res_ymean = {metr: np.array([]) for metr in self.metrics} 205 | res_ystd = {metr: np.array([]) for metr in self.metrics} 206 | res_tab = [None for item in self.tab_items] 207 | if type(dataholder) is not dict: # treated as a list of data file names 208 | resfiles = [] 209 | for file in dataholder: 210 | if os.path.isfile(file): resfiles.append(file) 211 | else: warnings.warn(f"file '{file}' does not exist") 212 | dataholder = dict() 213 | for file in resfiles: 214 | ckp = self.loader(file) 215 | for name in self.metrics + self.variables + self.tab_items: 216 | if name not in ckp: warnings.warn(f"metric or variable or item '{name}' not found in file '{file}'") 217 | else: 218 | if name not in dataholder: dataholder[name] = [] 219 | dataholder[name].append(ckp[name]) 220 | for metr in self.metrics: 221 | if metr not in dataholder or not dataholder[metr]: 222 | warnings.warn(f"metric '{metr}' not found or empty") 223 | continue 224 | n_align = min((len(line) for line in dataholder[metr]), default=0) 225 | if n_align: 226 | vals = np.array([line[:n_align] for line in dataholder[metr]]) 227 | res_ymean[metr] = vals.mean(0) 228 | res_ystd[metr] = vals.std(0) 229 | for var in self.variables: 230 | if var not in dataholder or not dataholder[var]: 231 | warnings.warn(f"variable '{var}' not found or empty") 232 | continue 233 | n_align = min((len(line) for line in dataholder[var]), default=0) 234 | if n_align: 235 | res_x[var] = dataholder[var][0] 236 | if self.check_var: 237 | for line in dataholder[var][1:]: 238 | if line != res_x[var]: raise RuntimeError(f"variable '{var}' does not match") 239 | for i, item in enumerate(self.tab_items): 240 | if item not in dataholder or not dataholder[item]: 241 | warnings.warn(f"item '{item}' not found or empty") 242 | continue 243 | res_tab[i] = dataholder[item][0] 244 | if self.check_tab: 245 | for val in dataholder[item][1:]: 246 | if val != res_tab[i]: raise RuntimeError(f"item '{item}' does not match") 247 | return res_x, res_ymean, res_ystd, res_tab 248 | 249 | def load(self, *triplets): 250 | # each triplet = (legend, pltsty, [filename1, filename2]), or 251 | # (legend, pltsty, {var1: [val1_1, val1_2], metr2: [val2_1, val2_2, val2_3]}) 252 | data = [(legend, pltsty, *self._get_res(dataholder)) for legend, 
pltsty, dataholder in triplets] 253 | self._plt_data += [entry[:-1] for entry in data] 254 | self._tab_data += [[entry[0]] + entry[-1] for entry in data] 255 | 256 | def clear(self): 257 | self._plt_data.clear() 258 | self._tab_data.clear() 259 | 260 | def plot(self, variables: list=None, metrics: list=None, 261 | var_xlim: dict=None, metr_ylim: dict=None, 262 | n_start: int=None, n_stop: int=None, n_step: int=None, n_win: int=1, 263 | plot_err: bool=True, ncol: int=None, 264 | fontsize: int=20, figheight: int=8, linewidth: int=4, alpha: float=.2, show_legend: bool=True): 265 | if variables is None: variables = self.variables 266 | if metrics is None: metrics = self.metrics 267 | if var_xlim is None: var_xlim = {} 268 | if metr_ylim is None: metr_ylim = {} 269 | slc = slice(n_start, n_stop, n_step) 270 | if ncol is None: ncol = max(2, len(variables)) 271 | nfig = len(variables) * len(metrics) 272 | nrow = (nfig-1) // ncol + 1 273 | if nfig < ncol: ncol = nfig 274 | plt.rcParams.update({'font.size': fontsize}) 275 | 276 | fig, axes0 = plt.subplots(nrow, ncol, figsize=(ncol*figheight, nrow*figheight)) 277 | if nfig == 1: axes = [axes0] 278 | elif nrow > 1: axes = [ax for row in axes0 for ax in row][:nfig] 279 | else: axes = axes0[:nfig] 280 | for ax, (metr, var) in zip(axes, product(metrics, variables)): 281 | plotted = False 282 | for legend, pltsty, res_x, res_ymean, res_ystd in self._plt_data: 283 | y, std = res_ymean[metr], res_ystd[metr] 284 | x = res_x[var] if var is not None else list(range(min(len(y), len(std)))) 285 | n_align = min(len(x), len(y), len(std)) 286 | x, y, std = x[:n_align], y[:n_align], std[:n_align] 287 | if n_win > 1: 288 | y = moving_average_full(y, n_win) 289 | if plot_err: std = moving_average_full(std, n_win) # Not precise! 
taking the std and taking the moving average do not commute, since the square root of a sum of squares is not linear 290 | x, y, std = x[slc], y[slc], std[slc] 291 | if len(x): 292 | if plot_err: 293 | ax.fill_between(x, y-std, y+std, facecolor=pltsty[0], alpha=alpha, linewidth=0) 294 | ax.plot(x, y, pltsty, label=legend, linewidth=linewidth) 295 | plotted = True 296 | if show_legend and plotted: ax.legend() 297 | if var in var_xlim: ax.set_xlim(var_xlim[var]) 298 | if metr in metr_ylim: ax.set_ylim(metr_ylim[metr]) 299 | ax.set_xlabel(self.var_xlab[var] if var is not None else "index") 300 | ax.set_ylabel(self.metr_ylab[metr]) 301 | return fig, axes0 302 | 303 | def inspect(self, metr: str, ids: list=None, var: str=None, vals: list=None, 304 | show_std: bool=True, **tbformat): 305 | if (ids is None) == (var is None and vals is None): 306 | raise ValueError("exactly one of `ids`, or `var` and `vals`, should be provided") 307 | if ids is not None: 308 | if not show_std: 309 | table = [[legend, *getlist(res_ymean[metr], ids)] for legend, pltsty, res_x, res_ymean, res_ystd in self._plt_data] 310 | print(tabulate(table, headers = ["indices"] + ids, **tbformat)) 311 | else: 312 | table = [[legend, *interleave(getlist(res_ymean[metr], ids), getlist(res_ystd[metr], ids))] 313 | for legend, pltsty, res_x, res_ymean, res_ystd in self._plt_data] 314 | print(tabulate(table, headers = ["indices"] + interleave(ids, ids), **tbformat)) 315 | else: 316 | if not show_std: 317 | table = [[legend, *[res_ymean[metr][res_x[var].index(val)] for val in vals]] 318 | for legend, pltsty, res_x, res_ymean, res_ystd in self._plt_data] 319 | print(tabulate(table, headers = [var] + vals, **tbformat)) 320 | else: 321 | ids_list = [[res_x[var].index(val) for val in vals] for _, _, res_x, _, _ in self._plt_data] 322 | table = [[legend, *interleave(getlist(res_ymean[metr], ids), getlist(res_ystd[metr], ids))] 323 | for ids, (legend, pltsty, res_x, res_ymean, res_ystd) in zip(ids_list, self._plt_data)] 324 | print(tabulate(table, headers = [var] + interleave(vals, vals), **tbformat)) 325 | return table 326 | 327 | def tabulate(self, **tbformat): 328 | print(tabulate(self._tab_data, headers = ["legend"] + self.tab_items, **tbformat)) 329 | 330 | --------------------------------------------------------------------------------
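A note on usage: the `distr` interface exercised by the tests above builds joint distributions by multiplying factor distributions with `*` and evaluates joint log-likelihoods via `logp`, which is all that is needed to fit parameters by plain gradient descent. Below is a minimal, self-contained sketch in the same style as the SGD fits in test/distr_test.ipynb; it only uses calls that appear in the tests (`ds.Normal`, `ds.Distr.default_device`, `logp`, and a zero-argument callable for a parameter, as with `std_tril=std_s` in the notebook). The synthetic data, variable names, and hyperparameters here are illustrative assumptions, not part of the repository.

#!/usr/bin/env python3
# Sketch: maximum-likelihood fit of a univariate Normal with the 'distr' package.
# Assumes it is run from within 'test/', as the test scripts are.
import sys
import torch as tc
sys.path.append('..')
import distr as ds

device = tc.device("cuda:0" if tc.cuda.is_available() else "cpu")
ds.Distr.default_device = device

# Synthetic data: draws from N(-1, 1.3^2); the "true" values are chosen for the demo only.
S = -1. + 1.3 * tc.randn(10000, device=device)

mu = tc.zeros((), requires_grad=True, device=device)
log_std = tc.zeros((), requires_grad=True, device=device)
def std(): return tc.exp(log_std)  # positivity via exp, as the notebook does for scale parameters

p_s = ds.Normal('s', mean=mu, std=std)  # zero-arg callable, re-evaluated at each `logp` call
opt = tc.optim.SGD([mu, log_std], lr=1e-3)
for i in range(10000):
    opt.zero_grad()
    mlogp = -p_s.logp({'s': S}).mean()  # average negative log-likelihood, as in the notebook
    mlogp.backward()
    opt.step()
print(mu.data, std().data, sep='\n')  # should approach -1 and 1.3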