├── .gitignore ├── LICENSE ├── README.md ├── assets ├── 0808_112404_cbcg.csv ├── 0810_191625_bcg.csv ├── 0821_213901_rcbcg.csv ├── 314393_began_Bald_topdo1_botcond1.png ├── 314393_began_Mustache_topdo1_botcond1.png ├── big_causal_graph.png ├── causalbegan_pictures │ ├── 190001_G_diversity.png │ ├── 190001_intvcond_Bald=1_2x10.png │ ├── 190001_intvcond_Eyeglasses=1_2x10.png │ ├── 190001_intvcond_Mouth_Slightly_Open=1_2x10.png │ ├── 190001_intvcond_Mustache=1_2x10.png │ ├── 190001_intvcond_Narrow_Eyes=1_2x10.png │ ├── 190001_intvcond_Smiling=1_2x10.png │ └── 190001_intvcond_Wearing_Lipstick=1_2x10.png ├── causalgan_pictures │ ├── 45507_G_diversity.png │ ├── 45507_intvcond_Bald=1_2x10.png │ ├── 45507_intvcond_Eyeglasses=1_2x10.png │ ├── 45507_intvcond_Mouth_Slightly_Open=1_2x10.png │ ├── 45507_intvcond_Mustache=1_2x10.png │ ├── 45507_intvcond_Narrow_Eyes=1_2x10.png │ ├── 45507_intvcond_Smiling=1_2x10.png │ └── 45507_intvcond_Wearing_Lipstick=1_2x10.png ├── guide_to_gifs.txt ├── tvd_vs_step.pdf ├── tvd_vs_step.png └── tvdplot.ipynb ├── causal_began ├── CausalBEGAN.py ├── __init__.py ├── config.py ├── models.py └── utils.py ├── causal_controller ├── ArrayDict.py ├── CausalController.py ├── __init__.py ├── config.py ├── models.py └── utils.py ├── causal_dcgan ├── CausalGAN.py ├── __init__.py ├── config.py ├── models.py ├── ops.py └── utils.py ├── causal_graph.py ├── config.py ├── data_loader.py ├── download.py ├── figure_scripts ├── __init__.py ├── distributions.py ├── encode.py ├── high_level.py ├── pairwise.py ├── probability_table.txt ├── sample.py └── utils.py ├── main.py ├── synthetic ├── README.md ├── assets │ ├── 0818_072052 │ │ ├── x1x3_all.pdf │ │ ├── x1x3_collider.pdf │ │ ├── x1x3_complete.pdf │ │ ├── x1x3_data.pdf │ │ ├── x1x3_fc5.pdf │ │ ├── x1x3_linear.pdf │ │ ├── x1x3_notextcollider.pdf │ │ ├── x1x3_notextcomplete.pdf │ │ ├── x1x3_notextdata.pdf │ │ ├── x1x3_notextfc5.pdf │ │ ├── x1x3_notextlinear.pdf │ │ ├── x1x3_notitlecollider.pdf │ │ ├── 
x1x3_notitlecomplete.pdf │ │ ├── x1x3_notitledata.pdf │ │ ├── x1x3_notitlefc5.pdf │ │ └── x1x3_notitlelinear.pdf │ ├── collider_synth_tvd_vs_time.pdf │ ├── complete_synth_tvd_vs_time.pdf │ ├── linear_synth_tvd_vs_time.pdf │ ├── liny_tvd │ │ ├── liny_collider_synth_tvd_vs_time.pdf │ │ ├── liny_complete_synth_tvd_vs_time.pdf │ │ └── liny_linear_synth_tvd_vs_time.pdf │ └── synth_tvd_vs_time_titled.pdf ├── collect_stats.py ├── config.py ├── figure_generation.ipynb ├── main.py ├── models.py ├── run_datasets.sh ├── tboard.py ├── trainer.py └── utils.py ├── tboard.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | data 3 | .*.swp 4 | 5 | logs 6 | old 7 | 8 | final_checkpoints 9 | checkpoint/ 10 | figures/ 11 | *.pyc 12 | .DS_Store 13 | .ipynb_checkpoints 14 | [._]*.s[a-v][a-z] 15 | [._]*.sw[a-p] 16 | [._]s[a-v][a-z] 17 | [._]sw[a-p] 18 | 19 | samples 20 | outputs 21 | 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Murat Kocaoglu, Christopher Snyder 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CausalGAN/CausalBEGAN in Tensorflow 2 | 3 | Tensorflow implementation of [CausalGAN: Learning Causal Implicit Generative Models with Adversarial Training](https://arxiv.org/abs/1709.02023) 4 | 5 | ### Top: Random samples from do(Bald=1); Bottom: Random samples from cond(Bald=1) 6 | ![alt text](./assets/314393_began_Bald_topdo1_botcond1.png) 7 | ### Top: Random samples from do(Mustache=1); Bottom: Random samples from cond(Mustache=1) 8 | ![alt text](./assets/314393_began_Mustache_topdo1_botcond1.png) 9 | 10 | 11 | ## Requirements 12 | - Python 2.7 13 | - [Pillow](https://pillow.readthedocs.io/en/4.0.x/) 14 | - [tqdm](https://github.com/tqdm/tqdm) 15 | - [requests](https://github.com/kennethreitz/requests) (Only used for downloading CelebA dataset) 16 | - [TensorFlow 1.1.0](https://github.com/tensorflow/tensorflow) 17 | 18 | ## Getting Started 19 | 20 | First download [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) datasets with: 21 | 22 | $ apt-get install p7zip-full # ubuntu 23 | $ brew install p7zip # Mac 24 | $ pip install tqdm 25 | $ python download.py 26 | 27 | ## Usage 28 | 29 | The CausalGAN/CausalBEGAN code factorizes into two components, which can be trained or loaded independently: the causal_controller module specifies the model which learns a causal generative model over 
labels, and the causal_dcgan or causal_began modules learn a GAN over images given those labels. We denote training the causal controller over labels as "pretraining" (--is_pretrain=True), and training a GAN over images given labels as "training" (--is_train=True). 30 | 31 | To train a causal implicit model over labels and then over the image given the labels use 32 | 33 | $ python main.py --causal_model big_causal_graph --is_pretrain True --model_type began --is_train True 34 | 35 | where "big_causal_graph" is one of the causal graphs specified by the keys in the causal_graphs dictionary in causal_graph.py. 36 | 37 | Alternatively, one can first train a causal implicit model over labels only with the following command: 38 | 39 | $ python main.py --causal_model big_causal_graph --is_pretrain True 40 | 41 | One can then train a conditional generative model for the images given the trained causal generative model for the labels (causal controller), which yields a causal implicit generative model for the image and the labels, as suggested in [our paper](https://arxiv.org/abs/1709.02023): 42 | 43 | $ export CC_MODEL_PATH='./logs/celebA_0810_191625_0.145tvd_bcg/controller/checkpoints/CC-Model-20000' 44 | $ python main.py --causal_model big_causal_graph --pt_load_path $CC_MODEL_PATH --model_type began --is_train True 45 | 46 | Instead of loading the model piecewise, once image training has been run once, the entire joint model can be loaded more simply by specifying the model directory: 47 | 48 | $ python main.py --causal_model big_causal_graph --load_path ./logs/celebA_0815_170635 --model_type began --is_train True 49 | 50 | Tensorboard visualization of the most recently created model is simply (as long as port 6006 is free): 51 | 52 | $ python tboard.py 53 | 54 | 55 | To interact with an already trained model I recommend the following procedure: 56 | 57 | ipython 58 | In [1]: %run main --causal_model big_causal_graph --load_path './logs/celebA_0815_170635' --model_type 'began' 59 | 60 | For 
example to sample N=22 interventional images from do(Smiling=1) (as long as your causal graph includes a "Smiling" node): 61 | 62 | In [2]: sess.run(model.G,{cc.Smiling.label:np.ones((22,1)), trainer.batch_size:22}) 63 | 64 | Conditional sampling is most efficiently done through 2 session calls: the first to cc.sample_label to get a label sample, and the second feeds that sampled label to get an image. See trainer.causal_sampling for a more extensive example. Note that it is also possible to combine conditioning and intervention during sampling. 65 | 66 | In [3]: lab_samples=cc.sample_label(sess,do_dict={'Bald':1}, cond_dict={'Mustache':1},N=22) 67 | 68 | will sample all labels from the joint distribution conditioned on Mustache=1 and do(Bald=1). These label samples can be turned into image samples as follows: 69 | 70 | In [4]: feed_dict={cc.label_dict[k]:v for k,v in lab_samples.iteritems()} 71 | In [5]: feed_dict[trainer.batch_size]=22 72 | In [6]: images=sess.run(trainer.G,feed_dict) 73 | 74 | 75 | ### Configuration 76 | Since this really controls training of 3 different models (CausalController, CausalGAN, and CausalBEGAN), many configuration options are available. To make things manageable, there are 4 files corresponding to configurations specific to different parts of the model. Not all configuration combinations are tested. Default parameters are guaranteed to work. 77 | 78 | configurations: 79 | ./config.py : generic data and scheduling 80 | ./causal_controller/config : specific to CausalController 81 | ./causal_dcgan/config : specific to CausalGAN 82 | ./causal_began/config : specific to CausalBEGAN 83 | 84 | For convenience, the configurations used are saved in 4 .json files in the model directory for future reference. 
85 | 86 | 87 | ## Results 88 | 89 | ### Causal Controller convergence 90 | We show convergence in TVD for Causal Graph 1 (big_causal_graph in causal_graph.py), a completed version of Causal Graph 1 (complete_big_causal_graph in causal_graph.py), and an edge-reversed version of the complete Causal Graph 1 (reverse_big_causal_graph in causal_graph.py). We could get reasonable marginals with a complete DAG containing all 40 nodes, but TVD becomes very difficult to measure. We show TVD convergence for 9 nodes for two complete graphs. When the graph is incomplete, there is a "TVD gap" but reasonable convergence. 91 | 92 | ![alt text](./assets/tvd_vs_step.png) 93 | 94 | ### Conditional vs Interventional Sampling: 95 | We trained a causal implicit generative model assuming we are given the following causal graph over labels: 96 | For the following images when we condition or intervene, these operations can be reasoned about from the graph structure. e.g., conditioning on mustache=1 should give more male images whereas intervening should not (since the edges from the parents are disconnected in an intervention). 97 | 98 | ### CausalGAN Conditioning vs Intervening 99 | For each label, images were randomly sampled by either _intervening_ (top row) or _conditioning_ (bottom row) on label=1. 
100 | 101 | ![alt text](./assets/causalgan_pictures/45507_intvcond_Bald=1_2x10.png) Bald 102 | 103 | ![alt text](./assets/causalgan_pictures/45507_intvcond_Mouth_Slightly_Open=1_2x10.png) Mouth Slightly Open 104 | 105 | ![alt text](./assets/causalgan_pictures/45507_intvcond_Mustache=1_2x10.png) Mustache 106 | 107 | ![alt text](./assets/causalgan_pictures/45507_intvcond_Narrow_Eyes=1_2x10.png) Narrow Eyes 108 | 109 | ![alt text](./assets/causalgan_pictures/45507_intvcond_Smiling=1_2x10.png) Smiling 110 | 111 | ![alt text](./assets/causalgan_pictures/45507_intvcond_Eyeglasses=1_2x10.png) Eyeglasses 112 | 113 | ![alt text](./assets/causalgan_pictures/45507_intvcond_Wearing_Lipstick=1_2x10.png) Wearing Lipstick 114 | 115 | ### CausalBEGAN Conditioning vs Intervening 116 | For each label, images were randomly sampled by either _intervening_ (top row) or _conditioning_ (bottom row) on label=1. 117 | 118 | ![alt text](./assets/causalbegan_pictures/190001_intvcond_Bald=1_2x10.png) Bald 119 | 120 | ![alt text](./assets/causalbegan_pictures/190001_intvcond_Mouth_Slightly_Open=1_2x10.png) Mouth Slightly Open 121 | 122 | ![alt text](./assets/causalbegan_pictures/190001_intvcond_Mustache=1_2x10.png) Mustache 123 | 124 | ![alt text](./assets/causalbegan_pictures/190001_intvcond_Narrow_Eyes=1_2x10.png) Narrow Eyes 125 | 126 | ![alt text](./assets/causalbegan_pictures/190001_intvcond_Smiling=1_2x10.png) Smiling 127 | 128 | ![alt text](./assets/causalbegan_pictures/190001_intvcond_Eyeglasses=1_2x10.png) Eyeglasses 129 | 130 | ![alt text](./assets/causalbegan_pictures/190001_intvcond_Wearing_Lipstick=1_2x10.png) Wearing Lipstick 131 | 132 | ### CausalGAN Generator output (10x10) (randomly sampled label) 133 | ![alt text](https://user-images.githubusercontent.com/10726729/30076306-09743002-923e-11e7-8011-8523cd914f25.gif) 134 | 135 | ### CausalBEGAN Generator output (10x10) (randomly sampled label) 136 | ![alt 
text](https://user-images.githubusercontent.com/10726729/30076379-38b407fc-923e-11e7-81aa-4310c76a2e39.gif) 137 | 138 | <!--- 139 | Repo originally forked from these two 140 | - [BEGAN-tensorflow](https://github.com/carpedm20/BEGAN-tensorflow) 141 | - [DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow) 142 | --> 143 | 144 | ## Related works 145 | - [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661) 146 | - [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434) 147 | - [Wasserstein GAN](https://arxiv.org/abs/1701.07875) 148 | - [BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717) 149 | 150 | ## Authors 151 | 152 | Christopher Snyder / [@22csnyder](http://22csnyder.github.io) 153 | Murat Kocaoglu / [@mkocaoglu](http://mkocaoglu.github.io) 154 | -------------------------------------------------------------------------------- /assets/0808_112404_cbcg.csv: -------------------------------------------------------------------------------- 1 | Wall time,Step,Value 2 | 1502209477.065396,1,0.9871935844421387 3 | 1502210175.629644,1001,0.5611526370048523 4 | 1502210858.027971,2001,0.48091334104537964 5 | 1502211539.450148,3001,0.3693326711654663 6 | 1502212228.305266,4001,0.2690610885620117 7 | 1502212916.163691,5001,0.1852252036333084 8 | 1502213605.455342,6001,0.11786147207021713 9 | 1502214290.655429,7001,0.10585799068212509 10 | 1502214974.834744,8001,0.11575613915920258 11 | 1502215664.377923,9001,0.09277261048555374 12 | 1502216342.813149,10001,0.08084549009799957 13 | 1502217004.542623,11001,0.07447165995836258 14 | 1502217677.840079,12001,0.07388914376497269 15 | 1502218338.794636,13001,0.06354445964097977 16 | 1502219000.20777,14001,0.058855485171079636 17 | 1502219659.079145,15001,0.06558254361152649 18 | 1502220348.8056,16001,0.051907140761613846 19 | 1502221033.399544,17001,0.04890892282128334 20 | 
1502221718.709654,18001,0.04604059085249901 21 | 1502222403.268966,19001,0.04389917105436325 22 | 1502223087.183902,20001,0.04280887916684151 23 | 1502223772.410776,21001,0.04196497052907944 24 | 1502224457.815937,22001,0.038901761174201965 25 | 1502225141.198389,23001,0.04273799806833267 26 | 1502225826.618027,24001,0.041886329650878906 27 | 1502226518.698883,25001,0.04319506511092186 28 | 1502227208.700241,26001,0.042861778289079666 29 | 1502227899.513253,27001,0.04321207478642464 30 | 1502228588.126751,28001,0.035417430102825165 31 | 1502229277.24218,29001,0.03713845834136009 32 | 1502229964.6007,30001,0.03938867151737213 33 | -------------------------------------------------------------------------------- /assets/0810_191625_bcg.csv: -------------------------------------------------------------------------------- 1 | Wall time,Step,Value 2 | 1502410626.387592,1,0.9544087648391724 3 | 1502411081.292726,1001,0.5290326476097107 4 | 1502411533.622933,2001,0.44044023752212524 5 | 1502411981.535893,3001,0.35751280188560486 6 | 1502412434.074014,4001,0.2676760256290436 7 | 1502412884.345166,5001,0.20682139694690704 8 | 1502413336.727762,6001,0.1853639930486679 9 | 1502413786.845507,7001,0.19252602756023407 10 | 1502414239.265506,8001,0.19284175336360931 11 | 1502414689.356373,9001,0.16991157829761505 12 | 1502415145.18223,10001,0.15723274648189545 13 | 1502415595.021095,11001,0.15078511834144592 14 | 1502416037.124821,12001,0.14841803908348083 15 | 1502416478.158467,13001,0.1522006243467331 16 | 1502416920.270544,14001,0.15191766619682312 17 | 1502417364.060506,15001,0.14936088025569916 18 | 1502417803.97219,16001,0.14549562335014343 19 | 1502418242.907475,17001,0.14224907755851746 20 | 1502418684.820146,18001,0.13779735565185547 21 | 1502419124.551228,19001,0.14404024183750153 22 | -------------------------------------------------------------------------------- /assets/0821_213901_rcbcg.csv: 
-------------------------------------------------------------------------------- 1 | Wall time,Step,Value 2 | 1503369574.677247,1,0.8920440077781677 3 | 1503370041.447478,1001,0.512530505657196 4 | 1503370517.215026,2001,0.44317319989204407 5 | 1503370985.171754,3001,0.35666027665138245 6 | 1503371450.274446,4001,0.2928802967071533 7 | 1503371929.346399,5001,0.19688302278518677 8 | 1503372408.39261,6001,0.13801704347133636 9 | 1503372886.733545,7001,0.1106921136379242 10 | 1503373363.362404,8001,0.08717407286167145 11 | 1503373839.834317,9001,0.0857364684343338 12 | 1503374318.503915,10001,0.07331433147192001 13 | 1503374802.444324,11001,0.07706638425588608 14 | 1503375279.389205,12001,0.06169278547167778 15 | 1503375752.728541,13001,0.059477031230926514 16 | 1503376226.577342,14001,0.061632610857486725 17 | 1503376699.448754,15001,0.06138858571648598 18 | 1503377174.465165,16001,0.05955960601568222 19 | 1503377653.261056,17001,0.04774799197912216 20 | 1503378126.625743,18001,0.05300581455230713 21 | 1503378604.128631,19001,0.047743991017341614 22 | 1503379079.647434,20001,0.05426724627614021 23 | 1503379555.901424,21001,0.04658582806587219 24 | 1503380028.219916,22001,0.04909271374344826 25 | 1503380498.204313,23001,0.05326574668288231 26 | 1503380962.853232,24001,0.05447468161582947 27 | 1503381428.927937,25001,0.05708151310682297 28 | 1503381893.354328,26001,0.051777616143226624 29 | 1503382360.002207,27001,0.046131476759910583 30 | 1503382825.077767,28001,0.04513547569513321 31 | 1503383290.90524,29001,0.044165026396512985 32 | -------------------------------------------------------------------------------- /assets/314393_began_Bald_topdo1_botcond1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/314393_began_Bald_topdo1_botcond1.png -------------------------------------------------------------------------------- 
/assets/314393_began_Mustache_topdo1_botcond1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/314393_began_Mustache_topdo1_botcond1.png -------------------------------------------------------------------------------- /assets/big_causal_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/big_causal_graph.png -------------------------------------------------------------------------------- /assets/causalbegan_pictures/190001_G_diversity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalbegan_pictures/190001_G_diversity.png -------------------------------------------------------------------------------- /assets/causalbegan_pictures/190001_intvcond_Bald=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalbegan_pictures/190001_intvcond_Bald=1_2x10.png -------------------------------------------------------------------------------- /assets/causalbegan_pictures/190001_intvcond_Eyeglasses=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalbegan_pictures/190001_intvcond_Eyeglasses=1_2x10.png -------------------------------------------------------------------------------- /assets/causalbegan_pictures/190001_intvcond_Mouth_Slightly_Open=1_2x10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalbegan_pictures/190001_intvcond_Mouth_Slightly_Open=1_2x10.png -------------------------------------------------------------------------------- /assets/causalbegan_pictures/190001_intvcond_Mustache=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalbegan_pictures/190001_intvcond_Mustache=1_2x10.png -------------------------------------------------------------------------------- /assets/causalbegan_pictures/190001_intvcond_Narrow_Eyes=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalbegan_pictures/190001_intvcond_Narrow_Eyes=1_2x10.png -------------------------------------------------------------------------------- /assets/causalbegan_pictures/190001_intvcond_Smiling=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalbegan_pictures/190001_intvcond_Smiling=1_2x10.png -------------------------------------------------------------------------------- /assets/causalbegan_pictures/190001_intvcond_Wearing_Lipstick=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalbegan_pictures/190001_intvcond_Wearing_Lipstick=1_2x10.png -------------------------------------------------------------------------------- /assets/causalgan_pictures/45507_G_diversity.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalgan_pictures/45507_G_diversity.png -------------------------------------------------------------------------------- /assets/causalgan_pictures/45507_intvcond_Bald=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalgan_pictures/45507_intvcond_Bald=1_2x10.png -------------------------------------------------------------------------------- /assets/causalgan_pictures/45507_intvcond_Eyeglasses=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalgan_pictures/45507_intvcond_Eyeglasses=1_2x10.png -------------------------------------------------------------------------------- /assets/causalgan_pictures/45507_intvcond_Mouth_Slightly_Open=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalgan_pictures/45507_intvcond_Mouth_Slightly_Open=1_2x10.png -------------------------------------------------------------------------------- /assets/causalgan_pictures/45507_intvcond_Mustache=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalgan_pictures/45507_intvcond_Mustache=1_2x10.png -------------------------------------------------------------------------------- /assets/causalgan_pictures/45507_intvcond_Narrow_Eyes=1_2x10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalgan_pictures/45507_intvcond_Narrow_Eyes=1_2x10.png -------------------------------------------------------------------------------- /assets/causalgan_pictures/45507_intvcond_Smiling=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalgan_pictures/45507_intvcond_Smiling=1_2x10.png -------------------------------------------------------------------------------- /assets/causalgan_pictures/45507_intvcond_Wearing_Lipstick=1_2x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/causalgan_pictures/45507_intvcond_Wearing_Lipstick=1_2x10.png -------------------------------------------------------------------------------- /assets/guide_to_gifs.txt: -------------------------------------------------------------------------------- 1 | #Approach uses imagemagick 2 | #Take the first 20 images in a folder and convert to gif 3 | ls -v | head -20 | xargs cp -t newfolder 4 | cd newfolder 5 | mogrify -format png *.pdf 6 | mogrify -crop 62.5%x62.5%+0+0 +repage *.png 7 | rm *.pdf 8 | convert -delay 20 $(ls -v) -loop 0 -layers optimize mygifname.gif 9 | -------------------------------------------------------------------------------- /assets/tvd_vs_step.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/tvd_vs_step.pdf -------------------------------------------------------------------------------- /assets/tvd_vs_step.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/assets/tvd_vs_step.png -------------------------------------------------------------------------------- /assets/tvdplot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stdout", 12 | "output_type": "stream", 13 | "text": [ 14 | "Using matplotlib backend: TkAgg\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import matplotlib.pyplot as plt\n", 20 | "import tensorflow as tf\n", 21 | "import pandas as pd\n", 22 | "%matplotlib" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": { 29 | "collapsed": false 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "\n", 34 | "raw_data={'cG1': pd.read_csv('0808_112404_cbcg.csv'),\n", 35 | " 'G1' : pd.read_csv('0810_191625_bcg.csv'),\n", 36 | " 'rcG1': pd.read_csv('0821_213901_rcbcg.csv')}\n", 37 | "xlabel='Training Step'\n", 38 | "dfs=[pd.DataFrame(data={k:v['Value'].values,xlabel:v['Step'].values}) for k,v in raw_data.items()]" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 7, 44 | "metadata": { 45 | "collapsed": false 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "\n", 50 | "raw_data={'Causal Graph 1' : pd.read_csv('0810_191625_bcg.csv'),\n", 51 | " 'complete Causal Graph 1': pd.read_csv('0808_112404_cbcg.csv'), \n", 52 | " 'edge-reversed complete Causal Graph 1': pd.read_csv('0821_213901_rcbcg.csv')}\n", 53 | "xlabel='Training Step'\n", 54 | "dfs=[pd.DataFrame(data={k:v['Value'].values,xlabel:v['Step'].values}) for k,v in raw_data.items()]" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 8, 60 | "metadata": { 61 | "collapsed": false 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "def my_merge(df1,df2):\n", 66 | " return 
pd.merge(df1,df2,how='outer',on=xlabel)\n", 67 | " \n", 68 | "\n", 69 | "plot_data=reduce(my_merge,dfs)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 9, 75 | "metadata": { 76 | "collapsed": false 77 | }, 78 | "outputs": [ 79 | { 80 | "data": { 81 | "text/plain": [ 82 | "" 83 | ] 84 | }, 85 | "execution_count": 9, 86 | "metadata": {}, 87 | "output_type": "execute_result" 88 | } 89 | ], 90 | "source": [ 91 | "ax=plot_data.plot.line(x=xlabel,xlim=[0,18000],ylim=[0,1],style = ['bs-','ro-','y^-'])\n", 92 | "ax.set_ylabel('Total Variation Distance',fontsize=18)\n", 93 | "ax.set_title('TVD of Label Generation',fontsize=18)\n", 94 | "ax.set_xlabel(xlabel,fontsize=18)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 10, 100 | "metadata": { 101 | "collapsed": false 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "plt.savefig('tvd_vs_step.pdf')" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": { 112 | "collapsed": true 113 | }, 114 | "outputs": [], 115 | "source": [] 116 | } 117 | ], 118 | "metadata": { 119 | "kernelspec": { 120 | "display_name": "Python 2", 121 | "language": "python", 122 | "name": "python2" 123 | }, 124 | "language_info": { 125 | "codemirror_mode": { 126 | "name": "ipython", 127 | "version": 2 128 | }, 129 | "file_extension": ".py", 130 | "mimetype": "text/x-python", 131 | "name": "python", 132 | "nbconvert_exporter": "python", 133 | "pygments_lexer": "ipython2", 134 | "version": "2.7.12" 135 | } 136 | }, 137 | "nbformat": 4, 138 | "nbformat_minor": 1 139 | } 140 | -------------------------------------------------------------------------------- /causal_began/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/causal_began/__init__.py -------------------------------------------------------------------------------- 
/causal_began/config.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | import argparse 3 | 4 | def str2bool(v): 5 | #return (v is True) or (v.lower() in ('true', '1')) 6 | return v is True or v.lower() in ('true', '1') 7 | 8 | arg_lists = [] 9 | parser = argparse.ArgumentParser() 10 | 11 | def add_argument_group(name): 12 | arg = parser.add_argument_group(name) 13 | arg_lists.append(arg) 14 | return arg 15 | 16 | 17 | #Network 18 | net_arg = add_argument_group('Network') 19 | net_arg.add_argument('--c_dim',type=int, default=3, 20 | help='''number of color channels. I wouldn't really change 21 | this from 3''') 22 | net_arg.add_argument('--conv_hidden_num', type=int, default=128, 23 | choices=[64, 128],help='n in the paper') 24 | net_arg.add_argument('--separate_labeler', type=str2bool, default=True) 25 | net_arg.add_argument('--z_dim', type=int, default=64, choices=[64, 128], 26 | help='''dimension of the noise input to the generator along 27 | with the labels''') 28 | net_arg.add_argument('--z_num', type=int, default=64, 29 | help='''dimension of the hidden space of the autoencoder''') 30 | 31 | 32 | # Data 33 | data_arg = add_argument_group('Data') 34 | data_arg.add_argument('--dataset', type=str, default='celebA') 35 | data_arg.add_argument('--split', type=str, default='train') 36 | data_arg.add_argument('--batch_size', type=int, default=16) 37 | 38 | # Training / test parameters 39 | train_arg = add_argument_group('Training') 40 | train_arg.add_argument('--beta1', type=float, default=0.5) 41 | train_arg.add_argument('--beta2', type=float, default=0.999) 42 | train_arg.add_argument('--d_lr', type=float, default=0.00008) 43 | train_arg.add_argument('--g_lr', type=float, default=0.00008) 44 | train_arg.add_argument('--label_loss',type=str,default='squarediff',choices=['xe','absdiff','squarediff'], 45 | help='''what comparison should be made between the 46 | labeler output and the actual labels''') 
47 | train_arg.add_argument('--lr_update_step', type=int, default=100000, choices=[100000, 75000]) 48 | train_arg.add_argument('--max_step', type=int, default=50000) 49 | train_arg.add_argument('--num_iter',type=int,default=250000, 50 | help='the number of training iterations to run the model for') 51 | train_arg.add_argument('--optimizer', type=str, default='adam') 52 | train_arg.add_argument('--round_fake_labels',type=str2bool,default=True, 53 | help='''Whether the label outputs of the causal 54 | controller should be rounded first before calculating 55 | the loss of generator or d-labeler''') 56 | train_arg.add_argument('--use_gpu', type=str2bool, default=True) 57 | train_arg.add_argument('--num_gpu', type=int, default=1, 58 | help='specify 0 for cpu. If k specified, will default to\ 59 | first k of n gpus detected. If use_gpu=True but num_gpu not\ 60 | specified will default to 1') 61 | 62 | margin_arg = add_argument_group('Margin') 63 | margin_arg.add_argument('--gamma', type=float, default=0.5) 64 | margin_arg.add_argument('--gamma_label', type=float, default=0.5) 65 | margin_arg.add_argument('--lambda_k', type=float, default=0.001) 66 | margin_arg.add_argument('--lambda_l', type=float, default=0.00008, 67 | help='''As mentioned in the paper this is lower because 68 | this margin can be responded to more quickly than the 69 | other margins. Im not sure if it definitely needs to be lower''') 70 | margin_arg.add_argument('--lambda_z', type=float, default=0.01) 71 | margin_arg.add_argument('--no_third_margin', type=str2bool, default=False, 72 | help='''Use True for appendix figure in paper. 
This is 73 | used to neglect the third margin (c3,b3)''') 74 | margin_arg.add_argument('--zeta', type=float, default=0.5, 75 | help='''This is gamma_3 in the paper''') 76 | 77 | # Misc 78 | misc_arg = add_argument_group('Misc') 79 | misc_arg.add_argument('--is_train',type=str2bool,default=False, 80 | help='''whether to enter the image training loop''') 81 | misc_arg.add_argument('--build_all', type=str2bool, default=False, 82 | help='''normally specifying is_pretrain=False will cause 83 | the pretraining components not to be built and likewise 84 | with is_train=False only the pretrain compoenent will 85 | (possibly) be built. This is here as a debug helper to 86 | enable building out the whole model without doing any 87 | training''') 88 | misc_arg.add_argument('--data_dir', type=str, default='data') 89 | misc_arg.add_argument('--dry_run', action='store_true') 90 | #misc_arg.add_argument('--dry_run', type=str2bool, default='False') 91 | misc_arg.add_argument('--log_step', type=int, default=100, 92 | help='''how often to log stuff. 
def gpu_logic(config):
    """Keep ``use_gpu`` consistent with ``num_gpu``.

    Any positive GPU count forces use_gpu on; zero (or negative) forces it
    off.  Mutates and returns the same config object.
    """
    config.use_gpu = config.num_gpu > 0
    return config
def DiscriminatorCNN(image, config, reuse=None):
    """BEGAN-style discriminator: a convolutional autoencoder.

    Encodes `image` down to a `config.z_num`-dimensional code, then decodes
    it back to an image; callers score samples by reconstruction quality.

    Args:
        image: batch of images, layout given by config.data_format.
        config: needs conv_hidden_num, data_format, channel, repeat_num,
            z_num.  # assumes spatial size reduces to 8x8 after the strided
            # convs (see the hard-coded [8, 8, ...] reshapes) -- TODO confirm
        reuse: forwarded to tf.variable_scope for weight sharing.

    Returns:
        (out, z, variables): reconstructed image, latent code, and the
        variables created under the "D" scope.
    """
    hidden_num=config.conv_hidden_num
    data_format=config.data_format
    input_channel=config.channel

    with tf.variable_scope("D",reuse=reuse) as vs:
        # Encoder: widen channels each stage, downsample with stride-2 convs.
        with tf.variable_scope('encoder'):
            x = slim.conv2d(image, hidden_num, 3, 1, activation_fn=tf.nn.elu,
                data_format=data_format,scope='conv0')

            prev_channel_num = hidden_num
            for idx in range(config.repeat_num):
                channel_num = hidden_num * (idx + 1)
                x = slim.conv2d(x, channel_num, 3, 1, activation_fn=tf.nn.elu,
                    data_format=data_format,scope='conv'+str(idx+1)+'a')
                x = slim.conv2d(x, channel_num, 3, 1, activation_fn=tf.nn.elu,
                    data_format=data_format,scope='conv'+str(idx+1)+'b')
                if idx < config.repeat_num - 1:
                    # stride 2 halves the spatial resolution
                    x = slim.conv2d(x, channel_num, 3, 2, activation_fn=tf.nn.elu,
                        data_format=data_format,scope='conv'+str(idx+1)+'c')
                    #x = tf.contrib.layers.max_pool2d(x, [2, 2], [2, 2], padding='VALID')

            x = tf.reshape(x, [-1, np.prod([8, 8, channel_num])])
            # z is the bottleneck code; x continues into the decoder
            z = x = slim.fully_connected(x, config.z_num, activation_fn=None,scope='proj')

        # Decoder: mirror of the encoder, nearest-neighbor upscaling between stages.
        with tf.variable_scope('decoder'):
            x = slim.fully_connected(x, np.prod([8, 8, hidden_num]), activation_fn=None)
            x = reshape(x, 8, 8, hidden_num, data_format)

            for idx in range(config.repeat_num):
                x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu,
                    data_format=data_format,scope='conv'+str(idx)+'a')
                x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu,
                    data_format=data_format,scope='conv'+str(idx)+'b')
                if idx < config.repeat_num - 1:
                    x = upscale(x, 2, data_format)
            out = slim.conv2d(x, input_channel, 3, 1, activation_fn=None,
                data_format=data_format,scope='proj')

    variables = tf.contrib.framework.get_variables(vs)
    return out, z, variables
def slerp(val, low, high):
    """Spherical linear interpolation between vectors `low` and `high`.

    `val` in [0, 1] selects a point along the great-circle arc between the
    two (non-normalized) endpoints.  Degenerates to ordinary LERP when the
    directions are (anti-)parallel.
    Adapted from https://github.com/soumith/dcgan.torch/issues/14
    """
    unit_low = low / np.linalg.norm(low)
    unit_high = high / np.linalg.norm(high)
    cos_omega = np.clip(np.dot(unit_low, unit_high), -1, 1)
    omega = np.arccos(cos_omega)
    sin_omega = np.sin(omega)
    if sin_omega == 0:
        # parallel endpoints: slerp reduces to linear interpolation
        return (1.0 - val) * low + val * high
    w_low = np.sin((1.0 - val) * omega) / sin_omega
    w_high = np.sin(val * omega) / sin_omega
    return w_low * low + w_high * high
def upscale(x, scale, data_format):
    """Nearest-neighbor upsample `x` by integer factor `scale` in both
    spatial dimensions, handling NCHW/NHWC via get_conv_shape."""
    # get_conv_shape always returns [N, H, W, C] regardless of data_format
    _, h, w, _ = get_conv_shape(x, data_format)
    return resize_nearest_neighbor(x, (h*scale, w*scale), data_format)
def summary_stats(name,tensor,collections=None,hist=False):
    """Attach scalar mean/std summaries (and optionally a histogram) for `tensor`.

    Summaries are tagged name+'_ave' and name+'_std' (plus name+'_hist' when
    hist=True) and added to `collections`, defaulting to the standard
    tf.GraphKeys.SUMMARIES collection.
    """
    collections=collections or [tf.GraphKeys.SUMMARIES]
    ave=tf.reduce_mean(tensor)
    # population std: sqrt(mean((x - mean)^2))
    std=tf.sqrt(tf.reduce_mean(tf.square(ave-tensor)))
    tf.summary.scalar(name+'_ave',ave,collections)
    tf.summary.scalar(name+'_std',std,collections)
    if hist:
        tf.summary.histogram(name+'_hist',tensor,collections)
def get_time():
    """Current local time as an 'MMDD_HHMMSS' string (used to name model dirs)."""
    return "{:%m%d_%H%M%S}".format(datetime.now())
def make_grid(tensor, nrow=8, padding=2,
              normalize=False, scale_each=False):
    """Tile a batch of images into a single uint8 grid image.

    tensor: array of shape (n, h, w, ...) whose entries broadcast into the
    grid's 3 channels.  `nrow` caps the number of columns; `padding` pixels
    separate cells.  normalize/scale_each are accepted for signature
    compatibility with torchvision's make_grid but are not used here.
    Based on https://github.com/pytorch/vision/blob/master/torchvision/utils.py
    """
    n_images = tensor.shape[0]
    cols = min(nrow, n_images)
    rows = int(math.ceil(float(n_images) / cols))
    cell_h = int(tensor.shape[1] + padding)
    cell_w = int(tensor.shape[2] + padding)
    margin = 1 + padding // 2
    canvas = np.zeros([cell_h * rows + margin, cell_w * cols + margin, 3],
                      dtype=np.uint8)
    for k in range(n_images):
        row, col = divmod(k, cols)
        top = row * cell_h + margin
        left = col * cell_w + margin
        canvas[top:top + cell_h - padding, left:left + cell_w - padding] = tensor[k]
    return canvas
dictionaries of arrays 6 | or dictionaries of scalars. I find this comes up pretty often when dealing 7 | with tensorflow, because you can pass dictionaries to feed_dict and get 8 | dictionaries back. If you use a smaller batch_size, you then want to 9 | "concatenate" these outputs for each key. 10 | ''' 11 | 12 | def __init__(self): 13 | self.dict={} 14 | def __len__(self): 15 | if len(self.dict)==0: 16 | return 0 17 | else: 18 | return len(self.dict.values()[0]) 19 | def __repr__(self): 20 | return repr(self.dict) 21 | def keys(self): 22 | return self.dict.keys() 23 | def items(self): 24 | return self.dict.items() 25 | 26 | def validate_dict(self,a_dict): 27 | #Check keys 28 | for key,val in self.dict.items(): 29 | if not key in a_dict.keys(): 30 | raise ValueError('key:',key,'was not in a_dict.keys()') 31 | 32 | for key,val in a_dict.items(): 33 | #Check same keys 34 | if not key in self.dict.keys(): 35 | raise ValueError('argument key:',key,'was not in self.dict') 36 | 37 | if isinstance(val,np.ndarray): 38 | #print('ndarray') 39 | my_val=self.dict[key] 40 | if not np.all(val.shape[1:]==my_val.shape[1:]): 41 | raise ValueError('key:',key,'value shape',val.shape,'does\ 42 | not match existing shape',my_val.shape) 43 | else: #scalar 44 | a_val=np.array([[val]])#[1,1]shape array 45 | my_val=self.dict[key] 46 | if not np.all(my_val.shape[1:]==a_val.shape[1:]): 47 | raise ValueError('key:',key,'value shape',val.shape,'does\ 48 | not match existing shape',my_val.shape) 49 | def arr_dict(self,a_dict): 50 | if isinstance(a_dict.values()[0],np.ndarray): 51 | return a_dict 52 | else: 53 | return {k:np.array([[v]]) for k,v in a_dict.items()} 54 | 55 | 56 | def concat(self,a_dict): 57 | if self.dict=={}: 58 | self.dict=self.arr_dict(a_dict)#store interally as array 59 | else: 60 | self.validate_dict(a_dict) 61 | self.dict={k:np.vstack([v,a_dict[k]]) for k,v in self.items()} 62 | 63 | def __getitem__(self,at): 64 | return {k:v[at] for k,v in self.items()} 65 | 66 | #debug, 
def str2bool(v):
    """Coerce an argparse flag value to bool.

    argparse passes option values through as strings, so 'true'/'True'/'1'
    (case-insensitive) map to True and everything else to False.  A genuine
    bool is returned unchanged: the previous ``v is True or v.lower(...)``
    form raised AttributeError when called with the bool ``False``.
    """
    if isinstance(v, bool):
        return v
    return v.lower() in ('true', '1')
help='learning rate for causal controller') 36 | pretrain_arg.add_argument('--pt_dcc_lr',type=float,default=0.00008,# 37 | help='learning rate for causal controller') 38 | pretrain_arg.add_argument('--lambda_W',type=float,default=0.1,# 39 | help='penalty for gradient of W critic') 40 | pretrain_arg.add_argument('--n_critic',type=int,default=20,#5 for speed 41 | help='number of critic iterations between gen update') 42 | pretrain_arg.add_argument('--critic_layers',type=int,default=6,#4 usual.8 might help 43 | help='number of layers in the Wasserstein discriminator') 44 | pretrain_arg.add_argument('--critic_hidden_size',type=int,default=15,#10,15 45 | help='hidden_size for critic of discriminator') 46 | 47 | pretrain_arg.add_argument('--min_tvd',type=float,default=0.02, 48 | help='if tvd bs x n_ker x bs : 135 | abs_dif = tf.reduce_sum(tf.abs(tf.expand_dims(act, 3) - tf.expand_dims(act_tp, 0)), 2) 136 | eye=tf.expand_dims( tf.eye( tf.shape(abs_dif)[0] ), 1)#bs x 1 x bs 137 | masked=tf.exp(-abs_dif) - eye 138 | f1=tf.reduce_mean( masked, 2) 139 | mb_features = tf.reshape(f1, [-1, 1, 1, n_kernels]) 140 | return conv_cond_concat(image, mb_features) 141 | 142 | ## following is from https://github.com/openai/improved-gan/blob/master/imagenet/discriminator.py#L88 143 | #def add_minibatch_features(image,df_dim,batch_size): 144 | # shape = image.get_shape().as_list() 145 | # dim = np.prod(shape[1:]) # dim = prod(9,2) = 18 146 | # h_mb0 = lrelu(conv2d(image, df_dim, name='d_mb0_conv')) 147 | # h_mb1 = conv2d(h_mb0, df_dim, name='d_mbh1_conv') 148 | # 149 | # dims=h_mb1.get_shape().as_list() 150 | # conv_dims=np.prod(dims[1:]) 151 | # 152 | # image_ = tf.reshape(h_mb1, tf.stack([-1, conv_dims])) 153 | # #image_ = tf.reshape(h_mb1, tf.stack([batch_size, -1])) 154 | # 155 | # n_kernels = 300 156 | # dim_per_kernel = 50 157 | # x = linear(image_, n_kernels * dim_per_kernel,'d_mbLinear') 158 | # activation = tf.reshape(x, (batch_size, n_kernels, dim_per_kernel)) 159 | # big = 
def merge(images, size):
    """Arrange a batch of images into a size[0] x size[1] mosaic.

    images: array of shape (n, h, w, 3) (entries must broadcast into the
    3-channel canvas); size: (rows, cols).  Images fill the canvas in
    row-major order.  Returns a float canvas of shape (h*rows, w*cols, 3).
    """
    h, w = images.shape[1], images.shape[2]
    rows, cols = size
    canvas = np.zeros((h * rows, w * cols, 3))
    for idx, image in enumerate(images):
        r, c = divmod(idx, cols)
        canvas[r*h:(r+1)*h, c*w:(c+1)*w, :] = image
    return canvas
def str2bool(v):
    """Coerce an argparse flag value to bool.

    argparse passes option values through as strings, so 'true'/'True'/'1'
    (case-insensitive) map to True and everything else to False.  A genuine
    bool is returned unchanged: the previous ``v is True or v.lower(...)``
    form raised AttributeError when called with the bool ``False``.
    """
    if isinstance(v, bool):
        return v
    return v.lower() in ('true', '1')
| # Data 17 | data_arg = add_argument_group('Data') 18 | #data_arg.add_argument('--batch_size', type=int, default=16)#default set elsewhere 19 | data_arg.add_argument('--causal_model', type=str, 20 | help='''Matches the argument with a key in ./causal_graph.py and sets the graph attribute of cc_config to be a list of lists defining the causal graph''') 21 | data_arg.add_argument('--data_dir', type=str, default='data') 22 | data_arg.add_argument('--dataset', type=str, default='celebA') 23 | data_arg.add_argument('--do_shuffle', type=str2bool, default=True)#never used 24 | data_arg.add_argument('--input_scale_size', type=int, default=64, 25 | help='input image will be resized with the given value as width and height') 26 | data_arg.add_argument('--is_crop', type=str2bool, default='True') 27 | data_arg.add_argument('--grayscale', type=str2bool, default=False)#never used 28 | data_arg.add_argument('--split', type=str, default='train')#never used 29 | data_arg.add_argument('--num_worker', type=int, default=24, 30 | help='number of threads to use for loading and preprocessing data') 31 | data_arg.add_argument('--resize_method',type=str,default='AREA',choices=['AREA','BILINEAR','BICUBIC','NEAREST_NEIGHBOR'], 32 | help='''methods to resize image to 64x64. AREA seems to work 33 | best, possibly some scipy methods could work better. It 34 | wasn't clear to me why the results should be so different''') 35 | 36 | 37 | # Training / test parameters 38 | train_arg = add_argument_group('Training') 39 | 40 | 41 | train_arg.add_argument('--build_train', type=str2bool, default=False, 42 | help='''You may want to build all the components for 43 | training, without doing any training right away. This is 44 | for that. This arg is effectively True when is_train=True''') 45 | train_arg.add_argument('--build_pretrain', type=str2bool, default=False, 46 | help='''You may want to build all the components for 47 | training, without doing any training right away. This is 48 | for that. 
This arg is effectively True when is_pretrain=True''') 49 | 50 | 51 | train_arg.add_argument('--model_type',type=str,default='',choices=['dcgan','began'], 52 | help='''Which model to use. If the argument is not 53 | passed, only causal_controller is built. This overrides 54 | is_train=True, since no image model to train''') 55 | train_arg.add_argument('--use_gpu', type=str2bool, default=True) 56 | train_arg.add_argument('--num_gpu', type=int, default=1, 57 | help='specify 0 for cpu. If k specified, will default to\ 58 | first k of n detected. If use_gpu=True but num_gpu not\ 59 | specified will default to 1') 60 | 61 | # Misc 62 | misc_arg = add_argument_group('Misc') 63 | #misc_arg.add_argument('--build_all', type=str2bool, default=False, 64 | # help='''normally specifying is_pretrain=False will cause 65 | # the pretraining components not to be built and likewise 66 | # with is_train=False only the pretrain compoenent will 67 | # (possibly) be built. This is here as a debug helper to 68 | # enable building out the whole model without doing any 69 | # training''') 70 | 71 | misc_arg.add_argument('--descrip', type=str, default='',help=''' 72 | Only use this when creating a new model. New model folder names 73 | are generated automatically by using the time-date. Then 74 | you cant rename them while the model is running. If 75 | provided, this is a short string that appends to the end 76 | of a model folder name to help keep track of what the 77 | contents of that folder were without getting into the 78 | content of that folder. No weird characters''') 79 | 80 | misc_arg.add_argument('--dry_run', action='store_true',help='''Build and load 81 | the model and all the specified components, but don't actually do 82 | any pretraining/training etc. This overrides 83 | --is_pretrain, --is_train. 
This is mostly used for just 84 | bringing the model into the workspace if you say wanted 85 | to manipulated it in ipython''') 86 | 87 | misc_arg.add_argument('--load_path', type=str, default='', 88 | help='''This is a "global" load path. You can simply pass 89 | the model_dir of the whatever run, and all the variables 90 | (dcgan/began and causal_controller both). If you want to 91 | just load one component: for example, the pretrained part 92 | of a previous model, use pt_load_path from the 93 | causal_controller.config section''') 94 | 95 | misc_arg.add_argument('--log_step', type=int, default=100, 96 | help='''this is used for generic summaries that are common 97 | to both models. Use model specific config files for 98 | logging done within train_step''') 99 | #misc_arg.add_argument('--save_step', type=int, default=5000) 100 | misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN']) 101 | misc_arg.add_argument('--log_dir', type=str, default='logs', help='''where to store model and model results. 
Do not put a leading "./" out front''') 102 | 103 | #misc_arg.add_argument('--sample_per_image', type=int, default=64, 104 | # help='# of sample per image during test sample generation') 105 | 106 | misc_arg.add_argument('--seed', type=int, default=22,help= 107 | '''Not working right now: TF seed should be fixed to make sure exogenous noise for each causal node is fixed also''') 108 | 109 | #Doesn't do anything atm 110 | #misc_arg.add_argument('--visualize', action='store_true') 111 | 112 | 113 | def gpu_logic(config): 114 | 115 | #consistency between use_gpu and num_gpu 116 | if config.num_gpu>0: 117 | config.use_gpu=True 118 | else: 119 | config.use_gpu=False 120 | # if config.use_gpu and config.num_gpu==0: 121 | # config.num_gpu=1 122 | return config 123 | 124 | 125 | def get_config(): 126 | config, unparsed = parser.parse_known_args() 127 | config=gpu_logic(config) 128 | config.num_devices=max(1,config.num_gpu)#that are used in backprop 129 | 130 | 131 | #Just for BEGAN: 132 | ##this has to respect gpu/cpu 133 | ##data_format = 'NCHW' 134 | #if config.use_gpu: 135 | # data_format = 'NCHW' 136 | #else: 137 | # data_format = 'NHWC' 138 | #setattr(config, 'data_format', data_format) 139 | 140 | print('Loaded ./config.py') 141 | 142 | return config, unparsed 143 | 144 | if __name__=='__main__': 145 | #for debug of config 146 | config, unparsed = get_config() 147 | 148 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from PIL import Image 5 | from glob import glob 6 | import tensorflow as tf 7 | 8 | from IPython.core import debugger 9 | debug = debugger.Pdb().set_trace 10 | 11 | 12 | def logodds(p): 13 | return np.log(p/(1.-p)) 14 | 15 | class DataLoader(object): 16 | '''This loads the image and the labels through a tensorflow queue. 
17 | All of the labels are loaded regardless of what is specified in graph, 18 | because this model is gpu throttled anyway so there shouldn't be any 19 | overhead 20 | 21 | For multiple gpu, the strategy here is to have 1 queue with 2xbatch_size 22 | then use tf.split within trainer.train() 23 | ''' 24 | def __init__(self,label_names,config): 25 | self.label_names=label_names 26 | self.config=config 27 | self.scale_size=config.input_scale_size 28 | #self.data_format=config.data_format 29 | self.split=config.split 30 | self.do_shuffle=config.do_shuffle 31 | self.num_worker=config.num_worker 32 | self.is_crop=config.is_crop 33 | self.is_grayscale=config.grayscale 34 | 35 | attr_file= glob("{}/*.{}".format(config.data_path, 'txt'))[0] 36 | setattr(config,'attr_file',attr_file) 37 | 38 | attributes = pd.read_csv(config.attr_file,delim_whitespace=True) #+-1 39 | #Store all labels for reference 40 | self.all_attr= 0.5*(attributes+1)# attributes is {0,1} 41 | self.all_label_means=self.all_attr.mean() 42 | 43 | #but only return desired labels in queues 44 | self.attr=self.all_attr[label_names] 45 | self.label_means=self.attr.mean()# attributes is 0,1 46 | 47 | self.image_dir=os.path.join(config.data_path,'images') 48 | self.filenames=[os.path.join(self.image_dir,j) for j in self.attr.index] 49 | 50 | self.num_examples_per_epoch=len(self.filenames) 51 | self.min_fraction_of_examples_in_queue=0.001#go faster during debug 52 | #self.min_fraction_of_examples_in_queue=0.01 53 | self.min_queue_examples=int(self.num_examples_per_epoch*self.min_fraction_of_examples_in_queue) 54 | 55 | 56 | def get_label_queue(self,batch_size): 57 | tf_labels = tf.convert_to_tensor(self.attr.values, dtype=tf.uint8)#0,1 58 | 59 | with tf.name_scope('label_queue'): 60 | uint_label=tf.train.slice_input_producer([tf_labels])[0] 61 | label=tf.to_float(uint_label) 62 | 63 | #All labels, not just those in causal_model 64 | dict_data={sl:tl for sl,tl in 65 | 
zip(self.label_names,tf.split(label,len(self.label_names)))} 66 | 67 | 68 | num_preprocess_threads = max(self.num_worker-3,1) 69 | 70 | data_batch = tf.train.shuffle_batch( 71 | dict_data, 72 | batch_size=batch_size, 73 | num_threads=num_preprocess_threads, 74 | capacity=self.min_queue_examples + 3 * batch_size, 75 | min_after_dequeue=self.min_queue_examples, 76 | ) 77 | 78 | return data_batch 79 | 80 | def get_data_queue(self,batch_size): 81 | image_files = tf.convert_to_tensor(self.filenames, dtype=tf.string) 82 | tf_labels = tf.convert_to_tensor(self.attr.values, dtype=tf.uint8) 83 | 84 | with tf.name_scope('filename_queue'): 85 | #must be list 86 | str_queue=tf.train.slice_input_producer([image_files,tf_labels]) 87 | img_filename, uint_label= str_queue 88 | 89 | img_contents=tf.read_file(img_filename) 90 | image = tf.image.decode_jpeg(img_contents, channels=3) 91 | 92 | image=tf.cast(image,dtype=tf.float32) 93 | if self.config.is_crop:#use dcgan cropping 94 | #dcgan center-crops input to 108x108, outputs 64x64 #centrally crops it #We emulate that here 95 | image=tf.image.resize_image_with_crop_or_pad(image,108,108) 96 | #image=tf.image.resize_bilinear(image,[scale_size,scale_size])#must be 4D 97 | 98 | resize_method=getattr(tf.image.ResizeMethod,self.config.resize_method) 99 | image=tf.image.resize_images(image,[self.scale_size,self.scale_size], 100 | method=resize_method) 101 | #Some dataset enlargement. Might as well. 102 | image=tf.image.random_flip_left_right(image) 103 | 104 | ##carpedm-began crops to 128x128 starting at (50,25), then resizes to 64x64 105 | #image=tf.image.crop_to_bounding_box(image, 50, 25, 128, 128) 106 | #image=tf.image.resize_nearest_neighbor(image, [scale_size, scale_size]) 107 | 108 | tf.summary.image('real_image',tf.expand_dims(image,0)) 109 | 110 | 111 | 112 | label=tf.to_float(uint_label) 113 | #Creates a dictionary {'Male',male_tensor, 'Young',young_tensor} etc.. 
114 | dict_data={sl:tl for sl,tl in 115 | zip(self.label_names,tf.split(label,len(self.label_names)))} 116 | assert not 'x' in dict_data.keys()#don't have a label named "x" 117 | dict_data['x']=image 118 | 119 | print ('Filling queue with %d Celeb images before starting to train. ' 120 | 'I don\'t know how long this will take' % self.min_queue_examples) 121 | num_preprocess_threads = max(self.num_worker,1) 122 | 123 | data_batch = tf.train.shuffle_batch( 124 | dict_data, 125 | batch_size=batch_size, 126 | num_threads=num_preprocess_threads, 127 | capacity=self.min_queue_examples + 3 * batch_size, 128 | min_after_dequeue=self.min_queue_examples, 129 | ) 130 | return data_batch 131 | 132 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modification of 3 | https://github.com/carpedm20/BEGAN-tensorflow/blob/master/download.py 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import zipfile 8 | import requests 9 | import subprocess 10 | from tqdm import tqdm 11 | from collections import OrderedDict 12 | 13 | def download_file_from_google_drive(id, destination): 14 | URL = "https://docs.google.com/uc?export=download" 15 | session = requests.Session() 16 | 17 | response = session.get(URL, params={ 'id': id }, stream=True) 18 | token = get_confirm_token(response) 19 | 20 | if token: 21 | params = { 'id' : id, 'confirm' : token } 22 | response = session.get(URL, params=params, stream=True) 23 | 24 | save_response_content(response, destination) 25 | 26 | def get_confirm_token(response): 27 | for key, value in response.cookies.items(): 28 | if key.startswith('download_warning'): 29 | return value 30 | return None 31 | 32 | def save_response_content(response, destination, chunk_size=32*1024): 33 | total_size = int(response.headers.get('content-length', 0)) 34 | with open(destination, "wb") as f: 35 | for chunk in 
def save_response_content(response, destination, chunk_size=32*1024):
    """Stream an HTTP response body to *destination* with a progress bar.

    BUG FIX: the old code passed total=total_size (in bytes) to tqdm while
    advancing the bar once per *chunk*, so a byte-denominated bar was fed
    chunk counts and never reflected real progress. The bar is now updated
    by len(chunk) bytes on each write.
    """
    total_size = int(response.headers.get('content-length', 0))
    with open(destination, "wb") as f:
        with tqdm(total=total_size, unit='B',
                  unit_scale=True, desc=destination) as pbar:
            for chunk in response.iter_content(chunk_size):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    pbar.update(len(chunk))


def unzip(filepath):
    """Extract a zip archive into its own directory, then delete the archive."""
    print("Extracting: " + filepath)
    base_path = os.path.dirname(filepath)
    with zipfile.ZipFile(filepath) as zf:
        zf.extractall(base_path)
    os.remove(filepath)


def download_celeb_a(base_path):
    """Download and extract the celebA images + attribute file under
    base_path/celebA. No-op if that directory already exists."""
    data_path = os.path.join(base_path, 'celebA')
    images_path = os.path.join(data_path, 'images')
    if os.path.exists(data_path):
        print('[!] Found celeb-A - skip')
        return

    filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
    save_path = os.path.join(base_path, filename)

    if os.path.exists(save_path):
        print('[*] {} already exists'.format(save_path))
    else:
        download_file_from_google_drive(drive_id, save_path)

    # NOTE: zip_dir was previously read from zf.namelist()[0] but never
    # used; the extraction is all that matters here.
    with zipfile.ZipFile(save_path) as zf:
        zf.extractall(base_path)
    if not os.path.exists(data_path):
        os.mkdir(data_path)
    os.rename(os.path.join(base_path, "img_align_celeba"), images_path)
    os.remove(save_path)

    download_attr_file(data_path)


def download_attr_file(data_path):
    """Fetch list_attr_celeba.txt and strip its sample-count header line."""
    attr_gdID = '0B7EVK8r0v71pblRyaVFSWGxPY0U'
    attr_fname = os.path.join(data_path, 'list_attr_celeba.txt')
    download_file_from_google_drive(attr_gdID, attr_fname)
    delete_top_line(attr_fname)  # make pandas readable
    # Top line was just an integer saying how many samples there were


def prepare_data_dir(path='./data'):
    """Create the data directory if it does not already exist."""
    if not os.path.exists(path):
        os.mkdir(path)


# check, if file exists, make link
def check_link(in_dir, basename, out_dir):
    """Symlink out_dir/basename -> in_dir/basename, but only if the source
    file exists; missing sources are silently skipped."""
    in_file = os.path.join(in_dir, basename)
    if os.path.exists(in_file):
        link_file = os.path.join(out_dir, basename)
        rel_link = os.path.relpath(in_file, out_dir)
        os.symlink(rel_link, link_file)


def add_splits(base_path):
    """Populate celebA splits/{train,valid,test} with symlinks to images,
    using the standard celebA split boundaries."""
    data_path = os.path.join(base_path, 'celebA')
    images_path = os.path.join(data_path, 'images')
    train_dir = os.path.join(data_path, 'splits', 'train')
    valid_dir = os.path.join(data_path, 'splits', 'valid')
    test_dir = os.path.join(data_path, 'splits', 'test')
    for split_dir in (train_dir, valid_dir, test_dir):
        if not os.path.exists(split_dir):
            os.makedirs(split_dir)

    # these constants based on the standard celebA splits
    NUM_EXAMPLES = 202599
    TRAIN_STOP = 162770
    VALID_STOP = 182637

    for i in range(0, TRAIN_STOP):
        basename = "{:06d}.jpg".format(i+1)
        check_link(images_path, basename, train_dir)
    for i in range(TRAIN_STOP, VALID_STOP):
        basename = "{:06d}.jpg".format(i+1)
        check_link(images_path, basename, valid_dir)
    for i in range(VALID_STOP, NUM_EXAMPLES):
        basename = "{:06d}.jpg".format(i+1)
        check_link(images_path, basename, test_dir)


def delete_top_line(txt_fname):
    """Rewrite txt_fname in place with its first line removed.

    BUG FIX: both file handles were previously left unclosed (relying on
    CPython refcounting to flush the rewrite); context managers now close
    and flush deterministically.
    """
    with open(txt_fname, 'r') as f:
        lines = f.readlines()
    with open(txt_fname, 'w') as f:
        f.writelines(lines[1:])


if __name__ == '__main__':
    base_path = './data'
    prepare_data_dir()
    download_celeb_a(base_path)
    add_splits(base_path)

# --------------------------------------------------------------------------
# figure_scripts/__init__.py  (empty package marker in the original repo)
# --------------------------------------------------------------------------
# figure_scripts/distributions.py -- header imports continue on the next
# source line of this dump:
#   import tensorflow as tf
#   import numpy as np
#   import os
#   import scipy.misc
numpy as np 6 | import pandas as pd 7 | from tqdm import trange,tqdm 8 | import pandas as pd 9 | from itertools import combinations, product 10 | import sys 11 | from utils import save_figure_images,make_sample_dir,guess_model_step 12 | from sample import get_joint,sample 13 | 14 | 15 | 16 | def get_pdf(model, do_dict=None,cond_dict=None,name='',N=6400,return_discrete=True,step=''): 17 | str_step=str(step) or guess_model_step(model) 18 | 19 | joint=get_joint(model,int_do_dict=do_dict,int_cond_dict=cond_dict,N=N,return_discrete=return_discrete) 20 | 21 | sample_dir=make_sample_dir(model) 22 | 23 | if name: 24 | name+='_' 25 | f_pdf=os.path.join(sample_dir,str_step+name+'dist'+'.csv') 26 | 27 | pdf=pd.DataFrame.from_dict({k:val.mean() for k,val in joint.items()}) 28 | 29 | #print 'get pdf cond_dict:',cond_dict 30 | if not do_dict and not cond_dict: 31 | data=model.attr.mean() 32 | pdf['data']=data 33 | if not do_dict and cond_dict: 34 | bool_cond=np.logical_and.reduce([model.attr[k]==v for k,v in cond_dict.items()]) 35 | attr=model.attr[bool_cond] 36 | pdf['data']=attr.mean() 37 | 38 | print 'Writing to file',f_pdf 39 | pdf.to_csv(f_pdf) 40 | 41 | return pdf 42 | 43 | 44 | TINY=1e-6 45 | def get_interv_table(model,intrv=True): 46 | 47 | n_batches=25 48 | table_outputs=[] 49 | d_vals=np.linspace(TINY,0.6,n_batches) 50 | for name in model.cc.node_names: 51 | outputs=[] 52 | for d_val in d_vals: 53 | do_dict={model.cc.node_dict[name].label_logit : d_val*np.ones((model.batch_size,1))} 54 | outputs.append(model.sess.run(model.fake_labels,do_dict)) 55 | 56 | out=np.vstack(outputs) 57 | table_outputs.append(out) 58 | 59 | table=np.stack(table_outputs,axis=2) 60 | 61 | np.mean(np.round(table),axis=0) 62 | 63 | return table 64 | 65 | #dT=pd.DataFrame(index=p_names, data=T, columns=do_names) 66 | #T=np.mean(np.round(table),axis=0) 67 | #table=get_interv_table(model) 68 | 69 | 70 | 71 | def record_interventional(model,step=''): 72 | ''' 73 | designed for truncated exponential 
def record_interventional(model, step=''):
    '''
    designed for truncated exponential noise.
    For each node that could be intervened on,
    sample interventions from the continuous
    distribution that discrete intervention
    corresponds to. Collect the joint and output
    to a csv file
    '''
    make_sample_dir(model)

    str_step = str(step)
    if str_step == '':
        if hasattr(model, 'step'):
            str_step = str(model.sess.run(model.step)) + '_'

    m = 20
    # m intervention values spread over [0, 0.8*val] (val = -1 or +1)
    do = lambda val: np.linspace(0, val*0.8, m)
    for name in model.cc.node_names:
        for int_val, intv in enumerate([do(-1), do(+1)]):
            do_dict = {name: intv}

            # BUG FIX: do_dict was built but get_joint was called with
            # do_dict=None, so every iteration sampled the identical
            # non-interventional joint. Pass the intervention through.
            joint = get_joint(model, do_dict=do_dict, N=5,
                              return_discrete=True, step='')

            lab_df = pd.DataFrame(data=joint['g_fake_label'])
            dfl_df = pd.DataFrame(data=joint['d_fake_label'])

            # BUG FIX: both frames were written to the *same* filename,
            # so the d_fake_label csv silently overwrote the g_fake_label
            # csv. Distinguish the two outputs.
            lab_fname = str_step + str(name) + str(int_val) + '_glabel.csv'
            dfl_fname = str_step + str(name) + str(int_val) + '_dlabel.csv'

            lab_df.to_csv(lab_fname)
            dfl_df.to_csv(dfl_fname)

#with open(dfl_xtab_fn,'w') as dlf_f, open(lab_xtab_fn,'w') as lab_f:


# --------------------------------------------------------------------------
# figure_scripts/encode.py
# --------------------------------------------------------------------------
#from __future__ import print_function
import tensorflow as tf
#import scipy
import scipy.misc
import numpy as np
from tqdm import trange
import os
import pandas as pd
from itertools import combinations
import sys
# NOTE(review): module name case looks stale -- the repo package is
# causal_controller/CausalController.py; confirm this import resolves
# from wherever figure_scripts is executed.
from Causal_controller import *
# NOTE(review): repo layout has causal_began/models.py, not began/models.py;
# verify these paths against how this script is invoked.
from began.models import GeneratorCNN, DiscriminatorCNN
from utils import to_nhwc,read_prepared_uint8_image,make_encode_dir

from utils import transform, inverse_transform #dcgan img norm
from utils import norm_img, denorm_img #began norm image

def var_like_z(z_ten, name):
    """Trainable (1, z_dim) variable shaped like the last axis of z_ten."""
    z_dim = z_ten.get_shape().as_list()[-1]
    return tf.get_variable(name, shape=(1, z_dim))

def noise_like_z(z_ten, name):
    """Fresh uniform[-1,1) noise tensor shaped like z_ten (name is unused)."""
    z_dim = z_ten.get_shape().as_list()[-1]
    noise = tf.random_uniform([1, z_dim], minval=-1., maxval=1.)
    return noise
z_dim=z_ten.get_shape().as_list()[-1] 23 | noise=tf.random_uniform([1,z_dim],minval=-1.,maxval=1.,) 24 | return noise 25 | 26 | 27 | class Encoder: 28 | ''' 29 | This is a class where you pass a model, and an image file 30 | and it creates more tensorflow variables, along with 31 | surrounding saving and summary functionality for encoding 32 | that image back into the hidden space using gradient descent 33 | ''' 34 | model_name = "Encode.model" 35 | model_type= 'encoder' 36 | summ_col='encoder_summaries' 37 | def __init__(self,model,image,image_name=None,max_tr_steps=50000,load_path=''): 38 | ''' 39 | image is assumed to be a path to a precropped 64x64x3 uint8 image 40 | ''' 41 | 42 | #Some hardcoded defaults here 43 | self.log_step=500 44 | self.lr=0.0005 45 | self.max_tr_steps=max_tr_steps 46 | 47 | self.model=model 48 | self.load_path=load_path 49 | 50 | self.image_name=image_name or os.path.basename(image).replace('.','_') 51 | self.encode_dir=make_encode_dir(model,self.image_name) 52 | self.model_dir=self.encode_dir#different from self.model.model_dir 53 | self.save_dir=os.path.join(self.model_dir,'save') 54 | 55 | self.sess=self.model.sess#session should already be in progress 56 | 57 | if model.model_type =='dcgan': 58 | self.data_format='NHWC'#Don't change 59 | elif model.model_type == 'began': 60 | self.data_format=model.data_format#'NCHW' if gpu 61 | else: 62 | raise Exception('Should not happen. 
model_type=',model.model_type) 63 | 64 | #Notation: 65 | #self.uint_x/G ; 3D [0,255] 66 | #self.x/G ; 4D [-1,1] 67 | self.uint_x=read_prepared_uint8_image(image)#x is [0,255] 68 | 69 | print('Read image shape',self.uint_x.shape) 70 | self.x=norm_img(np.expand_dims(self.uint_x,0),self.data_format)#bs=1 71 | #self.x=norm_img(tf.expand_dims(self.uint_x,0),self.data_format)#bs=1 72 | print('Shape after norm:',self.x.get_shape().as_list()) 73 | 74 | 75 | ##All variables created under encoder have uniform init 76 | vs=tf.variable_scope('encoder', 77 | initializer=tf.random_uniform_initializer(minval=-1.,maxval=1.), 78 | dtype=tf.float32) 79 | 80 | 81 | with vs as scope: 82 | #avoid creating adams params 83 | optimizer = tf.train.GradientDescentOptimizer 84 | #optimizer = tf.train.AdamOptimizer 85 | self.g_optimizer = optimizer(self.lr) 86 | 87 | encode_var={n.name:var_like_z(n.z,n.name) for n in model.cc.nodes} 88 | encode_var['gen']=var_like_z(model.z_gen,'gen') 89 | print 'encode variables created' 90 | self.train_var = tf.contrib.framework.get_variables(scope) 91 | self.step=tf.Variable(0,name='step') 92 | self.var = tf.contrib.framework.get_variables(scope) 93 | 94 | #all encode vars created by now 95 | self.saver = tf.train.Saver(var_list=self.var) 96 | print('Summaries will be written to ',self.model_dir) 97 | self.summary_writer = tf.summary.FileWriter(self.model_dir) 98 | 99 | #load or initialize enmodel variables 100 | self.init() 101 | 102 | if model.model_type =='dcgan': 103 | self.cc=CausalController(graph=model.graph, input_dict=encode_var, reuse=True) 104 | self.fake_labels_logits= tf.concat( self.cc.list_label_logits(),-1 ) 105 | self.z_fake_labels=self.fake_labels_logits 106 | #self.z_gen = noise_like_z( self.model.z_gen,'en_z_gen') 107 | self.z_gen=encode_var['gen'] 108 | self.z= tf.concat( [self.z_gen, self.z_fake_labels], axis=1 , name='z') 109 | 110 | self.G=model.generator( self.z , bs=1, reuse=True) 111 | 112 | elif model.model_type == 'began': 113 
| with tf.variable_scope('tower'):#reproduce variable scope 114 | self.cc=CausalController(graph=model.graph, input_dict=encode_var, reuse=True) 115 | 116 | self.fake_labels= tf.concat( self.cc.list_labels(),-1 ) 117 | self.fake_labels_logits= tf.concat( self.cc.list_label_logits(),-1 ) 118 | #self.z_gen = noise_like_z( self.model.z_gen,'en_z_gen') 119 | self.z_gen=encode_var['gen'] 120 | self.z= tf.concat( [self.fake_labels, self.z_gen],axis=-1,name='z') 121 | 122 | self.G,_ = GeneratorCNN( 123 | self.z, model.conv_hidden_num, model.channel, 124 | model.repeat_num, model.data_format,reuse=True) 125 | 126 | d_out, self.D_zG, self.D_var = DiscriminatorCNN( 127 | self.G, model.channel, model.z_num, 128 | model.repeat_num, model.conv_hidden_num, 129 | model.data_format,reuse=True) 130 | 131 | _ , self.D_zX, _ = DiscriminatorCNN( 132 | self.x, model.channel, model.z_num, 133 | model.repeat_num, model.conv_hidden_num, 134 | model.data_format,reuse=True) 135 | self.norm_AE_G=d_out 136 | 137 | #AE_G, AE_x = tf.split(d_out, 2) 138 | self.AE_G=denorm_img(self.norm_AE_G, model.data_format) 139 | self.aeg_sum=tf.summary.image('encoder/AE_G',self.AE_G) 140 | 141 | node_summaries=[] 142 | for node in self.cc.nodes: 143 | with tf.name_scope(node.name): 144 | ave_label=tf.reduce_mean(node.label) 145 | node_summaries.append(tf.summary.scalar('ave',ave_label)) 146 | 147 | 148 | #unclear how scope with adam param works 149 | #with tf.variable_scope('encoderGD') as scope: 150 | 151 | #use L1 loss 152 | #self.g_loss_image = tf.reduce_mean(tf.abs(self.x - self.G)) 153 | 154 | #use L2 loss 155 | #self.g_loss_image = tf.reduce_mean(tf.square(self.x - self.G)) 156 | 157 | #use autoencoder reconstruction loss #3.1.1 series 158 | #self.g_loss_image = tf.reduce_mean(tf.abs(self.x - self.norm_AE_G)) 159 | 160 | #use L1 in autoencoded space# 3.2 161 | self.g_loss_image = tf.reduce_mean(tf.abs(self.D_zX - self.D_zG)) 162 | 163 | g_loss_sum=tf.summary.scalar( 'encoder/g_loss_image',\ 164 | 
self.g_loss_image,self.summ_col) 165 | 166 | self.g_loss= self.g_loss_image 167 | self.train_op=self.g_optimizer.minimize(self.g_loss, 168 | var_list=self.train_var,global_step=self.step) 169 | 170 | self.uint_G=tf.squeeze(denorm_img( self.G ,self.data_format))#3D[0,255] 171 | gimg_sum=tf.summary.image( 'encoder/Reconstruct',tf.stack([self.uint_x,self.uint_G]),\ 172 | max_outputs=2,collections=self.summ_col) 173 | 174 | #self.summary_op=tf.summary.merge_all(self.summ_col) 175 | #self.summary_op=tf.summary.merge_all(self.summ_col) 176 | 177 | if model.model_type=='dcgan': 178 | self.summary_op=tf.summary.merge([g_loss_sum,gimg_sum]+node_summaries) 179 | elif model.model_type=='began': 180 | self.summary_op=tf.summary.merge([g_loss_sum,gimg_sum,self.aeg_sum]+node_summaries) 181 | 182 | 183 | #print 'encoder summaries:',self.summ_col 184 | #print 'encoder summaries:',tf.get_collection(self.summ_col) 185 | 186 | 187 | def init(self): 188 | if self.load_path: 189 | print 'Attempting to load directly from path:', 190 | print self.load_path 191 | self.saver.restore(self.sess,self.load_path) 192 | else: 193 | print 'New ENCODE Model..init new Z parameters' 194 | init=tf.variables_initializer(var_list=self.var) 195 | print 'Initializing following variables:' 196 | for v in self.var: 197 | print v.name, v.get_shape().as_list() 198 | 199 | self.model.sess.run(init) 200 | 201 | def save(self, step=None): 202 | if step is None: 203 | step=self.sess.run(self.step) 204 | 205 | if not os.path.exists(self.save_dir): 206 | print 'Creating Directory:',self.save_dir 207 | os.makedirs(self.save_dir) 208 | savefile=os.path.join(self.save_dir,self.model_name) 209 | print 'Saving file:',savefile 210 | self.saver.save(self.model.sess,savefile,global_step=step) 211 | 212 | def train(self, n_step=None): 213 | max_step=n_step or self.max_tr_steps 214 | 215 | if False:#debug 216 | print 'a' 217 | self.sess.run(self.train_op) 218 | print 'b' 219 | self.sess.run(self.summary_op) 220 | print 'c' 
221 | self.sess.run(self.g_loss) 222 | print 'd' 223 | 224 | print 'max_step;',max_step 225 | for counter in trange(max_step): 226 | 227 | fetch_dict = { 228 | "train_op": self.train_op, 229 | } 230 | if counter%self.log_step==0: 231 | fetch_dict.update({ 232 | "summary": self.summary_op, 233 | "g_loss": self.g_loss, 234 | "global_step":self.step 235 | }) 236 | 237 | result = self.sess.run(fetch_dict) 238 | 239 | if counter % self.log_step == 0: 240 | g_loss=result['g_loss'] 241 | step=result['global_step'] 242 | self.summary_writer.add_summary(result['summary'],step) 243 | self.summary_writer.flush() 244 | 245 | print("[{}/{}] Reconstr Loss_G: {:.6f}".format(counter,max_step,g_loss)) 246 | 247 | if counter % (10.*self.log_step) == 0: 248 | self.save(step=step) 249 | 250 | self.save() 251 | 252 | 253 | 254 | ##Just for reference## 255 | #def load(self, checkpoint_dir): 256 | # print(" [*] Reading checkpoints...") 257 | # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 258 | # ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 259 | # if ckpt and ckpt.model_checkpoint_path: 260 | # ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 261 | # self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 262 | # print(" [*] Success to read {}".format(ckpt_name)) 263 | # return True 264 | # else: 265 | # print(" [*] Failed to find a checkpoint") 266 | # return False 267 | #def norm_img(image, data_format=None): 268 | # image = image/127.5 - 1. 269 | # if data_format: 270 | # image = to_nhwc(image, data_format) 271 | # return image 272 | #def transform: 273 | # stuff 274 | # return np.array(cropped_image)/127.5 - 1. 275 | #def denorm_img(norm, data_format): 276 | # return tf.clip_by_value(to_nhwc((norm + 1)*127.5, data_format), 0, 255) 277 | #def inverse_transform(images): 278 | # return (images+1.)/2. 
279 | 280 | 281 | 282 | #if model.model_name=='began': 283 | # fake_labels=model.fake_labels 284 | # D_fake_labels=model.D_fake_labels 285 | # #result_dir=os.path.join('began',model.model_dir) 286 | # result_dir=model.model_dir 287 | # if str_step=='': 288 | # str_step=str( model.sess.run(model.step) )+'_' 289 | # attr=model.attr[list(model.cc.node_names)] 290 | #elif model.model_name=='dcgan': 291 | # fake_labels=model.fake_labels 292 | # D_fake_labels=model.D_labels_for_fake 293 | # result_dir=model.checkpoint_dir 294 | # attr=0.5*(model.attributes+1) 295 | # attr=attr[list(model.cc.names)] 296 | 297 | -------------------------------------------------------------------------------- /figure_scripts/high_level.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | import scipy.misc 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import trange,tqdm 8 | import pandas as pd 9 | from itertools import combinations, product 10 | import sys 11 | from utils import save_figure_images,make_sample_dir,guess_model_step 12 | from sample import get_joint,sample,find_logit_percentile 13 | 14 | 15 | 16 | ''' 17 | This is a file where each function creates a particular figure. No real need 18 | for this to be configurable. Just make a new function for each figure 19 | 20 | This uses functions in sample.py and distribution.py, which are intended to 21 | be lower level functions that can be used more generally. 
22 | 23 | ''' 24 | 25 | 26 | 27 | 28 | def fig1(model, output_folder): 29 | ''' 30 | This function makes two 2x10 images 31 | showing the difference between conditioning 32 | and intervening 33 | ''' 34 | 35 | str_step=guess_model_step(model) 36 | fname=os.path.join(output_folder,str_step+model.model_type) 37 | 38 | for key in ['Young','Smiling','Wearing_Lipstick','Male','Mouth_Slightly_Open','Narrow_Eyes']: 39 | #for key in ['Mustache','Bald']: 40 | #for key in ['Mustache']: 41 | print 'Starting ',key, 42 | #for key in ['Bald']: 43 | 44 | p50,n50=find_logit_percentile(model,key,50) 45 | do_dict={key:np.repeat([p50],10)} 46 | eps=3 47 | cond_dict={key:np.repeat([+eps],10)} 48 | 49 | out,_=sample(model,do_dict=do_dict) 50 | intv_images=out['G'] 51 | 52 | out,_=sample(model,cond_dict=cond_dict) 53 | cond_images=out['G'] 54 | 55 | images=np.vstack([intv_images,cond_images]) 56 | dc_file=fname+'_'+key+'_topdo1_botcond1.pdf' 57 | save_figure_images(model.model_type,images,dc_file,size=[2,10]) 58 | 59 | do_dict={key:np.repeat([p50,n50],10)} 60 | cond_dict={key:np.repeat([+eps,-eps],10)} 61 | 62 | dout,_=sample(model,do_dict=do_dict) 63 | cout,_=sample(model,cond_dict=cond_dict) 64 | 65 | itv_file = fname+'_'+key+'_topdo1_botdo0.pdf' 66 | cond_file = fname+'_'+key+'_topcond1_botcond0.pdf' 67 | eps=3 68 | 69 | save_figure_images(model.model_type,dout['G'],itv_file,size=[2,10]) 70 | save_figure_images(model.model_type,cout['G'],cond_file,size=[2,10]) 71 | print '..finished ',key 72 | 73 | #return images,cout['G'],dout['G'] 74 | return key 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /figure_scripts/pairwise.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | import tensorflow as tf 4 | import os 5 | import scipy.misc 6 | import numpy as np 7 | from tqdm import trange 8 | 9 | import pandas as pd 10 | from itertools import 
from itertools import combinations
import sys
# NOTE(review): `from sample import sample` is deferred into crosstab() below
# so calc_tvd stays importable/testable without the figure_scripts sampling
# machinery on the path.


def calc_tvd(label_dict, attr):
    '''
    Return the total variation distance between the empirical joint
    distribution of the generated (binarized) labels and the real data.

    attr should be a 0,1 pandas dataframe with
    columns corresponding to label names

    for example:
        names=zip(*self.graph)[0]
        calc_tvd(label_dict,attr[names])

    label_dict should be a dictionary key:1d-array of samples
    '''
    ####Calculate Total Variation####
    if np.min(attr.values) < 0:
        raise ValueError('calc_tvd received \
            attr that may not have been in {0,1}')

    # BUG FIX: under python3, dict.keys() is a view, which pandas will not
    # accept as a column indexer; materialize it as a list.
    label_names = list(label_dict.keys())
    attr = attr[label_names]

    # Assign an integer ID to each distinct label combination seen in the
    # real data, then express the real data as a pdf over those IDs.
    df2 = attr.drop_duplicates()
    df2 = df2.reset_index(drop=True).reset_index()
    df2 = df2.rename(columns={'index': 'ID'})
    real_data_id = pd.merge(attr, df2)
    # Series.value_counts: pd.value_counts() was removed in pandas 2.x
    real_counts = real_data_id['ID'].value_counts()
    real_pdf = real_counts / len(attr)

    # Binarize the sampled labels and map them onto the same IDs;
    # combinations never seen in the real data get NaN IDs and drop out,
    # which the fill_value=0 subtraction below accounts for.
    label_list_dict = {k: np.round(v.ravel()) for k, v in label_dict.items()}
    df_dat = pd.DataFrame.from_dict(label_list_dict)
    dat_id = pd.merge(df_dat, df2, on=label_names, how='left')
    dat_counts = dat_id['ID'].value_counts()
    dat_pdf = dat_counts / dat_counts.sum()

    diff = real_pdf.subtract(dat_pdf, fill_value=0)
    tvd = 0.5 * diff.abs().sum()
    return tvd


def crosstab(model, result_dir=None, report_tvd=True, no_save=False, N=500000):
    '''
    This is a script for outputing [0,1/2], [1/2,1] binned pdfs
    including the marginals and the pairwise comparisons

    report_tvd is given as optional because it is somewhat time consuming

    result_dir is where to save the distribution text files. defaults to
    model.cc.model_dir

    Returns a dict; result['tvd'] is set when report_tvd=True.
    tvd will not be reported as low unless N is large (hence N=500000).
    '''
    from sample import sample  # deferred project import (see note at top)

    result_dir = result_dir or model.cc.model_dir
    result = {}

    n_labels = len(model.cc.nodes)

    print('Calculating joint distribution with',)

    t0 = time.time()
    label_dict = sample(model, fetch_dict=model.cc.label_dict, N=N)
    print('sampling model N=', N, ' times took ', time.time()-t0, 'sec')

    str_step = str(model.sess.run(model.cc.step)) + '_'

    attr = model.data.attr
    attr = attr[model.cc.node_names]

    lab_xtab_fn = os.path.join(result_dir, str_step+'glabel_crosstab.txt')
    print('Writing to files:', lab_xtab_fn)

    if report_tvd:
        t0 = time.time()
        tvd = calc_tvd(label_dict, attr)
        result['tvd'] = tvd
        print('calculating tvd from samples took ', time.time()-t0, 'sec')

    if no_save:
        return result

    t0 = time.time()

    joint = {}
    for name, lab in label_dict.items():
        joint[name] = {'g_fake_label': lab}

    with open(lab_xtab_fn, 'w') as lab_f:
        if report_tvd:
            lab_f.write('TVD:'+str(tvd)+'\n\n')
        lab_f.write('Marginals:\n')

        # Marginals
        for name in joint.keys():
            lab_f.write('Node: '+name+'\n')

            true_marg = np.mean((attr[name] > 0.5).values)
            lab_marg = (joint[name]['g_fake_label'] > 0.5).astype('int')

            lab_f.write(' mean='+str(np.mean(lab_marg))+'\t'+
                        'true mean='+str(true_marg)+'\n')

        lab_f.write('\n')

        # Pairs of labels
        lab_f.write('\nPairwise:\n')

        for node1, node2 in combinations(joint.keys(), r=2):

            lab_node1 = (joint[node1]['g_fake_label'] > 0.5).astype('int')
            lab_node2 = (joint[node2]['g_fake_label'] > 0.5).astype('int')
            lab_df = pd.DataFrame(data=np.hstack([lab_node1, lab_node2]),
                                  columns=[node1, node2])
            lab_ct = pd.crosstab(index=lab_df[node1], columns=lab_df[node2],
                                 margins=True, normalize=True)

            true_ct = pd.crosstab(index=attr[node1], columns=attr[node2],
                                  margins=True, normalize=True)

            lab_f.write('\n\tFake:\n')
            # BUG FIX: the csv was previously appended through a *second*
            # handle (to_csv(lab_xtab_fn, mode='a')) while lab_f, opened in
            # 'w' mode, kept its own file offset -- lab_f's next write then
            # landed on top of the appended bytes. Write through the same
            # handle so output stays in order.
            lab_ct.to_csv(lab_f)
            lab_f.write(lab_ct.__repr__())
            lab_f.write('\n\tReal:\n')
            lab_f.write(true_ct.__repr__())

            lab_f.write('\n\n')

    print('calculating pairwise crosstabs and saving results took ', time.time()-t0, 'sec')
    return result

# --------------------------------------------------------------------------
# figure_scripts/probability_table.txt  (data file, reproduced as comment)
#
#   model: celebA_0627_200239
#   graph:MLS
#
#   [img,cc,d_fake_labels,true]
#
#   P(M=1|S=1) = [0.28,
# --------------------------------------------------------------------------
# figure_scripts/utils.py -- header only; the module continues beyond this
# chunk, so its imports are not reproduced as live code here:
#   from __future__ import print_function,division
#   import tensorflow as tf, os, shutil, sys, math, json, logging
#   import numpy as np; from PIL import Image; from datetime import datetime
#   import random, pprint, scipy.misc; from time import gmtime, strftime
pp = pprint.PrettyPrinter()

def nhwc_to_nchw(x):
    '''Transpose a batched NHWC tensor to NCHW (tf op).'''
    return tf.transpose(x, [0, 3, 1, 2])

def to_nchw_numpy(image):
    '''Numpy NHWC -> NCHW when the last axis looks like channels (1 or 3); identity otherwise.'''
    if image.shape[3] in [1, 3]:
        new_image = image.transpose([0, 3, 1, 2])
    else:
        new_image = image
    return new_image

def norm_img(image, data_format=None):
    '''
    Map a [0,255] image batch into a float32 tensor in [-1,1],
    converting to NCHW first when data_format=='NCHW'.
    '''
    image = image/127.5 - 1.
    if data_format == 'NCHW':
        image = to_nchw_numpy(image)
    image = tf.cast(image, tf.float32)
    return image


#Denorming
def nchw_to_nhwc(x):
    '''Transpose a batched NCHW tensor back to NHWC (tf op).'''
    return tf.transpose(x, [0, 2, 3, 1])

def to_nhwc(image, data_format):
    '''NCHW -> NHWC when data_format says so; identity otherwise.'''
    if data_format == 'NCHW':
        new_image = nchw_to_nhwc(image)
    else:
        new_image = image
    return new_image

def denorm_img(norm, data_format):
    '''Inverse of norm_img: [-1,1] -> [0,255], clipped, returned as NHWC.'''
    return tf.clip_by_value(to_nhwc((norm + 1)*127.5, data_format), 0, 255)


def read_prepared_uint8_image(img_path):
    '''
    img_path should point to a uint8 image that is
    already cropped and resized to 64x64x3.

    Raises ValueError when the shape does not match.
    '''
    # NOTE(review): scipy.misc.imread was removed from modern scipy;
    # this module assumes a pinned legacy environment -- confirm.
    cropped_image = scipy.misc.imread(img_path)
    if not np.all(np.array([64, 64, 3]) == cropped_image.shape):
        raise ValueError('image must already be cropped and resized:', img_path)
    #TODO: warn if wrong dtype
    return cropped_image

def _model_result_dir(model):
    '''Resolve the output directory for a 'began' or 'dcgan' model.'''
    if model.model_type == 'began':
        return model.model_dir
    elif model.model_type == 'dcgan':
        print('DCGAN')
        return model.checkpoint_dir
    # Previously an unknown model_type fell through to a NameError.
    raise ValueError('unknown model_type: %r' % (model.model_type,))

def make_encode_dir(model, image_name):
    '''Create (if absent) and return <result_dir>/encode_<image_name>.'''
    encode_dir = os.path.join(_model_result_dir(model), 'encode_' + str(image_name))
    if not os.path.exists(encode_dir):
        os.mkdir(encode_dir)
    return encode_dir

def make_sample_dir(model):
    '''Create (if absent) and return <result_dir>/sample_figures.'''
    sample_dir = os.path.join(_model_result_dir(model), 'sample_figures')
    if not os.path.exists(sample_dir):
        os.mkdir(sample_dir)
    return sample_dir

def guess_model_step(model):
    '''
    Return the training step as a string with a trailing underscore,
    read from the graph (began) or parsed off the newest checkpoint
    file name (dcgan).
    '''
    if model.model_type == 'began':
        str_step = str(model.sess.run(model.step)) + '_'
    elif model.model_type == 'dcgan':
        result_dir = model.checkpoint_dir
        ckpt = tf.train.get_checkpoint_state(result_dir)
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        # assumes checkpoint names end with a 5-digit step -- TODO confirm
        str_step = ckpt_name[-5:] + '_'
    return str_step

def infer_grid_image_shape(N):
    '''Pick a [rows, cols] grid for N images: 8 x N/8 when divisible, else 8 x 8.'''
    if N % 8 == 0:
        size = [8, N//8]
    else:
        size = [8, 8]
    return size


def save_figure_images(model_type, tensor, filename, size, padding=2, normalize=False, scale_each=False):
    '''
    Save a batch of images as one grid image, dispatching on model_type
    ('began' or 'dcgan'); size is [rows, cols].
    '''
    print('[*] saving:', filename)

    #nrow=size[0]
    nrow = size[1]  # Was this number per row and now number of rows?

    if model_type == 'began':
        began_save_image(tensor, filename, nrow, padding, normalize, scale_each)
    elif model_type == 'dcgan':
        dcgan_save_images(tensor, size, filename)
    else:
        # Previously an unknown model_type silently saved nothing.
        raise ValueError('unknown model_type: %r' % (model_type,))


#Began originally
def make_grid(tensor, nrow=8, padding=2,
              normalize=False, scale_each=False):
    """
    Tile a (N,H,W,3) uint8 batch into one grid image with `padding`
    pixels between tiles; nrow is the number of tiles per grid row.
    normalize/scale_each are accepted but unused here.

    Code based on https://github.com/pytorch/vision/blob/master/torchvision/utils.py
    """
    nmaps = tensor.shape[0]
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.shape[1] + padding), int(tensor.shape[2] + padding)
    grid = np.zeros([height * ymaps + 1 + padding // 2, width * xmaps + 1 + padding // 2, 3], dtype=np.uint8)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            h, h_width = y * height + 1 + padding // 2, height - padding
            w, w_width = x * width + 1 + padding // 2, width - padding
            grid[h:h+h_width, w:w+w_width] = tensor[k]
            k = k + 1
    return grid

def began_save_image(tensor, filename, nrow=8, padding=2,
                     normalize=False, scale_each=False):
    '''Tile the batch with make_grid and write it to filename via PIL.'''
    ndarr = make_grid(tensor, nrow=nrow, padding=padding,
                      normalize=normalize, scale_each=scale_each)
    im = Image.fromarray(ndarr)
    im.save(filename)


#Dcgan originally
get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])

def get_image(image_path, input_height, input_width,
              resize_height=64, resize_width=64,
              is_crop=True, is_grayscale=False):
    '''Load an image file and crop/resize it into [-1,1] floats.'''
    image = imread(image_path, is_grayscale)
    return transform(image, input_height, input_width,
                     resize_height, resize_width, is_crop)

def dcgan_save_images(images, size, image_path):
    '''Denormalize [-1,1] images and save them as one merged grid.'''
    return imsave(inverse_transform(images), size, image_path)

def imread(path, is_grayscale=False):
    '''Read an image file as float array, optionally flattened to grayscale.'''
    # NOTE(review): scipy.misc.imread removed in modern scipy -- see above.
    if is_grayscale:
        return scipy.misc.imread(path, flatten=True).astype(np.float)
    else:
        return scipy.misc.imread(path).astype(np.float)

def merge_images(images, size):
    '''Kept for API compatibility; size is unused.'''
    return inverse_transform(images)

def merge(images, size):
    '''Tile a (N,H,W,3) batch into a size[0] x size[1] grid (row-major).'''
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = image
    return img

def imsave(images, size, path):
    '''Write a merged grid of `images` to `path`.'''
    return scipy.misc.imsave(path, merge(images, size))

def center_crop(x, crop_h, crop_w,
                resize_h=64, resize_w=64):
    '''Crop the central crop_h x crop_w patch and resize it.'''
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    j = int(round((h - crop_h)/2.))
    i = int(round((w - crop_w)/2.))
    return scipy.misc.imresize(
        x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])

def transform(image, input_height, input_width,
              resize_height=64, resize_width=64, is_crop=True):
    '''Crop (optional) + resize, then scale pixels from [0,255] to [-1,1].'''
    if is_crop:
        cropped_image = center_crop(
            image, input_height, input_width,
            resize_height, resize_width)
    else:
        cropped_image = scipy.misc.imresize(image, [resize_height, resize_width])
    return np.array(cropped_image)/127.5 - 1.

def inverse_transform(images):
    '''Map [-1,1] images back to [0,1].'''
    return (images+1.)/2.
center_crop( 209 | image, input_height, input_width, 210 | resize_height, resize_width) 211 | else: 212 | cropped_image = scipy.misc.imresize(image, [resize_height, resize_width]) 213 | return np.array(cropped_image)/127.5 - 1. 214 | 215 | def inverse_transform(images): 216 | return (images+1.)/2. 217 | 218 | 219 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import os 4 | import tensorflow as tf 5 | 6 | from trainer import Trainer 7 | from causal_graph import get_causal_graph 8 | from utils import prepare_dirs_and_logger, save_configs 9 | 10 | #Generic configuration arguments 11 | from config import get_config 12 | #Submodel specific configurations 13 | from causal_controller.config import get_config as get_cc_config 14 | from causal_dcgan.config import get_config as get_dcgan_config 15 | from causal_began.config import get_config as get_began_config 16 | 17 | from causal_began import CausalBEGAN 18 | from causal_dcgan import CausalGAN 19 | 20 | from IPython.core import debugger 21 | debug = debugger.Pdb().set_trace 22 | 23 | def get_trainer(): 24 | print('tf: resetting default graph!') 25 | tf.reset_default_graph()#for repeated calls in ipython 26 | 27 | 28 | ####GET CONFIGURATION#### 29 | #TODO:load configurations from previous model when loading previous model 30 | ##if load_path: 31 | #load config files from dir 32 | #except if pt_load_path, get cc_config from before 33 | #overwrite is_train, is_pretrain with current args--sort of a mess 34 | 35 | ##else: 36 | config,_=get_config() 37 | cc_config,_=get_cc_config() 38 | dcgan_config,_=get_dcgan_config() 39 | began_config,_=get_began_config() 40 | 41 | ###SEEDS### 42 | np.random.seed(config.seed) 43 | #tf.set_random_seed(config.seed) # Not working right now. 
44 | 45 | prepare_dirs_and_logger(config) 46 | if not config.load_path: 47 | print('saving config because load path not given') 48 | save_configs(config,cc_config,dcgan_config,began_config) 49 | 50 | #Resolve model differences and batch_size 51 | if config.model_type: 52 | if config.model_type=='dcgan': 53 | config.batch_size=dcgan_config.batch_size 54 | cc_config.batch_size=dcgan_config.batch_size # make sure the batch size of cc is the same as the image model 55 | config.Model=CausalGAN.CausalGAN 56 | model_config=dcgan_config 57 | if config.model_type=='began': 58 | config.batch_size=began_config.batch_size 59 | cc_config.batch_size=began_config.batch_size # make sure the batch size of cc is the same as the image model 60 | config.Model=CausalBEGAN.CausalBEGAN 61 | model_config=began_config 62 | 63 | else:#no image model 64 | model_config=None 65 | config.batch_size=cc_config.batch_size 66 | 67 | if began_config.is_train or dcgan_config.is_train: 68 | raise ValueError('need to specify model_type for is_train=True') 69 | 70 | #Interpret causal_model keyword 71 | cc_config.graph=get_causal_graph(config.causal_model) 72 | 73 | #Builds and loads specified models: 74 | trainer=Trainer(config,cc_config,model_config) 75 | return trainer 76 | 77 | def main(trainer): 78 | #Do pretraining 79 | if trainer.cc_config.is_pretrain: 80 | trainer.pretrain_loop() 81 | 82 | if trainer.model_config: 83 | if trainer.model_config.is_train: 84 | trainer.train_loop() 85 | 86 | if __name__ == "__main__": 87 | trainer=get_trainer() 88 | 89 | #make ipython easier 90 | sess=trainer.sess 91 | cc=trainer.cc 92 | if hasattr(trainer,'model'): 93 | model=trainer.model 94 | 95 | main(trainer) 96 | 97 | tf.logging.set_verbosity(tf.logging.ERROR) -------------------------------------------------------------------------------- /synthetic/README.md: -------------------------------------------------------------------------------- 1 | # Causal(BE)GAN in Tensorflow 2 | 3 | # (test comment) 4 | 5 | 
Synthetic Data Figures 6 | <> (Tensorflow implementation of [BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717).) 7 | 8 | Authors' Tensorflow implementation of the synthetic-data portion of [CausalGAN: Learning Implicit Causal Models with Adversarial Training] 9 | 10 | <>some results files 11 | 12 | ## Setup 13 | 14 | If not already set, make sure that run_datasets.sh is executable by running 15 | $ chmod +x run_datasets.sh 16 | 17 | ## Usage 18 | 19 | A single run of main.py trains as many GANs as are defined in models.py (presently 6) for a single --data_type. This author can fit 3 such runs on a single GPU, and conveniently there are 3 datasets considered. 20 | 21 | $ CUDA_VISIBLE_DEVICES='0' python main.py --data_type=linear 22 | 23 | Again, the tboard.py utility is available to view the most recent model summaries. 24 | 25 | $ python tboard.py 26 | 27 | Recovering statistics means averaging over many runs. Mass usage follows the script run_datasets.sh. This bash script will train all GAN models on each of 3 datasets, 30 times per dataset.
The following will train 2(calls) x 30(loop/call) x 3(datasets/loop) x 6(gan models/dataset)=1080(gan models) 28 | 29 | 30 | $ (open first terminal) 31 | $ CUDA_VISIBLE_DEVICES='0' ./run_datasets.sh 32 | $ (open second terminal) 33 | $ CUDA_VISIBLE_DEVICES='1' ./run_datasets.sh 34 | 35 | 36 | ## Collecting Statistics 37 | 38 | 39 | ## Results 40 | 41 | 42 | ## Authors 43 | 44 | Christopher Snyder / [@22csnyder](http://22csnyder.github.io) 45 | Murat Kocaoglu / [@mkocaoglu](http://mkocaoglu.github.io) 46 | -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_all.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_all.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_collider.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_collider.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_complete.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_complete.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_data.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_data.pdf -------------------------------------------------------------------------------- 
/synthetic/assets/0818_072052/x1x3_fc5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_fc5.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_linear.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_linear.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_notextcollider.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_notextcollider.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_notextcomplete.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_notextcomplete.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_notextdata.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_notextdata.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_notextfc5.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_notextfc5.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_notextlinear.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_notextlinear.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_notitlecollider.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_notitlecollider.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_notitlecomplete.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_notitlecomplete.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_notitledata.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_notitledata.pdf -------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_notitlefc5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_notitlefc5.pdf 
-------------------------------------------------------------------------------- /synthetic/assets/0818_072052/x1x3_notitlelinear.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/0818_072052/x1x3_notitlelinear.pdf -------------------------------------------------------------------------------- /synthetic/assets/collider_synth_tvd_vs_time.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/collider_synth_tvd_vs_time.pdf -------------------------------------------------------------------------------- /synthetic/assets/complete_synth_tvd_vs_time.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/complete_synth_tvd_vs_time.pdf -------------------------------------------------------------------------------- /synthetic/assets/linear_synth_tvd_vs_time.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/linear_synth_tvd_vs_time.pdf -------------------------------------------------------------------------------- /synthetic/assets/liny_tvd/liny_collider_synth_tvd_vs_time.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkocaoglu/CausalGAN/9d52b520b5ef7309aa159c5368494a1a5c4fbab3/synthetic/assets/liny_tvd/liny_collider_synth_tvd_vs_time.pdf -------------------------------------------------------------------------------- /synthetic/assets/liny_tvd/liny_complete_synth_tvd_vs_time.pdf: 
import pandas as pd
import numpy as np
import time
from scipy import stats
import os
import matplotlib.pyplot as plt
from models import GeneratorTypes, DataTypes
import brewer2mpl


def makeplots(x_iter, tvd_datastore, show=False, save=False, save_name=None):
    '''
    One figure with a subplot per data type: mean tvd vs iteration
    (in thousands) for every generator type, with sem error bars.

    x_iter : 1d array of iteration numbers (shared x axis)
    tvd_datastore : {dtype: {gtype: DataFrame with 'tvd' and 'sem' columns}}
    Returns (fig, axes).
    '''
    dtypes = list(tvd_datastore.keys())
    fig, axes = plt.subplots(1, len(dtypes))
    fig.subplots_adjust(hspace=0.75, wspace=0.05)

    x_iter = x_iter.astype('float') / 1000

    for ax, dtype in zip(axes, dtypes):
        # legend only on the last subplot
        use_legend = ax not in axes[:-1]

        # only the first subplot carries the long prefix
        if ax == axes[0]:
            axtitle = 'Synthetic Data Graph: ' + dtype + ' '
        else:
            axtitle = dtype

        # BUGFIX: gtypes used to be read from a module-global defined only
        # in the __main__ block (NameError when imported and called alone).
        gtypes = list(tvd_datastore[dtype].keys())
        df_tvd = pd.DataFrame(data={gtype: tvd_datastore[dtype][gtype]['tvd'] for gtype in gtypes})
        df_sem = pd.DataFrame(data={gtype: tvd_datastore[dtype][gtype]['sem'] for gtype in gtypes})
        df_tvd.index = x_iter
        df_sem.index = x_iter

        df_tvd.plot.line(ax=ax, sharey=True, use_index=True, yerr=df_sem,
                         legend=use_legend, capsize=5, capthick=3,
                         elinewidth=1, errorevery=100)

        ax.set_title(axtitle.title(), fontsize=18)
        ax.set_ylabel('Total Variational Distance', fontsize=18)
        if ax is axes[1]:
            ax.set_xlabel('iter(thousands)', fontsize=18)

    t = 'Graph Structured Generator tvd Convergence on Synthetic Data with Known Causal Graph'
    plt.suptitle(t, fontsize=20)

    fig.set_figwidth(15, forward=True)
    fig.set_figheight(7, forward=True)

    if save:
        save_name = save_name or 'synth_tvd_vs_time.pdf'
        save_path = os.path.join('assets', save_name)
        plt.savefig(save_path, bbox_inches='tight')

    if show:
        plt.show(block=False)
    return fig, axes


def make_individual_plots(x_iter, tvd_datastore, smooth=True, show=False, save=False, save_name=None):
    '''
    One log-scale tvd-vs-iteration figure per data type, optionally
    smoothed with a centered moving average, saved to
    assets/<dtype>_<save_name or synth_tvd_vs_time.pdf>.
    '''
    fontsize = 17.5
    tickfont = 15

    gtypes = list(GeneratorTypes.keys())
    dtypes = list(tvd_datastore.keys())

    # pretty display names for the generator columns
    format_columns = {
        'fc3':      'FC3',
        'fc5':      'FC5',
        'fc10':     'FC10',
        'collider': 'Collider',
        'linear':   'Linear',
        'complete': 'Complete',
    }

    markers = ['s', 'o', '^', '+', '>', 'd']

    plt.style.use('seaborn-deep')

    x_iter = x_iter.astype('float') / 1000

    for dtype in dtypes:
        df_tvd = pd.DataFrame(data={format_columns[gtype]: tvd_datastore[dtype][gtype]['tvd'] for gtype in gtypes})
        df_sem = pd.DataFrame(data={format_columns[gtype]: tvd_datastore[dtype][gtype]['sem'] for gtype in gtypes})
        df_tvd.index = x_iter
        df_sem.index = x_iter

        if smooth:
            # centered moving average purely for display
            df_tvd = df_tvd.rolling(window=5, min_periods=1, center=True).mean()

        fig = plt.figure()
        ax = fig.add_subplot(111)
        for i, col in enumerate(df_tvd.columns):
            df_tvd[col].plot(ax=ax, use_index=True, yerr=df_sem[col],
                             capsize=5, capthick=3, elinewidth=1,
                             errorevery=100, figsize=(6, 4), linestyle='-',
                             marker=markers[i], markevery=50, markersize=7)

        ax.set_yscale('log')
        plt.legend()

        plt.xticks(fontsize=tickfont)
        plt.yticks(fontsize=tickfont)

        plt.ylim([0, 1])

        plt.ylabel('Total Variation Distance', fontsize=fontsize)
        plt.xlabel('Iteration (in thousands)', fontsize=fontsize)

        if save:
            file_name = save_name or 'synth_tvd_vs_time.pdf'
            file_name = dtype + '_' + file_name
            save_path = os.path.join('assets', file_name)
            plt.savefig(save_path, bbox_inches='tight')

        if show:
            plt.show(block=False)


if __name__ == '__main__':
    dtypes = list(DataTypes.keys())
    gtypes = list(GeneratorTypes.keys())

    logdir = 'logs/figure_logs'

    # one list of per-run tvd series per (data type, generator type)
    tvd_all_datastore = {dt: {gt: [] for gt in gtypes} for dt in dtypes}
    tvd_datastore = {dt: {} for dt in dtypes}
    runs = os.listdir(logdir)

    for dtype in dtypes:
        print('')
        print('Collecting data for datatype %s ...' % dtype)

        # list() so the result can be re-iterated once per gtype
        # (py3 filter objects are one-shot iterators)
        typed_runs = list(filter(lambda x: x.endswith(dtype), runs))

        for gtype in gtypes:
            n_runs = 0

            # Go through all runs for each (dtype, gtype) pair
            for run in typed_runs:
                tvd_csv = os.path.join(logdir, run, gtype, 'tvd.csv')

                dat = pd.read_csv(tvd_csv, sep=' ')

                # 1001 rows == a completed run; anything else is in progress
                if len(dat) != 1001:
                    print('WARN: file %s was of length: %d '
                          'it may be in the process of optimizing.. not using'
                          % (tvd_csv, len(dat)))
                    continue

                tvd_all_datastore[dtype][gtype].append(dat['tvd'])
                n_runs += 1

            # after (dtype, gtype) collection
            if n_runs == 0:
                print('Warning: for dtype %s no runs of gtype %s' % (dtype, gtype))
            else:
                # mean and standard error across runs, column-aligned
                df_concat = pd.concat(tvd_all_datastore[dtype][gtype], axis=1)
                gb = df_concat.groupby(by=df_concat.columns, axis=1)
                mean = gb.mean()
                sem = gb.sem().rename(columns={'tvd': 'sem'})
                tvd_datastore[dtype][gtype] = pd.concat([mean, sem], axis=1)

        # after dtype collection
        if len(tvd_datastore[dtype]) == 0:
            print('Warning: no runs of dtype %s' % dtype)
            tvd_datastore.pop(dtype)

        print('...There were %d runs of %s' % (n_runs, dtype))

    # NOTE(review): dat is whichever csv was read last; this assumes at
    # least one run exists -- confirm before running on an empty logdir.
    x_iter = dat['iter'].values

    # run in ipython depending on what you want
    #fig,axes=makeplots(x_iter,tvd_datastore,show=False,save=True)
    make_individual_plots(x_iter, tvd_datastore, smooth=True, show=True, save=True)

    time.sleep(10)
# --- data options ---------------------------------------------------------
data_arg.add_argument('--data_type',type=str,choices=dtypes,
        default='collider', help='''This is the graph structure
        that generates the synthetic dataset through polynomials''')

# --- GAN architecture / optimization --------------------------------------
gan_arg.add_argument('--gen_z_dim',type=int,default=10,
        help='''dim of noise input for generator''')
gan_arg.add_argument('--gen_hidden_size',type=int,default=10,#3,
        help='''hidden size used for layers of generator''')
gan_arg.add_argument('--disc_hidden_size',type=int,default=10,#6,
        help='''hidden size used for layers of discriminator''')
gan_arg.add_argument('--lr_gen',type=float,default=0.0005,#0.005
        help='''generator learning rate''')
gan_arg.add_argument('--lr_disc',type=float,default=0.0005,#0.0025
        help='''discriminator learning rate''')

#broken
#misc_arg.add_argument('--save_pdfs',type=str2bool,default=False,
#        help='''whether to save pdfs of scatterplots of x1x3 along
#        with tensorboard summaries''')

# --- misc -----------------------------------------------------------------
misc_arg.add_argument('--model_dir',type=str,default='logs')
#misc_arg.add_argument('--np_random_seed', type=int, default=123)
#misc_arg.add_argument('--tf_random_seed', type=int, default=123)

# --- model loading / training ---------------------------------------------
model_arg.add_argument('--load_path',type=str,default='',
        help='''Path to folder containing model to load. This
        should be actual checkpoint to load. Example:
        --load_path=./logs/0817_153755_collider/checkpoints/Model-50000''')
model_arg.add_argument('--is_train',type=str2bool,default=True,
        help='''whether the model should train''')
model_arg.add_argument('--batch_size',type=int,default=64,
        help='''batch_size for all generators and all
        discriminators''')


def get_config():
    '''Parse known command line args; returns (config, unparsed).'''
    #setattr(config, 'data_dir', data_format)
    config, unparsed = parser.parse_known_args()
    return config, unparsed
"built trainer successfully\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "%run main.py --data_type 'linear' --load_path './logs/0818_072052_linear/checkpoints/Model-50000' --is_train False" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": { 40 | "collapsed": false 41 | }, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "Using matplotlib backend: TkAgg\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "%matplotlib\n", 53 | "import matplotlib.pyplot as plt" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": { 60 | "collapsed": true 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "sess=trainer.sess;gans=trainer.gans\n", 65 | "Xgs=[sess.run(g.gen.X,{g.gen.N:5000}) for g in gans]\n", 66 | "split_Xgs=[np.split(x,3,axis=1) for x in Xgs]\n", 67 | "X13gs=[[x[0],x[-1]] for x in split_Xgs]\n", 68 | "Xds=np.split(sess.run(trainer.data.X,{trainer.data.N:5000}),3,axis=1)\n", 69 | "X13d=[Xds[0],Xds[-1]]\n", 70 | "\n", 71 | "data_dict={'data':X13d}\n", 72 | "for g,dat in zip(gans,X13gs):\n", 73 | " data_dict[g.gan_type]=dat\n", 74 | "\n", 75 | "gan_plots=['data','linear','collider','fc5']\n" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "metadata": { 82 | "collapsed": true 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "titles={'data':'Data Distribution',\n", 87 | " 'linear':'Linear Generator',\n", 88 | " 'complete':'Complete Generator',\n", 89 | " 'collider':'Collider Generator',\n", 90 | " 'fc5':'Fully Connected Generator'}" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 74, 96 | "metadata": { 97 | "collapsed": false 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "#all at once\n", 102 | "fig,axes=plt.subplots(1,len(gan_plots),sharey=True)\n", 103 | "\n", 104 | "for gtype,ax in zip(gan_plots,axes):\n", 105 | " data=data_dict[gtype]\n", 106 | " ax.scatter(data[0],data[1])\n", 107 | 
" \n", 108 | " ax.set_title(titles[gtype])\n", 109 | " ax.set_xlabel('X1')\n", 110 | " if gtype==gan_plots[0]:\n", 111 | " ax.set_ylabel('X3')\n", 112 | "\n", 113 | " \n", 114 | "fig.canvas.draw()\n", 115 | "plt.show() \n", 116 | "\n", 117 | "fig.subplots_adjust(wspace=0.04,left=0.05,hspace=0.04,right=0.98)\n", 118 | "\n", 119 | "fig.set_figheight(4)\n", 120 | "fig.set_figwidth(12)\n", 121 | "\n", 122 | "plt.savefig('assets/0818_072052_x1x3_all.pdf')" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 97, 128 | "metadata": { 129 | "collapsed": false 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "#one at a time\n", 134 | "\n", 135 | "for gtype in titles.keys():\n", 136 | " data=data_dict[gtype]\n", 137 | " fig=plt.figure()\n", 138 | " plt.scatter(data[0],data[1])\n", 139 | " plt.xlim([0,1])\n", 140 | " plt.ylim([0,1])\n", 141 | " \n", 142 | " plt.title(titles[gtype],fontsize=20)\n", 143 | "\n", 144 | " plt.ylabel('X3',fontsize=16)\n", 145 | " plt.xlabel('X1',fontsize=16)\n", 146 | " save_path='assets/'+'0818_072052/'+'x1x3_'+gtype+'.pdf'\n", 147 | " plt.savefig(save_path)\n", 148 | "\n" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 6, 154 | "metadata": { 155 | "collapsed": false 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "#no titles\n", 160 | "\n", 161 | "for gtype in titles.keys():\n", 162 | " data=data_dict[gtype]\n", 163 | " fig=plt.figure()\n", 164 | " plt.scatter(data[0],data[1])\n", 165 | " plt.xlim([0,1])\n", 166 | " plt.ylim([0,1])\n", 167 | " \n", 168 | " #plt.title(titles[gtype],fontsize=20)\n", 169 | "\n", 170 | " plt.ylabel('X3',fontsize=16)\n", 171 | " plt.xlabel('X1',fontsize=16)\n", 172 | " save_path='assets/'+'0818_072052/'+'x1x3_notitle'+gtype+'.pdf'\n", 173 | " plt.savefig(save_path)\n", 174 | "\n" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 96, 180 | "metadata": { 181 | "collapsed": false 182 | }, 183 | "outputs": [], 184 | "source": [ 185 
| "#no text\n", 186 | "#No titles: leave to latex to add titles/axes\n", 187 | "\n", 188 | "for gtype in titles.keys():\n", 189 | " data=data_dict[gtype]\n", 190 | " fig=plt.figure()\n", 191 | " plt.scatter(data[0],data[1])\n", 192 | " plt.xlim([0,1])\n", 193 | " plt.ylim([0,1])\n", 194 | " \n", 195 | " #plt.title(titles[gtype],fontsize=14)\n", 196 | "\n", 197 | " #plt.ylabel('X3',fontsize=14)\n", 198 | " #plt.xlabel('X1',fontsize=14)\n", 199 | " save_path='assets/'+'0818_072052/'+'x1x3_notext'+gtype+'.pdf'\n", 200 | " plt.savefig(save_path)\n", 201 | "\n" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 68, 207 | "metadata": { 208 | "collapsed": true 209 | }, 210 | "outputs": [], 211 | "source": [ 212 | "fig.subplots_adjust?" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 2, 218 | "metadata": { 219 | "collapsed": false 220 | }, 221 | "outputs": [ 222 | { 223 | "ename": "NameError", 224 | "evalue": "name 'trainer' is not defined", 225 | "output_type": "error", 226 | "traceback": [ 227 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 228 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 229 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 230 | "\u001b[0;31mNameError\u001b[0m: name 'trainer' is not defined" 231 | ] 232 | } 233 | ], 234 | "source": [ 235 | "trainer" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": { 242 | "collapsed": true 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "from utils import scatter2d" 247 | ] 248 | } 249 | ], 250 | "metadata": { 251 | "kernelspec": { 252 | "display_name": "Python 2", 253 | "language": "python", 254 | "name": "python2" 255 | }, 256 | "language_info": { 257 | "codemirror_mode": { 258 | "name": "ipython", 259 | "version": 
2 260 | }, 261 | "file_extension": ".py", 262 | "mimetype": "text/x-python", 263 | "name": "python", 264 | "nbconvert_exporter": "python", 265 | "pygments_lexer": "ipython2", 266 | "version": "2.7.12" 267 | } 268 | }, 269 | "nbformat": 4, 270 | "nbformat_minor": 1 271 | } 272 | -------------------------------------------------------------------------------- /synthetic/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from trainer import Trainer 6 | from config import get_config 7 | import os 8 | 9 | from IPython.core import debugger 10 | debug = debugger.Pdb().set_trace 11 | 12 | 13 | '''main code for synthetic experiments 14 | 15 | ''' 16 | 17 | 18 | def get_trainer(config): 19 | print('tf: resetting default graph!') 20 | tf.reset_default_graph() 21 | 22 | #tf.set_random_seed(config.random_seed) 23 | #np.random.seed(22) 24 | 25 | print('Using data_type ',config.data_type) 26 | trainer=Trainer(config,config.data_type) 27 | print('built trainer successfully') 28 | 29 | tf.logging.set_verbosity(tf.logging.ERROR) 30 | 31 | return trainer 32 | 33 | 34 | def main(trainer,config): 35 | 36 | if config.is_train: 37 | trainer.train() 38 | 39 | 40 | 41 | def get_model(config=None): 42 | if not None: 43 | config, unparsed = get_config() 44 | return get_trainer(config) 45 | 46 | if __name__ == "__main__": 47 | config, unparsed = get_config() 48 | if not os.path.exists(config.model_dir): 49 | os.mkdir(config.model_dir) 50 | trainer=get_trainer(config) 51 | main(trainer,config) 52 | 53 | 54 | -------------------------------------------------------------------------------- /synthetic/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import matplotlib.pyplot as plt 3 | from utils import * 4 | 5 | #class Data3d 6 | 7 | def sxe(logits,labels): 8 | #use zeros or ones if 
pass in scalar 9 | if not isinstance(labels,tf.Tensor): 10 | labels=labels*tf.ones_like(logits) 11 | return tf.nn.sigmoid_cross_entropy_with_logits( 12 | logits=logits,labels=labels) 13 | 14 | #def linear(input_, output_dim, scope=None, stddev=10.): 15 | def linear(input_, output_dim, scope=None, stddev=.7): 16 | unif = tf.uniform_unit_scaling_initializer() 17 | norm = tf.random_normal_initializer(stddev=stddev) 18 | const = tf.constant_initializer(0.0) 19 | with tf.variable_scope(scope or 'linear'): 20 | #w = tf.get_variable('w', [input_.get_shape()[1], output_dim], initializer=unif) 21 | w = tf.get_variable('w', [input_.get_shape()[1], output_dim], initializer=norm) 22 | b = tf.get_variable('b', [output_dim], initializer=const) 23 | return tf.matmul(input_, w) + b 24 | 25 | 26 | class Arrows: 27 | x_dim=3 28 | e_dim=3 29 | bdry_buffer=0.05# output in [bdry_buffer,1-bdry_buffer] 30 | def __init__(self,N): 31 | with tf.variable_scope('Arrow') as scope: 32 | self.N=tf.placeholder_with_default(N,shape=[]) 33 | #self.N=tf.constant(N) #how many to sample at a time 34 | self.e1=tf.random_uniform([self.N,1],0,1) 35 | self.e2=tf.random_uniform([self.N,1],0,1) 36 | self.e3=tf.random_uniform([self.N,1],0,1) 37 | self.build() 38 | #WARN. some of these are not trainable: i.e. poly 39 | self.var = tf.contrib.framework.get_variables(scope) 40 | def build(self): 41 | pass 42 | 43 | def normalize_output(self,X): 44 | ''' 45 | I think that data literally in [0,1] was difficult for sigmoid network. 
class Generator:
    """Base class for the synthetic-data GAN generators.

    Subclasses override build() to define self.X, the generated sample
    tensor of shape (N, x_dim), inside the 'Gen' variable scope.
    """
    x_dim=3
    def __init__(self, N, hidden_size=10,z_dim=10):
        with tf.variable_scope('Gen') as scope:
            # N is a default batch size that callers may override via feed_dict
            self.N=tf.placeholder_with_default(N,shape=[])
            self.hidden_size=hidden_size
            self.z_dim=z_dim
            self.build()
            # capture trainable vars created by build() BEFORE adding the
            # non-trainable step counter below
            self.tr_var = tf.contrib.framework.get_variables(scope)
            self.step=tf.Variable(0,name='step',trainable=False)
            self.var = tf.contrib.framework.get_variables(scope)
    def build(self):
        raise Exception('must override')
    def smallNN(self,inputs,name='smallNN'):
        """Two tanh hidden layers then a sigmoid scalar output in (0,1).

        inputs: a rank-2 tensor, or a list of rank-2 tensors that are
        concatenated along axis 1.
        """
        with tf.variable_scope(name):
            if isinstance(inputs,list):
                inputs=tf.concat(inputs,axis=1)
            h01 = tf.tanh(linear(inputs, self.hidden_size, name+'l1'))
            h11 = tf.tanh(linear(h01, self.hidden_size, name+'l21'))
            h21 = tf.sigmoid(linear(h11, 1, name+'l31'))
            return h21#rank2


randunif=tf.random_uniform_initializer(0,1,dtype=tf.float32)
def poly(cause,cause2=None,cause3=None,name='poly1d',reuse=None):
    '''Random, fixed (non-trainable) polynomial of 1, 2, or 3 causes.

    Assumes every input is in [0,1]. The random coefficient tensor is
    rescaled using its min/max over a grid so the output is approximately
    in [0,1].

    cause, cause2, cause3: rank-2 tensors shaped (N,1).
    Returns a (N,1) tensor.
    '''
    #Check conditions
    if isinstance(cause2,str):
        raise ValueError('cause2 was a string. you probably forgot to include\
                         the "name=" keyword when specifying only 1 cause')
    if isinstance(cause3,str):
        raise ValueError('cause3 was a string. you probably forgot to include\
                         the "name=" keyword when specifying only 1 cause')
    if not len(cause.shape)>=2:
        cshape=cause.get_shape().as_list()
        raise ValueError('cause and cause2 must have len(shape)>=2. shape was' , cshape )
    if cause2 is not None:
        if not len(cause2.get_shape().as_list())>=2:
            cshape2=cause2.get_shape().as_list()
            raise ValueError('cause and cause2 must have len(shape)>=2. shape was %r'%(cshape2))
    if cause3 is not None:
        if not len(cause3.get_shape().as_list())>=2:
            cshape3=cause3.get_shape().as_list()
            raise ValueError('cause and cause3 must have len(shape)>=2. shape was %r'%(cshape3))

    #Start
    with tf.variable_scope(name,reuse=reuse):
        # BUG FIX: the second test was a plain `if`, so when cause2 AND
        # cause3 were both given, its `else` branch clobbered the 4-element
        # inputs back down to [1, cause] and the extra causes were silently
        # ignored. Use an if/elif/else chain instead.
        if cause2 is not None and cause3 is not None:
            inputs=[tf.ones_like(cause),cause,cause2,cause3]
        elif cause2 is not None:
            inputs=[tf.ones_like(cause),cause,cause2]
        else:
            inputs=[tf.ones_like(cause),cause]
        dim=len(inputs)#2 or 3 or 4

        C=np.random.rand(1,dim,dim,dim).astype(np.float32)#unif
        C=2*C-1 #unif[-1,1]

        # Estimate min/max of the polynomial on a grid so we can rescale
        # the output into [0,1]. With the elif fix above the dim==4 case is
        # now reachable; a 200^3 grid there would allocate a ~(8e6,4,4,4)
        # intermediate (multiple GB), so use a coarser grid for dim==4.
        n = 200 if dim < 4 else 50
        N=n**(dim-1)
        grids=np.mgrid[[slice(0,1,1./n) for i in inputs[1:]]]
        y=np.hstack([np.ones((N,1))]+[g.reshape(N,1) for g in grids])
        y1=np.reshape(y,[N,-1,1,1])
        y2=np.reshape(y,[N,1,-1,1])
        y3=np.reshape(y,[N,1,1,-1])

        test_poly=np.sum(y1*y2*y3*C,axis=(1,2,3))
        Cmin=np.min(test_poly)
        Cmax=np.max(test_poly)
        #normalize [0,1]->[0,1]
        C[0,0,0,0]-=Cmin
        C/=(Cmax-Cmin)

        coeff=tf.Variable(C,name='coef',trainable=False)

        # evaluate the cubic form: sum over all triple products of inputs
        x=tf.concat(inputs,axis=1)
        x1=tf.reshape(x,[-1,dim,1,1])
        x2=tf.reshape(x,[-1,1,dim,1])
        x3=tf.reshape(x,[-1,1,1,dim])

        poly=tf.reduce_sum(x1*x2*x3*coeff,axis=[1,2,3])
        return tf.reshape(poly,[-1,1])
causal graph X1->X2->X3 148 | name='complete' 149 | def build(self): 150 | with tf.variable_scope(self.name): 151 | self.X1=poly(self.e1,name='X1') 152 | #self.X2=0.5*poly(self.X1,name='X1cX2')+0.5*self.e2 153 | #self.X3=0.5*poly(self.X1,self.X2,name='X1X2cX3')+0.5*self.e3 154 | self.X2=poly(self.X1,self.e2,name='X1cX2') 155 | self.X3=poly(self.X1,self.X2,self.e3,name='X1X2cX3') 156 | self.X=tf.concat([self.X1,self.X2,self.X3],axis=1) 157 | self.X=self.normalize_output(self.X) 158 | #print 'completearrowX.shape:',self.X.get_shape().as_list() 159 | class CompleteGenerator(Generator): 160 | name='complete' 161 | def build(self): 162 | with tf.variable_scope(self.name): 163 | self.z=tf.random_uniform((self.N,self.x_dim*self.z_dim), 0,1,name='z') 164 | z1,z2,z3=tf.split( self.z ,3,axis=1)#3=x_dim 165 | self.X1=self.smallNN(z1,'X1') 166 | self.X2=self.smallNN([self.X1,z2],'X1cX2') 167 | self.X3=self.smallNN([self.X1,self.X2,z3],'X1X2cX3') 168 | self.X=tf.concat([self.X1,self.X2,self.X3],axis=1) 169 | #print 'completegenX.shape:',self.X.get_shape().as_list() 170 | 171 | class ColliderArrows(Arrows): 172 | name='collider' 173 | def build(self): 174 | with tf.variable_scope(self.name): 175 | self.X1=poly(self.e1,name='X1') 176 | self.X3=poly(self.e3,name='X3') 177 | #self.X2=0.5*poly(self.X1,self.X3,name='X1X3cX2')+0.5*self.e2 178 | self.X2=poly(self.X1,self.X3,self.e2,name='X1X3cX2') 179 | self.X=tf.concat([self.X1,self.X2,self.X3],axis=1) 180 | self.X=self.normalize_output(self.X) 181 | class ColliderGenerator(Generator): 182 | name='collider' 183 | def build(self): 184 | with tf.variable_scope(self.name): 185 | self.z=tf.random_uniform((self.N,self.x_dim*self.z_dim), 0,1,name='z') 186 | z1,z2,z3=tf.split( self.z ,3,axis=1)#3=x_dim 187 | self.X1=self.smallNN(z1,'X1') 188 | self.X3=self.smallNN(z3,'X3') 189 | self.X2=self.smallNN([self.X1,self.X3,z2],'X1X3cX2') 190 | self.X=tf.concat([self.X1,self.X2,self.X3],axis=1) 191 | 192 | class LinearArrows(Arrows): 193 | 
name='linear' 194 | def build(self): 195 | with tf.variable_scope(self.name): 196 | self.X1=poly(self.e1,name='X1') 197 | #self.X2=0.5*poly(self.X1,name='X2')+0.5*self.e2 198 | #self.X3=0.5*poly(self.X2,name='X3')+0.5*self.e3 199 | self.X2=poly(self.X1,self.e2,name='X2') 200 | self.X3=poly(self.X2,self.e3,name='X3') 201 | self.X=tf.concat([self.X1,self.X2,self.X3],axis=1) 202 | self.X=self.normalize_output(self.X) 203 | class LinearGenerator(Generator): 204 | name='linear' 205 | def build(self): 206 | with tf.variable_scope(self.name): 207 | self.z=tf.random_uniform((self.N,self.x_dim*self.z_dim), 0,1,name='z') 208 | z1,z2,z3=tf.split( self.z ,3,axis=1)#3=x_dim 209 | self.X1=self.smallNN(z1,'X1') 210 | self.X2=self.smallNN([self.X1,z2],'X2') 211 | self.X3=self.smallNN([self.X2,z3],'X3') 212 | self.X=tf.concat([self.X1,self.X2,self.X3],axis=1) 213 | 214 | class NetworkArrows(Arrows): 215 | name='network' 216 | def build(self): 217 | with tf.variable_scope(self.name): 218 | self.hidden_size=10 219 | h0 = tf.tanh(linear(self.e1, self.hidden_size, 'netarrow0')) 220 | h1 = tf.tanh(linear(h0, self.hidden_size, 'netarrow1')) 221 | h2 = tf.tanh(linear(h1, self.hidden_size, 'netarrow2')) 222 | h3 = tf.tanh(linear(h2, self.hidden_size, 'netarrow3')) 223 | h4 = tf.sigmoid(linear(h3, self.x_dim, 'netarrow4')) 224 | self.X=self.normalize_output(h4) 225 | 226 | class FC3_Generator(Generator): 227 | name='fc3' 228 | def build(self): 229 | z=tf.random_uniform((self.N,self.x_dim*self.z_dim), 0,1,name='z') 230 | z1,z2,z3=tf.split( z ,3,axis=1)#3=x_dim 231 | h0 = tf.tanh(linear(z1, self.hidden_size, 'fc3gen0')) 232 | h1 = tf.tanh(linear(h0, self.hidden_size, 'fc3gen1')) 233 | h2 = tf.sigmoid(linear(h1, self.x_dim, 'fc3gen2')) 234 | self.X=h2 235 | 236 | class FC5_Generator(Generator): 237 | name='fc5' 238 | def build(self): 239 | z=tf.random_uniform((self.N,self.x_dim*self.z_dim), 0,1,name='z') 240 | z1,z2,z3=tf.split( z ,3,axis=1)#3=x_dim 241 | h0 = tf.tanh(linear(z1, 
self.hidden_size, 'fc5gen0')) 242 | h1 = tf.tanh(linear(h0, self.hidden_size, 'fc5gen1')) 243 | h2 = tf.tanh(linear(h1, self.hidden_size, 'fc5gen2')) 244 | h3 = tf.tanh(linear(h2, self.hidden_size, 'fc5gen3')) 245 | h4 = tf.sigmoid(linear(h3, self.x_dim, 'fc5gen4')) 246 | self.X=h4 247 | 248 | class FC10_Generator(Generator): 249 | name='fc10' 250 | def build(self): 251 | z=tf.random_uniform((self.N,self.x_dim*self.z_dim), 0,1,name='z') 252 | z1,z2,z3=tf.split( z ,3,axis=1)#3=x_dim 253 | h0 = tf.tanh(linear(z1, self.hidden_size, 'fc10gen0')) 254 | h1 = tf.tanh(linear(h0, self.hidden_size, 'fc10gen1')) 255 | h2 = tf.tanh(linear(h1, self.hidden_size, 'fc10gen2')) 256 | h3 = tf.tanh(linear(h2, self.hidden_size, 'fc10gen3')) 257 | h4 = tf.tanh(linear(h3, self.hidden_size, 'fc10gen4')) 258 | h5 = tf.tanh(linear(h4, self.hidden_size, 'fc10gen5')) 259 | h6 = tf.tanh(linear(h5, self.hidden_size, 'fc10gen6')) 260 | h7 = tf.tanh(linear(h6, self.hidden_size, 'fc10gen7')) 261 | h8 = tf.tanh(linear(h7, self.hidden_size, 'fc10gen8')) 262 | h9 = tf.sigmoid(linear(h8, self.x_dim, 'fc10gen9')) 263 | self.X=h9 264 | 265 | 266 | def minibatch(input_, num_kernels=5, kernel_dim=3): 267 | x = linear(input_, num_kernels * kernel_dim, scope='minibatch', stddev=0.02) 268 | activation = tf.reshape(x, (-1, num_kernels, kernel_dim)) 269 | diffs = tf.expand_dims(activation, 3) - tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0) 270 | abs_diffs = tf.reduce_sum(tf.abs(diffs), 2) 271 | minibatch_features = tf.reduce_sum(tf.exp(-abs_diffs), 2) 272 | return tf.concat([input_, minibatch_features],1) 273 | 274 | 275 | def Discriminator(input_, hidden_size,minibatch_layer=True,alpha=0.5,reuse=None): 276 | with tf.variable_scope('discriminator',reuse=reuse): 277 | h0_ = tf.nn.relu(linear(input_, hidden_size, 'disc0')) 278 | h0 = tf.maximum(alpha*h0_,h0_) 279 | h1_ = tf.nn.relu(linear(h0, hidden_size, 'disc1')) 280 | h1 = tf.maximum(alpha*h1_,h1_) 281 | if minibatch_layer: 282 | h2 = minibatch(h1) 
283 | else: 284 | h2_ = tf.nn.relu(linear(h1, hidden_size, 'disc2')) 285 | h2 = tf.maximum(alpha*h2_,h2_) 286 | h3 = linear(h2, 1, 'disc3') 287 | return h3 288 | 289 | 290 | 291 | GeneratorTypes={CompleteGenerator.name:CompleteGenerator, 292 | ColliderGenerator.name:ColliderGenerator, 293 | LinearGenerator.name:LinearGenerator, 294 | FC3_Generator.name:FC3_Generator, 295 | FC5_Generator.name:FC5_Generator, 296 | FC10_Generator.name:FC10_Generator} 297 | DataTypes={CompleteArrows.name:CompleteArrows, 298 | ColliderArrows.name:ColliderArrows, 299 | LinearArrows.name:LinearArrows, 300 | NetworkArrows.name:NetworkArrows} 301 | 302 | #def poly1d(cause,name='poly1d',reuse=None): 303 | # #assumes input is in [0,1]. Enforces output is in [0,1] 304 | # print 'Warning poly1d not ready yet' 305 | # with tf.variable_scope(name,initializer=randunif,reuse=reuse): 306 | # #C=np.random.rand(1,2,2).astype(np.float32)#unif 307 | # C=np.random.rand(1,2,2,2).astype(np.float32)#unif 308 | # 309 | # #find min and max 310 | # N=2000 311 | # y=np.hstack([np.ones((N,1)),np.linspace(0,1.,N).reshape((N,1))]) 312 | # y1=np.reshape(y,[N,2,1,1]) 313 | # y2=np.reshape(y,[N,1,2,1]) 314 | # y3=np.reshape(y,[N,1,1,2]) 315 | # 316 | # test_poly=np.sum(y1*y2*y3*C,axis=(1,2,3)) 317 | # Cmin=np.min(test_poly) 318 | # Cmax=np.max(test_poly) 319 | # 320 | # #normalize [0,1]->[0,1] 321 | # C[0,0,0,0]-=Cmin 322 | # C/=(Cmax-Cmin) 323 | # 324 | # coeff=tf.Variable(C,name='coef',trainable=False) 325 | # x2=tf.reshape(tf.stack([tf.ones_like(cause),cause],axis=1),[-1,1,2]) 326 | # x1=tf.transpose(x2,[0,2,1]) 327 | # poly=tf.reduce_sum(x1*x2*coeff,axis=[1,2]) 328 | # out= tf.squeeze(poly) 329 | # return poly 330 | # 331 | # #coeff=tf.Variable(trainable=False,expected_shape=[1,3]) 332 | # # X=tf.stack([cause,cause*cause,cause*cause*cause],axis=1) 333 | # # return tf.reduce_sum(coeff*X,axis=1)/tf.reduce_max(coeff) 334 | # 335 | #def poly2d(cause,cause2,name='poly2d',reuse=None): 336 | # with 
tf.variable_scope(name,initializer=randunif,reuse=reuse): 337 | # #coeff=tf.Variable(np.random.randn(1,2,2,2).astype(np.float32),trainable=False) 338 | # #x3=tf.reshape(tf.stack([cause,cause2],axis=0),[-1,1,1,2]) 339 | # #x2=tf.transpose(x3,[0,2,3,1]) 340 | # #x1=tf.transpose(x2,[0,2,3,1]) 341 | # 342 | # C=np.random.rand(1,3,3,3).astype(np.float32) 343 | # C[:,0,0,0]=0.#constant 344 | # C[:,0,2,0]=1.#x^3,y^3 coeff 345 | # C[:,0,0,2]=1. 346 | # coeff=tf.Variable(C, trainable=False) 347 | # x3=tf.reshape(tf.stack([tf.ones_like(cause),cause,cause2],axis=1),[-1,1,1,3]) 348 | # x2=tf.transpose(x3,[0,2,3,1]) 349 | # x1=tf.transpose(x2,[0,2,3,1]) 350 | # 351 | # poly=tf.reduce_sum(x1*x2*x3*coeff,axis=[1,2,3]) 352 | # 353 | # #out = tf.squeeze(poly)/tf.reduce_max(coeff) 354 | # out= tf.squeeze(poly) 355 | # return out 356 | 357 | -------------------------------------------------------------------------------- /synthetic/run_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #This script should be called with CUDA_VISIBLE_DEVICES 4 | #already set. 
This script runs 1 of each gan model for 5 | #1 of each dataset model 6 | 7 | set -e 8 | 9 | cvd=${CUDA_VISIBLE_DEVICES:?"Needs to be set"} 10 | echo "DEVICES=$cvd" 11 | 12 | #Sorry tqmd will produce some spastic output 13 | 14 | #for i in {1..5} 15 | for i in {1..30} 16 | do 17 | echo "GPU "$CUDA_VISIBLE_DEVICES" Iter $i" 18 | 19 | python main.py --data_type=linear & 20 | sleep 2s 21 | python main.py --data_type=collider & 22 | sleep 2s 23 | python main.py --data_type=complete 24 | 25 | #python main.py --data_type=linear & 26 | #sleep 2s 27 | #python main.py --data_type=linear & 28 | #sleep 2s 29 | #python main.py --data_type=linear 30 | 31 | #python main.py --data_type=network & 32 | #python main.py --data_type=network & 33 | #python main.py --data_type=network 34 | 35 | #Make sure all finished 36 | echo "Sleeping" 37 | sleep 5m 38 | 39 | done 40 | 41 | 42 | 43 | echo "finshed fork_datasets.sh" 44 | 45 | -------------------------------------------------------------------------------- /synthetic/tboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from subprocess import call 5 | 6 | def file2number(fname): 7 | nums=[s for s in fname.split('_') if s.isdigit()] 8 | if len(nums)==0: 9 | nums=['0'] 10 | number=int(''.join(nums)) 11 | return number 12 | 13 | if __name__=='__main__': 14 | root='./logs' 15 | 16 | logs=os.listdir(root) 17 | logs.sort(key=lambda x:file2number(x)) 18 | 19 | 20 | logdir=os.path.join(root,logs[-1]) 21 | print 'running tensorboard on logdir:',logdir 22 | 23 | call(['tensorboard', '--logdir',logdir]) 24 | 25 | -------------------------------------------------------------------------------- /synthetic/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import tensorflow as tf 3 | import logging 4 | import numpy as np 5 | import pandas as pd 6 | import shutil 7 | import json 8 | import 
sys 9 | import os 10 | from datetime import datetime 11 | from tqdm import trange 12 | import matplotlib.pyplot as plt 13 | 14 | from os import listdir 15 | from os.path import isfile,join 16 | 17 | from utils import calc_tvd,summary_scatterplots,Timer,summary_losses,make_summary 18 | from models import GeneratorTypes,DataTypes,Discriminator,sxe 19 | 20 | class GAN(object): 21 | def __init__(self,config,gan_type,data,parent_dir): 22 | self.config=config 23 | self.gan_type=gan_type 24 | self.data=data 25 | self.Xd=data.X 26 | self.parent_dir=parent_dir 27 | self.prepare_model_dir() 28 | self.prepare_logger() 29 | 30 | with tf.variable_scope(gan_type): 31 | self.step=tf.Variable(0,'step') 32 | self.inc_step=tf.assign(self.step,self.step+1) 33 | self.build_model() 34 | self.build_summaries()#This can be either in var_scope(name) or out 35 | 36 | def build_model(self): 37 | Gen=GeneratorTypes[self.gan_type] 38 | config=self.config 39 | self.gen=Gen(config.batch_size,config.gen_hidden_size,config.gen_z_dim) 40 | 41 | with tf.variable_scope('Disc') as scope: 42 | self.D1 = Discriminator(self.data.X, config.disc_hidden_size) 43 | scope.reuse_variables() 44 | self.D2 = Discriminator(self.gen.X, config.disc_hidden_size) 45 | d_var = tf.contrib.framework.get_variables(scope) 46 | 47 | d_loss_real=tf.reduce_mean( sxe(self.D1,1) ) 48 | d_loss_fake=tf.reduce_mean( sxe(self.D2,0) ) 49 | self.loss_d = d_loss_real + d_loss_fake 50 | self.loss_g = tf.reduce_mean( sxe(self.D2,1) ) 51 | 52 | optimizer=tf.train.AdamOptimizer 53 | g_optimizer=optimizer(self.config.lr_gen) 54 | d_optimizer=optimizer(self.config.lr_disc) 55 | self.opt_d = d_optimizer.minimize(self.loss_d,var_list= d_var) 56 | self.opt_g = g_optimizer.minimize(self.loss_g,var_list= self.gen.tr_var, 57 | global_step=self.gen.step) 58 | 59 | with tf.control_dependencies([self.inc_step]): 60 | self.train_op=tf.group(self.opt_d,self.opt_g) 61 | 62 | def build_summaries(self): 63 | 
d_summ=tf.summary.scalar(self.data.name+'_dloss',self.loss_d) 64 | g_summ=tf.summary.scalar(self.data.name+'_gloss',self.loss_g) 65 | self.summaries=[d_summ,g_summ] 66 | self.summary_op=tf.summary.merge(self.summaries) 67 | self.tf_scatter=tf.placeholder(tf.uint8,[3,480,640,3]) 68 | scatter_name='scatter_D'+self.data.name+'_G'+self.gen.name 69 | self.g_scatter_summary=tf.summary.image(scatter_name,self.tf_scatter,max_outputs=3) 70 | self.summary_writer=tf.summary.FileWriter(self.model_dir) 71 | 72 | def record_losses(self,sess): 73 | step, sum_loss_g, sum_loss_d = summary_losses(sess,self) 74 | self.summary_writer.add_summary(sum_loss_g,step) 75 | self.summary_writer.add_summary(sum_loss_d,step) 76 | self.summary_writer.flush() 77 | 78 | def record_tvd(self,sess): 79 | step,tvd,mvd = calc_tvd(sess,self.gen,self.data) 80 | self.log_tvd(step,tvd,mvd) 81 | summ_tvd=make_summary(self.data.name+'_tvd',tvd) 82 | summ_mvd=make_summary(self.data.name+'_mvd',mvd) 83 | self.summary_writer.add_summary(summ_tvd,step) 84 | self.summary_writer.add_summary(summ_mvd,step) 85 | self.summary_writer.flush() 86 | def record_scatter(self,sess): 87 | Xg=sess.run(self.gen.X,{self.gen.N:5000}) 88 | X1,X2,X3=np.split(Xg,3,axis=1) 89 | x1x2,x1x3,x2x3 = summary_scatterplots(X1,X2,X3) 90 | step,Pg_summ=sess.run([self.step,self.g_scatter_summary],{self.tf_scatter:np.concatenate([x1x2,x1x3,x2x3])}) 91 | self.summary_writer.add_summary(Pg_summ,step) 92 | self.summary_writer.flush() 93 | 94 | # if self.config.save_pdfs: 95 | # self.save_np_scatter(step,X1,X3) 96 | 97 | #Maybe it's the supervisor creating the segfault?? 98 | #Try just one model at a time 99 | 100 | # #will cause segfault ;) 101 | # def save_np_scatter(self,step,x,y,save_dir=None,ext='.pdf'): 102 | # ''' 103 | # This is a convenience that just saves the image as a pdf in addition to putting it on 104 | # tensorboard. 
only does x1x3 because that's what I needed at the moment 105 | # 106 | # sorry I wrote this really quickly 107 | # TODO: make less bad. 108 | # ''' 109 | # plt.scatter(x,y) 110 | # plt.title('X1X3') 111 | # plt.xlabel('X1') 112 | # plt.ylabel('X3') 113 | # plt.xlim([0,1]) 114 | # plt.ylim([0,1]) 115 | # 116 | # scatter_dir=os.path.join(self.model_dir,'scatter') 117 | # 118 | # save_dir=save_dir or scatter_dir 119 | # if not os.path.exists(save_dir): 120 | # os.mkdir(save_dir) 121 | # 122 | # save_name=os.path.join(save_dir,'{}_scatter_x1x3_{}_{}'+ext) 123 | # save_path=save_name.format(step,self.config.data_type,self.gan_type) 124 | # 125 | # plt.savefig(save_path) 126 | 127 | 128 | 129 | def prepare_model_dir(self): 130 | self.model_dir=os.path.join(self.parent_dir,self.gan_type) 131 | if not os.path.exists(self.model_dir): 132 | os.mkdir(self.model_dir) 133 | print('GAN Model directory is ',self.model_dir) 134 | def prepare_logger(self): 135 | self.logger=logging.getLogger(self.gan_type) 136 | pth=os.path.join(self.model_dir,'tvd.csv') 137 | file_handler=logging.FileHandler(pth) 138 | self.logger.addHandler(file_handler) 139 | self.logger.setLevel(logging.INFO) 140 | self.logger.info('iter tvd mvd') 141 | def log_tvd(self,step,tvd,mvd): 142 | log_str=' '.join([str(step),str(tvd),str(mvd)]) 143 | self.logger.info(log_str) 144 | 145 | 146 | class Trainer(object): 147 | def __init__(self,config,data_type): 148 | self.config=config 149 | self.data_type=data_type 150 | self.prepare_model_dir() 151 | 152 | 153 | 154 | #with tf.variable_scope('trainer'):#commented to get summaries on same plot 155 | self.step=tf.Variable(0,'step') 156 | self.inc_step=tf.assign(self.step,self.step+1) 157 | self.build_model() 158 | 159 | self.summary_writer=tf.summary.FileWriter(self.model_dir) 160 | 161 | self.saver=tf.train.Saver() 162 | 163 | #sv = tf.train.Supervisor( 164 | # logdir=self.save_model_dir, 165 | # is_chief=True, 166 | # saver=self.saver, 167 | # summary_op=None, 168 | # 
summary_writer=self.summary_writer, 169 | # save_model_secs=300, 170 | # global_step=self.step, 171 | # ready_for_local_init_op=None 172 | # ) 173 | 174 | gpu_options = tf.GPUOptions(allow_growth=True, 175 | per_process_gpu_memory_fraction=0.333) 176 | sess_config = tf.ConfigProto(allow_soft_placement=True, 177 | gpu_options=gpu_options) 178 | #self.sess = sv.prepare_or_wait_for_session(config=sess_config) 179 | self.sess = tf.Session(config=sess_config) 180 | 181 | 182 | init=tf.global_variables_initializer() 183 | self.sess.run(init) 184 | 185 | #if load_path, replace initialized values 186 | if self.config.load_path: 187 | print(" [*] Attempting to restore {}".format(self.config.load_path)) 188 | self.saver.restore(self.sess,self.config.load_path) 189 | 190 | #print(" [*] Attempting to restore {}".format(ckpt)) 191 | #self.saver.restore(self.sess,ckpt) 192 | #print(" [*] Success to read {}".format(ckpt)) 193 | 194 | 195 | 196 | if not self.config.load_path: 197 | #once data scatterplot (doesn't change during training) 198 | self.data_scatterplot() 199 | 200 | 201 | def data_scatterplot(self): 202 | Xd=self.sess.run(self.data.X,{self.data.N:5000}) 203 | X1,X2,X3=np.split(Xd,3,axis=1) 204 | x1x2,x1x3,x2x3 = summary_scatterplots(X1,X2,X3) 205 | step,Pg_summ=self.sess.run([self.step,self.d_scatter_summary],{self.tf_scatter:np.concatenate([x1x2,x1x3,x2x3])}) 206 | self.summary_writer.add_summary(Pg_summ,step) 207 | self.summary_writer.flush() 208 | 209 | 210 | def build_model(self): 211 | self.data=DataTypes[self.data_type](self.config.batch_size) 212 | 213 | self.gans=[GAN(self.config,n,self.data,self.model_dir) for n in GeneratorTypes.keys()] 214 | 215 | with tf.control_dependencies([self.inc_step]): 216 | self.train_op=tf.group(*[gan.train_op for gan in self.gans]) 217 | #self.train_op=tf.group(gan.train_op for gan in self.gans.values()) 218 | 219 | #Used for generating image summaries of scatterplots 220 | self.tf_scatter=tf.placeholder(tf.uint8,[3,480,640,3]) 
221 | self.d_scatter_summary=tf.summary.image('scatter_Data_'+self.data_type,self.tf_scatter,max_outputs=3) 222 | 223 | 224 | def train(self): 225 | self.train_timer =Timer() 226 | self.losses_timer =Timer() 227 | self.tvd_timer =Timer() 228 | self.scatter_timer =Timer() 229 | 230 | self.log_step=50 231 | self.max_step=50001 232 | #self.max_step=501 233 | for step in trange(self.max_step): 234 | 235 | if step % self.log_step == 0: 236 | for gan in self.gans: 237 | self.losses_timer.on() 238 | gan.record_losses(self.sess) 239 | self.losses_timer.off() 240 | 241 | self.tvd_timer.on() 242 | gan.record_tvd(self.sess) 243 | self.tvd_timer.off() 244 | 245 | if step % (10*self.log_step) == 0: 246 | for gan in self.gans: 247 | self.scatter_timer.on() 248 | gan.record_scatter(self.sess) 249 | 250 | #DEBUG: reassure me nothing changes during optimization 251 | #self.data_scatterplot() 252 | 253 | self.scatter_timer.off() 254 | 255 | if step % (5000) == 0: 256 | self.saver.save(self.sess,self.save_model_name,step) 257 | 258 | self.train_timer.on() 259 | self.sess.run(self.train_op) 260 | self.train_timer.off() 261 | 262 | 263 | print("Timers:") 264 | print(self.train_timer) 265 | print(self.losses_timer) 266 | print(self.tvd_timer) 267 | print(self.scatter_timer) 268 | 269 | 270 | def prepare_model_dir(self): 271 | if self.config.load_path: 272 | self.model_dir=self.config.load_path 273 | else: 274 | pth=datetime.now().strftime("%m%d_%H%M%S")+'_'+self.data_type 275 | self.model_dir=os.path.join(self.config.model_dir,pth) 276 | 277 | 278 | if not os.path.exists(self.model_dir): 279 | os.mkdir(self.model_dir) 280 | print('Model directory is ',self.model_dir) 281 | 282 | self.save_model_dir=os.path.join(self.model_dir,'checkpoints') 283 | if not os.path.exists(self.save_model_dir): 284 | os.mkdir(self.save_model_dir) 285 | self.save_model_name=os.path.join(self.save_model_dir,'Model') 286 | 287 | 288 | param_path = os.path.join(self.model_dir, "params.json") 289 | print("[*] 
MODEL dir: %s" % self.model_dir) 290 | print("[*] PARAM path: %s" % param_path) 291 | with open(param_path, 'w') as fp: 292 | json.dump(self.config.__dict__, fp, indent=4, sort_keys=True) 293 | 294 | config=self.config 295 | if config.is_train and not config.load_path: 296 | config.log_code_dir=os.path.join(self.model_dir,'code') 297 | for path in [self.model_dir, config.log_code_dir]: 298 | if not os.path.exists(path): 299 | os.makedirs(path) 300 | 301 | #Copy python code in directory into model_dir/code for future reference: 302 | code_dir=os.path.dirname(os.path.realpath(sys.argv[0])) 303 | model_files = [f for f in listdir(code_dir) if isfile(join(code_dir, f))] 304 | for f in model_files: 305 | if f.endswith('.py'): 306 | shutil.copy2(f,config.log_code_dir) 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | -------------------------------------------------------------------------------- /synthetic/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import tensorflow as tf 3 | import os 4 | from os import listdir 5 | from os.path import isfile, join 6 | from skimage import io 7 | import shutil 8 | import sys 9 | import math 10 | import time 11 | import json 12 | import logging 13 | import numpy as np 14 | from PIL import Image 15 | from datetime import datetime 16 | from tensorflow.core.framework import summary_pb2 17 | import matplotlib.pyplot as plt 18 | 19 | def make_summary(name, val): 20 | return summary_pb2.Summary(value=[summary_pb2.Summary.Value(tag=name, simple_value=val)]) 21 | 22 | def summary_losses(sess,model,N=1000): 23 | step,loss_g,loss_d=sess.run([model.step,model.loss_g,model.loss_d],{model.data.N:N,model.gen.N:N}) 24 | lgsum=make_summary(model.data.name+'_gloss',loss_g) 25 | ldsum=make_summary(model.data.name+'_dloss',loss_d) 26 | return step,lgsum, ldsum 27 | 28 | def calc_tvd(sess,Generator,Data,N=50000,nbins=10): 29 | Xd=sess.run(Data.X,{Data.N:N}) 30 | 
step,Xg=sess.run([Generator.step,Generator.X],{Generator.N:N}) 31 | 32 | p_gen,_ = np.histogramdd(Xg,bins=nbins,range=[[0,1],[0,1],[0,1]],normed=True) 33 | p_dat,_ = np.histogramdd(Xd,bins=nbins,range=[[0,1],[0,1],[0,1]],normed=True) 34 | p_gen/=nbins**3 35 | p_dat/=nbins**3 36 | tvd=0.5*np.sum(np.abs( p_gen-p_dat )) 37 | mvd=np.max(np.abs( p_gen-p_dat )) 38 | 39 | return step,tvd, mvd 40 | 41 | s_tvd=make_summary(Data.name+'_tvd',tvd) 42 | s_mvd=make_summary(Data.name+'_mvd',mvd) 43 | 44 | return step,s_tvd,s_mvd 45 | #return make_summary('tvd/'+Generator.name,tvd) 46 | 47 | 48 | def summary_stats(name,tensor,hist=False): 49 | ave=tf.reduce_mean(tensor) 50 | std=tf.sqrt(tf.reduce_mean(tf.square(ave-tensor))) 51 | tf.summary.scalar(name+'_ave',ave) 52 | tf.summary.scalar(name+'_std',std) 53 | if hist: 54 | tf.summary.histogram(name+'_hist',tensor) 55 | 56 | def summary_scatterplots(X1,X2,X3): 57 | with tf.name_scope('scatter'): 58 | img1=summary_scatter2d(X1,X2,'X1X2',xlabel='X1',ylabel='X2') 59 | img2=summary_scatter2d(X1,X3,'X1X3',xlabel='X1',ylabel='X3') 60 | img3=summary_scatter2d(X2,X3,'X2X3',xlabel='X2',ylabel='X3') 61 | plt.close() 62 | return img1,img2,img3 63 | 64 | 65 | 66 | def summary_scatter2d(x,y,title='2dscatterplot',xlabel=None,ylabel=None): 67 | fig=scatter2d(x,y,title,xlabel=xlabel,ylabel=ylabel) 68 | 69 | fig.canvas.draw() 70 | rgb=fig.canvas.tostring_rgb() 71 | buf=np.fromstring(rgb,dtype=np.uint8) 72 | 73 | w,h = fig.canvas.get_width_height() 74 | img=buf.reshape(1,h,w,3) 75 | #summary=tf.summary.image(title,img) 76 | plt.close(fig) 77 | #fig.clf() 78 | return img 79 | 80 | def scatter2d(x,y,title='2dscatterplot',xlabel=None,ylabel=None): 81 | fig=plt.figure() 82 | plt.scatter(x,y) 83 | plt.title(title) 84 | if xlabel: 85 | plt.xlabel(xlabel) 86 | if ylabel: 87 | plt.ylabel(ylabel) 88 | 89 | if not 0<=np.min(x)<=np.max(x)<=1: 90 | raise ValueError('summary_scatter2d title:',title,' input x exceeded [0,1] range.\ 91 | min:',np.min(x),' 
max:',np.max(x)) 92 | if not 0<=np.min(y)<=np.max(y)<=1: 93 | raise ValueError('summary_scatter2d title:',title,' input y exceeded [0,1] range.\ 94 | min:',np.min(y),' max:',np.max(y)) 95 | 96 | plt.xlim([0,1]) 97 | plt.ylim([0,1]) 98 | return fig 99 | 100 | 101 | def prepare_dirs_and_logger(config): 102 | formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s") 103 | logger = logging.getLogger() 104 | 105 | for hdlr in logger.handlers: 106 | logger.removeHandler(hdlr) 107 | 108 | handler = logging.StreamHandler() 109 | handler.setFormatter(formatter) 110 | 111 | logger.addHandler(handler) 112 | 113 | if config.load_path: 114 | if config.load_path.startswith(config.log_dir): 115 | config.model_dir = config.load_path 116 | else: 117 | if config.load_path.startswith(config.dataset): 118 | config.model_name = config.load_path 119 | else: 120 | config.model_name = "{}_{}".format(config.dataset, config.load_path) 121 | else: 122 | config.model_name = "{}_{}".format(config.dataset, get_time()) 123 | 124 | if not hasattr(config, 'model_dir'): 125 | config.model_dir = os.path.join(config.log_dir, config.model_name) 126 | config.data_path = os.path.join(config.data_dir, config.dataset) 127 | 128 | if config.is_train: 129 | config.log_code_dir=os.path.join(config.model_dir,'code') 130 | for path in [config.log_dir, config.data_dir, 131 | config.model_dir, config.log_code_dir]: 132 | if not os.path.exists(path): 133 | os.makedirs(path) 134 | 135 | #Copy python code in directory into model_dir/code for future reference: 136 | code_dir=os.path.dirname(os.path.realpath(sys.argv[0])) 137 | model_files = [f for f in listdir(code_dir) if isfile(join(code_dir, f))] 138 | for f in model_files: 139 | if f.endswith('.py'): 140 | shutil.copy2(f,config.log_code_dir) 141 | 142 | def get_time(): 143 | return datetime.now().strftime("%m%d_%H%M%S") 144 | 145 | def save_config(config): 146 | param_path = os.path.join(config.model_dir, "params.json") 147 | 148 | print("[*] MODEL 
dir: %s" % config.model_dir) 149 | print("[*] PARAM path: %s" % param_path) 150 | 151 | with open(param_path, 'w') as fp: 152 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 153 | 154 | 155 | 156 | class Timer(object): 157 | def __init__(self): 158 | self.total_section_time=0. 159 | self.iter=0 160 | def on(self): 161 | self.t0=time.time() 162 | def off(self): 163 | self.total_section_time+=time.time()-self.t0 164 | self.iter+=1 165 | def __str__(self): 166 | n_min=self.total_section_time/60. 167 | return '%.2fmin'%n_min 168 | -------------------------------------------------------------------------------- /tboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from subprocess import call 5 | 6 | def file2number(fname): 7 | nums=[s for s in fname.split('_') if s.isdigit()] 8 | if len(nums)==0: 9 | nums=['0'] 10 | number=int(''.join(nums)) 11 | return number 12 | 13 | if __name__=='__main__': 14 | root='./logs' 15 | 16 | logs=os.listdir(root) 17 | logs.sort(key=lambda x:file2number(x)) 18 | 19 | 20 | logdir=os.path.join(root,logs[-1]) 21 | print 'running tensorboard on logdir:',logdir 22 | 23 | call(['tensorboard', '--logdir',logdir]) 24 | 25 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import tensorflow as tf 3 | from functools import partial 4 | import os 5 | from os import listdir 6 | from os.path import isfile, join 7 | import shutil 8 | import sys 9 | from glob import glob 10 | import math 11 | import json 12 | import logging 13 | import numpy as np 14 | from PIL import Image 15 | from datetime import datetime 16 | from tensorflow.core.framework import summary_pb2 17 | 18 | 19 | 20 | def make_summary(name, val): 21 | return summary_pb2.Summary(value=[summary_pb2.Summary.Value(tag=name, simple_value=val)]) 
22 | 23 | def summary_stats(name,tensor,collections=None,hist=False): 24 | collections=collections or [tf.GraphKeys.SUMMARIES] 25 | ave=tf.reduce_mean(tensor) 26 | std=tf.sqrt(tf.reduce_mean(tf.square(ave-tensor))) 27 | tf.summary.scalar(name+'_ave',ave,collections) 28 | tf.summary.scalar(name+'_std',std,collections) 29 | if hist: 30 | tf.summary.histogram(name+'_hist',tensor,collections) 31 | 32 | 33 | def prepare_dirs_and_logger(config): 34 | 35 | if config.load_path: 36 | strip_lp=config.load_path.strip('./') 37 | if strip_lp.startswith(config.log_dir): 38 | config.model_dir = config.load_path 39 | else: 40 | if config.load_path.startswith(config.dataset): 41 | config.model_name = config.load_path 42 | else: 43 | config.model_name = "{}_{}".format(config.dataset, config.load_path) 44 | else:#new model 45 | config.model_name = "{}_{}".format(config.dataset, get_time()) 46 | if config.descrip: 47 | config.model_name+='_'+config.descrip 48 | 49 | 50 | if not hasattr(config, 'model_dir'): 51 | config.model_dir = os.path.join(config.log_dir, config.model_name) 52 | config.data_path = os.path.join(config.data_dir, config.dataset) 53 | 54 | 55 | if not config.load_path: 56 | config.log_code_dir=os.path.join(config.model_dir,'code') 57 | for path in [config.log_dir, config.data_dir, 58 | config.model_dir]: 59 | if not os.path.exists(path): 60 | os.makedirs(path) 61 | 62 | #Copy python code in directory into model_dir/code for future reference: 63 | #All python files in this directory are copied. 64 | code_dir=os.path.dirname(os.path.realpath(sys.argv[0])) 65 | 66 | ##additionally, all python files in these directories are also copied. Also symlinks are copied. 
The idea is to allow easier model loading in the future 67 | allowed_dirs=['causal_controller','causal_began','causal_dcgan','figure_scripts'] 68 | 69 | #ignore copy of all non-*.py except for these directories 70 | #If you make another folder you want copied, you have to add it here 71 | ignore_these=partial(ignore_except,allowed_dirs=allowed_dirs) 72 | shutil.copytree(code_dir,config.log_code_dir,symlinks=True,ignore=ignore_these) 73 | 74 | 75 | # model_files = [f for f in listdir(code_dir) if isfile(join(code_dir, f))] 76 | # for f in model_files: 77 | # if f.endswith('.py'): 78 | # shutil.copy2(f,config.log_code_dir) 79 | 80 | 81 | def ignore_except(src,contents,allowed_dirs): 82 | files=filter(os.path.isfile,contents) 83 | dirs=filter(os.path.isdir,contents) 84 | ignored_files=[f for f in files if not f.endswith('.py')] 85 | ignored_dirs=[d for d in dirs if not d in allowed_dirs] 86 | return ignored_files+ignored_dirs 87 | 88 | def get_time(): 89 | return datetime.now().strftime("%m%d_%H%M%S") 90 | 91 | def save_configs(config,cc_config,dcgan_config,began_config): 92 | model_dir=config.model_dir 93 | print("[*] MODEL dir: %s" % model_dir) 94 | save_config(config) 95 | save_config(cc_config,'cc_params.json',model_dir) 96 | save_config(dcgan_config,'dcgan_params.json',model_dir) 97 | save_config(began_config,'began_params.json',model_dir) 98 | 99 | 100 | def save_config(config,name="params.json",where=None): 101 | where=where or config.model_dir 102 | param_path = os.path.join(where, name) 103 | 104 | print("[*] PARAM path: %s" % param_path) 105 | 106 | with open(param_path, 'w') as fp: 107 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 108 | 109 | def get_available_gpus(): 110 | from tensorflow.python.client import device_lib 111 | local_device_protos = device_lib.list_local_devices() 112 | return [x.name for x in local_device_protos if x.device_type=='GPU'] 113 | 114 | def distribute_input_data(data_loader,num_gpu): 115 | ''' 116 | data_loader is a 
dictionary of tensors that are fed into our model 117 | 118 | This function takes that dictionary of n*batch_size dimension tensors 119 | and breaks it up into n dictionaries with the same key of tensors with 120 | dimension batch_size. One is given to each gpu 121 | ''' 122 | if num_gpu==0: 123 | return {'/cpu:0':data_loader} 124 | 125 | gpus=get_available_gpus() 126 | if num_gpu > len(gpus): 127 | raise ValueError('number of gpus specified={}, more than gpus available={}'.format(num_gpu,len(gpus))) 128 | 129 | gpus=gpus[:num_gpu] 130 | 131 | data_by_gpu={g:{} for g in gpus} 132 | for key,value in data_loader.items(): 133 | spl_vals=tf.split(value,num_gpu) 134 | for gpu,val in zip(gpus,spl_vals): 135 | data_by_gpu[gpu][key]=val 136 | 137 | return data_by_gpu 138 | 139 | 140 | def rank(array): 141 | return len(array.shape) 142 | 143 | def make_grid(tensor, nrow=8, padding=2, 144 | normalize=False, scale_each=False): 145 | """Code based on https://github.com/pytorch/vision/blob/master/torchvision/utils.py 146 | minor improvement, row/col was reversed""" 147 | nmaps = tensor.shape[0] 148 | ymaps = min(nrow, nmaps) 149 | xmaps = int(math.ceil(float(nmaps) / ymaps)) 150 | height, width = int(tensor.shape[1] + padding), int(tensor.shape[2] + padding) 151 | grid = np.zeros([height * ymaps + 1 + padding // 2, width * xmaps + 1 + padding // 2, 3], dtype=np.uint8) 152 | k = 0 153 | for y in range(ymaps): 154 | for x in range(xmaps): 155 | if k >= nmaps: 156 | break 157 | h, h_width = y * height + 1 + padding // 2, height - padding 158 | w, w_width = x * width + 1 + padding // 2, width - padding 159 | 160 | grid[h:h+h_width, w:w+w_width] = tensor[k] 161 | k = k + 1 162 | return grid 163 | 164 | def save_image(tensor, filename, nrow=8, padding=2, 165 | normalize=False, scale_each=False): 166 | ndarr = make_grid(tensor, nrow=nrow, padding=padding, 167 | normalize=normalize, scale_each=scale_each) 168 | im = Image.fromarray(ndarr) 169 | im.save(filename) 170 | 
--------------------------------------------------------------------------------