├── .gitignore ├── Analysis of Assortativity-Bottleneck and CTE latent space.ipynb ├── datasets ├── Erdos_Renyi_dataset.py └── SBM_dataset.py ├── environment_experiments.yml ├── experiments_all.sh ├── figs ├── AssortativeBottleneckAccuracy │ ├── COLLAB.png │ ├── HistAsssortBottleCOLLAB.pdf │ ├── HistAsssortBottleREDDIT-BINARY.pdf │ ├── REDDIT.png │ ├── ScatterAsssortBottleCOLLAB.pdf │ └── ScatterAsssortBottleREDDIT-BINARY.pdf ├── CT_REDDIT_READOUT_17_05_22__11_19.jpg ├── CT_REDDIT_READOUT_19_05_22__10_19.jpg ├── CT_REDDIT_READOUT_19_05_22__10_30.jpg ├── CT_RW.png ├── GAPLap_REDDIT_READOUT_17_05_22__11_38.jpg ├── GAP_REDDIT_READOUT_17_05_22__11_31.jpg ├── GAP_REDDIT_READOUT_19_05_22__10_35.jpg ├── MinCut_REDDIT_READOUT_17_05_22__11_21.jpg ├── MinCut_REDDIT_READOUT_17_05_22__11_22.jpg ├── MinCut_REDDIT_READOUT_19_05_22__10_19.jpg ├── MinCut_REDDIT_READOUT_19_05_22__10_30.jpg ├── REDDIT_CT_READOUT_1.jpg └── data_statistics │ ├── Degree_hist_CIFAR10.png │ ├── Degree_hist_COLLAB.png │ ├── Degree_hist_ENZYMES.png │ ├── Degree_hist_IMDB-BINARY.png │ ├── Degree_hist_MNIST.png │ ├── Degree_hist_MUTAG.png │ ├── Degree_hist_PROTEINS.png │ ├── Degree_hist_REDDIT-BINARY.png │ ├── avg_degree_hist_CIFAR10.png │ ├── avg_degree_hist_COLLAB.png │ ├── avg_degree_hist_ENZYMES.png │ ├── avg_degree_hist_IMDB-BINARY.png │ ├── avg_degree_hist_MNIST.png │ ├── avg_degree_hist_MUTAG.png │ ├── avg_degree_hist_PROTEINS.png │ ├── avg_degree_hist_REDDIT-BINARY.png │ ├── number_nodes_hist_CIFAR10.png │ ├── number_nodes_hist_COLLAB.png │ ├── number_nodes_hist_ENZYMES.png │ ├── number_nodes_hist_IMDB-BINARY.png │ ├── number_nodes_hist_MNIST.png │ ├── number_nodes_hist_MUTAG.png │ ├── number_nodes_hist_PROTEINS.png │ └── number_nodes_hist_REDDIT-BINARY.png ├── layers ├── CT_layer.py ├── GAP_layer.py ├── MinCut_Layer.py └── utils │ ├── approximate_fiedler.py │ ├── ein_utils.py │ └── spectral_utils.py ├── nets.py ├── readme.md ├── requirements.txt ├── train.py ├── trained_models ├── CTNet │ ├── COLLAB_CTNet_17_05_22__08_56_iter0.pth │ ├── ERDOS_CTNet_19_05_22__14_46_iter1.pth │ ├── IMDB-BINARY_CTNet_17_05_22__08_56_iter0.pth │ ├── MUTAG_CTNet_19_05_22__11_35_iter0.pth │ ├── REDDIT-BINARY_CTNet_17_05_22__08_50_iter0.pth │ └── SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_CTNet_18_05_22__22_18_iter0.pth ├── GAPNet │ ├── COLLAB_GAPNet_normalized_18_05_22__21_30_iter0.pth │ ├── ERDOS_GAPNet_normalized_19_05_22__14_46_iter0.pth │ ├── IMDB-BINARY_GAPNet_normalized_19_05_22__11_35_iter0.pth │ ├── MUTAG_GAPNet_normalized_19_05_22__11_35_iter0.pth │ ├── REDDIT-BINARY_GAPNet_normalized_19_05_22__10_09_iter0.pth │ └── SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_GAPNet_normalized_18_05_22__21_44_iter0.pth └── Laplacian │ ├── COLLAB_GAPNet_laplacian_16_05_22__16_46_iter0.pth │ └── REDDIT-BINARY_GAPNet_laplacian_16_05_22__11_04_iter0.pth └── transforms ├── __init__.py ├── sdrf ├── curvature.py ├── sdrf_transform.py └── utils.py └── transform_features.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .txt 6 | # C extensions 7 | *.so 8 | .DS_Store 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these 
files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | /data 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | data/ 132 | models/ 133 | *.sh 134 | !experiments_all.sh 135 | data_colab/ 136 | logs*/ 137 | !logs_stratified/ 138 | -------------------------------------------------------------------------------- /datasets/Erdos_Renyi_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | from torch_geometric.utils.random import erdos_renyi_graph 3 | import torch 4 | from torch_geometric.data import Data, InMemoryDataset 5 | from torch_geometric.data import DataLoader 6 | from torch_geometric.utils import to_dense_adj 7 | 8 | class Erdos_Renyi_pyg(InMemoryDataset): 9 | def __init__(self, root, nb_nodes1=200, nb_graphs1=500, nb_nodes2=200, nb_graphs2=500, p1_min=0.1, p1_max=0.5, 10 | p2_min=0.5, p2_max=0.8, directed=False, transform=None, pre_transform=None): 11 | self.nb_nodes1 = nb_nodes1 12 | self.nb_graphs1 = nb_graphs1 13 | self.nb_nodes2 = nb_nodes2 14 | self.nb_graphs2 = nb_graphs2 15 | self.nb_graphs = self.nb_graphs1 + self.nb_graphs2 16 | 17 | self.directed = directed 18 | 19 | self.p1_min = p1_min 20 | self.p1_max = p1_max 21 | self.p2_min = p2_min 22 | self.p2_max = p2_max 23 | 24 | self.details = f"{self.nb_nodes1}nodesT{self.nb_graphs}graphsT" \ 25 | + f"{int(p1_min*100)}to{int(p1_max*100)}p1T{int(p2_min*100)}to{int(p2_max*100)}p2" 26 | self.root = f"{root}_{self.details}" 27 | print(self.root) 28 | super(Erdos_Renyi_pyg, self).__init__(self.root, transform, pre_transform) 29 | self.data, self.slices = torch.load(self.processed_paths[0]) 30 | 31 | @property 32 | def raw_file_names(self): 33 | return ['tentative'] 
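# Note: no raw files are required here; the graphs are generated synthetically in process(),
# so download() is a no-op and 'tentative' is just a placeholder file name.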
34 | 35 | @property 36 | def processed_file_names(self): 37 | return ['data.pt'] 38 | 39 | def download(self): 40 | pass 41 | 42 | def process(self): 43 | # Read data into huge `Data` lists. 44 | data_list1 = self._generate_graphs(self.nb_nodes1, self.nb_graphs1, self.p1_min, self.p1_max, 0) 45 | data_list2 = self._generate_graphs(self.nb_nodes2, self.nb_graphs2, self.p2_min, self.p2_max, 1) 46 | data_list = [*data_list1, *data_list2] 47 | 48 | if self.pre_filter is not None: 49 | data_list = [data for data in data_list if self.pre_filter(data)] 50 | 51 | if self.pre_transform is not None: 52 | data_list = [self.pre_transform(data) for data in data_list] 53 | 54 | data, slices = self.collate(data_list) 55 | torch.save((data, slices), self.processed_paths[0]) 56 | 57 | def _generate_graphs(self, nb_nodes, nb_graphs, p_min, p_max, myclass): 58 | dataset = [] 59 | m = (p_max - p_min)/nb_graphs # Linear slope 60 | dataset = [] 61 | for i in range(nb_graphs): 62 | p_i = m*(i+1) + p_min 63 | # Get the SBM graph 64 | d = erdos_renyi_graph(num_nodes=nb_nodes, edge_prob=p_i, directed=self.directed) 65 | 66 | # Get degree as feature 67 | adj = to_dense_adj(d) 68 | A = adj 69 | D = A.sum(dim=1) 70 | x = torch.transpose(D,0,1) 71 | mydata = Data(x=x, edge_index=d, y = myclass, num_nodes=nb_nodes) 72 | dataset.append(mydata) 73 | return dataset 74 | 75 | def __repr__(self) -> str: 76 | return f'{self.__class__.__name__}({self.details})' 77 | 78 | 79 | if __name__ == "__main__": 80 | print("Tesdting Erdos-Renyi") 81 | dataset = Erdos_Renyi_pyg('../data/SBM_final', nb_nodes1=200, nb_graphs1=500, nb_nodes2=200, nb_graphs2=500, 82 | p1_min=0.1, p1_max=0.5, p2_min=0.3, p2_max=0.8) 83 | print() 84 | print(f'Dataset: {dataset}:') 85 | print('====================') 86 | print(f'Number of graphs: {len(dataset)}') 87 | print(f'Number of features: {dataset.num_features}') 88 | print(f'Number of classes: {dataset.num_classes}') 89 | 90 | data = dataset[0] # Get the first graph object. 91 | 92 | print() 93 | print(data) 94 | print('=============================================================') 95 | 96 | # Gather some statistics about the first graph. 
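# (For an Erdos-Renyi graph with edge probability p, the average node degree printed below
# should be roughly p*(num_nodes - 1).)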
97 | print(f'Number of nodes: {data.num_nodes}') 98 | print(f'Number of edges: {data.num_edges}') 99 | print(f'Label: {data.y}') 100 | print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}') 101 | print(f'Has isolated nodes: {data.has_isolated_nodes()}') 102 | print(f'Has self-loops: {data.has_self_loops()}') 103 | print(f'Is undirected: {data.is_undirected()}') 104 | print(dataset.details) 105 | 106 | """train_dataset = dataset[:int(0.8*len(dataset))] 107 | test_dataset = dataset[int(0.8*len(dataset)):] 108 | 109 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) # Original 64 110 | test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # Original 64 111 | 112 | 113 | for step, data in enumerate(train_loader): 114 | print(f'Step {step + 1}:') 115 | print('=======') 116 | print(f'Number of graphs in the current batch: {data.num_graphs}') 117 | print(data) 118 | print()""" 119 | 120 | -------------------------------------------------------------------------------- /datasets/SBM_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data, InMemoryDataset 3 | from torch_geometric.utils.random import stochastic_blockmodel_graph 4 | from torch_geometric.data import DataLoader 5 | from torch_geometric.utils import to_dense_adj 6 | 7 | class SBM_pyg(InMemoryDataset): 8 | def __init__(self, root, nb_nodes1=200, nb_graphs1=500, nb_nodes2=200, nb_graphs2=500, p1=0.8, p2=0.5, qmin1=0.5, qmax1=0.8, qmin2=0.25, qmax2=0.7, directed=False, transform=None, pre_transform=None): 9 | """ 10 | Create SBM dataset with graph of 2 classes. Each graph will have 2 communities. 11 | For each class we can parametrize the number of nodes, number of graphs and 12 | intra-interclass edge probability max an min probability. 13 | 14 | Args: 15 | root (str): path to save the data. 16 | nb_nodes1 (int, optional): number of nodes in the graphs of class 1. Defaults to 200. 17 | nb_graphs1 (int, optional): number of graphs of class 1. Defaults to 500. 18 | nb_nodes2 (int, optional): number of nodes in the graphs of class 2. Defaults to 200. 19 | nb_graphs2 (int, optional): number of graphs of class 2. Defaults to 500. 20 | p1 (float, optional): intraclass edge probability for community 1 for both graph classes. Defaults to 0.8. 21 | p2 (float, optional): intraclass edge probability for community 2 for both graph classes. Defaults to 0.5. 22 | qmin1 (float, optional): minimun intercalass probability for graphs in class 1. Defaults to 0.5. 23 | qmax1 (float, optional): minimun interclass probability for graphs in class 2. Defaults to 0.8. 24 | qmin2 (float, optional): maximun intercalass probability for graphs in class 1. Defaults to 0.25. 25 | qmax2 (float, optional): maximun intercalass probability for graphs in class 2. Defaults to 0.7. 26 | directed (bool, optional): Create directed or Undirected Graphs. Defaults to False. 27 | transform (torch_geometric.transforms.BaseTransform, optional): on the fly transformation. Defaults to None. 28 | pre_transform (torch_geometric.transforms.BaseTransform, optional): transformation to save in disk. Defaults to None. 
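Note: qmin1/qmax1 bound the inter-community edge probability q for class-1 graphs, and qmin2/qmax2 the one for class-2 graphs. Both classes share the intra-community probabilities p1 and p2; within each class, q is ramped linearly from qmin to qmax across the generated graphs (see _generate_graphs).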
29 | """ 30 | self.nb_nodes1 = nb_nodes1 31 | self.nb_graphs1 = nb_graphs1 32 | self.nb_nodes2 = nb_nodes2 33 | self.nb_graphs2 = nb_graphs2 34 | self.nb_graphs = self.nb_graphs1 + self.nb_graphs2 35 | #self.num_features = 1 # Degree 36 | #self.num_classes = 2 37 | self.edge_probs1_min = [[p1, qmin1], [qmin1, p2]] # Minimal setting class 1 38 | self.edge_probs2_min = [[p1, qmin2], [qmin2, p2]] # Minimal setting class 2 39 | self.edge_probs1_max = [[p1, qmax1], [qmax1, p2]] # Maximal setting class 1 40 | self.edge_probs2_max = [[p1, qmax2], [qmax2, p2]] # Maximal setting class 2 41 | 42 | self.details = f"{self.nb_nodes1}nodesT{self.nb_graphs}graphsT{int(p1*100)}" \ 43 | + f"p1T{int(p2*100)}p2T{int(qmin1*100)}to{int(qmax1*100)}q1T{int(qmin2*100)}to{int(qmax2*100)}q2" 44 | self.root = f"{root}_{self.details}" 45 | print(self.root) 46 | super(SBM_pyg, self).__init__(self.root, transform, pre_transform) 47 | self.data, self.slices = torch.load(self.processed_paths[0]) 48 | 49 | @property 50 | def raw_file_names(self): 51 | return ['tentative'] 52 | 53 | @property 54 | def processed_file_names(self): 55 | return ['data.pt'] 56 | 57 | def download(self): 58 | pass 59 | 60 | def process(self): 61 | # Read data into huge `Data` lists. 62 | data_list1 = self._generate_graphs(self.nb_nodes1, self.nb_graphs1, self.edge_probs1_min, self.edge_probs1_max, 0) 63 | data_list2 = self._generate_graphs(self.nb_nodes2, self.nb_graphs2, self.edge_probs2_min, self.edge_probs2_max, 1) 64 | data_list = [*data_list1, *data_list2] 65 | 66 | if self.pre_filter is not None: 67 | data_list = [data for data in data_list if self.pre_filter(data)] 68 | 69 | if self.pre_transform is not None: 70 | data_list = [self.pre_transform(data) for data in data_list] 71 | 72 | data, slices = self.collate(data_list) 73 | torch.save((data, slices), self.processed_paths[0]) 74 | 75 | def _generate_graphs(self, nb_nodes, nb_graphs, edge_probs_min, edge_probs_max, myclass): 76 | # p and q are static 77 | # qmin < qmax 78 | # for each graph move from qmin to qmax 79 | qmin = edge_probs_min[1][0] 80 | qmax = edge_probs_max[1][0] 81 | m = (qmax - qmin)/nb_graphs # Linear slope 82 | dataset = [] 83 | for i in range(nb_graphs): 84 | q = m*(i+1) + qmin 85 | p1 = edge_probs_min[0][0] 86 | p2 = edge_probs_min[1][1] 87 | # Get the SBM graph 88 | d = stochastic_blockmodel_graph(block_sizes=[int(nb_nodes/2),int(nb_nodes/2)], edge_probs=[[p1, q],[q, p2]], directed=False) 89 | 90 | # Get degree as feature 91 | adj = to_dense_adj(d) 92 | A = adj 93 | D = A.sum(dim=1) 94 | x = torch.transpose(D,0,1) 95 | mydata = Data(x=x, edge_index=d, y = myclass, num_nodes=nb_nodes) 96 | dataset.append(mydata) 97 | return dataset 98 | 99 | def __repr__(self) -> str: 100 | return f'{self.__class__.__name__}({self.details})' 101 | 102 | 103 | if __name__ == "__main__": 104 | print("Tesdting SBM") 105 | dataset = SBM_pyg('../data/SBM_final', nb_nodes1=200, nb_graphs1=500, nb_nodes2=200, nb_graphs2=500, p1=0.8, p2=0.5, qmin1=0.5, qmax1=0.8, qmin2=0.25, qmax2=0.71, directed=False, transform=None, pre_transform=None) 106 | 107 | print() 108 | print(f'Dataset: {dataset}:') 109 | print('====================') 110 | print(f'Number of graphs: {len(dataset)}') 111 | print(f'Number of features: {dataset.num_features}') 112 | print(f'Number of classes: {dataset.num_classes}') 113 | 114 | data = dataset[0] # Get the first graph object. 
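    # Illustrative sketch (optional; assumes networkx is available, as pinned in environment_experiments.yml):
    # inspect the degree assortativity of the generated graph, in the spirit of the
    # "Analysis of Assortativity-Bottleneck and CTE latent space" notebook shipped with the repo.
    # from torch_geometric.utils import to_networkx
    # import networkx as nx
    # G = to_networkx(data, to_undirected=True)
    # print(f'Degree assortativity: {nx.degree_assortativity_coefficient(G):.3f}')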
115 | 116 | print() 117 | print(data) 118 | print('=============================================================') 119 | 120 | # Gather some statistics about the first graph. 121 | print(f'Number of nodes: {data.num_nodes}') 122 | print(f'Number of edges: {data.num_edges}') 123 | print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}') 124 | print(f'Has isolated nodes: {data.has_isolated_nodes()}') 125 | print(f'Has self-loops: {data.has_self_loops()}') 126 | print(f'Is undirected: {data.is_undirected()}') 127 | print(dataset.details) 128 | 129 | """train_dataset = dataset[:int(0.8*len(dataset))] 130 | test_dataset = dataset[int(0.8*len(dataset)):] 131 | 132 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) # Original 64 133 | test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # Original 64 134 | 135 | 136 | for step, data in enumerate(train_loader): 137 | print(f'Step {step + 1}:') 138 | print('=======') 139 | print(f'Number of graphs in the current batch: {data.num_graphs}') 140 | print(data) 141 | print()""" 142 | 143 | -------------------------------------------------------------------------------- /environment_experiments.yml: -------------------------------------------------------------------------------- 1 | name: DiffWire 2 | channels: 3 | - pyg 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - blas=1.0=mkl 10 | - bottleneck=1.3.4=py39hce1f21e_0 11 | - brotli=1.0.9=he6710b0_2 12 | - brotlipy=0.7.0=py39h27cfd23_1003 13 | - bzip2=1.0.8=h7b6447c_0 14 | - ca-certificates=2022.4.26=h06a4308_0 15 | - certifi=2021.10.8=py39h06a4308_2 16 | - cffi=1.15.0=py39hd667e15_1 17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 18 | - colorama=0.4.4=pyhd3eb1b0_0 19 | - cryptography=36.0.0=py39h9ce1e76_0 20 | - cudatoolkit=10.2.89=hfd86e86_1 21 | - cycler=0.11.0=pyhd3eb1b0_0 22 | - dbus=1.13.18=hb2f20db_0 23 | - expat=2.4.4=h295c915_0 24 | - ffmpeg=4.3=hf484d3e_0 25 | - fontconfig=2.13.1=h6c09931_0 26 | - fonttools=4.25.0=pyhd3eb1b0_0 27 | - freetype=2.11.0=h70c0345_0 28 | - giflib=5.2.1=h7b6447c_0 29 | - glib=2.69.1=h4ff587b_1 30 | - gmp=6.2.1=h2531618_2 31 | - gnutls=3.6.15=he1e5248_0 32 | - gst-plugins-base=1.14.0=h8213a91_2 33 | - gstreamer=1.14.0=h28cd5cc_2 34 | - h5py=3.6.0=py39ha0f2276_0 35 | - hdf5=1.10.6=hb1b8bf9_0 36 | - icu=58.2=he6710b0_3 37 | - idna=3.3=pyhd3eb1b0_0 38 | - intel-openmp=2021.4.0=h06a4308_3561 39 | - jinja2=3.0.3=pyhd3eb1b0_0 40 | - joblib=1.1.0=pyhd3eb1b0_0 41 | - jpeg=9d=h7f8727e_0 42 | - kiwisolver=1.3.2=py39h295c915_0 43 | - lame=3.100=h7b6447c_0 44 | - lcms2=2.12=h3be6417_0 45 | - ld_impl_linux-64=2.35.1=h7274673_9 46 | - libffi=3.3=he6710b0_2 47 | - libgcc-ng=9.3.0=h5101ec6_17 48 | - libgfortran-ng=7.5.0=ha8ba4b0_17 49 | - libgfortran4=7.5.0=ha8ba4b0_17 50 | - libgomp=9.3.0=h5101ec6_17 51 | - libiconv=1.15=h63c8f33_5 52 | - libidn2=2.3.2=h7f8727e_0 53 | - libpng=1.6.37=hbc83047_0 54 | - libprotobuf=3.19.1=h4ff587b_0 55 | - libstdcxx-ng=9.3.0=hd4cf53a_17 56 | - libtasn1=4.16.0=h27cfd23_0 57 | - libtiff=4.2.0=h85742a9_0 58 | - libunistring=0.9.10=h27cfd23_0 59 | - libuuid=1.0.3=h7f8727e_2 60 | - libuv=1.40.0=h7b6447c_0 61 | - libwebp=1.2.2=h55f646e_0 62 | - libwebp-base=1.2.2=h7f8727e_0 63 | - libxcb=1.14=h7b6447c_0 64 | - libxml2=2.9.12=h03d6c58_0 65 | - lz4-c=1.9.3=h295c915_1 66 | - markupsafe=2.0.1=py39h27cfd23_0 67 | - matplotlib=3.5.1=py39h06a4308_1 68 | - matplotlib-base=3.5.1=py39ha18d171_1 69 | - mkl=2021.4.0=h06a4308_640 70 | - 
mkl-service=2.4.0=py39h7f8727e_0 71 | - mkl_fft=1.3.1=py39hd3c417c_0 72 | - mkl_random=1.2.2=py39h51133e4_0 73 | - munkres=1.1.4=py_0 74 | - ncurses=6.3=h7f8727e_2 75 | - nettle=3.7.3=hbbd107a_1 76 | - networkx=2.7.1=pyhd3eb1b0_0 77 | - numexpr=2.8.1=py39h6abb31d_0 78 | - numpy=1.21.2=py39h20f2e39_0 79 | - numpy-base=1.21.2=py39h79a1101_0 80 | - openh264=2.1.1=h4ff587b_0 81 | - openssl=1.1.1o=h7f8727e_0 82 | - packaging=21.3=pyhd3eb1b0_0 83 | - pandas=1.4.1=py39h295c915_1 84 | - path=16.2.0=pyhd3eb1b0_0 85 | - pcre=8.45=h295c915_0 86 | - pillow=9.0.1=py39h22f2fdc_0 87 | - pip=21.2.4=py39h06a4308_0 88 | - plotly=5.6.0=pyhd3eb1b0_0 89 | - protobuf=3.19.1=py39h295c915_0 90 | - pycparser=2.21=pyhd3eb1b0_0 91 | - pyg=2.0.4=py39_torch_1.11.0_cu102 92 | - pyopenssl=22.0.0=pyhd3eb1b0_0 93 | - pyparsing=3.0.4=pyhd3eb1b0_0 94 | - pyqt=5.9.2=py39h2531618_6 95 | - pysocks=1.7.1=py39h06a4308_0 96 | - python=3.9.11=h12debd9_2 97 | - python-dateutil=2.8.2=pyhd3eb1b0_0 98 | - python-louvain=0.15=pyhd3eb1b0_0 99 | - pytorch=1.11.0=py3.9_cuda10.2_cudnn7.6.5_0 100 | - pytorch-cluster=1.6.0=py39_torch_1.11.0_cu102 101 | - pytorch-mutex=1.0=cuda 102 | - pytorch-scatter=2.0.9=py39_torch_1.11.0_cu102 103 | - pytorch-sparse=0.6.13=py39_torch_1.11.0_cu102 104 | - pytorch-spline-conv=1.2.1=py39_torch_1.11.0_cu102 105 | - pytz=2021.3=pyhd3eb1b0_0 106 | - pyyaml=6.0=py39h7f8727e_1 107 | - qt=5.9.7=h5867ecd_1 108 | - readline=8.1.2=h7f8727e_1 109 | - requests=2.27.1=pyhd3eb1b0_0 110 | - scikit-learn=1.0.2=py39h51133e4_1 111 | - scipy=1.7.3=py39hc147768_0 112 | - seaborn=0.11.2=pyhd3eb1b0_0 113 | - setuptools=58.0.4=py39h06a4308_0 114 | - sip=4.19.13=py39h295c915_0 115 | - six=1.16.0=pyhd3eb1b0_1 116 | - sqlite=3.38.0=hc218d9a_0 117 | - tenacity=8.0.1=py39h06a4308_0 118 | - tensorboardx=2.2=pyhd3eb1b0_0 119 | - threadpoolctl=2.2.0=pyh0d69192_0 120 | - tk=8.6.11=h1ccaba5_0 121 | - torchaudio=0.11.0=py39_cu102 122 | - torchvision=0.12.0=py39_cu102 123 | - tornado=6.1=py39h27cfd23_0 124 | - tqdm=4.63.0=pyhd3eb1b0_0 125 | - typing_extensions=4.1.1=pyh06a4308_0 126 | - tzdata=2021e=hda174b7_0 127 | - urllib3=1.26.8=pyhd3eb1b0_0 128 | - wheel=0.37.1=pyhd3eb1b0_0 129 | - xz=5.2.5=h7b6447c_0 130 | - yacs=0.1.6=pyhd3eb1b0_1 131 | - yaml=0.2.5=h7b6447c_0 132 | - zlib=1.2.11=h7f8727e_4 133 | - zstd=1.4.9=haebb681_0 134 | prefix: /home/user/anaconda3/envs/tfg 135 | -------------------------------------------------------------------------------- /experiments_all.sh: -------------------------------------------------------------------------------- 1 | python train.py --dataset MUTAG --model MinCutNet --cuda cuda:0 2 | python train.py --dataset MUTAG --model MinCutNet --cuda cuda:0 --prepro digl 3 | python train.py --dataset MUTAG --model CTNet --cuda cuda:0 4 | python train.py --dataset MUTAG --model GAPNet --derivative laplacian --cuda cuda:0 5 | python train.py --dataset MUTAG --model GAPNet --derivative normalized --cuda cuda:0 6 | python train.py --dataset MUTAG --model GAPNet --derivative normalizedv2 --cuda cuda:0 7 | python train.py --dataset MUTAG --model DiffWire --cuda cuda:0 8 | 9 | python train.py --dataset ENZYMES --model MinCutNet --cuda cuda:0 10 | python train.py --dataset ENZYMES --model MinCutNet --cuda cuda:0 --prepro digl 11 | python train.py --dataset ENZYMES --model CTNet --cuda cuda:0 12 | python train.py --dataset ENZYMES --model GAPNet --derivative laplacian --cuda cuda:0 13 | python train.py --dataset ENZYMES --model GAPNet --derivative normalized --cuda cuda:0 14 | python train.py --dataset ENZYMES --model 
GAPNet --derivative normalizedv2 --cuda cuda:0 15 | python train.py --dataset ENZYMES --model DiffWire --cuda cuda:0 16 | 17 | python train.py --dataset PROTEINS --model MinCutNet --cuda cuda:0 18 | python train.py --dataset PROTEINS --model MinCutNet --cuda cuda:0 --prepro digl 19 | python train.py --dataset PROTEINS --model CTNet --cuda cuda:0 20 | python train.py --dataset PROTEINS --model GAPNet --derivative laplacian --cuda cuda:0 21 | python train.py --dataset PROTEINS --model GAPNet --derivative normalized --cuda cuda:0 22 | python train.py --dataset PROTEINS --model GAPNet --derivative normalizedv2 --cuda cuda:0 23 | python train.py --dataset PROTEINS --model DiffWire --cuda cuda:0 24 | 25 | python train.py --dataset REDDIT-BINARY --model MinCutNet --cuda cuda:0 26 | python train.py --dataset REDDIT-BINARY --model MinCutNet --cuda cuda:0 --prepro digl 27 | python train.py --dataset REDDIT-BINARY --model CTNet --cuda cuda:0 28 | python train.py --dataset REDDIT-BINARY --model GAPNet --derivative laplacian --cuda cuda:0 29 | python train.py --dataset REDDIT-BINARY --model GAPNet --derivative normalized --cuda cuda:0 30 | python train.py --dataset REDDIT-BINARY --model GAPNet --derivative normalizedv2 --cuda cuda:0 31 | python train.py --dataset REDDIT-BINARY --model DiffWire --cuda cuda:0 32 | 33 | python train.py --dataset COLLAB --model MinCutNet --cuda cuda:0 34 | python train.py --dataset COLLAB --model MinCutNet --cuda cuda:0 --prepro digl 35 | python train.py --dataset COLLAB --model CTNet --cuda cuda:0 36 | python train.py --dataset COLLAB --model GAPNet --derivative laplacian --cuda cuda:0 37 | python train.py --dataset COLLAB --model GAPNet --derivative normalized --cuda cuda:0 38 | python train.py --dataset COLLAB --model GAPNet --derivative normalizedv2 --cuda cuda:0 39 | python train.py --dataset COLLAB --model DiffWire --cuda cuda:0 40 | 41 | python train.py --dataset IMDB-BINARY --model MinCutNet --cuda cuda:0 42 | python train.py --dataset IMDB-BINARY --model MinCutNet --cuda cuda:0 --prepro digl 43 | python train.py --dataset IMDB-BINARY --model CTNet --cuda cuda:0 44 | python train.py --dataset IMDB-BINARY --model GAPNet --derivative laplacian --cuda cuda:0 45 | python train.py --dataset IMDB-BINARY --model GAPNet --derivative normalized --cuda cuda:0 46 | python train.py --dataset IMDB-BINARY --model GAPNet --derivative normalizedv2 --cuda cuda:0 47 | python train.py --dataset IMDB-BINARY --model DiffWire --cuda cuda:0 48 | 49 | python train.py --dataset PROTEINS --model MinCutNet --cuda cuda:0 50 | python train.py --dataset PROTEINS --model MinCutNet --cuda cuda:0 --prepro digl 51 | python train.py --dataset PROTEINS --model CTNet --cuda cuda:0 52 | python train.py --dataset PROTEINS --model GAPNet --derivative laplacian --cuda cuda:0 53 | python train.py --dataset PROTEINS --model GAPNet --derivative normalized --cuda cuda:0 54 | python train.py --dataset PROTEINS --model GAPNet --derivative normalizedv2 --cuda cuda:0 55 | python train.py --dataset PROTEINS --model DiffWire --cuda cuda:0 56 | 57 | python train.py --dataset MNIST --model MinCutNet --cuda cuda:0 58 | python train.py --dataset MNIST --model MinCutNet --cuda cuda:0 --prepro digl 59 | python train.py --dataset MNIST --model CTNet --cuda cuda:0 60 | python train.py --dataset MNIST --model GAPNet --derivative laplacian --cuda cuda:0 61 | python train.py --dataset MNIST --model GAPNet --derivative normalized --cuda cuda:0 62 | python train.py --dataset MNIST --model GAPNet --derivative normalizedv2 --cuda 
cuda:0 63 | python train.py --dataset MNIST --model DiffWire --cuda cuda:0 64 | 65 | python train.py --dataset CIFAR10 --model MinCutNet --cuda cuda:0 66 | python train.py --dataset CIFAR10 --model MinCutNet --cuda cuda:0 --prepro digl 67 | python train.py --dataset CIFAR10 --model CTNet --cuda cuda:0 68 | python train.py --dataset CIFAR10 --model GAPNet --derivative laplacian --cuda cuda:0 69 | python train.py --dataset CIFAR10 --model GAPNet --derivative normalized --cuda cuda:0 70 | python train.py --dataset CIFAR10 --model GAPNet --derivative normalizedv2 --cuda cuda:0 71 | python train.py --dataset CIFAR10 --model DiffWire --cuda cuda:0 72 | 73 | python train.py --dataset CSL --model MinCutNet --cuda cuda:0 74 | python train.py --dataset CSL --model CTNet --cuda cuda:0 75 | python train.py --dataset CSL --model GAPNet --derivative laplacian --cuda cuda:0 76 | python train.py --dataset CSL --model GAPNet --derivative normalized --cuda cuda:0 77 | python train.py --dataset CSL --model GAPNet --derivative normalizedv2 --cuda cuda:0 78 | python train.py --dataset CSL --model DiffWire --cuda cuda:0 -------------------------------------------------------------------------------- /figs/AssortativeBottleneckAccuracy/COLLAB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/COLLAB.png -------------------------------------------------------------------------------- /figs/AssortativeBottleneckAccuracy/HistAsssortBottleCOLLAB.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/HistAsssortBottleCOLLAB.pdf -------------------------------------------------------------------------------- /figs/AssortativeBottleneckAccuracy/HistAsssortBottleREDDIT-BINARY.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/HistAsssortBottleREDDIT-BINARY.pdf -------------------------------------------------------------------------------- /figs/AssortativeBottleneckAccuracy/REDDIT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/REDDIT.png -------------------------------------------------------------------------------- /figs/AssortativeBottleneckAccuracy/ScatterAsssortBottleCOLLAB.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/ScatterAsssortBottleCOLLAB.pdf -------------------------------------------------------------------------------- /figs/AssortativeBottleneckAccuracy/ScatterAsssortBottleREDDIT-BINARY.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/ScatterAsssortBottleREDDIT-BINARY.pdf -------------------------------------------------------------------------------- /figs/CT_REDDIT_READOUT_17_05_22__11_19.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/CT_REDDIT_READOUT_17_05_22__11_19.jpg -------------------------------------------------------------------------------- /figs/CT_REDDIT_READOUT_19_05_22__10_19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/CT_REDDIT_READOUT_19_05_22__10_19.jpg -------------------------------------------------------------------------------- /figs/CT_REDDIT_READOUT_19_05_22__10_30.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/CT_REDDIT_READOUT_19_05_22__10_30.jpg -------------------------------------------------------------------------------- /figs/CT_RW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/CT_RW.png -------------------------------------------------------------------------------- /figs/GAPLap_REDDIT_READOUT_17_05_22__11_38.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/GAPLap_REDDIT_READOUT_17_05_22__11_38.jpg -------------------------------------------------------------------------------- /figs/GAP_REDDIT_READOUT_17_05_22__11_31.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/GAP_REDDIT_READOUT_17_05_22__11_31.jpg -------------------------------------------------------------------------------- /figs/GAP_REDDIT_READOUT_19_05_22__10_35.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/GAP_REDDIT_READOUT_19_05_22__10_35.jpg -------------------------------------------------------------------------------- /figs/MinCut_REDDIT_READOUT_17_05_22__11_21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/MinCut_REDDIT_READOUT_17_05_22__11_21.jpg -------------------------------------------------------------------------------- /figs/MinCut_REDDIT_READOUT_17_05_22__11_22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/MinCut_REDDIT_READOUT_17_05_22__11_22.jpg -------------------------------------------------------------------------------- /figs/MinCut_REDDIT_READOUT_19_05_22__10_19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/MinCut_REDDIT_READOUT_19_05_22__10_19.jpg -------------------------------------------------------------------------------- /figs/MinCut_REDDIT_READOUT_19_05_22__10_30.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/MinCut_REDDIT_READOUT_19_05_22__10_30.jpg -------------------------------------------------------------------------------- /figs/REDDIT_CT_READOUT_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/REDDIT_CT_READOUT_1.jpg -------------------------------------------------------------------------------- /figs/data_statistics/Degree_hist_CIFAR10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_CIFAR10.png -------------------------------------------------------------------------------- /figs/data_statistics/Degree_hist_COLLAB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_COLLAB.png -------------------------------------------------------------------------------- /figs/data_statistics/Degree_hist_ENZYMES.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_ENZYMES.png -------------------------------------------------------------------------------- /figs/data_statistics/Degree_hist_IMDB-BINARY.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_IMDB-BINARY.png -------------------------------------------------------------------------------- /figs/data_statistics/Degree_hist_MNIST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_MNIST.png -------------------------------------------------------------------------------- /figs/data_statistics/Degree_hist_MUTAG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_MUTAG.png -------------------------------------------------------------------------------- /figs/data_statistics/Degree_hist_PROTEINS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_PROTEINS.png -------------------------------------------------------------------------------- /figs/data_statistics/Degree_hist_REDDIT-BINARY.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_REDDIT-BINARY.png -------------------------------------------------------------------------------- /figs/data_statistics/avg_degree_hist_CIFAR10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_CIFAR10.png -------------------------------------------------------------------------------- /figs/data_statistics/avg_degree_hist_COLLAB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_COLLAB.png -------------------------------------------------------------------------------- /figs/data_statistics/avg_degree_hist_ENZYMES.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_ENZYMES.png -------------------------------------------------------------------------------- /figs/data_statistics/avg_degree_hist_IMDB-BINARY.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_IMDB-BINARY.png -------------------------------------------------------------------------------- /figs/data_statistics/avg_degree_hist_MNIST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_MNIST.png -------------------------------------------------------------------------------- /figs/data_statistics/avg_degree_hist_MUTAG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_MUTAG.png -------------------------------------------------------------------------------- /figs/data_statistics/avg_degree_hist_PROTEINS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_PROTEINS.png -------------------------------------------------------------------------------- /figs/data_statistics/avg_degree_hist_REDDIT-BINARY.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_REDDIT-BINARY.png -------------------------------------------------------------------------------- /figs/data_statistics/number_nodes_hist_CIFAR10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_CIFAR10.png -------------------------------------------------------------------------------- /figs/data_statistics/number_nodes_hist_COLLAB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_COLLAB.png -------------------------------------------------------------------------------- /figs/data_statistics/number_nodes_hist_ENZYMES.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_ENZYMES.png -------------------------------------------------------------------------------- /figs/data_statistics/number_nodes_hist_IMDB-BINARY.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_IMDB-BINARY.png -------------------------------------------------------------------------------- /figs/data_statistics/number_nodes_hist_MNIST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_MNIST.png -------------------------------------------------------------------------------- /figs/data_statistics/number_nodes_hist_MUTAG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_MUTAG.png -------------------------------------------------------------------------------- /figs/data_statistics/number_nodes_hist_PROTEINS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_PROTEINS.png -------------------------------------------------------------------------------- /figs/data_statistics/number_nodes_hist_REDDIT-BINARY.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_REDDIT-BINARY.png -------------------------------------------------------------------------------- /layers/CT_layer.py: -------------------------------------------------------------------------------- 1 | # Commute Times rewiring 2 | #Graph Convolutional Network layer where the graph structure is given by an adjacency matrix. 3 | # We recommend user to use this module when applying graph convolution on dense graphs. 
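# Overview (summary of the function below): dense_CT_rewiring(x, adj, s, mask=None) takes batched
# node features x [B, N, F], a dense adjacency adj [B, N, N] and a soft assignment s [B, N, k],
# builds a commute-times embedding Z = tanh(s), and returns (i) the adjacency rewired by the
# pairwise distances ||z_i - z_j|| normalized by the graph volume and masked by adj,
# (ii) the CT loss Tr(S^T L S) / Tr(S^T D S) and (iii) the orthogonality loss on S.
#
# Minimal illustrative call (hypothetical shapes, not part of the training pipeline):
#   x = torch.randn(8, 40, 32); adj = torch.rand(8, 40, 40); s = torch.randn(8, 40, 16)
#   new_adj, ct_loss, ortho_loss = dense_CT_rewiring(x, adj, s)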
4 | #from torch_geometric.nn import GCNConv, DenseGraphConv 5 | import torch 6 | from layers.utils.ein_utils import _rank3_diag, _rank3_trace 7 | 8 | def dense_CT_rewiring(x, adj, s, mask=None, EPS=1e-15): # x torch.Size([20, 40, 32]) ; mask torch.Size([20, 40]) batch_size=20 9 | #print("Input x size to mincut pool", x.size()) 10 | x = x.unsqueeze(0) if x.dim() == 2 else x # x torch.Size([20, 40, 32]) if x has not 2 parameters 11 | #print("Unsqueezed x size to mincut pool", x.size(), x.dim()) # x.dim() is usually 3 12 | 13 | # adj torch.Size([20, N, N]) N=Mmax 14 | #print("Input adj size to mincut pool", adj.size()) 15 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj # adj torch.Size([20, N, N]) N=Mmax 16 | #print("Unsqueezed adj size to mincut pool", adj.size(), adj.dim()) # adj.dim() is usually 3 17 | 18 | # s torch.Size([20, N, k]) 19 | s = s.unsqueeze(0) if s.dim() == 2 else s # s torch.Size([20, N, k]) 20 | #print("Unsqueezed s size", s.size()) 21 | 22 | # x torch.Size([20, N, 32]) if x has not 2 parameters 23 | (batch_size, num_nodes, _), k = x.size(), s.size(-1) 24 | #print("batch_size and num_nodes", batch_size, num_nodes, k) # batch_size = 20, num_nodes = N, k = 16 25 | s = torch.tanh(s) # torch.Size([20, N, k]) One k for each N of each graph 26 | #print("s softmax size", s.size()) 27 | 28 | if mask is not None: # NOT None for now 29 | mask = mask.view(batch_size, num_nodes, 1).to(x.dtype) 30 | #print("mask size", mask.size()) # [20, N, 1] 31 | # Mask pointwise product. Since x is [20, N, 32] and s is [20, N, k] 32 | x, s = x * mask, s * mask # x*mask = [20, N, 32]*[20, N, 1] = [20, N, 32] s*mask = [20, N, k]*[20, N, 1] = [20, N, k] 33 | #print("x and s sizes after multiplying by mask", x.size(), s.size() 34 | 35 | # CT regularization 36 | # Calculate degree d_flat and degree matrix d 37 | d_flat = torch.einsum('ijk->ij', adj) # torch.Size([20, N]) 38 | #print("d_flat size", d_flat.size()) 39 | d = _rank3_diag(d_flat)+EPS # d torch.Size([20, N, N]) 40 | #print("d size", d.size()) 41 | 42 | # Calculate Laplacian L = D - A 43 | L = d - adj 44 | #print("Laplacian", L[1,:,:]) 45 | 46 | # Calculate out_adj as A_CT = S.T*L*S 47 | # out_adj: this tensor contains A_CT = S.T*L*S so that we can take its trace and retain coarsened adjacency (Eq. 
7) 48 | out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), L), s) #[20, k, N]*[20, N, N]-> [20, k ,N]*[20, N, k] = [20, k, k] 20 graphs of k nodes 49 | #print("out_adj size", out_adj.size()) 50 | #print("out_adj ", out_adj[0,]) # Has no zeros in the diagonal 51 | 52 | # Calculate CT_num 53 | CT_num = _rank3_trace(out_adj) # mincut_num torch.Size([20]) one sum over each graph 54 | #print("CT_num size", CT_num.size()) 55 | #print("CT_num", CT_num) 56 | # Calculate CT_den 57 | CT_den = _rank3_trace( 58 | torch.matmul(torch.matmul(s.transpose(1, 2), d ), s))+EPS # [20, k, N]*[20, N, N]->[20, k, N]*[20, N, k] -> [20] one sum over each graph 59 | #print("CT_den size", CT_den.size()) 60 | #print("CT_den", CT_den) 61 | 62 | # Calculate CT_dist (distance matrix) 63 | CT_dist = torch.cdist(s,s) # [20, N, k], [20, N, k]-> [20,N,N] 64 | #print("CT_dist",CT_dist) 65 | 66 | # Calculate Vol (volumes): one per graph 67 | vol = _rank3_trace(d) # torch.Size([20]) 68 | #print("vol size", vol.size()) 69 | 70 | 71 | #print("vol_flat size", vol_flat.size()) 72 | vol = _rank3_trace(d) # d torch.Size([20, N, N]) 73 | #print("vol size", vol.size()) 74 | #print("vol", vol) 75 | 76 | # Calculate out_adj as CT_dist*(N-1)/vol(G) 77 | N = adj.size(1) 78 | #CT_dist = (CT_dist*(N-1)) / vol.unsqueeze(1).unsqueeze(1) 79 | CT_dist = (CT_dist) / vol.unsqueeze(1).unsqueeze(1) 80 | #CT_dist = (CT_dist) / ((N-1)*vol).unsqueeze(1).unsqueeze(1) 81 | 82 | #print("R_dist",CT_dist) 83 | 84 | # Mask with adjacency if proceeds 85 | adj = CT_dist*adj 86 | #adj = CT_dist 87 | 88 | # Losses 89 | CT_loss = CT_num / CT_den 90 | CT_loss = torch.mean(CT_loss) # Mean over 20 graphs! 91 | #print("CT_loss", CT_loss) 92 | 93 | # Orthogonality regularization. 94 | ss = torch.matmul(s.transpose(1, 2), s) #[20, k, N]*[20, N, k]-> [20, k, k] 95 | #print("ss size", ss.size()) 96 | i_s = torch.eye(k).type_as(ss) # [k, k] 97 | ortho_loss = torch.norm( 98 | ss / torch.norm(ss, dim=(-1, -2), keepdim=True) - 99 | i_s ) 100 | #print("ortho_loss size", ortho_loss.size()) # [20] one sum over each graph 101 | ortho_loss = torch.mean(ortho_loss) 102 | #print("ortho_loss", ortho_loss) 103 | 104 | return adj, CT_loss, ortho_loss # [20, k, 32], [20, B, N], [1], [1] 105 | -------------------------------------------------------------------------------- /layers/GAP_layer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torch_geometric.utils import to_dense_batch, to_dense_adj 4 | from torch_geometric.nn import GCNConv, DenseGraphConv 5 | from layers.utils.ein_utils import _rank3_diag, _rank3_trace 6 | from layers.utils.approximate_fiedler import approximate_Fiedler 7 | from layers.utils.approximate_fiedler import NLderivative_of_lambda2_wrt_adjacency, NLfiedler_values 8 | from layers.utils.approximate_fiedler import derivative_of_lambda2_wrt_adjacency, fiedler_values 9 | from layers.utils.approximate_fiedler import NLderivative_of_lambda2_wrt_adjacencyV2, NLfiedler_valuesV2 10 | 11 | def dense_mincut_rewiring(x, adj, s, mask=None, derivative = None, EPS=1e-15, device=None): # x torch.Size([20, 40, 32]) ; mask torch.Size([20, 40]) batch_size=20 12 | 13 | k = 2 #We want bipartition to compute spectral gap 14 | # adj torch.Size([20, N, N]) N=Mmax 15 | #print("Input adj size to mincut pool", adj.size()) 16 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj # adj torch.Size([20, N, N]) N=Mmax 17 | #print("Unsqueezed adj size to mincut pool", adj.size(), adj.dim()) # adj.dim() is usually 3 
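    # High-level flow of the rewiring below (summary for readability):
    #   1. softmax(s) gives a soft bipartition; approximate_Fiedler turns its argmax into a
    #      signed +-1/sqrt(N) vector approximating the Fiedler vector.
    #   2. Depending on `derivative`, compute d(lambda_2)/dA for the plain or normalized
    #      Laplacian together with a Rayleigh-quotient estimate of lambda_2.
    #   3. Take a few gradient steps on Ac that trade off staying close to A (||Ac - A||^2)
    #      against the lambdaReg-weighted spectral-gap term, then apply a row softmax and
    #      mask with the original adjacency.
    #   4. Return Ac plus the MinCut and orthogonality losses computed from s.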
18 | 19 | # s torch.Size([20, N, k]) 20 | s = s.unsqueeze(0) if s.dim() == 2 else s #s torch.Size([20, N, k]) 21 | #print("Unsqueezed s size", s.size()) 22 | 23 | s = torch.softmax(s, dim=-1) # torch.Size([20, N, k]) One k for each N of each graph 24 | #print("s softmax size", s.size()) 25 | #print("s softmax", s[0,1,:], torch.argmax(s,dim=(2)).size()) 26 | 27 | # Put here the calculus of the degree matrix to optimize the complex derivative 28 | d_flat = torch.einsum('ijk->ij', adj) # torch.Size([20, N]) 29 | #print("d_flat size", d_flat.size()) 30 | d = _rank3_diag(d_flat) # d torch.Size([20, N, N]) 31 | #print("d size", d.size()) 32 | 33 | # Batched Laplacian 34 | L = d - adj 35 | 36 | # REWIRING: UPDATING adj wrt s using derivatives ------------------------------------------------- 37 | # Approximating the Fiedler vectors from s (assuming k=2) 38 | fiedlers = approximate_Fiedler(s, device) 39 | #print("fiedlers size", fiedlers.size()) 40 | #print("fiedlers ", fiedlers) 41 | 42 | # Recalculate 43 | if derivative == "laplacian": 44 | der = derivative_of_lambda2_wrt_adjacency(fiedlers, device) 45 | fvalues = fiedler_values(adj, fiedlers, EPS, device) 46 | elif derivative == "normalized": 47 | #start = time.time() 48 | der = NLderivative_of_lambda2_wrt_adjacency(adj, d_flat, fiedlers, EPS, device) 49 | fvalues = NLfiedler_values(L, d_flat, fiedlers, EPS, device) 50 | #print('\t\t NLderivative_of_lambda2_wrt_adjacency: {:.6f}s'.format(time.time()- start)) 51 | elif derivative == "normalizedv2": 52 | der = NLderivative_of_lambda2_wrt_adjacencyV2(adj, d_flat, fiedlers, EPS, device) 53 | fvalues = NLfiedler_valuesV2(L, d, fiedlers, EPS, device) 54 | 55 | mu = 0.01 56 | lambdaReg = 0.1 57 | lambdaReg = 1.0 58 | lambdaReg = 1.5 59 | lambdaReg = 2.0 60 | lambdaReg = 2.5 61 | lambdaReg = 5.0 62 | lambdaReg = 3.0 63 | lambdaReg = 1.0 64 | lambdaReg = 2.0 65 | #lambdaReg = 20.0 66 | 67 | Ac = adj.clone() 68 | for _ in range(5): 69 | #fvalues = fiedler_values(Ac, fiedlers) 70 | #print("Ac size", Ac.size()) 71 | partialJ = 2*(Ac-adj) + 2*lambdaReg*der*fvalues.unsqueeze(1).unsqueeze(2) # favalues is [B], partialJ is [B, N, N] 72 | #print("partialJ size", partialJ.size()) 73 | #print("diag size", torch.diag_embed(torch.diagonal(partialJ,dim1=1,dim2=2)).size()) 74 | dJ = partialJ + torch.transpose(partialJ,1,2) - torch.diag_embed(torch.diagonal(partialJ,dim1=1,dim2=2)) 75 | # Update adjacency 76 | Ac = Ac - mu*dJ 77 | # Clipping: negatives to 0, positives to 1 78 | #print("Ac is", Ac, Ac.size()) 79 | #Ac = torch.clamp(Ac, min=0.0, max=1.0) 80 | 81 | #print("Esta es la antigua adj",adj) 82 | #print("Esta es la antigua Ac",Ac) 83 | 84 | #print("Despues mask Ac",Ac) 85 | #print("Despues mask Adj",adj) 86 | #print("Mayores que 0",(Ac>0).sum()) #20,16,40 87 | #print("Menores que 0",(Ac<=0).sum()) 88 | Ac = torch.softmax(Ac, dim=-1) 89 | Ac = Ac*adj 90 | #print("Min Fiedlers",min(fvalues)) 91 | #print("NUeva salida",Ac) 92 | 93 | # out_adj: this tensor contains Apool = S.T*A*S so that we can take its trace and retain coarsened adjacency (Eq. 7) 94 | out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) #[20, k, N]*[20, N, N]-> [20, k ,N]*[20, N, k] = [20, k, k] 20 graphs of k nodes 95 | #print("out_adj size", out_adj.size()) 96 | #print("out_adj ", out_adj[0,]) # Has no zeros in the diagonal 97 | 98 | # MinCUT regularization. 
99 | mincut_num = _rank3_trace(out_adj) # mincut_num torch.Size([20]) one sum over each graph 100 | #print("mincut_num size", mincut_num.size()) 101 | #d_flat = torch.einsum('ijk->ij', adj) # torch.Size([20, N]) 102 | #print("d_flat size", d_flat.size()) 103 | #d = _rank3_diag(d_flat) # d torch.Size([20, N, N]) 104 | #print("d size", d.size()) 105 | mincut_den = _rank3_trace( 106 | torch.matmul(torch.matmul(s.transpose(1, 2), d), s)) # [20, k, N]*[20, N, N]->[20, k, N]*[20, N, k] -> [20] one sum over each graph 107 | #print("mincut_den size", mincut_den.size()) 108 | 109 | mincut_loss = -(mincut_num / mincut_den) 110 | #print("mincut_loss", mincut_loss) 111 | mincut_loss = torch.mean(mincut_loss) # Mean over 20 graphs! 112 | 113 | # Orthogonality regularization. 114 | ss = torch.matmul(s.transpose(1, 2), s) #[20, k, N]*[20, N, k]-> [20, k, k] 115 | #print("ss size", ss.size()) 116 | i_s = torch.eye(k).type_as(ss) # [k, k] 117 | ortho_loss = torch.norm( 118 | ss / torch.norm(ss, dim=(-1, -2), keepdim=True) - 119 | i_s / torch.norm(i_s), dim=(-1, -2)) 120 | #print("ortho_loss size", ortho_loss.size()) # [20] one sum over each graph 121 | ortho_loss = torch.mean(ortho_loss) 122 | 123 | """# Fix and normalize coarsened adjacency matrix. 124 | ind = torch.arange(k, device=out_adj.device) # range e.g. from 0 to 15 (k=16) 125 | # out_adj is [20, k, k] 126 | out_adj[:, ind, ind] = 0 # [20, k, k] the diagnonal will be 0 now: Ahat = Apool - I_k*diag(Apool) (Eq. 8) 127 | #print("out_adj", out_adj[0,]) 128 | 129 | # Final degree matrix and normalization of out_adj: Ahatpool = Dhat^{-1/2}AhatD^{-1/2} (Eq. 8) 130 | d = torch.einsum('ijk->ij', out_adj) #d torch.Size([20, k]) 131 | #print("d size", d.size()) 132 | d = torch.sqrt(d)[:, None] + EP S # d torch.Size([20, 1, k]) 133 | #print("sqrt(d) size", d.size()) 134 | #print( (out_adj / d).shape) # out_adj is [20, k, k] and d is [20, 1, k] -> torch.Size([20, k, k]) 135 | out_adj = (out_adj / d) / d.transpose(1, 2) # ([20, k, k] / [20, k, 1] ) -> [20, k, k] 136 | # out_adj torch.Size([20, k, k]) 137 | #print("out_adj size", out_adj.size())""" 138 | return Ac, mincut_loss, ortho_loss # [20, k, 32], [20, k, k], [1], [1] 139 | #return out, out_adj, mincut_loss, ortho_loss # [20, k, 32], [20, k, k], [1], [1] -------------------------------------------------------------------------------- /layers/MinCut_Layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import to_dense_batch, to_dense_adj 3 | from torch_geometric.nn import GCNConv, DenseGraphConv 4 | from layers.utils.ein_utils import _rank3_diag, _rank3_trace 5 | 6 | def dense_mincut_pool(x, adj, s, mask=None, EPS=1e-15): # x torch.Size([20, 40, 32]) ; mask torch.Size([20, 40]) batch_size=20 7 | #print("Input x size to mincut pool", x.size()) 8 | x = x.unsqueeze(0) if x.dim() == 2 else x # x torch.Size([20, 40, 32]) if x has not 2 parameters 9 | #print("Unsqueezed x size to mincut pool", x.size(), x.dim()) # x.dim() is usually 3 10 | 11 | # adj torch.Size([20, N, N]) N=Mmax 12 | #print("Input adj size to mincut pool", adj.size()) 13 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj # adj torch.Size([20, N, N]) N=Mmax 14 | #print("Unsqueezed adj size to mincut pool", adj.size(), adj.dim()) # adj.dim() is usually 3 15 | 16 | # s torch.Size([20, N, k]) 17 | s = s.unsqueeze(0) if s.dim() == 2 else s # s torch.Size([20, N, k]) 18 | #print("Unsqueezed s size", s.size()) 19 | 20 | # x torch.Size([20, N, 32]) if x has not 2 
parameters 21 | (batch_size, num_nodes, _), k = x.size(), s.size(-1) 22 | #print("batch_size and num_nodes", batch_size, num_nodes, k) # batch_size = 20, num_nodes = N, k = 16 23 | s = torch.softmax(s, dim=-1) # torch.Size([20, N, k]) One k for each N of each graph 24 | #print("s softmax size", s.size()) 25 | 26 | if mask is not None: # NOT None for now 27 | mask = mask.view(batch_size, num_nodes, 1).to(x.dtype) 28 | #print("mask size", mask.size()) # [20, N, 1] 29 | # Mask pointwise product. Since x is [20, N, 32] and s is [20, N, k] 30 | x, s = x * mask, s * mask # x*mask = [20, N, 32]*[20, N, 1] = [20, N, 32] s*mask = [20, N, k]*[20, N, 1] = [20, N, k] 31 | #print("x and s sizes after multiplying by mask", x.size(), s.size()) 32 | 33 | # out: this tensor contains Xpool=S.T*X (Eq. 7) 34 | out = torch.matmul(s.transpose(1, 2), x) # [20, k, N] * [20, N, 32] will yield [20, k, 32] 35 | #print("out size", out.size()) 36 | # out_adj: this tensor contains Apool = S.T*A*S so that we can take its trace and retain coarsened adjacency (Eq. 7) 37 | out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) #[20, k, N]*[20, N, N]-> [20, k ,N]*[20, N, k] = [20, k, k] 20 graphs of k nodes 38 | #print("out_adj size", out_adj.size()) 39 | #print("out_adj ", out_adj[0,]) # Has no zeros in the diagonal 40 | 41 | # MinCUT regularization. 42 | mincut_num = _rank3_trace(out_adj) # mincut_num torch.Size([20]) one sum over each graph 43 | #print("mincut_num size", mincut_num.size()) 44 | d_flat = torch.einsum('ijk->ij', adj) + EPS # torch.Size([20, N]) 45 | #print("d_flat size", d_flat.size()) 46 | d = _rank3_diag(d_flat) # d torch.Size([20, N, N]) 47 | #print("d size", d.size()) 48 | mincut_den = _rank3_trace( 49 | torch.matmul(torch.matmul(s.transpose(1, 2), d), s)) # [20, k, N]*[20, N, N]->[20, k, N]*[20, N, k] -> [20] one sum over each graph 50 | #print("mincut_den size", mincut_den.size()) 51 | 52 | mincut_loss = -(mincut_num / mincut_den) 53 | #print("mincut_loss", mincut_loss) 54 | mincut_loss = torch.mean(mincut_loss) # Mean over 20 graphs! 55 | 56 | # Orthogonality regularization. 57 | ss = torch.matmul(s.transpose(1, 2), s) #[20, k, N]*[20, N, k]-> [20, k, k] 58 | #print("ss size", ss.size()) 59 | i_s = torch.eye(k).type_as(ss) # [k, k] 60 | ortho_loss = torch.norm( 61 | ss / torch.norm(ss, dim=(-1, -2), keepdim=True) - 62 | i_s / torch.norm(i_s), dim=(-1, -2)) 63 | #print("ortho_loss size", ortho_loss.size()) # [20] one sum over each graph 64 | ortho_loss = torch.mean(ortho_loss) 65 | 66 | # Fix and normalize coarsened adjacency matrix. 67 | ind = torch.arange(k, device=out_adj.device) # range e.g. from 0 to 15 (k=16) 68 | # out_adj is [20, k, k] 69 | out_adj[:, ind, ind] = 0 # [20, k, k] the diagnonal will be 0 now: Ahat = Apool - I_k*diag(Apool) (Eq. 8) 70 | #print("out_adj", out_adj[0,]) 71 | 72 | # Final degree matrix and normalization of out_adj: Ahatpool = Dhat^{-1/2}AhatD^{-1/2} (Eq. 
8) 73 | d = torch.einsum('ijk->ij', out_adj) #d torch.Size([20, k]) 74 | #print("d size", d.size()) 75 | d = torch.sqrt(d+EPS)[:, None] + EPS # d torch.Size([20, 1, k]) 76 | #print("sqrt(d) size", d.size()) 77 | #print( (out_adj / d).shape) # out_adj is [20, k, k] and d is [20, 1, k] -> torch.Size([20, k, k]) 78 | out_adj = (out_adj / d) / d.transpose(1, 2) # ([20, k, k] / [20, k, 1] ) -> [20, k, k] 79 | # out_adj torch.Size([20, k, k]) 80 | #print("out_adj size", out_adj.size()) 81 | return out, out_adj, mincut_loss, ortho_loss # [20, k, 32], [20, k, k], [1], [1] -------------------------------------------------------------------------------- /layers/utils/approximate_fiedler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import time 4 | 5 | from layers.utils.ein_utils import _rank3_diag 6 | 7 | def approximate_Fiedler(s, device=None): # torch.Size([20, N, k]) One k for each N of each graph (asume k=2) 8 | """ 9 | Calculate approximate fiedler vector from S matrix. S in R^{B x N x 2} and fiedler vector S in R^{B x N} 10 | """ 11 | s_0 = s.size(0) #number of graphs 12 | s_1 = s.size(1) 13 | maxcluster = torch.argmax(s,dim=(2)) # torch.Size([20, N]) with binary values {0,1} if k=2 14 | trimmed_s = torch.FloatTensor(s_0,s_1).to(device) 15 | #print('\t'*4,'[DEVICES] s device', s.device,' -- trimmed_s device', trimmed_s.device,' -- maxcluster device', trimmed_s.device) 16 | trimmed_s[maxcluster==1] = -1/np.sqrt(float(s_1)) 17 | trimmed_s[maxcluster==0] = 1/np.sqrt(float(s_1)) 18 | return trimmed_s 19 | 20 | def NLderivative_of_lambda2_wrt_adjacency(adj, d_flat, fiedlers, EPS, device): # fiedlers torch.Size([20, N]) 21 | """ 22 | Complex derivative 23 | 24 | Args: 25 | adj (_type_): _description_ 26 | d_flat (_type_): _description_ 27 | fiedlers (_type_): _description_ 28 | 29 | Returns: 30 | _type_: _description_ 31 | """ 32 | N = fiedlers.size(1) 33 | B = fiedlers.size(0) 34 | # Batched structures for the complex derivative 35 | d_flat2 = torch.sqrt(d_flat+EPS)[:, None] + EPS # d torch.Size([B, 1, N]) 36 | #print("first d_flat2 size", d_flat2.size()) 37 | Ahat = (adj/d_flat2.transpose(1, 2)) # [B, N, N] / [B, N, 1] -> [B, N, N] 38 | AhatT = (adj.transpose(1,2)/d_flat2.transpose(1, 2)) # [B, N, N] / [B, N, 1] -> [B, N, N] 39 | dinv = 1/(d_flat + EPS)[:, None] 40 | dder = -0.5*dinv*d_flat2 41 | dder = dder.transpose(1,2) # [B, N, 1] 42 | # Storage 43 | derivatives = torch.FloatTensor(B, N, N).to(device) 44 | 45 | for b in range(B): 46 | # Eigenvectors 47 | u2 = fiedlers[b,:] 48 | u2 = u2.unsqueeze(1) # [N, 1] 49 | #u2 = u2.to(device) #its already in device because fiedlers is already in device 50 | #print("size of u2", u2.size()) 51 | 52 | # First term central: [N,1]x ([1,N]x[N,N]x[N,1]) x [N,1] 53 | firstT = torch.matmul(torch.matmul(u2.T, AhatT[b,:,:]), u2) # [1,N]x[N,N]x[N,1] -> [1] 54 | #print("first term central size", firstT.size()) 55 | firstT = torch.matmul(torch.matmul(dder[b,:], firstT), torch.ones(N).unsqueeze(0).to(device)) 56 | 57 | # Second term 58 | secT = torch.matmul(torch.matmul(u2.T, Ahat[b,:,:]), u2) # [1,N]x[N,N]x[N,1] -> [1] 59 | #print("second term central size", secT.size()) 60 | secT = torch.matmul(torch.matmul(dder[b,:], secT), torch.ones(N).unsqueeze(0).to(device)) 61 | 62 | # Third term 63 | u2u2T = torch.matmul(u2,u2.T) # [N,1] x [1,N] -> [N,N] 64 | #print("u2u2T size", u2u2T.size()) 65 | #print("d_flat2[b,:] size", d_flat2[b,:].size()) 66 | Du2u2TD = (u2u2T / d_flat2[b,:]) / 
d_flat2[b,:].transpose(0, 1) 67 | #print("size of Du2u2TD", Du2u2TD.size()) 68 | # dl2 = torch.matmul(torch.diag(u2u2T),torch.ones(N,N)) - u2u2T ERROR FUNCTIONAL 69 | #dl2 = torch.matmul(torch.diag(torch.diag(u2u2T)),torch.ones(N,N)) - u2u2T 70 | dl2 = firstT + secT + Du2u2TD 71 | # Symmetrize and subtract the diag since it is an undirected graph 72 | #dl2 = dl2 + dl2.T - torch.diag(torch.diag(dl2)) 73 | derivatives[b,:,:] = -dl2 74 | return derivatives # derivatives torch.Size([20, N, N]) 75 | 76 | def NLfiedler_values(L, d_flat, fiedlers, EPS, device): # adj torch.Size([B, N, N]) fiedlers torch.Size([B, N]) 77 | N = fiedlers.size(1) 78 | B = fiedlers.size(0) 79 | #print("original fiedlers size", fiedlers.size()) 80 | 81 | # Batched Fiedlers 82 | d_flat2 = torch.sqrt(d_flat+EPS)[:, None] + EPS # d torch.Size([B, 1, N]) 83 | #print("d_flat2 size", d_flat2.size()) 84 | fiedlers = fiedlers.unsqueeze(2) # [B, N, 1] 85 | #print("fiedlers size", fiedlers.size()) 86 | fiedlers_hats = (fiedlers/d_flat2.transpose(1, 2)) # [B, N, 1] / [B, N, 1] -> [B, N, 1] 87 | gfiedlers_hats = (fiedlers*d_flat2.transpose(1, 2)) # [B, N, 1] * [B, N, 1] -> [B, N, 1] 88 | #print("fiedlers size", fiedlers_hats.size()) 89 | #print("gfiedlers size", gfiedlers_hats.size()) 90 | 91 | #Laplacians = torch.FloatTensor(B, N, N) 92 | fiedler_values = torch.FloatTensor(B).to(device) 93 | for b in range(B): 94 | f = fiedlers_hats[b,:] 95 | g = gfiedlers_hats[b,:] 96 | num = torch.matmul(f.T,torch.matmul(L[b,:,:],f)) # f is [N,1], L is [N, N], f.T is [1,N] -> Lf is [N,1] -> f.TLf is [1] 97 | den = torch.matmul(g.T,g) 98 | #print("num fied", num.size()) 99 | #print("den fied", den.size()) 100 | #print("g size", g.size()) 101 | #print("f size", f.size()) 102 | #print("L size", L[b,:,:].size()) 103 | fiedler_values[b] = N*torch.abs(num/(den + EPS)) 104 | return fiedler_values # torch.Size([B]) 105 | 106 | def derivative_of_lambda2_wrt_adjacency(fiedlers, device): # fiedlers torch.Size([20, N]) 107 | """ 108 | Simple derivative 109 | """ 110 | N = fiedlers.size(1) 111 | B = fiedlers.size(0) 112 | derivatives = torch.FloatTensor(B, N, N).to(device) 113 | for b in range(B): 114 | u2 = fiedlers[b,:] 115 | u2 = u2.unsqueeze(1) 116 | #print("size of u2", u2.size()) 117 | u2u2T = torch.matmul(u2,u2.T) 118 | #print("size of u2u2T", u2u2T.size()) 119 | # dl2 = torch.matmul(torch.diag(u2u2T),torch.ones(N,N)) - u2u2T ERROR FUNCTIONAL 120 | dl2 = torch.matmul(torch.diag(torch.diag(u2u2T)),torch.ones(N,N).to(device)) - u2u2T 121 | # Symmetrize and subtract the diag since it is an undirected graph 122 | #dl2 = dl2 + dl2.T - torch.diag(torch.diag(dl2)) 123 | derivatives[b,:,:] = dl2 124 | 125 | return derivatives # derivatives torch.Size([20, N, N]) 126 | 127 | def fiedler_values(adj, fiedlers, EPS, device): # adj torch.Size([B, N, N]) fiedlers torch.Size([B, N]) 128 | N = fiedlers.size(1) 129 | B = fiedlers.size(0) 130 | #Laplacians = torch.FloatTensor(B, N, N) 131 | fiedler_values = torch.FloatTensor(B).to(device) 132 | for b in range(B): 133 | # Compute un-normalized Laplacian 134 | A = adj[b,:,:] 135 | D = A.sum(dim=1) 136 | D = torch.diag(D) 137 | L = D - A 138 | #Laplacians[b,:,:] = L 139 | #if torch.min(A)<0: 140 | # print("Negative adj") 141 | # Compute numerator 142 | f = fiedlers[b,:].unsqueeze(1) 143 | #f = f.to(device) 144 | num = torch.matmul(f.T,torch.matmul(L,f)) # f is [N,1], L is [N, N], f.T is [1,N] -> Lf is [N,1] -> f.TLf is [1] 145 | # Create complete graph Laplacian 146 | CA = 
torch.ones(N,N).to(device)-torch.eye(N).to(device) 147 | CD = CA.sum(dim=1) 148 | CD = torch.diag(CD) 149 | CL = CD - CA 150 | CL = CL.to(device) 151 | # Compute denominator 152 | den = torch.matmul(f.T,torch.matmul(CL,f)) 153 | fiedler_values[b] = N*torch.abs(num/(den + EPS)) 154 | 155 | return fiedler_values # torch.Size([B]) 156 | 157 | 158 | def NLderivative_of_lambda2_wrt_adjacencyV2(adj, d_flat, fiedlers, EPS, device): # fiedlers torch.Size([20, N]) 159 | """ 160 | Complex derivative 161 | Args: 162 | adj (_type_): _description_ 163 | d_flat (_type_): _description_ 164 | fiedlers (_type_): _description_ 165 | Returns: 166 | _type_: _description_ 167 | """ 168 | N = fiedlers.size(1) 169 | B = fiedlers.size(0) 170 | # Batched structures for the complex derivative 171 | d_flat2 = torch.sqrt(d_flat+EPS)[:, None] + EPS # d torch.Size([B, 1, N]) 172 | d_flat = d_flat2.squeeze(1) 173 | #print("first d_flat2 size", d_flat2.size()) 174 | d_half = _rank3_diag(d_flat) # d torch.Size([B, N, N]) 175 | #print("d size", d.size()) 176 | Ahat = (adj/d_flat2.transpose(1, 2)) # [B, N, N] / [B, N, 1] -> [B, N, N] 177 | AhatT = (adj.transpose(1,2)/d_flat2.transpose(1, 2)) # [B, N, N] / [B, N, 1] -> [B, N, N] 178 | dinv = 1/(d_flat + EPS)[:, None] 179 | dder = -0.5*dinv*d_flat2 180 | dder = dder.transpose(1,2) # [B, N, 1] 181 | # Storage 182 | derivatives = torch.FloatTensor(B, N, N).to(device) 183 | for b in range(B): 184 | # Eigenvectors 185 | u2 = fiedlers[b,:] 186 | u2 = u2.unsqueeze(1) # [N, 1] 187 | #u2 = u2.to(device) #its already in device because fiedlers is already in device 188 | #print("size of u2", u2.size()) 189 | # First term central: [N,1]x ([1,N]x[N,N]x[N,1]) x [N,1] 190 | firstT = torch.matmul(torch.matmul(u2.T, torch.matmul(d_half[b,:,:], AhatT[b,:,:])), u2) # [1,N]x[N,N]x[N,1] -> [1] 191 | #print("first term central size", firstT.size()) 192 | firstT = torch.matmul(torch.matmul(dder[b,:], firstT), torch.ones(N).unsqueeze(0).to(device)) 193 | #print("first term size", firstT.size()) 194 | # Second term 195 | secT = torch.matmul(torch.matmul(u2.T, torch.matmul(d_half[b,:,:], Ahat[b,:,:])), u2) # [1,N]x[N,N]x[N,1] -> [1] 196 | #print("second term central size", secT.size()) 197 | secT = torch.matmul(torch.matmul(dder[b,:], secT), torch.ones(N).unsqueeze(0).to(device)) 198 | # Third term 199 | Du2u2TD = torch.matmul(u2,u2.T) # [N,1] x [1,N] -> [N,N] 200 | #print("Du2u2T size", u2u2T.size()) 201 | #print("d_flat2[b,:] size", d_flat2[b,:].size()) 202 | #Du2u2TD = (u2u2T / d_flat2[b,:]) / d_flat2[b,:].transpose(0, 1) 203 | #print("size of Du2u2TD", Du2u2TD.size()) 204 | # dl2 = torch.matmul(torch.diag(u2u2T),torch.ones(N,N)) - u2u2T ERROR FUNCTIONAL 205 | #dl2 = torch.matmul(torch.diag(torch.diag(u2u2T)),torch.ones(N,N)) - u2u2T 206 | dl2 = firstT + secT + Du2u2TD 207 | # Symmetrize and subtract the diag since it is an undirected graph 208 | #dl2 = dl2 + dl2.T - torch.diag(torch.diag(dl2)) 209 | derivatives[b,:,:] = -dl2 210 | return derivatives # derivatives torch.Size([20, N, N]) 211 | 212 | 213 | def NLfiedler_valuesV2(L, d, fiedlers, EPS, device): # adj torch.Size([B, N, N]) fiedlers torch.Size([B, N]) 214 | N = fiedlers.size(1) 215 | B = fiedlers.size(0) 216 | #print("original fiedlers size", fiedlers.size()) 217 | #print("d size", d.size()) 218 | #Laplacians = torch.FloatTensor(B, N, N) 219 | fiedler_values = torch.FloatTensor(B).to(device) 220 | for b in range(B): 221 | f = fiedlers[b,:].unsqueeze(1) 222 | num = torch.matmul(f.T,torch.matmul(L[b,:,:],f)) # f is [N,1], L is [N, N], f.T 
is [1,N] -> Lf is [N,1] -> f.TLf is [1] 223 | den = torch.matmul(f.T,torch.matmul(d[b,:,:], f)) # d is [N, N], f is [N,1], f.T is [1,N] -> [1,N] x [N, N] x [N, 1] is [1] 224 | fiedler_values[b] = N*torch.abs(num/(den + EPS)) 225 | """print(f.shape) 226 | print(f.T.shape) 227 | print(num.shape, num) 228 | print(den.shape, den) 229 | print((N*torch.abs(num/(den + EPS))).shape) 230 | exit()""" 231 | return fiedler_values # torch.Size([B]) 232 | 233 | -------------------------------------------------------------------------------- /layers/utils/ein_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # Trace of a tensor [1,k,k] 3 | def _rank3_trace(x): 4 | return torch.einsum('ijj->i', x) 5 | 6 | # Diagonal version of a tensor [1,n] -> [1,n,n] 7 | def _rank3_diag(x): 8 | # Eye matrix of n=x.size(1): [n,n] 9 | eye = torch.eye(x.size(1)).type_as(x) 10 | #print(eye.size()) 11 | #print(x.unsqueeze(2).size()) 12 | # x.unsqueeze(2) adds a second dimension to [1,n] -> [1,n,1] 13 | # expand(*x.size(), x.size(1)) takes [1,n,1] and expands [1,n] with n -> [1,n,n] 14 | out = eye * x.unsqueeze(2).expand(*x.size(), x.size(1)) 15 | return out -------------------------------------------------------------------------------- /layers/utils/spectral_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import to_dense_adj 3 | 4 | def unnormalized_laplacian_eigenvectors(L): # 5 | el, ev = torch.linalg.eig(L) 6 | el = torch.real(el) 7 | ev = torch.real(ev) 8 | idx = torch.argsort(el) 9 | el = el[idx] 10 | ev = ev[:,idx] 11 | #print("L", L, el[1]) 12 | return el, ev 13 | 14 | 15 | # Compute Fiedler values of all the graphs 16 | def compute_fiedler_vectors(dataset): 17 | """ 18 | Calculate fieldver vector for all graphs in dataset 19 | """ 20 | vectors = [] 21 | values = [] 22 | for g in range(len(dataset)): 23 | G = dataset[g] 24 | #print(G) 25 | adj = to_dense_adj(G.edge_index) 26 | # adj is [1, N, N] 27 | adj = adj.squeeze(0) 28 | # adj is [N, N] 29 | #print(adj.size()) 30 | A = adj 31 | D = A.sum(dim=1) 32 | D = torch.diag(D) 33 | L = D - A 34 | #print(L) 35 | el, ev = unnormalized_laplacian_eigenvectors(L) 36 | vectors.append(ev[:,1]) 37 | values.append(el[1]) 38 | 39 | return values, vectors 40 | 41 | 42 | -------------------------------------------------------------------------------- /nets.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn import Linear 5 | from torch_geometric.nn import DenseGraphConv 6 | from torch_geometric.utils import to_dense_batch, to_dense_adj 7 | from layers.CT_layer import dense_CT_rewiring 8 | from layers.MinCut_Layer import dense_mincut_pool 9 | from layers.GAP_layer import dense_mincut_rewiring 10 | 11 | 12 | class GAPNet(torch.nn.Module): 13 | def __init__(self, in_channels, out_channels, hidden_channels=32, derivative=None, EPS=1e-15, device=None): 14 | super(GAPNet, self).__init__() 15 | self.device = device 16 | self.derivative = derivative 17 | self.EPS = EPS 18 | # GCN Layer - MLP - Dense GCN Layer 19 | #self.conv1 = GCNConv(in_channels, hidden_channels) 20 | self.conv1 = DenseGraphConv(hidden_channels, hidden_channels) 21 | self.conv2 = DenseGraphConv(hidden_channels, hidden_channels) 22 | num_of_centers2 = 16 # k2 23 | #num_of_centers2 = 10 # k2 24 | #num_of_centers2 = 5 # k2 25 | num_of_centers1 = 2 # k1 
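        # [Editor's note: illustrative comment, not part of the original code.]
        # k1 = 2 because the GAP-Layer only needs a two-way soft assignment:
        # approximate_Fiedler() in layers/utils/approximate_fiedler.py takes the argmax over
        # the two columns of s1 and maps the two clusters to +-1/sqrt(N). For example, with
        # hypothetical values:
        #   >>> s1 = torch.tensor([[[0.9, 0.1], [0.2, 0.8], [0.1, 0.9], [0.7, 0.3]]])
        #   >>> approximate_Fiedler(s1, device='cpu')
        #   tensor([[ 0.5000, -0.5000, -0.5000,  0.5000]])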
#Fiedler vector 26 | # The degree of the node belonging to any of the centers 27 | self.pool1 = Linear(hidden_channels, num_of_centers1) 28 | self.pool2 = Linear(hidden_channels, num_of_centers2) 29 | # MLPs towards out 30 | self.lin1 = Linear(in_channels, hidden_channels) 31 | self.lin2 = Linear(hidden_channels, hidden_channels) 32 | self.lin3 = Linear(hidden_channels, out_channels) 33 | 34 | # Input: Batch of 20 graphs, each node F=3 features 35 | # N1 + N2 + ... + N2 = 661 36 | # TSNE here? 37 | def forward(self, x, edge_index, batch): # x torch.Size([N, N]), data.batch torch.Size([661]) 38 | 39 | # Make all adjacencies of size NxN 40 | adj = to_dense_adj(edge_index, batch) # adj torch.Size(B, N, N]) 41 | #print("adj_size", adj.size()) 42 | #print("adj",adj) 43 | 44 | 45 | # Make all x_i of size N=MAX(N1,...,N20), e.g. N=40: 46 | #print("x size", x.size()) 47 | x, mask = to_dense_batch(x, batch) # x torch.Size([20, N, 32]) ; mask torch.Size([20, N]) batch_size=20 48 | #print("x size", x.size()) 49 | 50 | x = self.lin1(x) 51 | # First mincut pool for computing Fiedler adn rewire 52 | s1 = self.pool1(x) 53 | #s1 = torch.variable()#s1 torch.Size([20, N, k1=2) 54 | #s1 = Variable(torch.randn(D_in, H).type(float16), requires_grad=True) 55 | #print("s 1st pool",s1) 56 | #print("s 1st pool size", s1.size()) 57 | 58 | if torch.isnan(adj).any(): 59 | print("adj nan") 60 | if torch.isnan(x).any(): 61 | print("x nan") 62 | 63 | 64 | # REWIRING 65 | #start = time.time() 66 | adj, mincut_loss1, ortho_loss1 = dense_mincut_rewiring(x, adj, s1, mask, derivative = self.derivative, EPS=self.EPS, device=self.device) # out: x torch.Size([20, N, F'=32]), adj torch.Size([20, N, N]) 67 | #print('\t\tdense_mincut_rewiring: {:.6f}s'.format(time.time()- start)) 68 | #print("x",x) 69 | #print("adj",adj) 70 | #print("x and adj sizes", x.size(), adj.size()) 71 | #adj = torch.softmax(adj, dim=-1) 72 | #print("adj softmaxed", adj) 73 | 74 | # CONV1: Now on x and rewired adj: 75 | x = self.conv1(x, adj) #out: x torch.Size([20, N, F'=32]) 76 | #print("x_1 ", x) 77 | #print("x_1 size", x.size()) 78 | 79 | # MLP of k=16 outputs s 80 | #print("adj_size", adj.size()) 81 | s2 = self.pool2(x) # s torch.Size([20, N, k]) 82 | #print("s 2nd pool", s2) 83 | #print("s 2nd pool size", s2.size()) 84 | #adj = torch.softmax(adj, dim=-1) 85 | 86 | 87 | # MINCUT_POOL 88 | # Call to dense_cut_mincut_pool to get coarsened x, adj and the losses: k=16 89 | #x, adj, mincut_loss1, ortho_loss1 = dense_mincut_rewiring(x, adj, s1, mask) # x torch.Size([20, k=16, F'=32]), adj torch.Size([20, k2=16, k2=16]) 90 | x, adj, mincut_loss2, ortho_loss2 = dense_mincut_pool(x, adj, s2, mask, EPS=self.EPS) # out x torch.Size([20, k=16, F'=32]), adj torch.Size([20, k2=16, k2=16]) 91 | #print("lossses2",mincut_loss2, ortho_loss2) 92 | #print("mincut pool x", x) 93 | #print("mincut pool adj", adj) 94 | #print("mincut pool x size", x.size()) 95 | #print("mincut pool adj size", adj.size()) # Some nan in adjacency: maybe comming from the rewiring-> dissapear after clipping 96 | 97 | 98 | # CONV2: Now on coarsened x and adj: 99 | x = self.conv2(x, adj) #out x torch.Size([20, 16, 32]) 100 | #print("x_2", x) 101 | #print("x_2 size", x.size()) 102 | 103 | # Readout for each of the 20 graphs 104 | #x = x.mean(dim=1) # x torch.Size([20, 32]) 105 | x = x.sum(dim=1) # x torch.Size([20, 32]) 106 | #print("mean x_2 size", x.size()) 107 | 108 | # Final MLP for graph classification: hidden channels = 32 109 | x = F.relu(self.lin2(x)) # x torch.Size([20, 32]) 110 | 
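        # [Editor's note: comment added for clarity, not part of the original code.]
        # The two auxiliary losses returned a few lines below are not applied inside the
        # network; train.py adds them to the classification objective:
        #   loss = F.nll_loss(out, data.y.view(-1)) + mc_loss + o_loss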
#print("final x1 size", x.size()) 111 | x = self.lin3(x) #x torch.Size([20, 2]) 112 | #print("final x2 size", x.size()) 113 | #print("losses: ", mincut_loss1, mincut_loss2, ortho_loss2, mincut_loss2) 114 | mincut_loss = mincut_loss1 + mincut_loss2 115 | ortho_loss = ortho_loss1 + ortho_loss2 116 | #print("x", x) 117 | return F.log_softmax(x, dim=-1), mincut_loss, ortho_loss 118 | 119 | 120 | class CTNet(torch.nn.Module): 121 | def __init__(self, in_channels, out_channels, k_centers, hidden_channels=32, EPS=1e-15): 122 | super(CTNet, self).__init__() 123 | self.EPS=EPS 124 | # GCN Layer - MLP - Dense GCN Layer 125 | #self.conv1 = GCNConv(in_channels, hidden_channels) 126 | self.conv1 = DenseGraphConv(hidden_channels, hidden_channels) 127 | self.conv2 = DenseGraphConv(hidden_channels, hidden_channels) 128 | 129 | # The degree of the node belonging to any of the centers 130 | num_of_centers1 = k_centers # k1 #order of number of nodes 131 | self.pool1 = Linear(hidden_channels, num_of_centers1) 132 | num_of_centers2 = 16 # k2 #mincut 133 | self.pool2 = Linear(hidden_channels, num_of_centers2) 134 | 135 | # MLPs towards out 136 | self.lin1 = Linear(in_channels, hidden_channels) 137 | self.lin2 = Linear(hidden_channels, hidden_channels) 138 | self.lin3 = Linear(hidden_channels, out_channels) 139 | 140 | 141 | def forward(self, x, edge_index, batch): # x torch.Size([N, N]), data.batch torch.Size([661]) 142 | # Make all adjacencies of size NxN 143 | adj = to_dense_adj(edge_index, batch) # adj torch.Size(B, N, N]) 144 | #print("adj_size", adj.size()) 145 | #print("adj",adj) 146 | 147 | # Make all x_i of size N=MAX(N1,...,N20), e.g. N=40: 148 | #print("x size", x.size()) 149 | x, mask = to_dense_batch(x, batch) # x torch.Size([20, N, 32]) ; mask torch.Size([20, N]) batch_size=20 150 | #print("x size", x.size()) 151 | 152 | x = self.lin1(x) 153 | # First mincut pool for computing Fiedler adn rewire 154 | s1 = self.pool1(x) 155 | #s1 = torch.variable()#s1 torch.Size([20, N, k1=2) 156 | #s1 = Variable(torch.randn(D_in, H).type(float16), requires_grad=True) 157 | #print("s 1st pool",s1) 158 | #print("s 1st pool size", s1.size()) 159 | 160 | if torch.isnan(adj).any(): 161 | print("adj nan") 162 | if torch.isnan(x).any(): 163 | print("x nan") 164 | 165 | # CT REWIRING 166 | adj, CT_loss, ortho_loss1 = dense_CT_rewiring(x, adj, s1, mask, EPS = self.EPS) # out: x torch.Size([20, N, F'=32]), adj torch.Size([20, N, N]) 167 | 168 | #print("CT_loss, ortho_loss1", CT_loss, ortho_loss1) 169 | #print("x",x) 170 | #print("adj",adj) 171 | #print("x and adj sizes", x.size(), adj.size()) 172 | #adj = torch.softmax(adj, dim=-1) 173 | #print("adj softmaxed", adj) 174 | 175 | # CONV1: Now on x and rewired adj: 176 | x = self.conv1(x, adj) #out: x torch.Size([20, N, F'=32]) 177 | #print("x_1 ", x) 178 | #print("x_1 size", x.size()) 179 | 180 | # MLP of k=16 outputs s 181 | #print("adj_size", adj.size()) 182 | s2 = self.pool2(x) # s torch.Size([20, N, k]) 183 | #print("s 2nd pool", s2) 184 | #print("s 2nd pool size", s2.size()) 185 | #adj = torch.softmax(adj, dim=-1) 186 | 187 | 188 | # MINCUT_POOL 189 | # Call to dense_cut_mincut_pool to get coarsened x, adj and the losses: k=16 190 | #x, adj, mincut_loss1, ortho_loss1 = dense_mincut_rewiring(x, adj, s1, mask) # x torch.Size([20, k=16, F'=32]), adj torch.Size([20, k2=16, k2=16]) 191 | x, adj, mincut_loss2, ortho_loss2 = dense_mincut_pool(x, adj, s2, mask, EPS=self.EPS) # out x torch.Size([20, k=16, F'=32]), adj torch.Size([20, k2=16, k2=16]) 192 | 
#print("lossses2",mincut_loss2, ortho_loss2) 193 | #print("mincut pool x", x) 194 | #print("mincut pool adj", adj) 195 | #print("mincut pool x size", x.size()) 196 | #print("mincut pool adj size", adj.size()) # Some nan in adjacency: maybe comming from the rewiring-> dissapear after clipping 197 | 198 | 199 | # CONV2: Now on coarsened x and adj: 200 | x = self.conv2(x, adj) #out x torch.Size([20, 16, 32]) 201 | #print("x_2", x) 202 | #print("x_2 size", x.size()) 203 | 204 | # Readout for each of the 20 graphs 205 | #x = x.mean(dim=1) # x torch.Size([20, 32]) 206 | x = x.sum(dim=1) # x torch.Size([20, 32]) 207 | #print("mean x_2 size", x.size()) 208 | 209 | # Final MLP for graph classification: hidden channels = 32 210 | x = F.relu(self.lin2(x)) # x torch.Size([20, 32]) 211 | #print("final x1 size", x.size()) 212 | x = self.lin3(x) #x torch.Size([20, 2]) 213 | #print("final x2 size", x.size()) 214 | CT_loss = CT_loss + ortho_loss1 215 | mincut_loss = mincut_loss2 + ortho_loss2 216 | #print("x", x) 217 | return F.log_softmax(x, dim=-1), CT_loss, mincut_loss 218 | 219 | 220 | class MinCutNet(torch.nn.Module): 221 | def __init__(self, in_channels, out_channels, hidden_channels=32, EPS=1e-15): 222 | super(MinCutNet, self).__init__() 223 | self.EPS=EPS 224 | # GCN Layer - MLP - Dense GCN Layer 225 | self.conv1 = DenseGraphConv(hidden_channels, hidden_channels) 226 | self.conv2 = DenseGraphConv(hidden_channels, hidden_channels) 227 | 228 | # The degree of the node belonging to any of the centers 229 | num_of_centers2 = 16 # k2 #mincut 230 | self.pool2 = Linear(hidden_channels, num_of_centers2) 231 | 232 | # MLPs towards out 233 | self.lin1 = Linear(in_channels, hidden_channels) 234 | self.lin2 = Linear(hidden_channels, hidden_channels) 235 | self.lin3 = Linear(hidden_channels, out_channels) 236 | 237 | 238 | def forward(self, x, edge_index, batch): # x torch.Size([N, N]), data.batch torch.Size([661]) 239 | 240 | # Make all adjacencies of size NxN 241 | adj = to_dense_adj(edge_index, batch) # adj torch.Size(B, N, N]) 242 | # Make all x_i of size N=MAX(N1,...,N20), e.g. 
N=40: 243 | x, mask = to_dense_batch(x, batch) # x torch.Size([20, N, 32]) ; mask torch.Size([20, N]) batch_size=20 244 | 245 | x = self.lin1(x) 246 | 247 | if torch.isnan(adj).any(): 248 | print("adj nan") 249 | if torch.isnan(x).any(): 250 | print("x nan") 251 | 252 | # CONV1: Now on x and rewired adj: 253 | x = self.conv1(x, adj) #out: x torch.Size([20, N, F'=32]) 254 | 255 | # MLP of k=16 outputs s 256 | s2 = self.pool2(x) # s torch.Size([20, N, k]) 257 | 258 | # MINCUT_POOL 259 | # Call to dense_cut_mincut_pool to get coarsened x, adj and the losses: k=16 260 | x, adj, mincut_loss2, ortho_loss2 = dense_mincut_pool(x, adj, s2, mask, EPS=self.EPS) # out x torch.Size([20, k=16, F'=32]), adj torch.Size([20, k2=16, k2=16]) 261 | 262 | # CONV2: Now on coarsened x and adj: 263 | x = self.conv2(x, adj) #out x torch.Size([20, 16, 32]) 264 | 265 | # Readout for each of the 20 graphs 266 | #x = x.mean(dim=1) # x torch.Size([20, 32]) 267 | x = x.sum(dim=1) # x torch.Size([20, 32]) 268 | # Final MLP for graph classification: hidden channels = 32 269 | x = F.relu(self.lin2(x)) # x torch.Size([20, 32]) 270 | x = self.lin3(x) #x torch.Size([20, 2]) 271 | 272 | mincut_loss = mincut_loss2 + ortho_loss2 273 | #print("x", x) 274 | return F.log_softmax(x, dim=-1), mincut_loss2, ortho_loss2 275 | 276 | 277 | class DiffWire(torch.nn.Module): 278 | def __init__(self, in_channels, out_channels, k_centers, derivative=None, hidden_channels=32, EPS=1e-15, device=None): 279 | super(DiffWire, self).__init__() 280 | 281 | self.EPS=EPS 282 | self.derivative = derivative 283 | self.device=device 284 | 285 | # First X transformation 286 | self.lin1 = Linear(in_channels, hidden_channels) 287 | 288 | #Fiedler vector -- Pool previous to GAP-Layer 289 | self.pool_rw = Linear(hidden_channels, 2) 290 | 291 | #CT Embedding -- Pool previous to CT-Layer 292 | self.num_of_centers1 = k_centers # k1 - order of number of nodes 293 | self.pool_ct = Linear(hidden_channels, self.num_of_centers1) #CT 294 | 295 | #Conv1 296 | self.conv1 = DenseGraphConv(hidden_channels, hidden_channels) 297 | 298 | #MinCutPooling 299 | self.pool_mc = Linear(hidden_channels, 16) #MC 300 | 301 | #Conv2 302 | self.conv2 = DenseGraphConv(hidden_channels, hidden_channels) 303 | 304 | # MLPs towards out 305 | self.lin2 = Linear(hidden_channels, hidden_channels) 306 | self.lin3 = Linear(hidden_channels, out_channels) 307 | 308 | 309 | def forward(self, x, edge_index, batch): 310 | # Make all adjacencies of size NxN 311 | adj = to_dense_adj(edge_index, batch) 312 | # Make all x_i of size N=MAX(N1,...,N20), e.g. 
N=40: 313 | x, mask = to_dense_batch(x, batch) 314 | 315 | x = self.lin1(x) 316 | 317 | if torch.isnan(adj).any(): 318 | print("adj nan") 319 | if torch.isnan(x).any(): 320 | print("x nan") 321 | 322 | #Gap Layer RW 323 | s0 = self.pool_rw(x) 324 | adj, mincut_loss_rw, ortho_loss_rw = dense_mincut_rewiring(x, adj, s0, mask, 325 | derivative = self.derivative, EPS=self.EPS, device=self.device) 326 | 327 | # CT REWIRING 328 | # First mincut pool for computing Fiedler adn rewire 329 | s1 = self.pool_ct(x) 330 | adj, CT_loss, ortho_loss_ct = dense_CT_rewiring(x, adj, s1, mask, EPS = self.EPS) # out: x torch.Size([20, N, F'=32]), adj torch.Size([20, N, N]) 331 | 332 | # CONV1: Now on x and rewired adj: 333 | x = self.conv1(x, adj) 334 | 335 | # MINCUT_POOL 336 | # MLP of k=16 outputs s 337 | s2 = self.pool_mc(x) 338 | # Call to dense_cut_mincut_pool to get coarsened x, adj and the losses: k=16 339 | x, adj, mincut_loss, ortho_loss_mc = dense_mincut_pool(x, adj, s2, mask, EPS=self.EPS) # out x torch.Size([20, k=16, F'=32]), adj torch.Size([20, k2=16, k2=16]) 340 | 341 | # CONV2: Now on coarsened x and adj: 342 | x = self.conv2(x, adj) 343 | 344 | # Readout for each of the 20 graphs 345 | x = x.sum(dim=1) 346 | # Final MLP for graph classification: hidden channels = 32 347 | x = F.relu(self.lin2(x)) 348 | x = self.lin3(x) 349 | 350 | 351 | main_loss = mincut_loss_rw + CT_loss + mincut_loss 352 | ortho_loss = ortho_loss_rw + ortho_loss_ct + ortho_loss_mc 353 | #ortho_loss_rw/2 + (1/self.num_of_centers1)*ortho_loss_ct + ortho_loss_mc/16 354 | #print("x", x) 355 | return F.log_softmax(x, dim=-1), main_loss, ortho_loss 356 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # DiffWire: Inductive Graph Rewiring via the Lovász Bound 2 | 3 | **Accepted at the First Learning on Graphs Conference 2022** 4 | 5 | [![LoG](https://img.shields.io/badge/Published%20-Learning%20on%20Graphs-blue.svg)](https://openreview.net/forum?id=IXvfIex0mX6f¬eId=t5zJZuEIy1y) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/diffwire-inductive-graph-rewiring-via-the/graph-classification-on-imdb-binary)](https://paperswithcode.com/sota/graph-classification-on-imdb-binary?p=diffwire-inductive-graph-rewiring-via-the) 7 | 8 | 9 | 10 | image 11 | 12 | $$ 13 | \left| \frac{1}{vol(G)}CT_{uv}-\left(\frac{1}{d_u} + \frac{1}{d_v}\right)\right|\le \frac{1}{\lambda_2}\frac{2}{d_{min}} 14 | $$ 15 | 16 | ```bibtex 17 | @InProceedings{arnaiz2022diffwire, 18 | title = {{DiffWire: Inductive Graph Rewiring via the Lov{\'a}sz Bound}}, 19 | author = {Arnaiz-Rodr{\'i}guez, Adri{\'a}n and Begga, Ahmed and Escolano, Francisco and Oliver, Nuria M}, 20 | booktitle = {Proceedings of the First Learning on Graphs Conference}, 21 | pages = {15:1--15:27}, 22 | year = {2022}, 23 | editor = {Rieck, Bastian and Pascanu, Razvan}, 24 | volume = {198}, 25 | series = {Proceedings of Machine Learning Research}, 26 | month = {09--12 Dec}, 27 | publisher = {PMLR}, 28 | pdf = {https://proceedings.mlr.press/v198/arnaiz-rodri-guez22a/arnaiz-rodri-guez22a.pdf}, 29 | url = {https://proceedings.mlr.press/v198/arnaiz-rodri-guez22a.html} 30 | } 31 | ``` 32 | 33 | ## Dependencies 34 | 35 | Conda environment 36 | ``` 37 | conda create --name --file requirements.txt 38 | ``` 39 | 40 | or 41 | 42 | ``` 43 | conda env create -f environment_experiments.yml 44 | conda activate DiffWire 45 | ``` 46 | ## Code organization 
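The modules listed below are wired together by `nets.py` and driven end-to-end by `train.py`. As a quick orientation, here is a minimal usage sketch added by the editor (not code from the repository): the toy graphs and all numeric values are made up, and only the `CTNet(in_channels, out_channels, k_centers)` signature and its `(log_probs, ct_loss, mincut_loss)` outputs follow `nets.py` and `train.py`.

```python
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Batch

from nets import CTNet  # assumes the repository root is on the Python path

# Two toy undirected cycle graphs with 3-dimensional node features and binary labels
g1 = Data(x=torch.randn(4, 3),
          edge_index=torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                                   [1, 0, 2, 1, 3, 2, 0, 3]]),
          y=torch.tensor([0]))
g2 = Data(x=torch.randn(5, 3),
          edge_index=torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 0],
                                   [1, 0, 2, 1, 3, 2, 4, 3, 0, 4]]),
          y=torch.tensor([1]))
batch = Batch.from_data_list([g1, g2])

# k_centers is chosen close to the average number of nodes, as in train.py
model = CTNet(in_channels=3, out_channels=2, k_centers=5)
log_probs, ct_loss, mincut_loss = model(batch.x, batch.edge_index, batch.batch)

# train.py adds the auxiliary losses to the classification loss
loss = F.nll_loss(log_probs, batch.y.view(-1)) + ct_loss + mincut_loss
loss.backward()
```

The repository layout: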
47 | 
48 | * `datasets/`: scripts for creating the synthetic datasets. Non-synthetic datasets are loaded with PyG in `train.py`.
49 | * `layers/`: Implementation of the proposed **GAP-Layer**, **CT-Layer**, and the baseline MinCutPool (based on the authors' official repository).
50 | * `transforms/`: Implementation of the graph preprocessing baselines DIGL and SDRF, both based on the official repositories of those works.
51 | * `trained_models/`: weight files for some trained models.
52 | * `nets.py`: Implementation of the GNNs used in our experiments.
53 | * `train.py`: Script with inline arguments for running the experiments.
54 | 
55 | ## Run experiments
56 | ```python
57 | python train.py --dataset REDDIT-BINARY --model CTNet --cuda cuda:0
58 | python train.py --dataset REDDIT-BINARY --model GAPNet --derivative laplacian --cuda cuda:0
59 | python train.py --dataset REDDIT-BINARY --model GAPNet --derivative normalized --cuda cuda:0
60 | ```
61 | 
62 | `experiments_all.sh` lists all the experiments.
63 | 
64 | 
65 | ## Code Examples
66 | 
67 | See the Jupyter notebook examples from the tutorial presented at ***The First Learning on Graphs Conference***: **[Graph Rewiring: From Theory to Applications in Fairness](https://github.com/ellisalicante/GraphRewiring-Tutorial)**
68 | 
69 | * CT-Layer [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ellisalicante/GraphRewiring-Tutorial/blob/main/3-Inductive-Rewiring-CTLayer.ipynb)
70 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=conda_forge
5 | _openmp_mutex=4.5=2_kmp_llvm
6 | blas=2.114=mkl
7 | blas-devel=3.9.0=14_linux64_mkl
8 | ca-certificates=2021.10.8=ha878542_0
9 | certifi=2021.10.8=pypi_0
10 | charset-normalizer=2.0.12=pypi_0
11 | cudatoolkit=11.5.1=hcf5317a_10
12 | cycler=0.11.0=pypi_0
13 | fonttools=4.33.3=pypi_0
14 | freetype=2.10.4=h0708190_1
15 | giflib=5.2.1=h36c2ea0_2
16 | idna=3.3=pypi_0
17 | jbig=2.1=h7f98852_2003
18 | jinja2=3.1.2=pypi_0
19 | joblib=1.1.0=pypi_0
20 | jpeg=9e=h166bdaf_1
21 | kiwisolver=1.4.2=pypi_0
22 | lcms2=2.12=hddcbb42_0
23 | ld_impl_linux-64=2.36.1=hea4e1c9_2
24 | lerc=3.0=h9c3ff4c_0
25 | libblas=3.9.0=14_linux64_mkl
26 | libcblas=3.9.0=14_linux64_mkl
27 | libdeflate=1.10=h7f98852_0
28 | libffi=3.4.2=h7f98852_5
29 | libgcc-ng=11.2.0=h1d223b6_16
30 | libgfortran-ng=11.2.0=h69a702a_16
31 | libgfortran5=11.2.0=h5c6108e_16
32 | libgomp=11.2.0=h1d223b6_16
33 | liblapack=3.9.0=14_linux64_mkl
34 | liblapacke=3.9.0=14_linux64_mkl
35 | libnsl=2.0.0=h7f98852_0
36 | libpng=1.6.37=h21135ba_2
37 | libstdcxx-ng=11.2.0=he4da1e4_16
38 | libtiff=4.3.0=h542a066_3
39 | libuv=1.43.0=h7f98852_0
40 | libwebp=1.2.2=h3452ae3_0
41 | libwebp-base=1.2.2=h7f98852_1
42 | libxcb=1.13=h7f98852_1004
43 | libzlib=1.2.11=h166bdaf_1014
44 | llvm-openmp=14.0.3=he0ac6c6_0
45 | lz4-c=1.9.3=h9c3ff4c_1
46 | markupsafe=2.1.1=pypi_0
47 | matplotlib=3.5.2=pypi_0
48 | mkl=2022.0.1=h8d4b97c_803
49 | mkl-devel=2022.0.1=ha770c72_804
50 | mkl-include=2022.0.1=h8d4b97c_803
51 | ncurses=6.3=h27087fc_1
52 | numpy=1.22.3=py38h99721a1_2
53 | openjpeg=2.4.0=hb52868f_1
54 | openssl=3.0.3=h166bdaf_0
55 | packaging=21.3=pypi_0
56 | pandas=1.4.2=pypi_0
57 | pillow=9.1.0=py38h0ee0e06_2
58 | pip=22.0.4=pyhd8ed1ab_0
59 | pthread-stubs=0.4=h36c2ea0_1001
60 | pyparsing=3.0.8=pypi_0
61
| python=3.8.12=h0744224_3_cpython 62 | python-dateutil=2.8.2=pypi_0 63 | python_abi=3.8=2_cp38 64 | pytorch=1.11.0=py3.8_cuda11.5_cudnn8.3.2_0 65 | pytorch-mutex=1.0=cuda 66 | pytz=2022.1=pypi_0 67 | readline=8.1=h46c0cb4_0 68 | requests=2.27.1=pypi_0 69 | scikit-learn=1.0.2=pypi_0 70 | scipy=1.8.0=pypi_0 71 | setuptools=62.1.0=py38h578d9bd_0 72 | six=1.16.0=pyh6c4a22f_0 73 | sqlite=3.38.5=h4ff8645_0 74 | tbb=2021.5.0=h924138e_1 75 | threadpoolctl=3.1.0=pypi_0 76 | tk=8.6.12=h27826a3_0 77 | torch-cluster=1.6.0=pypi_0 78 | torch-geometric=2.0.4=pypi_0 79 | torch-scatter=2.0.9=pypi_0 80 | torch-sparse=0.6.13=pypi_0 81 | torch-spline-conv=1.2.1=pypi_0 82 | torchaudio=0.11.0=py38_cu115 83 | torchvision=0.2.2=py_3 84 | tqdm=4.64.0=pypi_0 85 | typing_extensions=4.2.0=pyha770c72_1 86 | urllib3=1.26.9=pypi_0 87 | wheel=0.37.1=pyhd8ed1ab_0 88 | xorg-libxau=1.0.9=h7f98852_0 89 | xorg-libxdmcp=1.1.3=h7f98852_0 90 | xz=5.2.5=h516909a_1 91 | zlib=1.2.11=h166bdaf_1014 92 | zstd=1.5.2=ha95c52a_0 93 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import os 3 | import random 4 | 5 | from sklearn.model_selection import train_test_split 6 | from datasets.Erdos_Renyi_dataset import Erdos_Renyi_pyg 7 | from datasets.SBM_dataset import SBM_pyg 8 | from nets import CTNet, DiffWire, GAPNet, MinCutNet 9 | import torch 10 | import torch.nn.functional as F 11 | from torch_geometric.datasets import GNNBenchmarkDataset, TUDataset 12 | import torch_geometric.transforms as T 13 | from transforms import FeatureDegree, DIGLedges, SDRF, KNNGraph 14 | from torch_geometric.loader import DataLoader 15 | from torch_geometric.utils import to_dense_batch, to_dense_adj 16 | import time 17 | import argparse 18 | import numpy as np 19 | ''' 20 | Dataset arguments: 21 | -Name: 22 | *TUDataset: 23 | + No features: [REDDEDIT BINARY, IMBD BINARY, COLLAB] 24 | + Featured: [MUTAG, ENZYMES, PROTEINS] 25 | *Benchmark: [MNIST,CIFAR10] 26 | Model arguments: 27 | -Name: 28 | MC(num_features,num_classes) 29 | GAPNet(new_num_features, dataset.new_num_classes,derivative) 30 | CTNet(num_features,num_classes) 31 | Other arguments: 32 | Lr 33 | weight decay 34 | 35 | ''' 36 | 37 | ################### Arguments parameters ################################### 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument( 40 | "--dataset", 41 | default="MUTAG", 42 | choices=["MUTAG","ENZYMES","PROTEINS","CIFAR10","MNIST","COLLAB","IMDB-BINARY","REDDIT-BINARY","CSL", "SBM", "ERDOS"], 43 | help="nada", 44 | ) 45 | parser.add_argument( 46 | "--model", 47 | default="CTNet", 48 | choices=["CTNet","GAPNet","MinCutNet", "DiffWire"], 49 | help="nada", 50 | ) 51 | parser.add_argument( 52 | "--derivative", 53 | default="laplacian", 54 | choices=["laplacian","normalizedv2","normalized"], #,"normalized" 55 | help="Only used if model is GAP", 56 | ) 57 | parser.add_argument( 58 | "--cuda", 59 | default="cuda:0", 60 | choices=["cuda:0","cuda:1"], 61 | help="cuda version", 62 | ) 63 | parser.add_argument( 64 | "--prepro", 65 | default=None, 66 | choices=[None,"digl", "sdrf", "knn"], 67 | help="preprocessing", 68 | ) 69 | parser.add_argument( 70 | "--store", 71 | action="store_true", 72 | help="nada", 73 | ) 74 | parser.add_argument( 75 | "--iter", 76 | type=int, 77 | default=10, 78 | help="The number of games to simulate" 79 | ) 80 | parser.add_argument( 81 | "--logs", 82 | default="logs", 83 | help="log 
folders", 84 | ) 85 | parser.add_argument( 86 | "--lr", type=float, default=5e-4, help="Outer learning rate of model" 87 | ) 88 | parser.add_argument( 89 | "--wd", type=float, default=1e-4, help="Outer weight decay rate of model" 90 | ) 91 | args = parser.parse_args() 92 | 93 | #Procesing dataset 94 | No_Features = ["COLLAB","IMDB-BINARY","REDDIT-BINARY", "CSL"] 95 | 96 | if args.prepro == 'digl': 97 | preprocessing = DIGLedges(alpha=0.001) 98 | aux_prepro_folder = "/DIGL" if args.prepro == "digl" else "" 99 | 100 | elif args.prepro == 'sdrf': 101 | preprocessing = SDRF(undirected = True, max_steps="dynamic", tau = 20, 102 | remove_edges = True, removal_bound = 0) 103 | aux_prepro_folder = "/SDRF" if args.prepro == "sdrf" else "" 104 | elif args.prepro == 'knn': 105 | preprocessing = KNNGraph( 106 | k=None, 107 | force_undirected=True 108 | ) 109 | aux_prepro_folder = "/KNN" if args.prepro == "knn" else "" 110 | 111 | elif args.prepro is None: 112 | aux_prepro_folder = "" 113 | preprocessing = None 114 | 115 | else: 116 | raise NotImplementedError("Not implemented preprocessing") 117 | 118 | if args.dataset == "SBM": 119 | dataset = SBM_pyg('./data/SBM_final'+aux_prepro_folder, nb_nodes1=200, nb_graphs1=500, nb_nodes2=200, nb_graphs2=500, 120 | p1=0.8, p2=0.5, qmin1=0.1, qmax1=0.15, qmin2=0.01, qmax2=0.1, 121 | directed=False, pre_transform=preprocessing) 122 | TRAIN_SPLIT = 800 123 | BATCH_SIZE = 32 124 | num_of_centers = 200 125 | 126 | elif args.dataset == "ERDOS": 127 | dataset = Erdos_Renyi_pyg('./data/SBM_final', nb_nodes1=200, nb_graphs1=500, nb_nodes2=200, nb_graphs2=500, 128 | p1_min=0.4, p1_max=0.6, p2_min=0.5, p2_max=0.8) 129 | TRAIN_SPLIT = 800 130 | BATCH_SIZE = 32 131 | num_of_centers = 200 132 | 133 | elif args.dataset not in GNNBenchmarkDataset.names: 134 | if args.dataset not in No_Features: #Features 135 | dataset = TUDataset(root='data'+os.sep+aux_prepro_folder+os.sep+'TUDataset', name=args.dataset, pre_transform=preprocessing) 136 | if args.dataset =="MUTAG": # 188 graphs 137 | TRAIN_SPLIT = 150 138 | BATCH_SIZE = 32 139 | num_of_centers = 17 #mean number of nodes according to PyGeom 140 | if args.dataset =="ENZYMES": # 600 graphs 141 | TRAIN_SPLIT = 500 142 | BATCH_SIZE = 32 143 | num_of_centers = 16 #mean number of nodes according to PyGeom 144 | if args.dataset =="PROTEINS": # 1113 graphs 145 | TRAIN_SPLIT = 1000 146 | BATCH_SIZE = 64 147 | num_of_centers = 39 #mean number of nodes according to PyGeom 148 | else: #No Features 149 | if args.prepro is not None: 150 | preprocessing = preprocessing 151 | processing = FeatureDegree() 152 | else: 153 | preprocessing = FeatureDegree() 154 | processing = None 155 | dataset = TUDataset(root='data'+os.sep+aux_prepro_folder+os.sep+'TUDataset',name=args.dataset, 156 | pre_transform=preprocessing, transform = processing, use_node_attr=True) 157 | #dataset = TUDatasetFeatures(root='data/TUDataset', name=args.dataset,dataset=datasetGNN) 158 | if args.dataset =="IMDB-BINARY": # 1000 graphs 159 | TRAIN_SPLIT = 800 160 | BATCH_SIZE = 64 161 | num_of_centers = 20 #mean number of nodes according to PyGeom 162 | elif args.dataset == "REDDIT-BINARY": # 2000 graphs 163 | TRAIN_SPLIT = 1500 164 | BATCH_SIZE = 64 165 | num_of_centers = 420 #mean number of nodes according to PyGeom 166 | elif args.dataset == "COLLAB": # 2000 graphs 167 | TRAIN_SPLIT = 4500 168 | BATCH_SIZE = 64 169 | num_of_centers = 75 #mean number of nodes according to PyGeom 170 | else: 171 | raise Exception("Not dataset in list of datasets") 172 | else: 
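# [Editor's note: comment added for clarity, not part of the original script.]
# The preceding branches handle the synthetic SBM/ERDOS datasets and TUDataset graphs
# (MUTAG, ENZYMES and PROTEINS come with node features; COLLAB, IMDB-BINARY and
# REDDIT-BINARY get FeatureDegree() as a stand-in for missing node features). The
# else-branch below handles GNNBenchmarkDataset (CSL, MNIST, CIFAR10), e.g. invoked as:
#   python train.py --dataset MNIST --model CTNet --cuda cuda:0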
#GNNBenchmarkDataset 173 | if args.dataset in No_Features: 174 | if args.prepro is not None: 175 | preprocessing = preprocessing 176 | processing = FeatureDegree() 177 | else: 178 | preprocessing = FeatureDegree() 179 | processing = None 180 | 181 | dataset = GNNBenchmarkDataset(root='data'+os.sep+aux_prepro_folder+os.sep+'GNNBenchmarkDataset', name=args.dataset, 182 | pre_transform=preprocessing, transform = processing) 183 | if args.dataset =="CSL": 184 | TRAIN_SPLIT = 120 185 | BATCH_SIZE = 10 186 | num_of_centers = 42 187 | else: 188 | dataset = GNNBenchmarkDataset(root='data'+os.sep+aux_prepro_folder+os.sep+'GNNBenchmarkDataset', name=args.dataset, pre_transform=preprocessing) #MNISTo CIFAR10 189 | if args.dataset =="MNIST": 190 | TRAIN_SPLIT = 50000 191 | BATCH_SIZE = 100 192 | num_of_centers = 100 193 | elif args.dataset == "CIFAR10": 194 | TRAIN_SPLIT = 40000 195 | BATCH_SIZE = 100 196 | num_of_centers = 100 197 | 198 | ##################### STATIC Variables ################################# 199 | 200 | device = args.cuda 201 | 202 | N_EPOCH = 60 203 | 204 | exp_name = f"{args.dataset}" 205 | exp_name = exp_name + f"{dataset.details}" if args.dataset == "SBM" else exp_name #add sbm details 206 | exp_name = exp_name + f"_{args.model}" 207 | exp_name = exp_name+"DIGL" if args.prepro=="digl" else exp_name 208 | exp_name = exp_name+"SDRF" if args.prepro=="sdrf" else exp_name 209 | exp_name = exp_name + f"_{args.derivative}" if args.model=="GAPNet" else exp_name # add derivative details 210 | exp_time = time.strftime('%d_%m_%y__%H_%M') 211 | train_log_file = exp_name + f"_{exp_time}.txt" 212 | 213 | RandList = [12345, 42345, 64345, 54345, 74345, 47345, 54321, 14321, 94321, 84328] 214 | RandList = RandList[:args.iter] 215 | 216 | if not os.path.exists(args.logs): 217 | os.makedirs(args.logs) 218 | if not os.path.exists("models/") and args.store: 219 | os.makedirs("models") 220 | 221 | ###################################################### 222 | ###################################################### 223 | def train(epoch, loader): 224 | model.train() 225 | loss_all = 0 226 | correct = 0 227 | #i = 0 228 | for data in loader: 229 | data = data.to(device) 230 | optimizer.zero_grad() 231 | out, mc_loss, o_loss = model(data.x, data.edge_index, data.batch) # data.batch torch.Size([783]) 232 | loss = F.nll_loss(out, data.y.view(-1)) + mc_loss + o_loss 233 | loss.backward() 234 | loss_all += data.y.size(0) * loss.item() 235 | optimizer.step() 236 | correct += out.max(dim=1)[1].eq(data.y.view(-1)).sum().item() #accuracy in train AFTER EACH BACH 237 | return loss_all / len(loader.dataset), correct / len(loader.dataset) 238 | 239 | @torch.no_grad() 240 | def test(loader): 241 | model.eval() 242 | correct = 0 243 | for data in loader: 244 | data = data.to(device) 245 | pred, mc_loss, o_loss = model(data.x, data.edge_index, data.batch) 246 | loss = F.nll_loss(pred, data.y.view(-1)) + mc_loss + o_loss 247 | correct += pred.max(dim=1)[1].eq(data.y.view(-1)).sum().item() 248 | 249 | return loss, correct / len(loader.dataset) 250 | 251 | ###################################################### 252 | print(device) 253 | 254 | #torch.autograd.set_detect_anomaly(True) 255 | ExperimentResult = [] 256 | 257 | f = open(args.logs+os.sep+train_log_file, 'w') #clear file 258 | print("- M:", args.model, "- D:",dataset, 259 | "- Train_split:", TRAIN_SPLIT, "- B:",BATCH_SIZE, 260 | "- Centers (if CTNet):", num_of_centers, "- LAP (if GAPNet):", args.derivative, 261 | "- Classes" ,dataset.num_classes,"- 
Feats",dataset.num_features, file=f) 262 | f.close() 263 | 264 | print("- M:", args.model, "- D:",dataset, "- Train_split:", TRAIN_SPLIT, "- B:",BATCH_SIZE) 265 | 266 | EPS=1e-10 267 | for e in range(len(RandList)): 268 | if args.model == 'CTNet': 269 | model = CTNet(dataset.num_features, dataset.num_classes, k_centers=num_of_centers, EPS=EPS).to(device) 270 | elif args.model == 'GAPNet': 271 | model = GAPNet(dataset.num_features, dataset.num_classes, derivative=args.derivative, device=device).to(device) 272 | elif args.model == 'MinCutNet': 273 | model = MinCutNet(dataset.num_features, dataset.num_classes).to(device) 274 | elif args.model == 'DiffWire': 275 | model = DiffWire(dataset.num_features, dataset.num_classes, k_centers=num_of_centers, 276 | derivative="normalized", device=device, EPS=EPS).to(device) 277 | else: 278 | raise Exception(f"Not implemented model: {args.model}") 279 | 280 | train_indices, test_indices = train_test_split(list(range(len(dataset.data.y))), test_size=0.15, stratify=dataset.data.y, 281 | random_state=RandList[e], shuffle=True) 282 | train_dataset = torch.utils.data.Subset(dataset, train_indices) 283 | test_dataset = torch.utils.data.Subset(dataset, test_indices) 284 | #model = GAPNet(dataset.num_features, dataset.num_classes, derivative=DERIVATIVE, device=device).to(device) 285 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4) # 286 | train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) # Original 64 287 | test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False) # Original 64 288 | #train_dataset = train_dataset.shuffle() 289 | #test_dataset = test_dataset.shuffle() 290 | optimizer.zero_grad() 291 | torch.manual_seed(RandList[e]) 292 | print("Experimen run", RandList[e]) 293 | 294 | for epoch in range(1, 60): 295 | start_time_epoch = time.time() 296 | train_loss, train_acc = train(epoch, train_loader) # return also train_acc_t if want Accuracy after BATCH 297 | #_, train_acc = test(train_loader) 298 | test_loss, test_acc = test(test_loader) 299 | time_lapse = time.time() - start_time_epoch 300 | 301 | f = open(args.logs+os.sep+train_log_file, 'a') 302 | print('Epoch: {:03d}, ' 303 | 'Train Loss: {:.3f}, Train Acc: {:.3f}, ' 304 | 'Test Loss: {:.3f}, Test Acc: {:.3f}'.format(epoch, train_loss, 305 | train_acc, test_loss, 306 | test_acc), file=f) 307 | print('{} - Epoch: {:03d}, ' 308 | 'Train Loss: {:.3f}, Train Acc: {:.3f}, ' 309 | 'Test Loss: {:.3f}, Test Acc: {:.3f}, Time: {:.2f}'.format(exp_name, epoch, train_loss, 310 | train_acc, test_loss, 311 | test_acc, time_lapse)) 312 | f.close() 313 | 314 | if args.store: 315 | torch.save(model.state_dict(), f"models{os.sep}{exp_name}_{exp_time}_iter{e}.pth") 316 | print(f"Model saved in models{os.sep}{exp_name}_{exp_time}_iter{e}.pth") 317 | 318 | ExperimentResult.append(test_acc) 319 | f = open(args.logs+os.sep+train_log_file, 'a') 320 | print('Result of run {:.3f} is {:.3f}'.format(e,test_acc), file=f) 321 | f.close() 322 | 323 | f = open(args.logs+os.sep+train_log_file, 'a') 324 | print('Test Acc of 10 execs {}'.format(ExperimentResult), file=f) 325 | print('{} +- {}'.format(np.mean(ExperimentResult), np.std(ExperimentResult)), file=f) 326 | f.close() 327 | -------------------------------------------------------------------------------- /trained_models/CTNet/COLLAB_CTNet_17_05_22__08_56_iter0.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/COLLAB_CTNet_17_05_22__08_56_iter0.pth -------------------------------------------------------------------------------- /trained_models/CTNet/ERDOS_CTNet_19_05_22__14_46_iter1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/ERDOS_CTNet_19_05_22__14_46_iter1.pth -------------------------------------------------------------------------------- /trained_models/CTNet/IMDB-BINARY_CTNet_17_05_22__08_56_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/IMDB-BINARY_CTNet_17_05_22__08_56_iter0.pth -------------------------------------------------------------------------------- /trained_models/CTNet/MUTAG_CTNet_19_05_22__11_35_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/MUTAG_CTNet_19_05_22__11_35_iter0.pth -------------------------------------------------------------------------------- /trained_models/CTNet/REDDIT-BINARY_CTNet_17_05_22__08_50_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/REDDIT-BINARY_CTNet_17_05_22__08_50_iter0.pth -------------------------------------------------------------------------------- /trained_models/CTNet/SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_CTNet_18_05_22__22_18_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_CTNet_18_05_22__22_18_iter0.pth -------------------------------------------------------------------------------- /trained_models/GAPNet/COLLAB_GAPNet_normalized_18_05_22__21_30_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/COLLAB_GAPNet_normalized_18_05_22__21_30_iter0.pth -------------------------------------------------------------------------------- /trained_models/GAPNet/ERDOS_GAPNet_normalized_19_05_22__14_46_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/ERDOS_GAPNet_normalized_19_05_22__14_46_iter0.pth -------------------------------------------------------------------------------- /trained_models/GAPNet/IMDB-BINARY_GAPNet_normalized_19_05_22__11_35_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/IMDB-BINARY_GAPNet_normalized_19_05_22__11_35_iter0.pth -------------------------------------------------------------------------------- /trained_models/GAPNet/MUTAG_GAPNet_normalized_19_05_22__11_35_iter0.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/MUTAG_GAPNet_normalized_19_05_22__11_35_iter0.pth -------------------------------------------------------------------------------- /trained_models/GAPNet/REDDIT-BINARY_GAPNet_normalized_19_05_22__10_09_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/REDDIT-BINARY_GAPNet_normalized_19_05_22__10_09_iter0.pth -------------------------------------------------------------------------------- /trained_models/GAPNet/SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_GAPNet_normalized_18_05_22__21_44_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_GAPNet_normalized_18_05_22__21_44_iter0.pth -------------------------------------------------------------------------------- /trained_models/Laplacian/COLLAB_GAPNet_laplacian_16_05_22__16_46_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/Laplacian/COLLAB_GAPNet_laplacian_16_05_22__16_46_iter0.pth -------------------------------------------------------------------------------- /trained_models/Laplacian/REDDIT-BINARY_GAPNet_laplacian_16_05_22__11_04_iter0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/Laplacian/REDDIT-BINARY_GAPNet_laplacian_16_05_22__11_04_iter0.pth -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | #from transforms.sdrf.sdrf_transform import SDRF 2 | from transforms.transform_features import FeatureDegree 3 | from transforms.transform_features import DIGLedges 4 | from transforms.transform_features import KNNGraph -------------------------------------------------------------------------------- /transforms/sdrf/curvature.py: -------------------------------------------------------------------------------- 1 | import math 2 | from numba import cuda 3 | import numpy as np 4 | import torch 5 | from torch_geometric.utils import ( 6 | to_networkx, 7 | from_networkx, 8 | to_dense_adj, 9 | remove_self_loops, 10 | to_undirected, 11 | ) 12 | 13 | from transforms.sdrf.utils import softmax 14 | 15 | 16 | @cuda.jit( 17 | "void(float32[:,:], float32[:,:], float32[:], float32[:], int32, float32[:,:])" 18 | ) 19 | def _balanced_forman_curvature(A, A2, d_in, d_out, N, C): 20 | i, j = cuda.grid(2) 21 | 22 | if (i < N) and (j < N): 23 | if A[i, j] == 0: 24 | C[i, j] = 0 25 | return 26 | 27 | if d_in[i] > d_out[j]: 28 | d_max = d_in[i] 29 | d_min = d_out[j] 30 | else: 31 | d_max = d_out[j] 32 | d_min = d_in[i] 33 | 34 | if d_max * d_min == 0: 35 | C[i, j] = 0 36 | return 37 | 38 | sharp_ij = 0 39 | lambda_ij = 0 40 | for k in range(N): 41 | TMP = A[k, j] * (A2[i, k] - A[i, k]) * A[i, j] 42 | if TMP > 0: 43 | sharp_ij += 1 44 | if TMP > lambda_ij: 45 | 
lambda_ij = TMP 46 | 47 | TMP = A[i, k] * (A2[k, j] - A[k, j]) * A[i, j] 48 | if TMP > 0: 49 | sharp_ij += 1 50 | if TMP > lambda_ij: 51 | lambda_ij = TMP 52 | 53 | C[i, j] = ( 54 | (2 / d_max) + (2 / d_min) - 2 + (2 / d_max + 1 / d_min) * A2[i, j] * A[i, j] 55 | ) 56 | if lambda_ij > 0: 57 | C[i, j] += sharp_ij / (d_max * lambda_ij) 58 | 59 | 60 | def balanced_forman_curvature(A, C=None): 61 | N = A.shape[0] 62 | A2 = torch.matmul(A, A) 63 | d_in = A.sum(axis=0) 64 | d_out = A.sum(axis=1) 65 | if C is None: 66 | C = torch.zeros(N, N).cuda() 67 | 68 | threadsperblock = (16, 16) 69 | blockspergrid_x = math.ceil(N / threadsperblock[0]) 70 | blockspergrid_y = math.ceil(N / threadsperblock[1]) 71 | blockspergrid = (blockspergrid_x, blockspergrid_y) 72 | 73 | _balanced_forman_curvature[blockspergrid, threadsperblock](A, A2, d_in, d_out, N, C) 74 | return C 75 | 76 | 77 | @cuda.jit( 78 | "void(float32[:,:], float32[:,:], float32, float32, int32, float32[:,:], int32, int32, int32[:], int32[:], int32, int32)" 79 | ) 80 | def _balanced_forman_post_delta( 81 | A, A2, d_in_x, d_out_y, N, D, x, y, i_neighbors, j_neighbors, dim_i, dim_j 82 | ): 83 | I, J = cuda.grid(2) 84 | 85 | if (I < dim_i) and (J < dim_j): 86 | i = i_neighbors[I] 87 | j = j_neighbors[J] 88 | 89 | if (i == j) or (A[i, j] != 0): 90 | D[I, J] = -1000 91 | return 92 | 93 | # Difference in degree terms 94 | if j == x: 95 | d_in_x += 1 96 | elif i == y: 97 | d_out_y += 1 98 | 99 | if d_in_x * d_out_y == 0: 100 | D[I, J] = 0 101 | return 102 | 103 | if d_in_x > d_out_y: 104 | d_max = d_in_x 105 | d_min = d_out_y 106 | else: 107 | d_max = d_out_y 108 | d_min = d_in_x 109 | 110 | # Difference in triangles term 111 | A2_x_y = A2[x, y] 112 | if (x == i) and (A[j, y] != 0): 113 | A2_x_y += A[j, y] 114 | elif (y == j) and (A[x, i] != 0): 115 | A2_x_y += A[x, i] 116 | 117 | # Difference in four-cycles term 118 | sharp_ij = 0 119 | lambda_ij = 0 120 | for z in range(N): 121 | A_z_y = A[z, y] + 0 122 | A_x_z = A[x, z] + 0 123 | A2_z_y = A2[z, y] + 0 124 | A2_x_z = A2[x, z] + 0 125 | 126 | if (z == i) and (y == j): 127 | A_z_y += 1 128 | if (x == i) and (z == j): 129 | A_x_z += 1 130 | if (z == i) and (A[j, y] != 0): 131 | A2_z_y += A[j, y] 132 | if (x == i) and (A[j, z] != 0): 133 | A2_x_z += A[j, z] 134 | if (y == j) and (A[z, i] != 0): 135 | A2_z_y += A[z, i] 136 | if (z == j) and (A[x, i] != 0): 137 | A2_x_z += A[x, i] 138 | 139 | TMP = A_z_y * (A2_x_z - A_x_z) * A[x, y] 140 | if TMP > 0: 141 | sharp_ij += 1 142 | if TMP > lambda_ij: 143 | lambda_ij = TMP 144 | 145 | TMP = A_x_z * (A2_z_y - A_z_y) * A[x, y] 146 | if TMP > 0: 147 | sharp_ij += 1 148 | if TMP > lambda_ij: 149 | lambda_ij = TMP 150 | 151 | D[I, J] = ( 152 | (2 / d_max) + (2 / d_min) - 2 + (2 / d_max + 1 / d_min) * A2_x_y * A[x, y] 153 | ) 154 | if lambda_ij > 0: 155 | D[I, J] += sharp_ij / (d_max * lambda_ij) 156 | 157 | 158 | def balanced_forman_post_delta(A, x, y, i_neighbors, j_neighbors, D=None): 159 | N = A.shape[0] 160 | A2 = torch.matmul(A, A) 161 | d_in = A[:, x].sum() 162 | d_out = A[y].sum() 163 | if D is None: 164 | D = torch.zeros(len(i_neighbors), len(j_neighbors)).cuda() 165 | 166 | threadsperblock = (16, 16) 167 | blockspergrid_x = math.ceil(D.shape[0] / threadsperblock[0]) 168 | blockspergrid_y = math.ceil(D.shape[1] / threadsperblock[1]) 169 | blockspergrid = (blockspergrid_x, blockspergrid_y) 170 | 171 | _balanced_forman_post_delta[blockspergrid, threadsperblock]( 172 | A, 173 | A2, 174 | d_in, 175 | d_out, 176 | N, 177 | D, 178 | x, 179 | y, 180 | 
np.array(i_neighbors), 181 | np.array(j_neighbors), 182 | D.shape[0], 183 | D.shape[1], 184 | ) 185 | return D 186 | 187 | 188 | def sdrf( 189 | data, 190 | loops=10, 191 | remove_edges=True, 192 | removal_bound=0.5, 193 | tau=1, 194 | is_undirected=False, 195 | ): 196 | edge_index = data.edge_index 197 | if is_undirected: 198 | edge_index = to_undirected(edge_index) 199 | A = to_dense_adj(remove_self_loops(edge_index)[0])[0] 200 | N = A.shape[0] 201 | G = to_networkx(data) 202 | if is_undirected: 203 | G = G.to_undirected() 204 | A = A.cuda() 205 | C = torch.zeros(N, N).cuda() 206 | 207 | for x in range(loops): 208 | can_add = True 209 | balanced_forman_curvature(A, C=C) 210 | ix_min = C.argmin().item() 211 | x = ix_min // N 212 | y = ix_min % N 213 | 214 | if is_undirected: 215 | x_neighbors = list(G.neighbors(x)) + [x] 216 | y_neighbors = list(G.neighbors(y)) + [y] 217 | else: 218 | x_neighbors = list(G.successors(x)) + [x] 219 | y_neighbors = list(G.predecessors(y)) + [y] 220 | candidates = [] 221 | for i in x_neighbors: 222 | for j in y_neighbors: 223 | if (i != j) and (not G.has_edge(i, j)): 224 | candidates.append((i, j)) 225 | 226 | if len(candidates): 227 | D = balanced_forman_post_delta(A, x, y, x_neighbors, y_neighbors) 228 | improvements = [] 229 | for (i, j) in candidates: 230 | improvements.append( 231 | (D - C[x, y])[x_neighbors.index(i), y_neighbors.index(j)].item() 232 | ) 233 | 234 | k, l = candidates[ 235 | np.random.choice( 236 | range(len(candidates)), p=softmax(np.array(improvements), tau=tau) 237 | ) 238 | ] 239 | G.add_edge(k, l) 240 | if is_undirected: 241 | A[k, l] = A[l, k] = 1 242 | else: 243 | A[k, l] = 1 244 | else: 245 | can_add = False 246 | if not remove_edges: 247 | break 248 | 249 | if remove_edges: 250 | ix_max = C.argmax().item() 251 | x = ix_max // N 252 | y = ix_max % N 253 | if C[x, y] > removal_bound: 254 | G.remove_edge(x, y) 255 | if is_undirected: 256 | A[x, y] = A[y, x] = 0 257 | else: 258 | A[x, y] = 0 259 | else: 260 | if can_add is False: 261 | break 262 | 263 | return from_networkx(G) 264 | -------------------------------------------------------------------------------- /transforms/sdrf/sdrf_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import BaseTransform 3 | from torch_geometric.data import Data 4 | from transforms.sdrf.curvature import sdrf 5 | from transforms.sdrf.utils import get_dataset 6 | 7 | class SDRF(BaseTransform): 8 | 9 | def __init__(self, 10 | max_steps: int = None, 11 | remove_edges: bool = True, 12 | removal_bound: float = 0.5, 13 | tau: float = 1, 14 | undirected: bool = False, 15 | use_edge_weigths = True 16 | ): 17 | 18 | self.max_steps = max_steps 19 | self.remove_edges = remove_edges 20 | self.removal_bound = removal_bound 21 | self.tau = tau 22 | self.undirected = undirected 23 | self.use_edge_weigths = use_edge_weigths 24 | 25 | def __call__(self, graph_data): 26 | 27 | graph_data = get_dataset(graph_data, use_lcc=False) 28 | 29 | if self.max_steps == 'dynamic': 30 | self.max_steps = int(0.7 * graph_data.num_nodes) 31 | else: 32 | self.max_steps = int(self.max_steps) 33 | 34 | altered_data = sdrf( 35 | graph_data, 36 | loops=self.max_steps, 37 | remove_edges=self.remove_edges, 38 | removal_bound=self.removal_bound, 39 | tau=self.tau, 40 | is_undirected=self.undirected, 41 | ) 42 | 43 | new_data = Data( 44 | edge_index=torch.LongTensor(altered_data.edge_index), 45 | edge_attr=torch.FloatTensor(altered_data.edge_attr) 
if altered_data.edge_attr is not None else None, 46 | y=graph_data.y, 47 | x=graph_data.x, 48 | num_nodes = graph_data.num_nodes 49 | ) 50 | return new_data 51 | 52 | 53 | def __repr__(self) -> str: 54 | return f'{self.__class__.__name__}({self.max_steps})' 55 | 56 | -------------------------------------------------------------------------------- /transforms/sdrf/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch_geometric.data import Data 3 | import torch 4 | 5 | def softmax(a, tau=1): 6 | exp_a = np.exp(a * tau) 7 | return exp_a / exp_a.sum() 8 | 9 | 10 | def get_node_mapper(lcc: np.ndarray) -> dict: 11 | mapper = {} 12 | counter = 0 13 | for node in lcc: 14 | mapper[node] = counter 15 | counter += 1 16 | return mapper 17 | 18 | 19 | def remap_edges(edges: list, mapper: dict) -> list: 20 | row = [e[0] for e in edges] 21 | col = [e[1] for e in edges] 22 | row = list(map(lambda x: mapper[x], row)) 23 | col = list(map(lambda x: mapper[x], col)) 24 | return [row, col] 25 | 26 | 27 | def get_component(dataset, start: int = 0) -> set: 28 | visited_nodes = set() 29 | queued_nodes = set([start]) 30 | row, col = dataset.edge_index.numpy() 31 | while queued_nodes: 32 | current_node = queued_nodes.pop() 33 | visited_nodes.update([current_node]) 34 | neighbors = col[np.where(row == current_node)[0]] 35 | neighbors = [ 36 | n for n in neighbors if n not in visited_nodes and n not in queued_nodes 37 | ] 38 | queued_nodes.update(neighbors) 39 | return visited_nodes 40 | 41 | 42 | def get_largest_connected_component(dataset) -> np.ndarray: 43 | remaining_nodes = set(range(dataset.num_nodes)) 44 | comps = [] 45 | while remaining_nodes: 46 | start = min(remaining_nodes) 47 | comp = get_component(dataset, start) 48 | comps.append(comp) 49 | remaining_nodes = remaining_nodes.difference(comp) 50 | return np.array(list(comps[np.argmax(list(map(len, comps)))])) 51 | 52 | 53 | def get_dataset(dataset: Data, use_lcc: bool = True): 54 | if use_lcc: 55 | lcc = get_largest_connected_component(dataset) 56 | 57 | if dataset.x is not None: 58 | x_new = dataset.x[lcc] 59 | else: 60 | x_new = None 61 | y_new = dataset.y # for graph clf, same y 62 | 63 | row, col = dataset.edge_index.numpy() 64 | edges = [[i, j] for i, j in zip(row, col) if i in lcc and j in lcc] 65 | edges = remap_edges(edges, get_node_mapper(lcc)) 66 | 67 | data = Data( 68 | x=x_new, 69 | edge_index=torch.LongTensor(edges), 70 | y=y_new, 71 | num_nodes = dataset.num_nodes 72 | ) 73 | dataset = data 74 | 75 | return dataset 76 | -------------------------------------------------------------------------------- /transforms/transform_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import torch_geometric 5 | from torch_geometric.transforms import BaseTransform 6 | from torch_geometric.utils import degree, to_undirected 7 | from torch_geometric.utils.convert import to_scipy_sparse_matrix, from_scipy_sparse_matrix 8 | 9 | import scipy.sparse as sp 10 | import numpy as np 11 | 12 | 13 | class FeatureDegree(BaseTransform): 14 | r"""Adds the node degree as one hot encodings to the node features. 15 | 16 | Args: 17 | max_degree (int): Maximum degree. 18 | in_degree (bool, optional): If set to :obj:`True`, will compute the 19 | in-degree of nodes instead of the out-degree. 
20 | (default: :obj:`False`) 21 | cat (bool, optional): Concat node degrees to node features instead 22 | of replacing them. (default: :obj:`True`) 23 | """ 24 | def __init__(self, in_degree=False, cat=True): 25 | self.in_degree = in_degree 26 | self.cat = cat 27 | 28 | def __call__(self, data): 29 | idx, x = data.edge_index[1 if self.in_degree else 0], data.x 30 | deg = degree(idx, data.num_nodes, dtype=torch.float).unsqueeze(-1) 31 | #deg = F.one_hot(deg, num_classes=self.max_degree + 1).to(torch.float) 32 | 33 | if x is not None and self.cat: 34 | x = x.view(-1, 1) if x.dim() == 1 else x 35 | data.x = torch.cat([x, deg.to(x.dtype)], dim=-1) 36 | else: 37 | data.x = deg 38 | 39 | return data 40 | 41 | def __repr__(self) -> str: 42 | return f'{self.__class__.__name__}({self.in_degree})' 43 | 44 | 45 | class DIGLedges(BaseTransform): 46 | def __init__(self, alpha:float, use_edge_weigths = False): 47 | self.alpha = alpha 48 | self.eps = 0.005 49 | self.use_edge_weigths = use_edge_weigths 50 | 51 | def __call__(self, data): 52 | new_edges, new_weights = self.digl_edges(data.edge_index, data.num_edges) 53 | data.edge_index = new_edges 54 | 55 | if self.use_edge_weigths: 56 | data.edge_weight = new_weights 57 | 58 | return data 59 | 60 | 61 | def gdc(self, A: sp.csr_matrix, alpha: float, num_previous_edges): 62 | N = A.shape[0] 63 | 64 | # Self-loops 65 | A_loop = sp.eye(N) + A 66 | 67 | # Symmetric transition matrix 68 | D_loop_vec = A_loop.sum(0).A1 69 | D_loop_vec_invsqrt = 1 / np.sqrt(D_loop_vec) 70 | D_loop_invsqrt = sp.diags(D_loop_vec_invsqrt) 71 | T_sym = D_loop_invsqrt @ A_loop @ D_loop_invsqrt 72 | 73 | # PPR-based diffusion 74 | S = alpha * sp.linalg.inv(sp.eye(N) - (1 - alpha) * T_sym) 75 | 76 | # Same as e-threshold based on average degree 77 | # but dynamic for all graphs 78 | A = np.array(S.todense()) 79 | top_k_idx = np.unravel_index(np.argsort(A.ravel())[-num_previous_edges:], A.shape) 80 | mask = np.ones(A.shape, bool) 81 | mask[top_k_idx] = False 82 | A[mask] = 0 83 | S_tilde = sp.csr_matrix(A) 84 | 85 | 86 | # Column-normalized transition matrix on graph S_tilde 87 | D_tilde_vec = S_tilde.sum(0).A1 88 | T_S = S_tilde / D_tilde_vec 89 | 90 | return T_S 91 | 92 | def get_top_k_matrix(self, A: np.ndarray, k: int = 128) -> np.ndarray: 93 | """ 94 | Get k best edges for EACH NODE 95 | """ 96 | num_nodes = A.shape[0] 97 | print('AA', num_nodes) 98 | row_idx = np.arange(num_nodes) 99 | A[A.argsort(axis=0)[:num_nodes - k], row_idx] = 0. 100 | norm = A.sum(axis=0) 101 | norm[norm <= 0] = 1 # avoid dividing by zero 102 | return A/norm 103 | 104 | def digl_edges(self, edges, num_previous_edges): 105 | A0 = sp.csr_matrix(to_scipy_sparse_matrix(edges)) 106 | new_sp_matrix = sp.csr_matrix(self.gdc(A0, self.alpha, num_previous_edges)) 107 | new_edge_index, weights = from_scipy_sparse_matrix(new_sp_matrix) 108 | return new_edge_index, weights 109 | 110 | 111 | def __repr__(self) -> str: 112 | return f'{self.__class__.__name__}(alpha={self.alpha}, eps={self.alpha})' 113 | 114 | 115 | class KNNGraph(BaseTransform): 116 | r"""Creates a k-NN graph based on node positions :obj:`pos` 117 | (functional name: :obj:`knn_graph`). 118 | 119 | Args: 120 | k (int, optional): The number of neighbors. (default: :obj:`6`) 121 | loop (bool, optional): If :obj:`True`, the graph will contain 122 | self-loops. (default: :obj:`False`) 123 | force_undirected (bool, optional): If set to :obj:`True`, new edges 124 | will be undirected. 
(default: :obj:`False`) 125 | flow (string, optional): The flow direction when used in combination 126 | with message passing (:obj:`"source_to_target"` or 127 | :obj:`"target_to_source"`). 128 | If set to :obj:`"source_to_target"`, every target node will have 129 | exactly :math:`k` source nodes pointing to it. 130 | (default: :obj:`"source_to_target"`) 131 | cosine (boolean, optional): If :obj:`True`, will use the cosine 132 | distance instead of euclidean distance to find nearest neighbors. 133 | (default: :obj:`False`) 134 | num_workers (int): Number of workers to use for computation. Has no 135 | effect in case :obj:`batch` is not :obj:`None`, or the input lies 136 | on the GPU. (default: :obj:`1`) 137 | """ 138 | def __init__( 139 | self, 140 | k=None, 141 | loop=False, 142 | force_undirected=True, 143 | flow='source_to_target', 144 | cosine: bool = False, 145 | num_workers: int = 1, 146 | ): 147 | self.k = k 148 | self.loop = loop 149 | self.force_undirected = force_undirected 150 | self.flow = flow 151 | self.cosine = cosine 152 | self.num_workers = num_workers 153 | 154 | def __call__(self, data): 155 | had_features = True 156 | data.edge_attr = None 157 | batch = data.batch if 'batch' in data else None 158 | 159 | if data.x is None: 160 | idx = data.edge_index[1] 161 | deg = degree(idx, data.num_nodes, dtype=torch.float).unsqueeze(-1) 162 | data.x = deg 163 | had_features = False 164 | 165 | if self.k is None: 166 | self.k = int(data.num_edges / (data.num_nodes*4)) #mean degree - Note: num_edges is already doubled by default in PyG 167 | 168 | 169 | edge_index = torch_geometric.nn.knn_graph( 170 | data.x, 171 | self.k, 172 | batch, 173 | loop=self.loop, 174 | flow=self.flow, 175 | cosine=self.cosine, 176 | num_workers=self.num_workers, 177 | ) 178 | if self.force_undirected: 179 | edge_index = to_undirected(edge_index, num_nodes=data.num_nodes) 180 | 181 | data.edge_index = edge_index 182 | 183 | #Update degree 184 | if not had_features: 185 | idx = data.edge_index[1] 186 | deg = degree(idx, data.num_nodes, dtype=torch.float).unsqueeze(-1) 187 | data.x = deg 188 | 189 | return data 190 | 191 | def __repr__(self) -> str: 192 | return f'{self.__class__.__name__}(k={self.k})' --------------------------------------------------------------------------------
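Usage sketch (illustrative, not a file in the repository): the transforms above implement PyTorch Geometric's BaseTransform interface, so they can be handed to a dataset as transform or pre_transform. The snippet below shows one way FeatureDegree could be combined with either DIGLedges or KNNGraph on a TUDataset; the dataset name and the alpha and k values are assumptions chosen for illustration, and the script assumes it is run from the repository root so that the transforms package is importable.

import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset

from transforms import FeatureDegree, DIGLedges, KNNGraph

# Append each node's degree to its feature vector, then rewire the edge set
# with PPR-based diffusion (DIGL); alpha=0.05 is an illustrative teleport value.
digl_pipeline = T.Compose([
    FeatureDegree(in_degree=False, cat=True),
    DIGLedges(alpha=0.05, use_edge_weigths=False),
])

# Alternative rewiring: rebuild the edges as a k-NN graph over the node
# features (k=6 is illustrative; k=None lets the transform derive k from the
# graph's mean degree).
knn_pipeline = T.Compose([
    FeatureDegree(in_degree=False, cat=True),
    KNNGraph(k=6, force_undirected=True),
])

dataset = TUDataset(root='data/MUTAG', name='MUTAG', pre_transform=digl_pipeline)
print(dataset[0])  # Data object with rewired edge_index and degree-augmented x

Because pre_transform is applied once when the dataset is processed and cached, switching pipelines requires a fresh root (or deleting the processed folder); passing the pipeline as transform instead applies it on the fly at every access.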
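A similar sketch, under the same caveats, for the SDRF transform defined in transforms/sdrf/sdrf_transform.py. The balanced Forman curvature kernels in transforms/sdrf/curvature.py are numba CUDA kernels and the adjacency tensors are moved to the GPU, so this path requires numba and a CUDA-capable device; the import is taken directly from the submodule because the SDRF export is commented out in transforms/__init__.py. The max_steps, tau and dataset values are illustrative assumptions.

from torch_geometric.datasets import TUDataset
from transforms.sdrf.sdrf_transform import SDRF

# Curvature-based rewiring: in each of max_steps rounds, add an edge around
# the most negatively curved edge (candidate sampled by a softmax over the
# curvature improvement, sharpened by tau) and, if remove_edges is set, drop
# the most positively curved edge when its curvature exceeds removal_bound.
rewire = SDRF(max_steps=20, remove_edges=True, removal_bound=0.5,
              tau=20, undirected=True)

dataset = TUDataset(root='data/MUTAG', name='MUTAG')
rewired_graph = rewire(dataset[0])  # runs the CUDA curvature kernels
print(rewired_graph)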