├── .gitignore
├── LICENSE
├── README.md
├── requirements.txt
└── src
    ├── data_loader
    │   ├── __init__.py
    │   └── synthetic_dataset.py
    ├── demo.ipynb
    ├── helpers
    │   ├── analyze_utils.py
    │   ├── config_utils.py
    │   ├── dir_utils.py
    │   ├── log_helper.py
    │   └── tf_utils.py
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   └── notears.py
    └── trainers
        ├── __init__.py
        └── al_trainer.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Ignavier Ng Zhi Yong

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NOTEARS-Tensorflow

This repository is a TensorFlow reimplementation of NOTEARS [1].

## 1. Setup

```
pip install -r requirements.txt
```

## 2. Training

To run `NoTears`, for example:

```
python main.py --seed 1230 \
               --d 20 \
               --n 1000 \
               --degree 4 \
               --graph_thres 0.3 \
               --l1_lambda 0.1
```
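
The script writes its artifacts (log, config, plots, and `.npy` arrays) to a timestamped directory under `output/`. A minimal sketch for inspecting them afterwards; the directory name below is a placeholder for the actual timestamp:

```python
import numpy as np

output_dir = 'output/2020-08-01_00-00-00-000'  # hypothetical run directory

W_true = np.load('{}/true_graph.npy'.format(output_dir))                # ground-truth weighted DAG
X = np.load('{}/X.npy'.format(output_dir))                              # [n, d] observational data
W_est = np.load('{}/final_raw_estimated_graph.npy'.format(output_dir))  # estimate before thresholding
print(W_true.shape, X.shape, W_est.shape)
```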
### Remark
- Parts of the implementation are adapted from https://github.com/xunzheng/notears

## References
[1] Zheng, X., Aragam, B., Ravikumar, P., and Xing, E. P. DAGs with NO TEARS: Continuous optimization for structure learning. In Advances in Neural Information Processing Systems, 2018.

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
networkx
pyyaml
pytz
matplotlib
tensorflow==1.15.4

--------------------------------------------------------------------------------
/src/data_loader/__init__.py:
--------------------------------------------------------------------------------
from data_loader.synthetic_dataset import SyntheticDataset

--------------------------------------------------------------------------------
/src/data_loader/synthetic_dataset.py:
--------------------------------------------------------------------------------
import logging
import numpy as np
import networkx as nx


class SyntheticDataset(object):
    _logger = logging.getLogger(__name__)

    def __init__(self, n, d, graph_type, degree, sem_type, noise_scale=1.0, dataset_type='linear'):
        self.n = n
        self.d = d
        self.graph_type = graph_type
        self.degree = degree
        self.sem_type = sem_type
        self.noise_scale = noise_scale
        self.dataset_type = dataset_type
        self.w_range = (0.5, 2.0)

        self._setup()
        self._logger.debug('Finished setting up dataset class')

    def _setup(self):
        self.W = SyntheticDataset.simulate_random_dag(self.d, self.degree,
                                                      self.graph_type, self.w_range)

        self.X = SyntheticDataset.simulate_sem(self.W, self.n, self.sem_type, self.w_range,
                                               self.noise_scale, self.dataset_type)

    @staticmethod
    def simulate_random_dag(d, degree, graph_type, w_range):
        """Simulate random DAG with some expected degree.

        Args:
            d: number of nodes
            degree: expected node degree, in + out
            graph_type: {erdos-renyi, barabasi-albert, full}
            w_range: weight range +/- (low, high)

        Returns:
            W: weighted DAG
        """
        if graph_type == 'erdos-renyi':
            prob = float(degree) / (d - 1)
            B = np.tril((np.random.rand(d, d) < prob).astype(float), k=-1)
        elif graph_type == 'barabasi-albert':
            m = int(round(degree / 2))
            B = np.zeros([d, d])
            bag = [0]
            for ii in range(1, d):
                dest = np.random.choice(bag, size=m)
                for jj in dest:
                    B[ii, jj] = 1
                bag.append(ii)
                bag.extend(dest)
        elif graph_type == 'full':  # ignore degree, only for experimental use
            B = np.tril(np.ones([d, d]), k=-1)
        else:
            raise ValueError('Unknown graph type')
        # random permutation
        P = np.random.permutation(np.eye(d, d))  # permutes first axis only
        B_perm = P.T.dot(B).dot(P)
        U = np.random.uniform(low=w_range[0], high=w_range[1], size=[d, d])
        U[np.random.rand(d, d) < 0.5] *= -1
        W = (B_perm != 0).astype(float) * U

        return W

    @staticmethod
    def simulate_sem(W, n, sem_type, w_range, noise_scale=1.0, dataset_type='linear'):
        """Simulate samples from SEM with specified type of noise.

        Args:
            W: weighted DAG
            n: number of samples
            sem_type: {linear-gauss, linear-exp, linear-gumbel}
            w_range: weight range (unused here; kept for interface consistency)
            noise_scale: scale parameter of noise distribution in linear SEM

        Returns:
            X: [n, d] sample matrix
        """
        G = nx.DiGraph(W)
        d = W.shape[0]
        X = np.zeros([n, d])
        ordered_vertices = list(nx.topological_sort(G))
        assert len(ordered_vertices) == d
        for j in ordered_vertices:
            parents = list(G.predecessors(j))
            if dataset_type == 'linear':
                eta = X[:, parents].dot(W[parents, j])
            else:
                raise ValueError('Unknown dataset type')

            if sem_type == 'linear-gauss':
                X[:, j] = eta + np.random.normal(scale=noise_scale, size=n)
            elif sem_type == 'linear-exp':
                X[:, j] = eta + np.random.exponential(scale=noise_scale, size=n)
            elif sem_type == 'linear-gumbel':
                X[:, j] = eta + np.random.gumbel(scale=noise_scale, size=n)
            else:
                raise ValueError('Unknown sem type')

        return X


if __name__ == '__main__':
    n, d = 3000, 20
    graph_type, degree, sem_type = 'erdos-renyi', 3, 'linear-gauss'
    noise_scale = 1.0

    dataset = SyntheticDataset(n, d, graph_type, degree, sem_type,
                               noise_scale, dataset_type='linear')
    print('dataset.X.shape: {}'.format(dataset.X.shape))
    print('dataset.W.shape: {}'.format(dataset.W.shape))
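    # Illustrative sanity check (an addition, not part of the original demo):
    # any graph sampled by simulate_random_dag is a permuted strictly lower
    # triangular matrix, so it should always be acyclic.
    assert nx.is_directed_acyclic_graph(nx.DiGraph(dataset.W))
    print('dataset.W is a DAG')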
--------------------------------------------------------------------------------
/src/demo.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-08-01 04:07:04,175 INFO - helpers.log_helper - Finished configuring logger.\n",
      "2020-08-01 04:07:04,248 INFO - __main__ - Finished generating dataset\n",
      "WARNING:tensorflow:From /home/mila/z/zhi.yong.ignavier-ng/.conda/envs/notears/lib/python3.6/site-packages/tensorflow_core/python/ops/linalg/linalg_impl.py:283: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "2020-08-01 04:07:04,372 WARNING - tensorflow - From /home/mila/z/zhi.yong.ignavier-ng/.conda/envs/notears/lib/python3.6/site-packages/tensorflow_core/python/ops/linalg/linalg_impl.py:283: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "2020-08-01 04:07:05,518 INFO - models.notears - Model summary:\n",
      "2020-08-01 04:07:05,519 INFO - models.notears - ---------\n",
      "2020-08-01 04:07:05,520 INFO - models.notears - Variables: name (type shape) [size]\n",
      "2020-08-01 04:07:05,520 INFO - models.notears - ---------\n",
      "2020-08-01 04:07:05,521 INFO - models.notears - Variable:0 (float32_ref 20x20) [400, bytes: 1600]\n",
      "2020-08-01 04:07:05,521 INFO - models.notears - Total size of variables: 400\n",
      "2020-08-01 04:07:05,522 INFO - models.notears - Total bytes of variables: 1600\n",
      "2020-08-01 04:07:11,432 INFO - trainers.al_trainer - Started training for 20 iterations\n",
      "2020-08-01 04:07:11,433 INFO - trainers.al_trainer - rho 1.000E+00, alpha 0.000E+00\n",
      "2020-08-01 04:07:17,054 INFO - trainers.al_trainer - [Iter 1] loss 1.487E+01, mse 2.145E+04, acyclic 4.007E-01, shd 52, tpr 0.421, fdr 0.667, pred_size 48\n",
      "2020-08-01 04:07:17,063 INFO - trainers.al_trainer - rho 1.000E+00, alpha 4.007E-01\n",
      "2020-08-01 04:07:20,870 INFO - trainers.al_trainer - rho 1.000E+01, alpha 4.007E-01\n",
      "2020-08-01 04:07:24,679 INFO - trainers.al_trainer - rho 1.000E+02, alpha 4.007E-01\n",
      "2020-08-01 04:07:28,454 INFO - trainers.al_trainer - rho 1.000E+03, alpha 4.007E-01\n",
      "2020-08-01 04:07:32,218 INFO - trainers.al_trainer - [Iter 2] loss 1.200E+01, mse 1.546E+04, acyclic 3.098E-02, shd 16, tpr 0.737, fdr 0.176, pred_size 34\n",
      "2020-08-01 04:07:32,221 INFO - trainers.al_trainer - rho 1.000E+03, alpha 3.138E+01\n",
      "2020-08-01 04:07:35,945 INFO - trainers.al_trainer - rho 1.000E+04, alpha 3.138E+01\n",
      "2020-08-01 04:07:39,771 INFO - trainers.al_trainer - [Iter 3] loss 1.291E+01, mse 1.712E+04, acyclic 7.328E-03, shd 6, tpr 0.868, fdr 0.029, pred_size 34\n",
      "2020-08-01 04:07:39,774 INFO - trainers.al_trainer - rho 1.000E+04, alpha 1.047E+02\n",
      "2020-08-01 04:07:43,542 INFO - trainers.al_trainer - rho 1.000E+05, alpha 1.047E+02\n",
      "2020-08-01 04:07:47,304 INFO - trainers.al_trainer - rho 1.000E+06, alpha 1.047E+02\n",
      "2020-08-01 04:07:51,112 INFO - trainers.al_trainer - [Iter 4] loss 1.412E+01, mse 1.959E+04, acyclic 6.008E-04, shd 3, tpr 0.921, fdr 0.028, pred_size 36\n",
      "2020-08-01 04:07:51,115 INFO - trainers.al_trainer - rho 1.000E+06, alpha 7.055E+02\n",
      "2020-08-01 04:07:54,886 INFO - trainers.al_trainer - rho 1.000E+07, alpha 7.055E+02\n",
      "2020-08-01 04:07:58,674 INFO - trainers.al_trainer - [Iter 5] loss 1.449E+01, mse 2.034E+04, acyclic 8.202E-05, shd 2, tpr 0.974, fdr 0.051, pred_size 39\n",
      "2020-08-01 04:07:58,677 INFO - trainers.al_trainer - rho 1.000E+07, alpha 1.526E+03\n",
      "2020-08-01 04:08:02,429 INFO - trainers.al_trainer - rho 1.000E+08, alpha 1.526E+03\n",
      "2020-08-01 04:08:06,158 INFO - trainers.al_trainer - [Iter 6] loss 1.457E+01, mse 2.052E+04, acyclic 1.717E-05, shd 2, tpr 0.974, fdr 0.051, pred_size 39\n",
      "2020-08-01 04:08:06,161 INFO - trainers.al_trainer - rho 1.000E+08, alpha 3.242E+03\n",
      "2020-08-01 04:08:09,923 INFO - trainers.al_trainer - rho 1.000E+09, alpha 3.242E+03\n",
      "2020-08-01 04:08:13,684 INFO - trainers.al_trainer - rho 1.000E+10, alpha 3.242E+03\n",
      "2020-08-01 04:08:17,478 INFO - trainers.al_trainer - [Iter 7] loss 1.476E+01, mse 2.107E+04, acyclic 1.907E-06, shd 2, tpr 0.974, fdr 0.051, pred_size 39\n",
      "2020-08-01 04:08:17,481 INFO - trainers.al_trainer - rho 1.000E+10, alpha 2.232E+04\n",
      "2020-08-01 04:08:21,230 INFO - trainers.al_trainer - [Iter 8] loss 1.468E+01, mse 2.093E+04, acyclic 0.000E+00, shd 2, tpr 0.974, fdr 0.051, pred_size 39\n",
      "2020-08-01 04:08:21,233 INFO - trainers.al_trainer - Early stopping at 8-th iteration\n",
      "2020-08-01 04:08:21,295 INFO - trainers.al_trainer - Model saved to output/2020-07-31_15-07-04-172/model/\n",
      "2020-08-01 04:08:21,296 INFO - __main__ - Finished training model\n",
      "Figure(800x300)\n",
      "2020-08-01 04:08:22,570 INFO - __main__ - Thresholding.\n",
      "Figure(800x300)\n",
      "2020-08-01 04:08:23,727 INFO - __main__ - Results after thresholding by 0.3: {'fdr': 0.05128205128205128, 'tpr': 0.9736842105263158, 'fpr': 0.013157894736842105, 'shd': 2, 'pred_size': 39}\n"
     ]
    }
   ],
   "source": [
    "!python main.py --seed 1230 \\\n",
    "                --d 20 \\\n",
    "                --n 1000 \\\n",
    "                --degree 4 \\\n",
    "                --graph_thres 0.3 \\\n",
    "                --l1_lambda 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}

--------------------------------------------------------------------------------
/src/helpers/analyze_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt


def count_accuracy(W_true, W_est, W_und=None):
    """
    Compute FDR, TPR, and FPR for B, or optionally for CPDAG B + B_und.

    Args:
        W_true: ground truth graph
        W_est: predicted graph
        W_und: predicted undirected edges in CPDAG, asymmetric

    Returns in dict:
        fdr: (reverse + false positive) / prediction positive
        tpr: (true positive) / condition positive
        fpr: (reverse + false positive) / condition negative
        shd: undirected extra + undirected missing + reverse
        pred_size: prediction positive
    """
    B_true = W_true != 0
    B = W_est != 0
    B_und = None if W_und is None else W_und != 0
    d = B.shape[0]

    # linear index of nonzeros
    if B_und is not None:
        pred_und = np.flatnonzero(B_und)
    pred = np.flatnonzero(B)
    cond = np.flatnonzero(B_true)
    cond_reversed = np.flatnonzero(B_true.T)
    cond_skeleton = np.concatenate([cond, cond_reversed])
    # true pos
    true_pos = np.intersect1d(pred, cond, assume_unique=True)
    if B_und is not None:
        # treat undirected edge favorably
        true_pos_und = np.intersect1d(pred_und, cond_skeleton, assume_unique=True)
        true_pos = np.concatenate([true_pos, true_pos_und])
    # false pos
    false_pos = np.setdiff1d(pred, cond_skeleton, assume_unique=True)
    if B_und is not None:
        false_pos_und = np.setdiff1d(pred_und, cond_skeleton, assume_unique=True)
        false_pos = np.concatenate([false_pos, false_pos_und])
    # reverse
    extra = np.setdiff1d(pred, cond, assume_unique=True)
    reverse = np.intersect1d(extra, cond_reversed, assume_unique=True)
    # compute ratio
    pred_size = len(pred)
    if B_und is not None:
        pred_size += len(pred_und)
    cond_neg_size = 0.5 * d * (d - 1) - len(cond)
    fdr = float(len(reverse) + len(false_pos)) / max(pred_size, 1)
    tpr = float(len(true_pos)) / max(len(cond), 1)
    fpr = float(len(reverse) + len(false_pos)) / max(cond_neg_size, 1)
    # structural hamming distance
    B_lower = np.tril(B + B.T)
    if B_und is not None:
        B_lower += np.tril(B_und + B_und.T)
    pred_lower = np.flatnonzero(B_lower)
    cond_lower = np.flatnonzero(np.tril(B_true + B_true.T))
    extra_lower = np.setdiff1d(pred_lower, cond_lower, assume_unique=True)
    missing_lower = np.setdiff1d(cond_lower, pred_lower, assume_unique=True)
    shd = len(extra_lower) + len(missing_lower) + len(reverse)

    return {
        'fdr': fdr,
        'tpr': tpr,
        'fpr': fpr,
        'shd': shd,
        'pred_size': pred_size
    }


def plot_estimated_graph(W_est, W, save_name=None):
    fig, (ax1, ax2) = plt.subplots(figsize=(8, 3), ncols=2)

    # plot the estimated graph and save the
    # color "mappable" object returned by ax1.imshow
    ax1.set_title('estimated_graph')
    map1 = ax1.imshow(W_est, cmap='Greys', interpolation='none')

    # add the colorbar using the figure's method,
    # telling which mappable we're talking about and
    # which axes object it should be near
    fig.colorbar(map1, ax=ax1)

    # repeat everything above for the ground-truth graph
    ax2.set_title('true_graph')
    map2 = ax2.imshow(W, cmap='Greys', interpolation='none')
    fig.colorbar(map2, ax=ax2)

    plt.show()

    if save_name is not None:
        fig.savefig(save_name)
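

if __name__ == '__main__':
    # Minimal illustrative check on a hand-crafted 3-node graph (this demo
    # block is an addition, mirroring the sibling modules). W_est recovers
    # the edge 0 -> 1 correctly but reverses the edge 2 -> 1, so expect
    # tpr = 0.5, fdr = 0.5 and shd = 1.
    W_true = np.array([[0.0, 1.2, 0.0],
                       [0.0, 0.0, 0.0],
                       [0.0, 0.9, 0.0]])
    W_est = np.array([[0.0, 1.0, 0.0],
                      [0.0, 0.0, 0.7],
                      [0.0, 0.0, 0.0]])
    print(count_accuracy(W_true, W_est))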
--------------------------------------------------------------------------------
/src/helpers/config_utils.py:
--------------------------------------------------------------------------------
import sys
import yaml
import argparse


def load_yaml_config(path, skip_lines=0):
    with open(path, 'r') as infile:
        for i in range(skip_lines):
            # Skip some lines (e.g., the namespace tag at the first line)
            _ = infile.readline()

        return yaml.safe_load(infile)


def save_yaml_config(config, path):
    with open(path, 'w') as outfile:
        yaml.dump(config, outfile, default_flow_style=False)


def get_args():
    parser = argparse.ArgumentParser()

    ##### General settings #####
    parser.add_argument('--seed',
                        type=int,
                        default=1230,
                        help='Random seed for reproducibility')

    ##### Dataset settings #####
    parser.add_argument('--n',
                        type=int,
                        default=1000,
                        help='Number of observed samples')

    parser.add_argument('--d',
                        type=int,
                        default=20,
                        help='Number of nodes')

    parser.add_argument('--graph_type',
                        type=str,
                        default='erdos-renyi',
                        help='Type of graph [erdos-renyi, barabasi-albert]')

    parser.add_argument('--degree',
                        type=int,
                        default=4,
                        help='Expected node degree of graph')

    parser.add_argument('--sem_type',
                        type=str,
                        default='linear-gauss',
                        help='Type of SEM [linear-gauss, linear-exp, linear-gumbel]')

    parser.add_argument('--noise_scale',
                        type=float,
                        default=1.0,
                        help='Scale parameter (standard deviation) of the noise distribution')

    parser.add_argument('--dataset_type',
                        type=str,
                        default='linear',
                        help='Type of dataset [only linear is implemented]')

    ##### Model settings #####
    parser.add_argument('--l1_lambda',
                        type=float,
                        default=0.0,
                        help='L1 penalty for sparse graph. Set to 0 to disable')

    # Note: 'type=bool' does not behave as intended with argparse (any
    # non-empty string parses as True), so this is exposed as a flag instead
    parser.add_argument('--use_float64',
                        action='store_true',
                        help='Use tf.float64 instead of tf.float32 during training')

    ##### Training settings #####
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-3,
                        help='Learning rate')

    parser.add_argument('--max_iter',
                        type=int,
                        default=20,
                        help='Number of outer iterations of the optimization problem')

    parser.add_argument('--iter_step',
                        type=int,
                        default=1500,
                        help='Number of optimization steps within each iteration')

    parser.add_argument('--init_iter',
                        type=int,
                        default=3,
                        help='Number of initial iterations during which early stopping is disabled')

    parser.add_argument('--h_tol',
                        type=float,
                        default=1e-8,
                        help='Tolerance on the acyclicity term h for early stopping')

    parser.add_argument('--init_rho',
                        type=float,
                        default=1.0,
                        help='Initial value of rho')

    parser.add_argument('--rho_max',
                        type=float,
                        default=1e+16,
                        help='Maximum value of rho')

    parser.add_argument('--h_factor',
                        type=float,
                        default=0.25,
                        help='rho is increased until h falls below h_factor times its previous value')

    parser.add_argument('--rho_multiply',
                        type=float,
                        default=10.0,
                        help='Factor by which rho is multiplied each time it is increased')

    ##### Other settings #####
    parser.add_argument('--graph_thres',
                        type=float,
                        default=0.3,
                        help='Threshold to filter out small values in graph')

    return parser.parse_args(args=sys.argv[1:])
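

if __name__ == '__main__':
    # Illustrative round trip (an added demo; 'config.yaml' is a scratch path).
    # main.py dumps the argparse.Namespace object directly, so the saved YAML
    # starts with a '!!python/object:argparse.Namespace' tag line; passing
    # skip_lines=1 skips that tag so that yaml.safe_load can parse the rest.
    args = get_args()
    save_yaml_config(args, path='config.yaml')
    print(load_yaml_config('config.yaml', skip_lines=1))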
--------------------------------------------------------------------------------
/src/helpers/dir_utils.py:
--------------------------------------------------------------------------------
import os
import sys
import pathlib
import logging

_logger = logging.getLogger(__name__)


def create_dir(dir_path):
    """
    dir_path - A path of directory to create if it is not found
    """
    try:
        if not os.path.exists(dir_path):
            pathlib.Path(dir_path).mkdir(parents=True, exist_ok=True)

        return 0
    except Exception as err:
        _logger.critical('Creating directories error: {0}'.format(err))
        sys.exit(-1)

--------------------------------------------------------------------------------
/src/helpers/log_helper.py:
--------------------------------------------------------------------------------
import sys
import logging
from datetime import datetime
from pytz import timezone, utc


class LogHelper(object):
    """
    Helper class to configure logger
    """
    log_format = '%(asctime)s %(levelname)s - %(name)s - %(message)s'

    @staticmethod
    def setup(log_path, level_str='INFO'):
        logging.basicConfig(
            filename=log_path,
            level=logging.getLevelName(level_str),
            format=LogHelper.log_format,
        )

        def customTime(*args):
            # Convert log timestamps from UTC to the local timezone
            utc_dt = utc.localize(datetime.utcnow())
            my_tz = timezone('Canada/Central')
            converted = utc_dt.astimezone(my_tz)
            return converted.timetuple()

        logging.Formatter.converter = customTime

        # Set up logging to console
        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        console.setFormatter(logging.Formatter(LogHelper.log_format))
        # Add the console handler to the root logger
        logging.getLogger('').addHandler(console)

        # Log for unhandled exception
        logger = logging.getLogger(__name__)
        sys.excepthook = lambda *ex: logger.critical('Unhandled exception', exc_info=ex)

        logger.info('Finished configuring logger.')
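

if __name__ == '__main__':
    # Illustrative usage (an added demo; 'training.log' is a placeholder path).
    # This mirrors how main.py configures logging before training starts.
    LogHelper.setup(log_path='training.log', level_str='INFO')
    logging.getLogger(__name__).info('Logger is ready.')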
--------------------------------------------------------------------------------
/src/helpers/tf_utils.py:
--------------------------------------------------------------------------------
import os
import random
import numpy as np
import tensorflow as tf


def is_cuda_available():
    return tf.test.is_gpu_available(cuda_only=True)


def set_seed(seed):
    """
    Adapted from:
    - https://stackoverflow.com/questions/38469632/tensorflow-non-repeatable-results
    """
    # Reproducibility
    random.seed(seed)
    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)
    try:
        os.environ['PYTHONHASHSEED'] = str(seed)
    except Exception:
        pass


def tensor_description(var):
    """
    Returns a compact and informative string about a tensor.
    Args:
        var: A tensor variable.
    Returns:
        a string with type and size, e.g.: (float32 1x8x8x1024).

    Adapted from:
    - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/model_analyzer.py
    """
    description = '(' + str(var.dtype.name) + ' '
    sizes = var.get_shape()
    for i, size in enumerate(sizes):
        description += str(size)
        if i < len(sizes) - 1:
            description += 'x'
    description += ')'
    return description


def print_summary(print_func):
    """
    Print a summary table of the network structure.
    Adapted from:
    - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/model_analyzer.py
    """
    variables = tf.compat.v1.trainable_variables()

    print_func('Model summary:')
    print_func('---------')
    print_func('Variables: name (type shape) [size]')
    print_func('---------')

    total_size = 0
    total_bytes = 0
    for var in variables:
        # if var.num_elements() is None or [] assume size 0.
        var_size = var.get_shape().num_elements() or 0
        var_bytes = var_size * var.dtype.size
        total_size += var_size
        total_bytes += var_bytes

        print_func('{} {} [{}, bytes: {}]'.format(var.name, tensor_description(var), var_size, var_bytes))

    print_func('Total size of variables: {}'.format(total_size))
    print_func('Total bytes of variables: {}'.format(total_bytes))

--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
import os
import logging
from pytz import timezone
from datetime import datetime
import numpy as np

from data_loader import SyntheticDataset
from models import NoTears
from trainers import ALTrainer
from helpers.config_utils import save_yaml_config, get_args
from helpers.log_helper import LogHelper
from helpers.tf_utils import set_seed
from helpers.dir_utils import create_dir
from helpers.analyze_utils import count_accuracy, plot_estimated_graph


# Suppress INFO and WARNING messages from tensorflow
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def main():
    # Get arguments parsed
    args = get_args()

    # Setup for logging
    output_dir = 'output/{}'.format(datetime.now(timezone('Canada/Central')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir), level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = SyntheticDataset(args.n, args.d, args.graph_type, args.degree, args.sem_type,
                               args.noise_scale, args.dataset_type)
    _logger.info('Finished generating dataset')

    model = NoTears(args.n, args.d, args.seed, args.l1_lambda, args.use_float64)
    model.print_summary(print_func=model.logger.info)

    trainer = ALTrainer(args.init_rho, args.rho_max, args.h_factor, args.rho_multiply,
                        args.init_iter, args.learning_rate, args.h_tol)
    W_est = trainer.train(model, dataset.X, dataset.W, args.graph_thres,
                          args.max_iter, args.iter_step, output_dir)
    _logger.info('Finished training model')

    # Save raw estimated graph, ground truth and observational data after training
    np.save('{}/true_graph.npy'.format(output_dir), dataset.W)
    np.save('{}/X.npy'.format(output_dir), dataset.X)
    np.save('{}/final_raw_estimated_graph.npy'.format(output_dir), W_est)

    # Plot raw estimated graph
    plot_estimated_graph(W_est, dataset.W,
                         save_name='{}/raw_estimated_graph.png'.format(output_dir))

    _logger.info('Thresholding.')
    # Plot thresholded estimated graph
    W_est[np.abs(W_est) < args.graph_thres] = 0    # Thresholding
    plot_estimated_graph(W_est, dataset.W,
                         save_name='{}/thresholded_estimated_graph.png'.format(output_dir))
    results_thresholded = count_accuracy(dataset.W, W_est)
    _logger.info('Results after thresholding by {}: {}'.format(args.graph_thres, results_thresholded))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
from models.notears import NoTears

--------------------------------------------------------------------------------
/src/models/notears.py:
--------------------------------------------------------------------------------
import logging
import tensorflow as tf

from helpers.dir_utils import create_dir
from helpers.tf_utils import print_summary


class NoTears(object):
    _logger = logging.getLogger(__name__)

    def __init__(self, n, d, seed=8, l1_lambda=0, use_float64=False):
        self.print_summary = print_summary    # Print summary for tensorflow variables

        self.n = n
        self.d = d
        self.seed = seed
        self.l1_lambda = l1_lambda
        self.tf_float_type = tf.dtypes.float64 if use_float64 else tf.dtypes.float32

        # Initializer (for reproducibility; currently unused, since W below
        # is initialized to zeros)
        self.initializer = tf.keras.initializers.glorot_uniform(seed=self.seed)

        self._build()
        self._init_session()
        self._init_saver()

    def _init_session(self):
        self.sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(
                per_process_gpu_memory_fraction=0.5,
                allow_growth=True,
            )
        ))

    def _init_saver(self):
        self.saver = tf.compat.v1.train.Saver()

    def _build(self):
        tf.compat.v1.reset_default_graph()

        self.rho = tf.compat.v1.placeholder(self.tf_float_type)
        self.alpha = tf.compat.v1.placeholder(self.tf_float_type)
        self.lr = tf.compat.v1.placeholder(self.tf_float_type)

        self.X = tf.compat.v1.placeholder(self.tf_float_type, shape=[self.n, self.d])
        W = tf.Variable(tf.zeros([self.d, self.d], self.tf_float_type))

        self.W_prime = self._preprocess_graph(W)
        self.mse_loss = self._get_mse_loss(self.X, self.W_prime)

        self.h = tf.linalg.trace(tf.linalg.expm(self.W_prime * self.W_prime)) - self.d    # Acyclicity

        # Augmented Lagrangian: least-squares term, L1 penalty, and the
        # linear (alpha) and quadratic (rho) penalties on the acyclicity term
        self.loss = 0.5 / self.n * self.mse_loss \
                    + self.l1_lambda * tf.norm(self.W_prime, ord=1) \
                    + self.alpha * self.h + 0.5 * self.rho * self.h * self.h

        self.train_op = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)
        self._logger.debug('Finished building Tensorflow graph')

    def _preprocess_graph(self, W):
        # Mask the diagonal entries of graph
        return tf.linalg.set_diag(W, tf.zeros(W.shape[0], dtype=self.tf_float_type))

    def _get_mse_loss(self, X, W_prime):
        X_prime = tf.matmul(X, W_prime)
        return tf.square(tf.linalg.norm(X - X_prime))

    def save(self, model_dir):
        create_dir(model_dir)
        self.saver.save(self.sess, '{}/model'.format(model_dir))

    @property
    def logger(self):
        try:
            return self._logger
        except AttributeError:
            raise NotImplementedError('self._logger does not exist!')


if __name__ == '__main__':
    n, d = 3000, 20
    model = NoTears(n, d)
    model.print_summary(print)

    print()
    print('model.W_prime: {}'.format(model.W_prime))
    print('model.mse_loss: {}'.format(model.mse_loss))
    print('model.h: {}'.format(model.h))
    print('model.loss: {}'.format(model.loss))
    print('model.train_op: {}'.format(model.train_op))
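
    # Illustrative check (an added demo): with W initialized to zeros,
    # W_prime * W_prime is the zero matrix, expm(0) is the identity, and
    # trace(I) equals d, so the acyclicity term h should evaluate to 0.
    model.sess.run(tf.compat.v1.global_variables_initializer())
    print('h at initialization: {}'.format(model.sess.run(model.h)))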
--------------------------------------------------------------------------------
/src/trainers/__init__.py:
--------------------------------------------------------------------------------
from trainers.al_trainer import ALTrainer

--------------------------------------------------------------------------------
/src/trainers/al_trainer.py:
--------------------------------------------------------------------------------
import logging
import numpy as np
import tensorflow as tf

from helpers.dir_utils import create_dir
from helpers.analyze_utils import count_accuracy


class ALTrainer(object):
    """
    Augmented Lagrangian method with gradient-based optimization
    """
    _logger = logging.getLogger(__name__)

    def __init__(self, init_rho, rho_max, h_factor, rho_multiply, init_iter, learning_rate, h_tol):
        self.init_rho = init_rho
        self.rho_max = rho_max
        self.h_factor = h_factor
        self.rho_multiply = rho_multiply
        self.init_iter = init_iter
        self.learning_rate = learning_rate
        self.h_tol = h_tol

    def train(self, model, X, W_true, graph_thres, max_iter, iter_step, output_dir):
        """
        The model object should expose the following members:
        - sess
        - train_op
        - loss
        - mse_loss
        - h
        - W_prime
        - X
        - rho
        - alpha
        - lr
        """
        model.sess.run(tf.compat.v1.global_variables_initializer())
        rho, alpha, h, h_new = self.init_rho, 0.0, np.inf, np.inf

        self._logger.info('Started training for {} iterations'.format(max_iter))
        for epoch in range(1, max_iter + 1):
            # Increase rho until the acyclicity term decreases sufficiently
            while rho < self.rho_max:
                self._logger.info('rho {:.3E}, alpha {:.3E}'.format(rho, alpha))
                loss_new, mse_new, h_new, W_new = self.train_step(model, iter_step, X, rho, alpha)
                if h_new > self.h_factor * h:
                    rho *= self.rho_multiply
                else:
                    break

            self.train_callback(epoch, loss_new, mse_new, h_new, W_true, W_new, graph_thres, output_dir)
            W_est, h = W_new, h_new
            alpha += rho * h    # Dual ascent step on the Lagrange multiplier

            if h <= self.h_tol and epoch > self.init_iter:
                self._logger.info('Early stopping at {}-th iteration'.format(epoch))
                break

        # Save model
        model_dir = '{}/model/'.format(output_dir)
        model.save(model_dir)
        self._logger.info('Model saved to {}'.format(model_dir))

        return W_est

    def train_step(self, model, iter_step, X, rho, alpha):
        for _ in range(iter_step):
            _, curr_loss, curr_mse, curr_h, curr_W \
                = model.sess.run([model.train_op, model.loss, model.mse_loss, model.h, model.W_prime],
                                 feed_dict={model.X: X,
                                            model.rho: rho,
                                            model.alpha: alpha,
                                            model.lr: self.learning_rate})

        return curr_loss, curr_mse, curr_h, curr_W

    def train_callback(self, epoch, loss, mse, h, W_true, W_est, graph_thres, output_dir):
        # Evaluate the learned W in each iteration after thresholding
        W_thresholded = np.copy(W_est)
        W_thresholded[np.abs(W_thresholded) < graph_thres] = 0
        results_thresholded = count_accuracy(W_true, W_thresholded)

        self._logger.info(
            '[Iter {}] loss {:.3E}, mse {:.3E}, acyclic {:.3E}, shd {}, tpr {:.3f}, fdr {:.3f}, pred_size {}'.format(
                epoch, loss, mse, h, results_thresholded['shd'], results_thresholded['tpr'],
                results_thresholded['fdr'], results_thresholded['pred_size']
            )
        )

        # Save the raw estimated graph in each iteration
        create_dir('{}/raw_estimated_graph'.format(output_dir))
        np.save('{}/raw_estimated_graph/graph_iteration_{}.npy'.format(output_dir, epoch), W_est)
--------------------------------------------------------------------------------