├── __init__.py ├── data └── count │ ├── emnlp-x.npy │ ├── emnlp-y.npy │ ├── emnlp-adj.npz │ ├── wiki-squirrel-x.npy │ ├── wiki-squirrel-y.npy │ ├── wiki-chameleon-x.npy │ ├── wiki-chameleon-y.npy │ ├── wiki-chameleon-edge.npy │ └── wiki-squirrel-edge.npy ├── utils.py ├── LICENSE ├── README.md ├── environment.yml ├── copula.py ├── data.py ├── main.py └── models.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/count/emnlp-x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqima/CopulaGNN/HEAD/data/count/emnlp-x.npy -------------------------------------------------------------------------------- /data/count/emnlp-y.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqima/CopulaGNN/HEAD/data/count/emnlp-y.npy -------------------------------------------------------------------------------- /data/count/emnlp-adj.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqima/CopulaGNN/HEAD/data/count/emnlp-adj.npz -------------------------------------------------------------------------------- /data/count/wiki-squirrel-x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqima/CopulaGNN/HEAD/data/count/wiki-squirrel-x.npy -------------------------------------------------------------------------------- /data/count/wiki-squirrel-y.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqima/CopulaGNN/HEAD/data/count/wiki-squirrel-y.npy -------------------------------------------------------------------------------- /data/count/wiki-chameleon-x.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqima/CopulaGNN/HEAD/data/count/wiki-chameleon-x.npy -------------------------------------------------------------------------------- /data/count/wiki-chameleon-y.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqima/CopulaGNN/HEAD/data/count/wiki-chameleon-y.npy -------------------------------------------------------------------------------- /data/count/wiki-chameleon-edge.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqima/CopulaGNN/HEAD/data/count/wiki-chameleon-edge.npy -------------------------------------------------------------------------------- /data/count/wiki-squirrel-edge.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqima/CopulaGNN/HEAD/data/count/wiki-squirrel-edge.npy -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | 5 | class Logger(object): 6 | 7 | def __init__(self, verbose=0, log_path=None, file_prefix=""): 8 | self.verbose = verbose 9 | self.filename = None 10 | if log_path is not None: 11 | if not os.path.exists(log_path): 12 | os.makedirs(log_path) 13 | self.filename = os.path.join( 14 | log_path, file_prefix + ".log") 15 | with open(self.filename, "w") as f: 16 | f.write(self.filename) 17 | f.write("\n") 18 | 19 | def p(self, s, level=1): 20 | if self.verbose >= level: 21 | print(s) 22 | if self.filename is not None: 23 | with open(self.filename, "a") as f: 24 | f.write(datetime.now().strftime("[%m/%d %H:%M:%S] ") + str(s)) 25 | f.write("\n") 26 | -------------------------------------------------------------------------------- 
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Jiaqi Ma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CopulaGNN 2 | 3 | This repo provides a PyTorch implementation for the CopulaGNN models as described in the following paper: 4 | 5 | [CopulaGNN: Towards Integrating Representational and Correlational Roles of Graphs in Graph Neural Networks](https://arxiv.org/abs/2010.02089) 6 | 7 | Jiaqi Ma, Bo Chang, Xuefei Zhang, and Qiaozhu Mei. ICLR 2021. 8 | 9 | ## Requirements 10 | Most dependency packages are included in `environment.yml`. Run `conda env create -f environment.yml` to create the `copula_env` environment with the required packages. 
11 | 12 | In addition, one also needs to install [PyTorch-Geometric](https://github.com/rusty1s/pytorch_geometric) following the [official installation instructions](https://github.com/rusty1s/pytorch_geometric#installation). 13 | 14 | The code is tested with the following PyTorch-Geometric versions. 15 | 16 | ``` 17 | torch-scatter==2.0.5 18 | torch-sparse==0.6.7 19 | torch-cluster==1.5.7 20 | torch-geometric==1.6.1 21 | ``` 22 | 23 | ## Run the code 24 | Example: `python main.py --lr 0.001 --hidden_size 16 --dataset wiki-squirrel --model_type regcgcn`. 25 | 26 | ## Cite 27 | ``` 28 | @inproceedings{ma2020copulagnn, 29 | title={CopulaGNN: Towards Integrating Representational and Correlational Roles of Graphs in Graph Neural Networks}, 30 | author={Ma, Jiaqi and Chang, Bo and Zhang, Xuefei and Mei, Qiaozhu}, 31 | booktitle={International Conference on Learning Representations}, 32 | year={2021} 33 | } 34 | ``` -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: copula_env 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - backcall=0.2.0=py_0 8 | - ca-certificates=2020.10.14=0 9 | - certifi=2020.6.20=py38_0 10 | - cudatoolkit=10.1.243=h6bb024c_0 11 | - decorator=4.4.2=py_0 12 | - ipykernel=5.3.4=py38h5ca1d4c_0 13 | - ipython=7.18.1=py38h5ca1d4c_0 14 | - ipython_genutils=0.2.0=py38_0 15 | - jedi=0.17.2=py38_0 16 | - jupyter_client=6.1.7=py_0 17 | - jupyter_core=4.6.3=py38_0 18 | - libsodium=1.0.18=h7b6447c_0 19 | - openssl=1.1.1h=h7b6447c_0 20 | - parso=0.7.0=py_0 21 | - pexpect=4.8.0=py38_0 22 | - pickleshare=0.7.5=py38_1000 23 | - prompt-toolkit=3.0.8=py_0 24 | - ptyprocess=0.6.0=py38_0 25 | - pygments=2.7.1=py_0 26 | - pyzmq=19.0.2=py38he6710b0_1 27 | - tornado=6.0.4=py38h7b6447c_1 28 | - traitlets=5.0.5=py_0 29 | - wcwidth=0.2.5=py_0 30 | - zeromq=4.3.3=he6710b0_3 31 | - _libgcc_mutex=0.1=main 32 | - 
blas=1.0=mkl 33 | - freetype=2.10.2=h5ab3b9f_0 34 | - intel-openmp=2020.2=254 35 | - joblib=0.16.0=py_0 36 | - jpeg=9b=h024ee3a_2 37 | - lcms2=2.11=h396b838_0 38 | - ld_impl_linux-64=2.33.1=h53a641e_7 39 | - libedit=3.1.20191231=h14c3975_1 40 | - libffi=3.3=he6710b0_2 41 | - libgcc-ng=9.1.0=hdf63c60_0 42 | - libgfortran-ng=7.3.0=hdf63c60_0 43 | - libpng=1.6.37=hbc83047_0 44 | - libstdcxx-ng=9.1.0=hdf63c60_0 45 | - libtiff=4.1.0=h2733197_1 46 | - lz4-c=1.9.2=he6710b0_1 47 | - mkl=2020.2=256 48 | - mkl-service=2.3.0=py38he904b0f_0 49 | - mkl_fft=1.1.0=py38h23d657b_0 50 | - mkl_random=1.1.1=py38h0573a6f_0 51 | - ncurses=6.2=he6710b0_1 52 | - ninja=1.10.1=py38hfd86e86_0 53 | - numpy=1.19.1=py38hbc911f0_0 54 | - numpy-base=1.19.1=py38hfa32c7d_0 55 | - olefile=0.46=py_0 56 | - pandas=1.1.1=py38he6710b0_0 57 | - pillow=7.2.0=py38hb39fc2d_0 58 | - pip=20.2.2=py38_0 59 | - python=3.8.5=h7579374_1 60 | - python-dateutil=2.8.1=py_0 61 | - pytz=2020.1=py_0 62 | - readline=8.0=h7b6447c_0 63 | - scikit-learn=0.23.2=py38h0573a6f_0 64 | - scipy=1.5.2=py38h0b6359f_0 65 | - setuptools=49.6.0=py38_0 66 | - six=1.15.0=py_0 67 | - sqlite=3.33.0=h62c20be_0 68 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 69 | - tk=8.6.10=hbc83047_0 70 | - wheel=0.35.1=py_0 71 | - xz=5.2.5=h7b6447c_0 72 | - zlib=1.2.11=h7b6447c_3 73 | - zstd=1.4.5=h9ceee32_0 74 | - pytorch=1.6.0=py3.8_cuda10.1.243_cudnn7.6.3_0 75 | - torchvision=0.7.0=py38_cu101 76 | -------------------------------------------------------------------------------- /copula.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.distributions import constraints 5 | from torch.distributions.distribution import Distribution 6 | from torch.distributions.multivariate_normal import ( 7 | MultivariateNormal, 8 | _batch_mahalanobis, 9 | ) 10 | 11 | 12 | def _standard_normal_quantile(u): 13 | # Ref: https://en.wikipedia.org/wiki/Normal_distribution 14 | return math.sqrt(2) * 
torch.erfinv(2 * u - 1) 15 | 16 | 17 | def _standard_normal_cdf(x): 18 | # Ref: https://en.wikipedia.org/wiki/Normal_distribution 19 | return 0.5 * (1 + torch.erf(x / math.sqrt(2))) 20 | 21 | 22 | class GaussianCopula(Distribution): 23 | r""" 24 | A Gaussian copula. 25 | 26 | Args: 27 | covariance_matrix (torch.Tensor): positive-definite covariance matrix 28 | """ 29 | arg_constraints = {"covariance_matrix": constraints.positive_definite} 30 | support = constraints.interval(0.0, 1.0) 31 | has_rsample = True 32 | 33 | def __init__(self, covariance_matrix=None, validate_args=None): 34 | # convert the covariance matrix to the correlation matrix 35 | # self.covariance_matrix = covariance_matrix.clone() 36 | # batch_diag = torch.diagonal(self.covariance_matrix, dim1=-1, dim2=-2).pow(-0.5) 37 | # self.covariance_matrix *= batch_diag.unsqueeze(-1) 38 | # self.covariance_matrix *= batch_diag.unsqueeze(-2) 39 | diag = torch.diag(covariance_matrix).pow(-0.5) 40 | self.covariance_matrix = ( 41 | torch.diag(diag)).matmul(covariance_matrix).matmul( 42 | torch.diag(diag)) 43 | 44 | batch_shape, event_shape = ( 45 | covariance_matrix.shape[:-2], 46 | covariance_matrix.shape[-1:], 47 | ) 48 | 49 | super().__init__(batch_shape, event_shape, validate_args=validate_args) 50 | 51 | self.multivariate_normal = MultivariateNormal( 52 | loc=torch.zeros(event_shape), 53 | covariance_matrix=self.covariance_matrix, 54 | validate_args=validate_args, 55 | ) 56 | 57 | def log_prob(self, value): 58 | if self._validate_args: 59 | self._validate_sample(value) 60 | value_x = _standard_normal_quantile(value) 61 | half_log_det = ( 62 | self.multivariate_normal._unbroadcasted_scale_tril.diagonal( 63 | dim1=-2, dim2=-1 64 | ) 65 | .log() 66 | .sum(-1) 67 | ) 68 | M = _batch_mahalanobis( 69 | self.multivariate_normal._unbroadcasted_scale_tril, value_x 70 | ) 71 | M -= value_x.pow(2).sum(-1) 72 | return -0.5 * M - half_log_det 73 | 74 | def conditional_sample( 75 | self, cond_val, 
sample_shape=torch.Size([]), cond_idx=None, sample_idx=None 76 | ): 77 | """ 78 | Draw samples conditioning on cond_val. 79 | 80 | Args: 81 | cond_val (torch.Tensor): conditional values. Should be a 1D tensor. 82 | sample_shape (torch.Size): same as in 83 | `Distribution.sample(sample_shape=torch.Size([]))`. 84 | cond_idx (torch.LongTensor): indices that correspond to cond_val. 85 | If None, use the last m dimensions, where m is the length of cond_val. 86 | sample_idx (torch.LongTensor): indices to sample from. If None, sample 87 | from all remaining dimensions. 88 | 89 | Returns: 90 | Generates a sample_shape shaped sample or sample_shape shaped batch of 91 | samples if the distribution parameters are batched. 92 | """ 93 | m, n = *cond_val.shape, *self.event_shape 94 | 95 | if cond_idx is None: 96 | cond_idx = torch.arange(n - m, n) 97 | if sample_idx is None: 98 | sample_idx = torch.tensor( 99 | [i for i in range(n) if i not in set(cond_idx.tolist())] 100 | ) 101 | 102 | assert ( 103 | len(cond_idx) == m 104 | and len(sample_idx) + len(cond_idx) <= n 105 | and not set(cond_idx.tolist()) & set(sample_idx.tolist()) 106 | ) 107 | 108 | cov_00 = self.covariance_matrix.index_select( 109 | dim=0, index=sample_idx 110 | ).index_select(dim=1, index=sample_idx) 111 | cov_01 = self.covariance_matrix.index_select( 112 | dim=0, index=sample_idx 113 | ).index_select(dim=1, index=cond_idx) 114 | cov_10 = self.covariance_matrix.index_select( 115 | dim=0, index=cond_idx 116 | ).index_select(dim=1, index=sample_idx) 117 | cov_11 = self.covariance_matrix.index_select( 118 | dim=0, index=cond_idx 119 | ).index_select(dim=1, index=cond_idx) 120 | 121 | cond_val_nscale = _standard_normal_quantile(cond_val) # Phi^{-1}(u_cond) 122 | reg_coeff, _ = torch.solve(cov_10, cov_11) # Sigma_{11}^{-1} Sigma_{10} 123 | cond_mu = torch.mv(reg_coeff.t(), cond_val_nscale) 124 | cond_sigma = cov_00 - torch.mm(cov_01, reg_coeff) 125 | cond_normal = MultivariateNormal(loc=cond_mu, 
covariance_matrix=cond_sigma) 126 | 127 | samples_nscale = cond_normal.sample(sample_shape) 128 | samples_uscale = _standard_normal_cdf(samples_nscale) 129 | 130 | return samples_uscale 131 | 132 | 133 | if __name__ == "__main__": 134 | covariance_matrix = torch.tensor( 135 | [ 136 | [1.0, 0.5, 0.5, 0.5], 137 | [0.5, 1.0, 0.5, 0.5], 138 | [0.5, 0.5, 1.0, 0.5], 139 | [0.5, 0.5, 0.5, 1.0], 140 | ] 141 | ) 142 | gaussian_copula = GaussianCopula(covariance_matrix=covariance_matrix) 143 | cond_samples = gaussian_copula.conditional_sample( 144 | torch.Tensor([0.1]), sample_shape=[5] 145 | ) 146 | print(cond_samples) 147 | 148 | from torch.distributions.normal import Normal 149 | 150 | covariance_matrix = torch.tensor([[1.5, 0.5], [0.5, 2.0]]) 151 | multivariate_normal = MultivariateNormal( 152 | loc=torch.zeros(2), covariance_matrix=covariance_matrix 153 | ) 154 | normal = Normal(loc=torch.zeros(2), scale=torch.diag(covariance_matrix).pow(0.5)) 155 | gaussian_copula = GaussianCopula(covariance_matrix=covariance_matrix) 156 | 157 | for _ in range(10): 158 | value_x = torch.randn(5, 2) 159 | value_u = normal.cdf(value_x) 160 | actual = gaussian_copula.log_prob(value_u) 161 | expected = multivariate_normal.log_prob(value_x) - normal.log_prob(value_x).sum( 162 | -1 163 | ) 164 | 165 | print(f"expected: {expected}, actual: {actual}.") 166 | assert torch.norm(actual - expected) < 1e-5 167 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import scipy.sparse as sp 5 | import torch 6 | from six.moves import cPickle as pickle 7 | from torch_geometric.data import Data 8 | from torch_geometric.utils import to_undirected, remove_self_loops 9 | 10 | 11 | def to_data(x, y, adj=None, edge_index=None, train_idx=None, valid_idx=None, 12 | test_idx=None, train_size=1. / 3, valid_size=1. 
/ 3): 13 | x_tensor = torch.tensor(x, dtype=torch.float32) 14 | y_tensor = torch.tensor(y, dtype=torch.float32) 15 | if edge_index is None: 16 | assert adj is not None 17 | edge_index = torch.tensor(np.array(list(adj.nonzero()))) 18 | else: 19 | edge_index = torch.tensor(edge_index, dtype=torch.int64) 20 | edge_index = remove_self_loops(to_undirected(edge_index))[0] 21 | 22 | data = Data(x=x_tensor, y=y_tensor, edge_index=edge_index) 23 | n = data.x.size(0) 24 | if train_idx is not None: 25 | assert valid_idx is not None and test_idx is not None 26 | all_idx = set(list(range(n))) 27 | train_idx = set(train_idx) 28 | valid_idx = set(valid_idx) 29 | test_idx = all_idx.difference(train_idx.union(valid_idx)) 30 | elif isinstance(train_size, float): 31 | train_size = int(n * train_size) 32 | valid_size = int(n * valid_size) 33 | test_size = n - train_size - valid_size 34 | train_idx = set(range(train_size)) 35 | valid_idx = set(range(train_size, train_size + valid_size)) 36 | test_idx = set(range(n - test_size, n)) 37 | assert len(test_idx.intersection(train_idx.union(valid_idx))) == 0 38 | data.train_mask = torch.zeros(n).to(dtype=torch.bool) 39 | data.train_mask[list(train_idx)] = True 40 | data.valid_mask = torch.zeros(n).to(dtype=torch.bool) 41 | data.valid_mask[list(valid_idx)] = True 42 | data.test_mask = torch.zeros(n).to(dtype=torch.bool) 43 | data.test_mask[list(test_idx)] = True 44 | return data 45 | 46 | 47 | def generate_lsn(n=300, 48 | d=10, 49 | m=1500, 50 | gamma=0.1, 51 | tau=1., 52 | seed=1, 53 | lsn_mode="xw", 54 | root='./data', 55 | load_file=False, 56 | save_file=False): 57 | path = os.path.join(root, "lsn") 58 | assert lsn_mode in ["xw", "daxwi", "daxw"] 59 | filename = "lsn_{}_n{}_d{}_m{}_g{}_t{}_s{}.pkl".format( 60 | lsn_mode, n, d, m, gamma, tau, seed) 61 | if load_file and os.path.exists(os.path.join(path, filename)): 62 | with open(os.path.join(path, filename), "rb") as f: 63 | data, params = pickle.load(f) 64 | return data[0], data[1], 
data[2], filename 65 | 66 | rs = np.random.RandomState(seed=seed) 67 | x = rs.normal(size=(n, d)) 68 | w_a = rs.normal(size=(d, d)) 69 | w_y = rs.normal(size=(d, )) 70 | 71 | prod = x.dot(w_a) # (n, d) 72 | logits = -np.linalg.norm( 73 | prod.reshape(1, n, d) - prod.reshape(n, 1, d), axis=2) # (n, n) 74 | threshold = np.sort(logits.reshape(-1))[-m] 75 | adj = (logits >= threshold).astype(float) 76 | L = np.diag(adj.sum(axis=0)) - adj 77 | 78 | if lsn_mode == "xw": 79 | y_mean = x.dot(w_y) 80 | else: 81 | y_mean = np.diag(1. / adj.sum(axis=0)).dot(adj).dot(x).dot(w_y) 82 | 83 | if lsn_mode == "daxwi": 84 | y_cov = tau * np.linalg.inv(gamma * np.eye(n)) 85 | else: 86 | y_cov = tau * np.linalg.inv(L + gamma * np.eye(n)) 87 | y = rs.multivariate_normal(y_mean, y_cov) 88 | 89 | if save_file: 90 | if not os.path.exists(path): 91 | os.makedirs(path) 92 | with open(os.path.join(path, filename), "wb") as f: 93 | pickle.dump(((x, y, adj), (w_a, w_y)), f) 94 | 95 | return x, y, adj, filename 96 | 97 | 98 | def read_wiki(path, name="chameleon", seed=1): 99 | data_path = os.path.join(path, "count") 100 | 101 | x = np.load(os.path.join(data_path, "wiki-{}-x.npy".format(name))) 102 | y = np.load(os.path.join(data_path, "wiki-{}-y.npy".format(name))) 103 | edge_index = np.load( 104 | os.path.join(data_path, "wiki-{}-edge.npy".format(name))) 105 | 106 | rs = np.random.RandomState(seed) 107 | idx = rs.permutation(len(y)) 108 | split_size = int(len(idx) / 3) 109 | train_idx = idx[:1 * split_size] 110 | valid_idx = idx[1 * split_size:2 * split_size] 111 | test_idx = idx[-split_size:] 112 | data = to_data(x, y, edge_index=edge_index, train_idx=train_idx, 113 | valid_idx=valid_idx, test_idx=test_idx) 114 | return data 115 | 116 | 117 | def read_emnlp(path, seed=1): 118 | data_path = os.path.join(path, "count") 119 | x = np.load(os.path.join(data_path, "emnlp-x.npy")) 120 | y = np.load(os.path.join(data_path, "emnlp-y.npy")) 121 | adj = sp.load_npz(os.path.join(data_path, 
"emnlp-adj.npz")) 122 | adj = adj.todense() 123 | 124 | rs = np.random.RandomState(seed) 125 | idx = rs.permutation(len(y)) 126 | split_size = int(len(idx) / 3) 127 | train_idx = idx[:1 * split_size] 128 | valid_idx = idx[1 * split_size:2 * split_size] 129 | test_idx = idx[-split_size:] 130 | data = to_data(x, y, adj=adj, train_idx=train_idx, 131 | valid_idx=valid_idx, test_idx=test_idx) 132 | return data 133 | 134 | 135 | def read_election(path, target="election", seed=1): 136 | data_path = os.path.join(path, "election") 137 | 138 | edge_file = "2012_adj.csv" 139 | with open(os.path.join(data_path, edge_file)) as f: 140 | edge_index = np.loadtxt(f, dtype=int, delimiter=",", skiprows=1) 141 | edge_index = edge_index.T - 1 142 | feature_file = "2012_xy.csv" 143 | with open(os.path.join(data_path, feature_file)) as f: 144 | x = np.loadtxt(f, dtype=float, delimiter=",", skiprows=1) 145 | if target == "income": 146 | pos = 0 147 | elif target == "education": 148 | pos = 4 149 | elif target == "unemployment": 150 | pos = 5 151 | elif target == "election": 152 | pos = 6 153 | else: 154 | NotImplementedError("Unexpected target type {}".format(target)) 155 | y = x[:, pos] 156 | x = np.concatenate((x[:, :pos], x[:, pos+1:]), axis=1) 157 | 158 | rs = np.random.RandomState(seed) 159 | idx = rs.permutation(len(y)) 160 | split_size = int(len(idx) / 5) 161 | train_idx = idx[:3 * split_size] 162 | valid_idx = idx[3 * split_size:4 * split_size] 163 | test_idx = idx[-split_size:] 164 | data = to_data(x, y, edge_index=edge_index, train_idx=train_idx, 165 | valid_idx=valid_idx, test_idx=test_idx) 166 | return data 167 | 168 | 169 | if __name__ == "__main__": 170 | # Test to_data 171 | x = np.arange(6).reshape(3, 2) 172 | y = np.ones((3,)) 173 | y[0] = 0 174 | adj = np.array([ 175 | [0, 1, 0], 176 | [0, 0, 0], 177 | [0, 0, 0] 178 | ]) 179 | data = to_data(x, y, adj=adj) 180 | assert (data.edge_index.numpy() == np.array([[0, 1], [1, 0]])).all() 181 | assert (data.train_mask.numpy() == 
np.array([True, False, False])).all() 182 | assert (data.test_mask.numpy() == np.array([False, False, True])).all() 183 | assert (data.x.numpy() == x).all() 184 | assert (data.y.numpy() == y).all() 185 | 186 | adj = np.array([ 187 | [1, 1, 0], 188 | [0, 0, 0], 189 | [1, 0, 0] 190 | ]) 191 | data = to_data(x, y, adj=adj, train_idx=[2], valid_idx=[1], test_idx=[0]) 192 | print(data.edge_index) 193 | assert (data.train_mask.numpy() == np.array([False, False, True])).all() 194 | assert (data.test_mask.numpy() == np.array([True, False, False])).all() 195 | 196 | edge_index = np.array([[1, 0], [0, 2]]) 197 | data = to_data(x, y, edge_index=edge_index) 198 | assert (data.edge_index.numpy() == np.array([[0, 0, 1, 2], [1, 2, 0, 0]])).all() 199 | 200 | for target in ["election", "income", "education", "unemployment"]: 201 | print("reading ", target) 202 | data = read_election("data", target) 203 | assert len(data.x) == data.edge_index.max().item() + 1 204 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import os 5 | import random 6 | import time 7 | from six.moves import cPickle as pickle 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from data import (generate_lsn, to_data, read_election, read_wiki, read_emnlp) 14 | from models import (MLP, GCN, SAGE, GAT, APPNPNet, CorCopulaGCN, CorCopulaSAGE, 15 | RegCopulaGCN, RegCopulaSAGE,) 16 | from utils import Logger 17 | 18 | import warnings 19 | warnings.filterwarnings('error') 20 | 21 | parser = argparse.ArgumentParser(description='Main.') 22 | parser.add_argument("--verbose", type=int, default=2) 23 | parser.add_argument("--debug", action="store_true") 24 | parser.add_argument("--device", default="cpu") 25 | parser.add_argument("--seed", type=int, 
default=1) 26 | parser.add_argument("--num_trials", type=int, default=10) 27 | 28 | # Dataset configuration 29 | parser.add_argument("--path", default="./data") 30 | parser.add_argument("--dataset", default="wiki-squirrel") 31 | # Synthetic data configuration 32 | parser.add_argument( 33 | "--lsn_mode", default="daxw", 34 | help=("Choices: `daxwi`, `xw', or `daxw`. \n" 35 | " `daxwi`: only mean is graph-dependent; \n" 36 | " `xw`: only cov is graph-dependent; \n" 37 | " `daxw`: both mean and cov are graph-dependent.")) 38 | parser.add_argument("--num_features", type=int, default=10) 39 | parser.add_argument("--num_nodes", type=int, default=300) 40 | parser.add_argument("--num_edges", type=int, default=5000) 41 | parser.add_argument("--gamma", type=float, default=0.1) 42 | parser.add_argument("--tau", type=float, default=1.0) 43 | 44 | # Model configuration 45 | parser.add_argument("--model_type", default="mlp") 46 | parser.add_argument("--hidden_size", type=int, default=8) 47 | parser.add_argument("--dropout", type=float, default=0.) 
48 | parser.add_argument("--num_heads", type=int, default=4) 49 | parser.add_argument("--clip_output", type=float, default=0.5) 50 | 51 | # Training configuration 52 | parser.add_argument("--opt", default="Adam") 53 | parser.add_argument("--lr", type=float, default=0.001) 54 | parser.add_argument("--num_epochs", type=int, default=10000) 55 | parser.add_argument("--patience", type=int, default=50) 56 | 57 | # Other configuration 58 | parser.add_argument("--log_interval", type=int, default=20) 59 | parser.add_argument("--result_path", default=None) 60 | 61 | args = parser.parse_args() 62 | 63 | # Set random seed 64 | if args.seed >= 0: 65 | random.seed(args.seed) 66 | np.random.seed(args.seed) 67 | torch.manual_seed(args.seed) 68 | if args.device.startswith("cuda"): 69 | torch.cuda.manual_seed(args.seed) 70 | 71 | # Load data 72 | data_seed = int(np.ceil(args.seed / float(args.num_trials))) 73 | if args.dataset == "lsn": 74 | x, y, adj, datafile = generate_lsn(n=args.num_nodes, 75 | d=args.num_features, 76 | m=args.num_edges, 77 | gamma=args.gamma, 78 | tau=args.tau, 79 | seed=data_seed, 80 | lsn_mode=args.lsn_mode, 81 | root=args.path, 82 | save_file=False) 83 | data = to_data(x, y, adj=adj) 84 | data.is_count_data = False 85 | data.to(args.device) 86 | elif args.dataset.startswith("election"): 87 | target = args.dataset.split("-")[1] 88 | data = read_election("data", target, seed=data_seed) 89 | data.is_count_data = False 90 | data.to(args.device) 91 | elif args.dataset.startswith("wiki"): 92 | name = args.dataset.split("-")[1] 93 | data = read_wiki("data", name=name, seed=data_seed) 94 | data.is_count_data = True 95 | data.to(args.device) 96 | elif args.dataset.startswith("emnlp", seed=data_seed): 97 | data = read_emnlp("data") 98 | data.is_count_data = True 99 | data.to(args.device) 100 | else: 101 | raise NotImplementedError("Dataset {} is not supported.".format( 102 | args.dataset)) 103 | 104 | # Outcome type config 105 | if not data.is_count_data: 106 | 
marginal_type = "Normal" 107 | 108 | # R-squared 109 | def metric(preds, labels): 110 | num = torch.mean((preds - labels)**2) 111 | denum = torch.mean((labels - torch.mean(labels))**2) 112 | return 1 - num / denum 113 | else: 114 | marginal_type = "Poisson" 115 | 116 | # R-squared based on deviance residuals for count data 117 | # Suitable for heteroscedastic data 118 | # http://cameron.econ.ucdavis.edu/research/jbes96preprint.pdf 119 | def metric(preds, labels): 120 | labels = 1 + labels 121 | preds = 1 + preds 122 | ratio = torch.log(labels / preds) 123 | num = torch.mean(labels * ratio - (labels - preds)) 124 | denum = torch.mean(labels * torch.log(labels / torch.mean(labels))) 125 | return 1 - num / denum 126 | 127 | minimize_metric = -1 128 | 129 | # Log file 130 | time_stamp = time.time() 131 | log_file = ( 132 | "data__{}__model__{}__lr__{}__h__{}__seed__{}__stamp__{}").format( 133 | args.dataset, args.model_type, args.lr, args.hidden_size, args.seed, 134 | time_stamp) 135 | if args.dataset == "lsn": 136 | log_file += "__datafile__{}".format(os.path.splitext(datafile)[0]) 137 | log_path = os.path.join(args.path, "logs") 138 | lgr = Logger(args.verbose, log_path, log_file) 139 | lgr.p(args) 140 | 141 | # Model config 142 | model_args = { 143 | "num_features": data.x.size(1), 144 | "hidden_size": args.hidden_size, 145 | "dropout": args.dropout, 146 | "activation": "relu" 147 | } 148 | 149 | if args.model_type in ["corcgcn", "regcgcn", "corcsage", "regcsage"]: 150 | model_args["marginal_type"] = marginal_type 151 | 152 | if args.model_type == "mlp": 153 | model = MLP(**model_args) 154 | elif args.model_type == "gcn": 155 | model = GCN(**model_args) 156 | elif args.model_type == "sage": 157 | model = SAGE(**model_args) 158 | elif args.model_type == "gat": 159 | model_args["num_heads"] = args.num_heads 160 | model_args["hidden_size"] = int(args.hidden_size / args.num_heads) 161 | model = GAT(**model_args) 162 | elif args.model_type == "appnp": 163 | model = 
APPNPNet(**model_args) 164 | elif args.model_type == "corcgcn": 165 | model = CorCopulaGCN(**model_args) 166 | elif args.model_type == "corcsage": 167 | model = CorCopulaSAGE(**model_args) 168 | elif args.model_type == "regcgcn": 169 | model = RegCopulaGCN(**model_args) 170 | elif args.model_type == "regcsage": 171 | model = RegCopulaSAGE(**model_args) 172 | else: 173 | raise NotImplementedError("Model {} is not supported.".format( 174 | args.model_type)) 175 | model.to(args.device) 176 | 177 | # Optimizer 178 | if args.opt == "Adam": 179 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4) 180 | else: 181 | raise NotImplementedError("Optimizer {} is not supported.".format( 182 | args.opt)) 183 | 184 | # Training objective 185 | if hasattr(model, "nll"): 186 | 187 | def train_loss_fn(model, data): 188 | return model.nll(data) 189 | 190 | else: 191 | 192 | if marginal_type == "Normal": 193 | criterion = nn.MSELoss() 194 | elif marginal_type == "Poisson": 195 | 196 | def criterion(logits, labels): 197 | return torch.mean(torch.exp(logits) - labels * logits) 198 | 199 | else: 200 | raise NotImplementedError("Marginal type {} is not supported.".format( 201 | marginal_type)) 202 | 203 | def train_loss_fn(model, data): 204 | return criterion( 205 | model(data)[data.train_mask], data.y[data.train_mask]) 206 | 207 | 208 | # Training and evaluation 209 | def train(): 210 | model.train() 211 | optimizer.zero_grad() 212 | loss = train_loss_fn(model, data) 213 | loss.backward() 214 | optimizer.step() 215 | 216 | 217 | def test(): 218 | model.eval() 219 | with torch.no_grad(): 220 | if hasattr(model, "predict"): 221 | preds = model.predict(data, num_samples=1000) 222 | else: 223 | preds = model(data) 224 | if marginal_type == "Poisson": 225 | preds = torch.exp(preds) 226 | if args.clip_output != 0: # clip output logits to avoid extreme outliers 227 | left = torch.min(data.y[data.train_mask]) / args.clip_output 228 | right = 
torch.max(data.y[data.train_mask]) * args.clip_output 229 | preds = torch.clamp(preds, left, right) 230 | train_metric = metric( 231 | preds[data.train_mask], data.y[data.train_mask]).item() 232 | valid_metric = metric( 233 | preds[data.valid_mask], data.y[data.valid_mask]).item() 234 | test_metric = metric( 235 | preds[data.test_mask], data.y[data.test_mask]).item() 236 | return train_metric, valid_metric, test_metric 237 | 238 | 239 | patience = args.patience 240 | best_metric = np.inf 241 | stats_to_save = {"args": args, "traj": []} 242 | for epoch in range(args.num_epochs): 243 | train() 244 | if (epoch + 1) % args.log_interval == 0: 245 | train_metric, valid_metric, test_metric = test() 246 | this_metric = valid_metric * minimize_metric 247 | patience -= 1 248 | if this_metric < best_metric: 249 | patience = args.patience 250 | best_metric = this_metric 251 | stats_to_save["valid_metric"] = valid_metric 252 | stats_to_save["test_metric"] = test_metric 253 | stats_to_save["epoch"] = epoch 254 | stats_to_save["traj"].append({ 255 | "epoch": epoch, 256 | "valid_metric": valid_metric, 257 | "test_metric": test_metric 258 | }) 259 | if patience == 0: 260 | break 261 | lgr.p("Epoch {}: train {:.4f}, valid {:.4f}, test {:.4f}".format( 262 | epoch, train_metric, valid_metric, test_metric)) 263 | 264 | lgr.p("-----\nBest epoch {}: valid {:.4f}, test {:.4f}".format( 265 | stats_to_save["epoch"], stats_to_save["valid_metric"], 266 | stats_to_save["test_metric"])) 267 | 268 | # Write outputs 269 | if args.verbose == 0: 270 | if args.result_path is None: 271 | result_path = os.path.join(args.path, "results") 272 | else: 273 | result_path = args.result_path 274 | if not os.path.exists(result_path): 275 | os.makedirs(result_path) 276 | result_file = ( 277 | "data__{}__valid__{}__test__{}__model__{}__lr__{}__h__{}__seed__{}" 278 | "__stamp__{}").format( 279 | args.dataset, stats_to_save["valid_metric"], 280 | stats_to_save["test_metric"], args.model_type, args.lr, 281 | 
args.hidden_size, args.seed, time_stamp) 282 | if args.dataset == "lsn": 283 | result_file += "__datafile__{}".format(os.path.splitext(datafile)[0]) 284 | with open(os.path.join(result_path, result_file), "wb") as f: 285 | pickle.dump(stats_to_save, f) 286 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from torch.autograd import Function 12 | from torch.nn import Linear 13 | from torch_geometric.nn import GCNConv, GATConv, APPNP 14 | from torch_geometric.utils import get_laplacian, to_dense_adj 15 | 16 | from scipy.special import gammainc 17 | from scipy.stats import poisson 18 | from torch.distributions.normal import Normal 19 | from torch.distributions.poisson import Poisson 20 | from copula import GaussianCopula 21 | 22 | EPS = 1e-3 23 | 24 | 25 | def grad_x_gammainc(a, x, grad_output): 26 | temp = -x + (a-1)*torch.log(x) - torch.lgamma(a) 27 | temp = torch.where(temp > -25, torch.exp(temp), torch.zeros_like(temp)) # avoid underflow 28 | return temp * grad_output # everything is element-wise 29 | 30 | 31 | class GammaIncFunc(Function): 32 | '''Regularized lower incomplete gamma function.''' 33 | @staticmethod 34 | def forward(ctx, a, x): 35 | # detach so we can cast to NumPy 36 | a = a.detach() 37 | x = x.detach() 38 | result = gammainc(a.cpu().numpy(), x.cpu().numpy()) 39 | result = torch.as_tensor(result, dtype=x.dtype, device=x.device) 40 | ctx.save_for_backward(a, x, result) 41 | return result 42 | 43 | @staticmethod 44 | def backward(ctx, grad_output): 45 | grad_output = grad_output.detach() 46 | a, x, result = ctx.saved_tensors 47 | grad_a = torch.zeros_like(a) # grad_a is never needed 48 | grad_x = grad_x_gammainc(a, 
x, grad_output) 49 | return grad_a, grad_x 50 | 51 | 52 | def _batch_normal_icdf(loc, scale, value): 53 | return loc[None, :] + scale[None, :] * torch.erfinv(2 * value - 1) * math.sqrt(2) 54 | 55 | 56 | class MLP(nn.Module): 57 | def __init__(self, 58 | num_features, 59 | hidden_size, 60 | dropout=0.5, 61 | activation="relu", 62 | *args, 63 | **kwargs): 64 | super().__init__(*args, **kwargs) 65 | self.fc1 = Linear(num_features, hidden_size) 66 | self.fc2 = Linear(hidden_size, 1) 67 | 68 | self.dropout = dropout 69 | assert activation in ["relu", "elu"] 70 | self.activation = getattr(F, activation) 71 | 72 | def forward(self, data): 73 | x = data.x 74 | x = self.activation(self.fc1(x)) 75 | x = F.dropout(x, p=self.dropout, training=self.training) 76 | x = self.fc2(x) 77 | return x.view(-1) 78 | 79 | 80 | class GCN(nn.Module): 81 | def __init__(self, 82 | num_features, 83 | hidden_size, 84 | dropout=0.5, 85 | activation="relu", 86 | *args, 87 | **kwargs): 88 | super().__init__(*args, **kwargs) 89 | self.conv1 = GCNConv(num_features, hidden_size) 90 | self.conv2 = GCNConv(hidden_size, 1) 91 | 92 | self.dropout = dropout 93 | assert activation in ["relu", "elu"] 94 | self.activation = getattr(F, activation) 95 | 96 | def forward(self, data): 97 | x, edge_index = data.x, data.edge_index 98 | x = self.activation(self.conv1(x, edge_index)) 99 | x = F.dropout(x, p=self.dropout, training=self.training) 100 | x = self.conv2(x, edge_index) 101 | return x.view(-1) 102 | 103 | 104 | class SAGE(nn.Module): 105 | def __init__(self, 106 | num_features, 107 | hidden_size, 108 | dropout=0.5, 109 | activation="relu", 110 | *args, 111 | **kwargs): 112 | super().__init__(*args, **kwargs) 113 | if hidden_size % 2 == 1: 114 | hidden_size += 1 115 | self.conv1 = GCNConv(num_features, hidden_size // 2) 116 | self.conv2 = GCNConv(hidden_size, 1) 117 | self.fc1 = Linear(num_features, hidden_size // 2) 118 | self.fc2 = Linear(hidden_size, 1) 119 | 120 | self.dropout = dropout 121 | assert 
activation in ["relu", "elu"] 122 | self.activation = getattr(F, activation) 123 | 124 | def forward(self, data): 125 | x, edge_index = data.x, data.edge_index 126 | x = torch.cat([self.conv1(x, edge_index), self.fc1(x)], dim=1) 127 | x = self.activation(x) 128 | x = F.dropout(x, p=self.dropout, training=self.training) 129 | x = self.conv2(x, edge_index) + self.fc2(x) 130 | return x.view(-1) 131 | 132 | 133 | class GAT(nn.Module): 134 | def __init__(self, 135 | num_features, 136 | hidden_size, 137 | dropout=0.5, 138 | activation="relu", 139 | num_heads=8, 140 | *args, 141 | **kwargs): 142 | super().__init__(*args, **kwargs) 143 | self.conv1 = GATConv( 144 | num_features, hidden_size, heads=num_heads, dropout=dropout) 145 | self.conv2 = GATConv( 146 | hidden_size * num_heads, 1, dropout=dropout) 147 | 148 | self.dropout = dropout 149 | assert activation in ["relu", "elu"] 150 | self.activation = getattr(F, activation) 151 | 152 | def forward(self, data): 153 | x, edge_index = data.x, data.edge_index 154 | x = self.activation(self.conv1(x, edge_index)) 155 | x = F.dropout(x, p=self.dropout, training=self.training) 156 | x = self.conv2(x, edge_index) 157 | return x.view(-1) 158 | 159 | 160 | class APPNPNet(nn.Module): 161 | def __init__(self, 162 | num_features, 163 | hidden_size, 164 | dropout=0.5, 165 | activation="relu", 166 | K=10, 167 | alpha=0.1, 168 | *args, 169 | **kwargs): 170 | super().__init__(*args, **kwargs) 171 | self.lin1 = Linear(num_features, hidden_size) 172 | self.lin2 = Linear(hidden_size, 1) 173 | self.prop1 = APPNP(K, alpha) 174 | 175 | self.dropout = dropout 176 | assert activation in ["relu", "elu"] 177 | self.activation = getattr(F, activation) 178 | 179 | def reset_parameters(self): 180 | self.lin1.reset_parameters() 181 | self.lin2.reset_parameters() 182 | 183 | def forward(self, data): 184 | x, edge_index = data.x, data.edge_index 185 | x = F.dropout(x, p=self.dropout, training=self.training) 186 | x = self.activation(self.lin1(x)) 187 | x 
= F.dropout(x, p=self.dropout, training=self.training) 188 | x = self.lin2(x) 189 | x = self.prop1(x, edge_index) 190 | return x.view(-1) 191 | 192 | 193 | class CopulaModel(nn.Module): 194 | 195 | def __init__(self, marginal_type, eps=EPS, *args, **kwargs): 196 | super().__init__(*args, **kwargs) 197 | self.marginal_type = marginal_type 198 | self.eps = eps 199 | 200 | def marginal(self, logits, cov=None): 201 | if self.marginal_type == "Normal": 202 | return Normal(loc=logits, scale=torch.diag(cov).pow(0.5)) 203 | elif self.marginal_type == "Poisson": 204 | return Poisson(rate=torch.exp(logits) + self.eps) 205 | else: 206 | raise NotImplementedError( 207 | "Marginal type `{}` not supported.".format(self.marginal_type)) 208 | 209 | def cdf(self, marginal, labels): 210 | if self.marginal_type == "Normal": 211 | res = marginal.cdf(labels) 212 | elif self.marginal_type == "Poisson": 213 | cdf_left = 1 - GammaIncFunc.apply(labels, marginal.mean) 214 | cdf_right = 1 - GammaIncFunc.apply(labels + 1, marginal.mean) 215 | res = (cdf_left + cdf_right) / 2 216 | else: 217 | raise NotImplementedError( 218 | "Marginal type `{}` not supported.".format(self.marginal_type)) 219 | return torch.clamp(res, self.eps, 1-self.eps) 220 | 221 | def icdf(self, marginal, u): 222 | if self.marginal_type == "Normal": 223 | res = _batch_normal_icdf(marginal.mean, marginal.stddev, u) 224 | elif self.marginal_type == "Poisson": 225 | mean = marginal.mean.detach().cpu().numpy() 226 | q = u.detach().cpu().numpy() 227 | res = poisson.ppf(q, mean) 228 | res = torch.as_tensor(res, dtype=u.dtype, device=u.device) 229 | else: 230 | raise NotImplementedError( 231 | "Marginal type `{}` not supported.".format(self.marginal_type)) 232 | if (res == float("inf")).sum() + (res == float("nan")).sum() > 0: 233 | # remove the rows containing inf or nan values 234 | inf_mask = res.sum(dim=-1) == float("inf") 235 | nan_mask = res.sum(dim=-1) == float("nan") 236 | res[inf_mask] = 0 237 | res[nan_mask] = 0 238 | 
valid_num = inf_mask.size(0) - inf_mask.sum() - nan_mask.sum() 239 | return res.sum(dim=0) / valid_num 240 | return res.mean(dim=0) 241 | 242 | def get_cov(self, data): 243 | raise NotImplementedError("`get_cov` not implemented.") 244 | 245 | def get_prec(self, data): 246 | raise NotImplementedError("`get_prec` not implemented.") 247 | 248 | def nll(self, data): 249 | cov = self.get_cov(data) 250 | cov = cov[data.train_mask, :] 251 | cov = cov[:, data.train_mask] 252 | logits = self.forward(data)[data.train_mask] 253 | labels = data.y[data.train_mask] 254 | 255 | copula = GaussianCopula(cov) 256 | marginal = self.marginal(logits, cov) 257 | 258 | u = self.cdf(marginal, labels) 259 | nll_copula = - copula.log_prob(u) 260 | nll_marginal = - torch.sum(marginal.log_prob(labels)) 261 | return (nll_copula + nll_marginal) / labels.size(0) 262 | 263 | def predict(self, data, num_samples=500): 264 | cond_mask = data.train_mask 265 | eval_mask = torch.logical_xor( 266 | torch.ones_like(data.train_mask).to(dtype=torch.bool), 267 | data.train_mask) 268 | cov = self.get_cov(data) 269 | logits = self.forward(data) 270 | copula = GaussianCopula(cov) 271 | 272 | cond_cov = (cov[cond_mask, :])[:, cond_mask] 273 | cond_marginal = self.marginal(logits[cond_mask], cond_cov) 274 | eval_cov = (cov[eval_mask, :])[:, eval_mask] 275 | eval_marginal = self.marginal(logits[eval_mask], eval_cov) 276 | 277 | cond_u = torch.clamp( 278 | self.cdf(cond_marginal, data.y[cond_mask]), self.eps, 1-self.eps) 279 | cond_idx = torch.where(cond_mask)[0] 280 | sample_idx = torch.where(eval_mask)[0] 281 | eval_u = copula.conditional_sample( 282 | cond_val=cond_u, sample_shape=[num_samples, ], cond_idx=cond_idx, 283 | sample_idx=sample_idx) 284 | eval_u = torch.clamp(eval_u, self.eps, 1-self.eps) 285 | eval_y = self.icdf(eval_marginal, eval_u) 286 | 287 | pred_y = data.y.clone() 288 | pred_y[eval_mask] = eval_y 289 | return pred_y 290 | 291 | 292 | class CorCopulaGCN(GCN, CopulaModel): 293 | 294 | def 
__init__(self, 295 | num_features, 296 | hidden_size, 297 | marginal_type="Normal", 298 | dropout=0., 299 | activation="relu"): 300 | super().__init__( 301 | num_features=num_features, hidden_size=hidden_size, 302 | dropout=dropout, activation=activation, 303 | marginal_type=marginal_type) 304 | 305 | self.alpha = nn.Parameter(torch.tensor(1.0)) 306 | self.beta = nn.Parameter(torch.tensor(3.0)) 307 | self.S = None 308 | self.I = None 309 | 310 | def get_prec(self, data): 311 | if self.S is None: 312 | adj = to_dense_adj(data.edge_index)[0].cpu().numpy() 313 | degree = adj.sum(axis=0) 314 | degree[degree==0] = 1 315 | D = np.diag(degree**(-0.5)) 316 | S = D.dot(adj).dot(D) 317 | self.S = torch.tensor(S, dtype=torch.float32).to(data.x.device) 318 | self.I = torch.eye(self.S.size(0)).to(data.x.device) 319 | prec = torch.exp(self.beta) * (self.I - torch.tanh(self.alpha) * self.S) 320 | return prec 321 | 322 | def get_cov(self, data): 323 | return torch.inverse(self.get_prec(data)) 324 | 325 | 326 | class RegCopulaGCN(GCN, CopulaModel): 327 | 328 | def __init__(self, 329 | num_features, 330 | hidden_size, 331 | marginal_type="Normal", 332 | dropout=0., 333 | activation="relu"): 334 | super().__init__( 335 | num_features=num_features, hidden_size=hidden_size, 336 | dropout=dropout, activation=activation, 337 | marginal_type=marginal_type) 338 | 339 | self.reg_fc1 = nn.Linear(num_features * 2, hidden_size) 340 | self.reg_fc2 = nn.Linear(hidden_size, 1) 341 | 342 | def get_prec(self, data): 343 | triangle_mask = data.edge_index[0] < data.edge_index[1] 344 | edge_index = torch.stack([data.edge_index[0][triangle_mask], 345 | data.edge_index[1][triangle_mask]]) 346 | x_query = F.embedding(edge_index[0], data.x) 347 | x_key = F.embedding(edge_index[1], data.x) 348 | x = torch.cat([x_query, x_key], dim=1) 349 | x = F.dropout(x, p=self.dropout, training=self.training) 350 | x = self.activation(self.reg_fc1(x)) 351 | x = F.dropout(x, p=self.dropout, training=self.training) 352 | 
x = F.softplus(self.reg_fc2(x)) 353 | und_edge_index = torch.stack( 354 | [torch.cat([edge_index[0], edge_index[1]], dim=0), 355 | torch.cat([edge_index[1], edge_index[0]], dim=0)]) 356 | und_edge_weight = torch.cat([x.view(-1), x.view(-1)], dim=0) 357 | L_edge_index, L_edge_weight = get_laplacian( 358 | und_edge_index, edge_weight=und_edge_weight, 359 | num_nodes=data.x.size(0)) 360 | L = to_dense_adj(L_edge_index, edge_attr=L_edge_weight)[0] 361 | return L + torch.eye(L.size(0), dtype=L.dtype, device=L.device) 362 | 363 | def get_cov(self, data): 364 | return torch.inverse(self.get_prec(data)) 365 | 366 | 367 | class CorCopulaSAGE(SAGE, CopulaModel): 368 | 369 | def __init__(self, 370 | num_features, 371 | hidden_size, 372 | marginal_type="Normal", 373 | dropout=0., 374 | activation="relu"): 375 | super().__init__( 376 | num_features=num_features, hidden_size=hidden_size, 377 | dropout=dropout, activation=activation, 378 | marginal_type=marginal_type) 379 | 380 | self.alpha = nn.Parameter(torch.tensor(1.0)) 381 | self.beta = nn.Parameter(torch.tensor(3.0)) 382 | self.S = None 383 | self.I = None 384 | 385 | def get_prec(self, data): 386 | if self.S is None: 387 | adj = to_dense_adj(data.edge_index)[0].cpu().numpy() 388 | degree = adj.sum(axis=0) 389 | degree[degree==0] = 1 390 | D = np.diag(degree**(-0.5)) 391 | S = D.dot(adj).dot(D) 392 | self.S = torch.tensor(S, dtype=torch.float32).to(data.x.device) 393 | self.I = torch.eye(self.S.size(0)).to(data.x.device) 394 | prec = torch.exp(self.beta) * (self.I - torch.tanh(self.alpha) * self.S) 395 | return prec 396 | 397 | def get_cov(self, data): 398 | return torch.inverse(self.get_prec(data)) 399 | 400 | 401 | class RegCopulaSAGE(SAGE, CopulaModel): 402 | 403 | def __init__(self, 404 | num_features, 405 | hidden_size, 406 | marginal_type="Normal", 407 | dropout=0., 408 | activation="relu"): 409 | super().__init__( 410 | num_features=num_features, hidden_size=hidden_size, 411 | dropout=dropout, activation=activation, 
412 | marginal_type=marginal_type) 413 | 414 | self.reg_fc1 = nn.Linear(num_features * 2, hidden_size) 415 | self.reg_fc2 = nn.Linear(hidden_size, 1) 416 | 417 | def get_prec(self, data): 418 | triangle_mask = data.edge_index[0] < data.edge_index[1] 419 | edge_index = torch.stack([data.edge_index[0][triangle_mask], 420 | data.edge_index[1][triangle_mask]]) 421 | x_query = F.embedding(edge_index[0], data.x) 422 | x_key = F.embedding(edge_index[1], data.x) 423 | x = torch.cat([x_query, x_key], dim=1) 424 | x = F.dropout(x, p=self.dropout, training=self.training) 425 | x = self.activation(self.reg_fc1(x)) 426 | x = F.dropout(x, p=self.dropout, training=self.training) 427 | x = F.softplus(self.reg_fc2(x)) 428 | und_edge_index = torch.stack( 429 | [torch.cat([edge_index[0], edge_index[1]], dim=0), 430 | torch.cat([edge_index[1], edge_index[0]], dim=0)]) 431 | und_edge_weight = torch.cat([x.view(-1), x.view(-1)], dim=0) 432 | L_edge_index, L_edge_weight = get_laplacian( 433 | und_edge_index, edge_weight=und_edge_weight, 434 | num_nodes=data.x.size(0)) 435 | L = to_dense_adj(L_edge_index, edge_attr=L_edge_weight)[0] 436 | return L + torch.eye(L.size(0), dtype=L.dtype, device=L.device) 437 | 438 | def get_cov(self, data): 439 | return torch.inverse(self.get_prec(data)) 440 | --------------------------------------------------------------------------------