├── .gitignore
├── Analysis of Assortativity-Bottleneck and CTE latent space.ipynb
├── datasets
├── Erdos_Renyi_dataset.py
└── SBM_dataset.py
├── environment_experiments.yml
├── experiments_all.sh
├── figs
├── AssortativeBottleneckAccuracy
│ ├── COLLAB.png
│ ├── HistAsssortBottleCOLLAB.pdf
│ ├── HistAsssortBottleREDDIT-BINARY.pdf
│ ├── REDDIT.png
│ ├── ScatterAsssortBottleCOLLAB.pdf
│ └── ScatterAsssortBottleREDDIT-BINARY.pdf
├── CT_REDDIT_READOUT_17_05_22__11_19.jpg
├── CT_REDDIT_READOUT_19_05_22__10_19.jpg
├── CT_REDDIT_READOUT_19_05_22__10_30.jpg
├── CT_RW.png
├── GAPLap_REDDIT_READOUT_17_05_22__11_38.jpg
├── GAP_REDDIT_READOUT_17_05_22__11_31.jpg
├── GAP_REDDIT_READOUT_19_05_22__10_35.jpg
├── MinCut_REDDIT_READOUT_17_05_22__11_21.jpg
├── MinCut_REDDIT_READOUT_17_05_22__11_22.jpg
├── MinCut_REDDIT_READOUT_19_05_22__10_19.jpg
├── MinCut_REDDIT_READOUT_19_05_22__10_30.jpg
├── REDDIT_CT_READOUT_1.jpg
└── data_statistics
│ ├── Degree_hist_CIFAR10.png
│ ├── Degree_hist_COLLAB.png
│ ├── Degree_hist_ENZYMES.png
│ ├── Degree_hist_IMDB-BINARY.png
│ ├── Degree_hist_MNIST.png
│ ├── Degree_hist_MUTAG.png
│ ├── Degree_hist_PROTEINS.png
│ ├── Degree_hist_REDDIT-BINARY.png
│ ├── avg_degree_hist_CIFAR10.png
│ ├── avg_degree_hist_COLLAB.png
│ ├── avg_degree_hist_ENZYMES.png
│ ├── avg_degree_hist_IMDB-BINARY.png
│ ├── avg_degree_hist_MNIST.png
│ ├── avg_degree_hist_MUTAG.png
│ ├── avg_degree_hist_PROTEINS.png
│ ├── avg_degree_hist_REDDIT-BINARY.png
│ ├── number_nodes_hist_CIFAR10.png
│ ├── number_nodes_hist_COLLAB.png
│ ├── number_nodes_hist_ENZYMES.png
│ ├── number_nodes_hist_IMDB-BINARY.png
│ ├── number_nodes_hist_MNIST.png
│ ├── number_nodes_hist_MUTAG.png
│ ├── number_nodes_hist_PROTEINS.png
│ └── number_nodes_hist_REDDIT-BINARY.png
├── layers
├── CT_layer.py
├── GAP_layer.py
├── MinCut_Layer.py
└── utils
│ ├── approximate_fiedler.py
│ ├── ein_utils.py
│ └── spectral_utils.py
├── nets.py
├── readme.md
├── requirements.txt
├── train.py
├── trained_models
├── CTNet
│ ├── COLLAB_CTNet_17_05_22__08_56_iter0.pth
│ ├── ERDOS_CTNet_19_05_22__14_46_iter1.pth
│ ├── IMDB-BINARY_CTNet_17_05_22__08_56_iter0.pth
│ ├── MUTAG_CTNet_19_05_22__11_35_iter0.pth
│ ├── REDDIT-BINARY_CTNet_17_05_22__08_50_iter0.pth
│ └── SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_CTNet_18_05_22__22_18_iter0.pth
├── GAPNet
│ ├── COLLAB_GAPNet_normalized_18_05_22__21_30_iter0.pth
│ ├── ERDOS_GAPNet_normalized_19_05_22__14_46_iter0.pth
│ ├── IMDB-BINARY_GAPNet_normalized_19_05_22__11_35_iter0.pth
│ ├── MUTAG_GAPNet_normalized_19_05_22__11_35_iter0.pth
│ ├── REDDIT-BINARY_GAPNet_normalized_19_05_22__10_09_iter0.pth
│ └── SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_GAPNet_normalized_18_05_22__21_44_iter0.pth
└── Laplacian
│ ├── COLLAB_GAPNet_laplacian_16_05_22__16_46_iter0.pth
│ └── REDDIT-BINARY_GAPNet_laplacian_16_05_22__11_04_iter0.pth
└── transforms
├── __init__.py
├── sdrf
├── curvature.py
├── sdrf_transform.py
└── utils.py
└── transform_features.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .txt
6 | # C extensions
7 | *.so
8 | .DS_Store
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 | /data
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | data/
132 | models/
133 | *.sh
134 | !experiments_all.sh
135 | data_colab/
136 | logs*/
137 | !logs_stratified/
138 |
--------------------------------------------------------------------------------
/datasets/Erdos_Renyi_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from torch_geometric.utils.random import erdos_renyi_graph
3 | import torch
4 | from torch_geometric.data import Data, InMemoryDataset
5 | from torch_geometric.data import DataLoader
6 | from torch_geometric.utils import to_dense_adj
7 |
class Erdos_Renyi_pyg(InMemoryDataset):
    """Synthetic two-class graph-classification dataset of Erdos-Renyi graphs.

    Graphs of class 0 (``nb_graphs1`` graphs with ``nb_nodes1`` nodes) use edge
    probabilities that sweep linearly from just above ``p1_min`` up to
    ``p1_max``; graphs of class 1 (``nb_graphs2`` graphs with ``nb_nodes2``
    nodes) sweep from just above ``p2_min`` up to ``p2_max``. Every node has a
    single input feature: its degree.

    Args:
        root (str): path prefix; ``_<details>`` is appended to form the dataset folder.
        nb_nodes1 (int): number of nodes of each class-0 graph.
        nb_graphs1 (int): number of class-0 graphs.
        nb_nodes2 (int): number of nodes of each class-1 graph.
        nb_graphs2 (int): number of class-1 graphs.
        p1_min, p1_max (float): edge-probability range for class 0.
        p2_min, p2_max (float): edge-probability range for class 1.
        directed (bool): forwarded to ``erdos_renyi_graph``.
        transform: on-the-fly transform applied when accessing samples.
        pre_transform: transform applied once before saving to disk.
    """

    def __init__(self, root, nb_nodes1=200, nb_graphs1=500, nb_nodes2=200, nb_graphs2=500, p1_min=0.1, p1_max=0.5,
                 p2_min=0.5, p2_max=0.8, directed=False, transform=None, pre_transform=None):
        self.nb_nodes1 = nb_nodes1
        self.nb_graphs1 = nb_graphs1
        self.nb_nodes2 = nb_nodes2
        self.nb_graphs2 = nb_graphs2
        self.nb_graphs = self.nb_graphs1 + self.nb_graphs2

        self.directed = directed

        self.p1_min = p1_min
        self.p1_max = p1_max
        self.p2_min = p2_min
        self.p2_max = p2_max

        # NOTE(review): `details` names the cache folder but omits nb_nodes2, the
        # per-class graph split and `directed`. PyG skips process() when
        # processed/data.pt already exists, so datasets differing only in those
        # parameters silently reuse stale data -- delete the folder after changing them.
        self.details = f"{self.nb_nodes1}nodesT{self.nb_graphs}graphsT" \
            + f"{int(p1_min*100)}to{int(p1_max*100)}p1T{int(p2_min*100)}to{int(p2_max*100)}p2"
        self.root = f"{root}_{self.details}"
        print(self.root)
        super().__init__(self.root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Placeholder: there are no raw files, graphs are generated in process().
        return ['tentative']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Nothing to download; the dataset is synthetic.
        pass

    def process(self):
        """Generate both classes, apply pre_filter/pre_transform, and save to disk."""
        data_list1 = self._generate_graphs(self.nb_nodes1, self.nb_graphs1, self.p1_min, self.p1_max, 0)
        data_list2 = self._generate_graphs(self.nb_nodes2, self.nb_graphs2, self.p2_min, self.p2_max, 1)
        data_list = [*data_list1, *data_list2]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _generate_graphs(self, nb_nodes, nb_graphs, p_min, p_max, myclass):
        """Generate ``nb_graphs`` Erdos-Renyi graphs labelled ``myclass``.

        Graph ``i`` (0-based) uses edge probability
        ``p_min + (i + 1) * (p_max - p_min) / nb_graphs``, so the last graph uses
        exactly ``p_max``. Returns an empty list when ``nb_graphs <= 0``.
        """
        if nb_graphs <= 0:
            return []
        slope = (p_max - p_min) / nb_graphs  # linear sweep of the edge probability
        dataset = []
        for i in range(nb_graphs):
            p_i = slope * (i + 1) + p_min
            edge_index = erdos_renyi_graph(num_nodes=nb_nodes, edge_prob=p_i, directed=self.directed)
            x = self._degree_features(edge_index, nb_nodes)
            dataset.append(Data(x=x, edge_index=edge_index, y=myclass, num_nodes=nb_nodes))
        return dataset

    @staticmethod
    def _degree_features(edge_index, num_nodes):
        """Return a ``[num_nodes, 1]`` float tensor holding each node's degree.

        Counts edges per target node (``edge_index[1]``), i.e. the column sums
        of the adjacency matrix. ``minlength`` guarantees exactly one row per
        node even when the highest-indexed nodes are isolated or the graph has
        no edges; a dense adjacency built without an explicit node count would
        be too small in those cases.
        """
        deg = torch.bincount(edge_index[1], minlength=num_nodes)
        return deg.to(torch.float).view(-1, 1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.details})'
77 |
78 |
if __name__ == "__main__":
    # Manual smoke test: build (or reload from the on-disk cache) an
    # Erdos-Renyi dataset, then print statistics for it and its first graph.
    print("Tesdting Erdos-Renyi")
    dataset = Erdos_Renyi_pyg(
        '../data/SBM_final',
        nb_nodes1=200, nb_graphs1=500,
        nb_nodes2=200, nb_graphs2=500,
        p1_min=0.1, p1_max=0.5,
        p2_min=0.3, p2_max=0.8,
    )

    # --- dataset-level summary ---
    print()
    print(f'Dataset: {dataset}:')
    print('====================')
    print(f'Number of graphs: {len(dataset)}')
    print(f'Number of features: {dataset.num_features}')
    print(f'Number of classes: {dataset.num_classes}')

    # --- summary of the first graph ---
    first_graph = dataset[0]
    print()
    print(first_graph)
    print('=============================================================')
    print(f'Number of nodes: {first_graph.num_nodes}')
    print(f'Number of edges: {first_graph.num_edges}')
    print(f'Label: {first_graph.y}')
    print(f'Average node degree: {first_graph.num_edges / first_graph.num_nodes:.2f}')
    print(f'Has isolated nodes: {first_graph.has_isolated_nodes()}')
    print(f'Has self-loops: {first_graph.has_self_loops()}')
    print(f'Is undirected: {first_graph.is_undirected()}')
    print(dataset.details)

    # Example train/test split with loaders (kept for reference, not executed):
    # train_dataset = dataset[:int(0.8*len(dataset))]
    # test_dataset = dataset[int(0.8*len(dataset)):]
    #
    # train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  # Original 64
    # test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)  # Original 64
    #
    # for step, data in enumerate(train_loader):
    #     print(f'Step {step + 1}:')
    #     print('=======')
    #     print(f'Number of graphs in the current batch: {data.num_graphs}')
    #     print(data)
    #     print()
119 |
120 |
--------------------------------------------------------------------------------
/datasets/SBM_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data, InMemoryDataset
3 | from torch_geometric.utils.random import stochastic_blockmodel_graph
4 | from torch_geometric.data import DataLoader
5 | from torch_geometric.utils import to_dense_adj
6 |
class SBM_pyg(InMemoryDataset):
    def __init__(self, root, nb_nodes1=200, nb_graphs1=500, nb_nodes2=200, nb_graphs2=500, p1=0.8, p2=0.5,
                 qmin1=0.5, qmax1=0.8, qmin2=0.25, qmax2=0.7, directed=False, transform=None, pre_transform=None):
        """
        Create an SBM dataset with graphs of 2 classes. Each graph has 2 communities.
        For each class we can parametrize the number of nodes, the number of graphs,
        and the range swept by the inter-community edge probability q.
        Graphs of class 1 get label 0 and graphs of class 2 get label 1.

        Args:
            root (str): path prefix where the data is saved; ``_<details>`` is appended.
            nb_nodes1 (int, optional): number of nodes in the graphs of class 1. Defaults to 200.
            nb_graphs1 (int, optional): number of graphs of class 1. Defaults to 500.
            nb_nodes2 (int, optional): number of nodes in the graphs of class 2. Defaults to 200.
            nb_graphs2 (int, optional): number of graphs of class 2. Defaults to 500.
            p1 (float, optional): intra-community edge probability of community 1, for both graph classes. Defaults to 0.8.
            p2 (float, optional): intra-community edge probability of community 2, for both graph classes. Defaults to 0.5.
            qmin1 (float, optional): minimum inter-community probability for graphs in class 1. Defaults to 0.5.
            qmax1 (float, optional): maximum inter-community probability for graphs in class 1. Defaults to 0.8.
            qmin2 (float, optional): minimum inter-community probability for graphs in class 2. Defaults to 0.25.
            qmax2 (float, optional): maximum inter-community probability for graphs in class 2. Defaults to 0.7.
            directed (bool, optional): create directed or undirected graphs. Defaults to False.
            transform (torch_geometric.transforms.BaseTransform, optional): on-the-fly transformation. Defaults to None.
            pre_transform (torch_geometric.transforms.BaseTransform, optional): transformation applied before saving to disk. Defaults to None.
        """
        self.nb_nodes1 = nb_nodes1
        self.nb_graphs1 = nb_graphs1
        self.nb_nodes2 = nb_nodes2
        self.nb_graphs2 = nb_graphs2
        self.nb_graphs = self.nb_graphs1 + self.nb_graphs2
        self.directed = directed
        self.edge_probs1_min = [[p1, qmin1], [qmin1, p2]]  # Minimal setting class 1
        self.edge_probs2_min = [[p1, qmin2], [qmin2, p2]]  # Minimal setting class 2
        self.edge_probs1_max = [[p1, qmax1], [qmax1, p2]]  # Maximal setting class 1
        self.edge_probs2_max = [[p1, qmax2], [qmax2, p2]]  # Maximal setting class 2

        # NOTE(review): `details` names the cache folder but omits nb_nodes2, the
        # per-class graph split and `directed`. PyG skips process() when
        # processed/data.pt already exists, so datasets differing only in those
        # parameters silently reuse stale data -- delete the folder after changing them.
        self.details = f"{self.nb_nodes1}nodesT{self.nb_graphs}graphsT{int(p1*100)}" \
            + f"p1T{int(p2*100)}p2T{int(qmin1*100)}to{int(qmax1*100)}q1T{int(qmin2*100)}to{int(qmax2*100)}q2"
        self.root = f"{root}_{self.details}"
        print(self.root)
        super().__init__(self.root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Placeholder: there are no raw files, graphs are generated in process().
        return ['tentative']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Nothing to download; the dataset is synthetic.
        pass

    def process(self):
        """Generate both classes, apply pre_filter/pre_transform, and save to disk."""
        data_list1 = self._generate_graphs(self.nb_nodes1, self.nb_graphs1, self.edge_probs1_min, self.edge_probs1_max, 0)
        data_list2 = self._generate_graphs(self.nb_nodes2, self.nb_graphs2, self.edge_probs2_min, self.edge_probs2_max, 1)
        data_list = [*data_list1, *data_list2]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _generate_graphs(self, nb_nodes, nb_graphs, edge_probs_min, edge_probs_max, myclass):
        """Generate ``nb_graphs`` two-community SBM graphs labelled ``myclass``.

        The intra-community probabilities p1/p2 stay fixed. Graph ``i`` (0-based)
        uses inter-community probability ``qmin + (i + 1) * (qmax - qmin) / nb_graphs``,
        so the last graph uses exactly ``qmax``. Returns an empty list when
        ``nb_graphs <= 0``.
        """
        if nb_graphs <= 0:
            return []
        p1 = edge_probs_min[0][0]
        p2 = edge_probs_min[1][1]
        qmin = edge_probs_min[1][0]
        qmax = edge_probs_max[1][0]
        slope = (qmax - qmin) / nb_graphs  # linear sweep of q
        # Two communities covering all nb_nodes nodes; the second one takes the
        # extra node when nb_nodes is odd (halving both sides would drop it).
        half = nb_nodes // 2
        block_sizes = [half, nb_nodes - half]
        dataset = []
        for i in range(nb_graphs):
            q = slope * (i + 1) + qmin
            edge_index = stochastic_blockmodel_graph(block_sizes=block_sizes, edge_probs=[[p1, q], [q, p2]],
                                                     directed=self.directed)
            x = self._degree_features(edge_index, nb_nodes)
            dataset.append(Data(x=x, edge_index=edge_index, y=myclass, num_nodes=nb_nodes))
        return dataset

    @staticmethod
    def _degree_features(edge_index, num_nodes):
        """Return a ``[num_nodes, 1]`` float tensor holding each node's degree.

        Counts edges per target node (``edge_index[1]``), i.e. the column sums
        of the adjacency matrix. ``minlength`` guarantees exactly one row per
        node even when the highest-indexed nodes are isolated or the graph has
        no edges; a dense adjacency built without an explicit node count would
        be too small in those cases.
        """
        deg = torch.bincount(edge_index[1], minlength=num_nodes)
        return deg.to(torch.float).view(-1, 1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.details})'
101 |
102 |
if __name__ == "__main__":
    # Manual smoke test: build (or reload from the on-disk cache) an SBM
    # dataset, then print statistics for it and its first graph.
    print("Tesdting SBM")
    dataset = SBM_pyg(
        '../data/SBM_final',
        nb_nodes1=200, nb_graphs1=500,
        nb_nodes2=200, nb_graphs2=500,
        p1=0.8, p2=0.5,
        qmin1=0.5, qmax1=0.8,
        qmin2=0.25, qmax2=0.71,
        directed=False, transform=None, pre_transform=None,
    )

    # --- dataset-level summary ---
    print()
    print(f'Dataset: {dataset}:')
    print('====================')
    print(f'Number of graphs: {len(dataset)}')
    print(f'Number of features: {dataset.num_features}')
    print(f'Number of classes: {dataset.num_classes}')

    # --- summary of the first graph ---
    first_graph = dataset[0]
    print()
    print(first_graph)
    print('=============================================================')
    print(f'Number of nodes: {first_graph.num_nodes}')
    print(f'Number of edges: {first_graph.num_edges}')
    print(f'Average node degree: {first_graph.num_edges / first_graph.num_nodes:.2f}')
    print(f'Has isolated nodes: {first_graph.has_isolated_nodes()}')
    print(f'Has self-loops: {first_graph.has_self_loops()}')
    print(f'Is undirected: {first_graph.is_undirected()}')
    print(dataset.details)

    # Example train/test split with loaders (kept for reference, not executed):
    # train_dataset = dataset[:int(0.8*len(dataset))]
    # test_dataset = dataset[int(0.8*len(dataset)):]
    #
    # train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  # Original 64
    # test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)  # Original 64
    #
    # for step, data in enumerate(train_loader):
    #     print(f'Step {step + 1}:')
    #     print('=======')
    #     print(f'Number of graphs in the current batch: {data.num_graphs}')
    #     print(data)
    #     print()
142 |
143 |
--------------------------------------------------------------------------------
/environment_experiments.yml:
--------------------------------------------------------------------------------
1 | name: DiffWire
2 | channels:
3 | - pyg
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=4.5=1_gnu
9 | - blas=1.0=mkl
10 | - bottleneck=1.3.4=py39hce1f21e_0
11 | - brotli=1.0.9=he6710b0_2
12 | - brotlipy=0.7.0=py39h27cfd23_1003
13 | - bzip2=1.0.8=h7b6447c_0
14 | - ca-certificates=2022.4.26=h06a4308_0
15 | - certifi=2021.10.8=py39h06a4308_2
16 | - cffi=1.15.0=py39hd667e15_1
17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
18 | - colorama=0.4.4=pyhd3eb1b0_0
19 | - cryptography=36.0.0=py39h9ce1e76_0
20 | - cudatoolkit=10.2.89=hfd86e86_1
21 | - cycler=0.11.0=pyhd3eb1b0_0
22 | - dbus=1.13.18=hb2f20db_0
23 | - expat=2.4.4=h295c915_0
24 | - ffmpeg=4.3=hf484d3e_0
25 | - fontconfig=2.13.1=h6c09931_0
26 | - fonttools=4.25.0=pyhd3eb1b0_0
27 | - freetype=2.11.0=h70c0345_0
28 | - giflib=5.2.1=h7b6447c_0
29 | - glib=2.69.1=h4ff587b_1
30 | - gmp=6.2.1=h2531618_2
31 | - gnutls=3.6.15=he1e5248_0
32 | - gst-plugins-base=1.14.0=h8213a91_2
33 | - gstreamer=1.14.0=h28cd5cc_2
34 | - h5py=3.6.0=py39ha0f2276_0
35 | - hdf5=1.10.6=hb1b8bf9_0
36 | - icu=58.2=he6710b0_3
37 | - idna=3.3=pyhd3eb1b0_0
38 | - intel-openmp=2021.4.0=h06a4308_3561
39 | - jinja2=3.0.3=pyhd3eb1b0_0
40 | - joblib=1.1.0=pyhd3eb1b0_0
41 | - jpeg=9d=h7f8727e_0
42 | - kiwisolver=1.3.2=py39h295c915_0
43 | - lame=3.100=h7b6447c_0
44 | - lcms2=2.12=h3be6417_0
45 | - ld_impl_linux-64=2.35.1=h7274673_9
46 | - libffi=3.3=he6710b0_2
47 | - libgcc-ng=9.3.0=h5101ec6_17
48 | - libgfortran-ng=7.5.0=ha8ba4b0_17
49 | - libgfortran4=7.5.0=ha8ba4b0_17
50 | - libgomp=9.3.0=h5101ec6_17
51 | - libiconv=1.15=h63c8f33_5
52 | - libidn2=2.3.2=h7f8727e_0
53 | - libpng=1.6.37=hbc83047_0
54 | - libprotobuf=3.19.1=h4ff587b_0
55 | - libstdcxx-ng=9.3.0=hd4cf53a_17
56 | - libtasn1=4.16.0=h27cfd23_0
57 | - libtiff=4.2.0=h85742a9_0
58 | - libunistring=0.9.10=h27cfd23_0
59 | - libuuid=1.0.3=h7f8727e_2
60 | - libuv=1.40.0=h7b6447c_0
61 | - libwebp=1.2.2=h55f646e_0
62 | - libwebp-base=1.2.2=h7f8727e_0
63 | - libxcb=1.14=h7b6447c_0
64 | - libxml2=2.9.12=h03d6c58_0
65 | - lz4-c=1.9.3=h295c915_1
66 | - markupsafe=2.0.1=py39h27cfd23_0
67 | - matplotlib=3.5.1=py39h06a4308_1
68 | - matplotlib-base=3.5.1=py39ha18d171_1
69 | - mkl=2021.4.0=h06a4308_640
70 | - mkl-service=2.4.0=py39h7f8727e_0
71 | - mkl_fft=1.3.1=py39hd3c417c_0
72 | - mkl_random=1.2.2=py39h51133e4_0
73 | - munkres=1.1.4=py_0
74 | - ncurses=6.3=h7f8727e_2
75 | - nettle=3.7.3=hbbd107a_1
76 | - networkx=2.7.1=pyhd3eb1b0_0
77 | - numexpr=2.8.1=py39h6abb31d_0
78 | - numpy=1.21.2=py39h20f2e39_0
79 | - numpy-base=1.21.2=py39h79a1101_0
80 | - openh264=2.1.1=h4ff587b_0
81 | - openssl=1.1.1o=h7f8727e_0
82 | - packaging=21.3=pyhd3eb1b0_0
83 | - pandas=1.4.1=py39h295c915_1
84 | - path=16.2.0=pyhd3eb1b0_0
85 | - pcre=8.45=h295c915_0
86 | - pillow=9.0.1=py39h22f2fdc_0
87 | - pip=21.2.4=py39h06a4308_0
88 | - plotly=5.6.0=pyhd3eb1b0_0
89 | - protobuf=3.19.1=py39h295c915_0
90 | - pycparser=2.21=pyhd3eb1b0_0
91 | - pyg=2.0.4=py39_torch_1.11.0_cu102
92 | - pyopenssl=22.0.0=pyhd3eb1b0_0
93 | - pyparsing=3.0.4=pyhd3eb1b0_0
94 | - pyqt=5.9.2=py39h2531618_6
95 | - pysocks=1.7.1=py39h06a4308_0
96 | - python=3.9.11=h12debd9_2
97 | - python-dateutil=2.8.2=pyhd3eb1b0_0
98 | - python-louvain=0.15=pyhd3eb1b0_0
99 | - pytorch=1.11.0=py3.9_cuda10.2_cudnn7.6.5_0
100 | - pytorch-cluster=1.6.0=py39_torch_1.11.0_cu102
101 | - pytorch-mutex=1.0=cuda
102 | - pytorch-scatter=2.0.9=py39_torch_1.11.0_cu102
103 | - pytorch-sparse=0.6.13=py39_torch_1.11.0_cu102
104 | - pytorch-spline-conv=1.2.1=py39_torch_1.11.0_cu102
105 | - pytz=2021.3=pyhd3eb1b0_0
106 | - pyyaml=6.0=py39h7f8727e_1
107 | - qt=5.9.7=h5867ecd_1
108 | - readline=8.1.2=h7f8727e_1
109 | - requests=2.27.1=pyhd3eb1b0_0
110 | - scikit-learn=1.0.2=py39h51133e4_1
111 | - scipy=1.7.3=py39hc147768_0
112 | - seaborn=0.11.2=pyhd3eb1b0_0
113 | - setuptools=58.0.4=py39h06a4308_0
114 | - sip=4.19.13=py39h295c915_0
115 | - six=1.16.0=pyhd3eb1b0_1
116 | - sqlite=3.38.0=hc218d9a_0
117 | - tenacity=8.0.1=py39h06a4308_0
118 | - tensorboardx=2.2=pyhd3eb1b0_0
119 | - threadpoolctl=2.2.0=pyh0d69192_0
120 | - tk=8.6.11=h1ccaba5_0
121 | - torchaudio=0.11.0=py39_cu102
122 | - torchvision=0.12.0=py39_cu102
123 | - tornado=6.1=py39h27cfd23_0
124 | - tqdm=4.63.0=pyhd3eb1b0_0
125 | - typing_extensions=4.1.1=pyh06a4308_0
126 | - tzdata=2021e=hda174b7_0
127 | - urllib3=1.26.8=pyhd3eb1b0_0
128 | - wheel=0.37.1=pyhd3eb1b0_0
129 | - xz=5.2.5=h7b6447c_0
130 | - yacs=0.1.6=pyhd3eb1b0_1
131 | - yaml=0.2.5=h7b6447c_0
132 | - zlib=1.2.11=h7f8727e_4
133 | - zstd=1.4.9=haebb681_0
134 | prefix: /home/user/anaconda3/envs/tfg
135 |
--------------------------------------------------------------------------------
/experiments_all.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run the full benchmark grid: for each dataset, train MinCutNet (plain and
# with DIGL preprocessing), CTNet, GAPNet with each derivative variant, and
# DiffWire. Each dataset appears exactly once.
#
# Usage: [CUDA_DEVICE=cuda:1] bash experiments_all.sh   (defaults to cuda:0)
CUDA_DEVICE="${CUDA_DEVICE:-cuda:0}"

# run_suite DATASET WITH_DIGL
#   WITH_DIGL=yes also trains the DIGL-preprocessed MinCutNet baseline.
run_suite() {
    local dataset="$1"
    local with_digl="$2"
    python train.py --dataset "$dataset" --model MinCutNet --cuda "$CUDA_DEVICE"
    if [ "$with_digl" = "yes" ]; then
        python train.py --dataset "$dataset" --model MinCutNet --cuda "$CUDA_DEVICE" --prepro digl
    fi
    python train.py --dataset "$dataset" --model CTNet --cuda "$CUDA_DEVICE"
    for derivative in laplacian normalized normalizedv2; do
        python train.py --dataset "$dataset" --model GAPNet --derivative "$derivative" --cuda "$CUDA_DEVICE"
    done
    python train.py --dataset "$dataset" --model DiffWire --cuda "$CUDA_DEVICE"
}

for dataset in MUTAG ENZYMES PROTEINS REDDIT-BINARY COLLAB IMDB-BINARY MNIST CIFAR10; do
    run_suite "$dataset" yes
done

# CSL is trained without the DIGL-preprocessed baseline.
run_suite CSL no
--------------------------------------------------------------------------------
/figs/AssortativeBottleneckAccuracy/COLLAB.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/COLLAB.png
--------------------------------------------------------------------------------
/figs/AssortativeBottleneckAccuracy/HistAsssortBottleCOLLAB.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/HistAsssortBottleCOLLAB.pdf
--------------------------------------------------------------------------------
/figs/AssortativeBottleneckAccuracy/HistAsssortBottleREDDIT-BINARY.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/HistAsssortBottleREDDIT-BINARY.pdf
--------------------------------------------------------------------------------
/figs/AssortativeBottleneckAccuracy/REDDIT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/REDDIT.png
--------------------------------------------------------------------------------
/figs/AssortativeBottleneckAccuracy/ScatterAsssortBottleCOLLAB.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/ScatterAsssortBottleCOLLAB.pdf
--------------------------------------------------------------------------------
/figs/AssortativeBottleneckAccuracy/ScatterAsssortBottleREDDIT-BINARY.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/AssortativeBottleneckAccuracy/ScatterAsssortBottleREDDIT-BINARY.pdf
--------------------------------------------------------------------------------
/figs/CT_REDDIT_READOUT_17_05_22__11_19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/CT_REDDIT_READOUT_17_05_22__11_19.jpg
--------------------------------------------------------------------------------
/figs/CT_REDDIT_READOUT_19_05_22__10_19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/CT_REDDIT_READOUT_19_05_22__10_19.jpg
--------------------------------------------------------------------------------
/figs/CT_REDDIT_READOUT_19_05_22__10_30.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/CT_REDDIT_READOUT_19_05_22__10_30.jpg
--------------------------------------------------------------------------------
/figs/CT_RW.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/CT_RW.png
--------------------------------------------------------------------------------
/figs/GAPLap_REDDIT_READOUT_17_05_22__11_38.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/GAPLap_REDDIT_READOUT_17_05_22__11_38.jpg
--------------------------------------------------------------------------------
/figs/GAP_REDDIT_READOUT_17_05_22__11_31.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/GAP_REDDIT_READOUT_17_05_22__11_31.jpg
--------------------------------------------------------------------------------
/figs/GAP_REDDIT_READOUT_19_05_22__10_35.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/GAP_REDDIT_READOUT_19_05_22__10_35.jpg
--------------------------------------------------------------------------------
/figs/MinCut_REDDIT_READOUT_17_05_22__11_21.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/MinCut_REDDIT_READOUT_17_05_22__11_21.jpg
--------------------------------------------------------------------------------
/figs/MinCut_REDDIT_READOUT_17_05_22__11_22.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/MinCut_REDDIT_READOUT_17_05_22__11_22.jpg
--------------------------------------------------------------------------------
/figs/MinCut_REDDIT_READOUT_19_05_22__10_19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/MinCut_REDDIT_READOUT_19_05_22__10_19.jpg
--------------------------------------------------------------------------------
/figs/MinCut_REDDIT_READOUT_19_05_22__10_30.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/MinCut_REDDIT_READOUT_19_05_22__10_30.jpg
--------------------------------------------------------------------------------
/figs/REDDIT_CT_READOUT_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/REDDIT_CT_READOUT_1.jpg
--------------------------------------------------------------------------------
/figs/data_statistics/Degree_hist_CIFAR10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_CIFAR10.png
--------------------------------------------------------------------------------
/figs/data_statistics/Degree_hist_COLLAB.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_COLLAB.png
--------------------------------------------------------------------------------
/figs/data_statistics/Degree_hist_ENZYMES.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_ENZYMES.png
--------------------------------------------------------------------------------
/figs/data_statistics/Degree_hist_IMDB-BINARY.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_IMDB-BINARY.png
--------------------------------------------------------------------------------
/figs/data_statistics/Degree_hist_MNIST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_MNIST.png
--------------------------------------------------------------------------------
/figs/data_statistics/Degree_hist_MUTAG.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_MUTAG.png
--------------------------------------------------------------------------------
/figs/data_statistics/Degree_hist_PROTEINS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_PROTEINS.png
--------------------------------------------------------------------------------
/figs/data_statistics/Degree_hist_REDDIT-BINARY.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/Degree_hist_REDDIT-BINARY.png
--------------------------------------------------------------------------------
/figs/data_statistics/avg_degree_hist_CIFAR10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_CIFAR10.png
--------------------------------------------------------------------------------
/figs/data_statistics/avg_degree_hist_COLLAB.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_COLLAB.png
--------------------------------------------------------------------------------
/figs/data_statistics/avg_degree_hist_ENZYMES.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_ENZYMES.png
--------------------------------------------------------------------------------
/figs/data_statistics/avg_degree_hist_IMDB-BINARY.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_IMDB-BINARY.png
--------------------------------------------------------------------------------
/figs/data_statistics/avg_degree_hist_MNIST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_MNIST.png
--------------------------------------------------------------------------------
/figs/data_statistics/avg_degree_hist_MUTAG.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_MUTAG.png
--------------------------------------------------------------------------------
/figs/data_statistics/avg_degree_hist_PROTEINS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_PROTEINS.png
--------------------------------------------------------------------------------
/figs/data_statistics/avg_degree_hist_REDDIT-BINARY.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/avg_degree_hist_REDDIT-BINARY.png
--------------------------------------------------------------------------------
/figs/data_statistics/number_nodes_hist_CIFAR10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_CIFAR10.png
--------------------------------------------------------------------------------
/figs/data_statistics/number_nodes_hist_COLLAB.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_COLLAB.png
--------------------------------------------------------------------------------
/figs/data_statistics/number_nodes_hist_ENZYMES.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_ENZYMES.png
--------------------------------------------------------------------------------
/figs/data_statistics/number_nodes_hist_IMDB-BINARY.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_IMDB-BINARY.png
--------------------------------------------------------------------------------
/figs/data_statistics/number_nodes_hist_MNIST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_MNIST.png
--------------------------------------------------------------------------------
/figs/data_statistics/number_nodes_hist_MUTAG.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_MUTAG.png
--------------------------------------------------------------------------------
/figs/data_statistics/number_nodes_hist_PROTEINS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_PROTEINS.png
--------------------------------------------------------------------------------
/figs/data_statistics/number_nodes_hist_REDDIT-BINARY.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/figs/data_statistics/number_nodes_hist_REDDIT-BINARY.png
--------------------------------------------------------------------------------
/layers/CT_layer.py:
--------------------------------------------------------------------------------
1 | # Commute Times rewiring
2 | #Graph Convolutional Network layer where the graph structure is given by an adjacency matrix.
3 | # We recommend user to use this module when applying graph convolution on dense graphs.
4 | #from torch_geometric.nn import GCNConv, DenseGraphConv
5 | import torch
6 | from layers.utils.ein_utils import _rank3_diag, _rank3_trace
7 |
def dense_CT_rewiring(x, adj, s, mask=None, EPS=1e-15):
    """Rewire a dense batch of graphs with approximate commute-time distances.

    ``s`` is passed through ``tanh`` and treated as a node embedding; the
    pairwise Euclidean distances between its rows, divided by the graph
    volume, re-weight the existing edges of ``adj``. Two regularisers are
    returned alongside the rewired adjacency.

    Args:
        x: Node features, ``[N, F]`` or ``[B, N, F]``. Only its shape (and the
            optional masking) is used; it is not part of the output.
        adj: Dense adjacency, ``[N, N]`` or ``[B, N, N]``.
        s: Unnormalised embedding, ``[N, k]`` or ``[B, N, k]``.
        mask: Optional node mask reshapeable to ``[B, N]``; rows of ``s``
            (after ``tanh``) are multiplied by it.
        EPS: Small constant guarding the degree matrix and the CT denominator.

    Returns:
        Tuple ``(adj, CT_loss, ortho_loss)``:
            * ``adj`` -- ``[B, N, N]``, input adjacency element-wise scaled by
              ``cdist(s, s) / vol(G)`` with ``vol(G) = Tr(D)``.
            * ``CT_loss`` -- scalar, batch mean of
              ``Tr(S^T L S) / (Tr(S^T D S) + EPS)``.
            * ``ortho_loss`` -- scalar, batch mean of the per-graph
              ``|| S^T S / ||S^T S||_F - I_k ||_F``.
    """
    x = x.unsqueeze(0) if x.dim() == 2 else x
    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
    s = s.unsqueeze(0) if s.dim() == 2 else s

    (batch_size, num_nodes, _), k = x.size(), s.size(-1)
    s = torch.tanh(s)  # [B, N, k]

    if mask is not None:
        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x, s = x * mask, s * mask

    # Degree vector and degree matrix. NOTE: EPS is added to every entry of
    # the [B, N, N] matrix (not only its diagonal); kept as in the original.
    d_flat = torch.einsum('ijk->ij', adj)  # [B, N]
    d = _rank3_diag(d_flat) + EPS  # [B, N, N]

    # Laplacian L = D - A.
    L = d - adj

    s_t = s.transpose(1, 2)  # [B, k, N]

    # CT regularisation: Tr(S^T L S) / Tr(S^T D S), averaged over the batch.
    CT_num = _rank3_trace(torch.matmul(torch.matmul(s_t, L), s))  # [B]
    CT_den = _rank3_trace(torch.matmul(torch.matmul(s_t, d), s)) + EPS  # [B]
    CT_loss = torch.mean(CT_num / CT_den)

    # Commute-time distances scaled by the graph volume vol(G) = Tr(D).
    CT_dist = torch.cdist(s, s)  # [B, N, N]
    vol = _rank3_trace(d)  # [B]
    CT_dist = CT_dist / vol.unsqueeze(1).unsqueeze(1)

    # Keep only existing edges, now weighted by their CT distance.
    adj = CT_dist * adj

    # Orthogonality regularisation, one Frobenius norm per graph and then
    # averaged. Without dim=(-1, -2) torch.norm collapsed the whole batch
    # into a single scalar, making the loss grow with the batch size and the
    # subsequent mean a no-op. The identity is deliberately left
    # unnormalised, as in the original.
    ss = torch.matmul(s_t, s)  # [B, k, k]
    i_s = torch.eye(k).type_as(ss)  # [k, k]
    ortho_loss = torch.norm(
        ss / torch.norm(ss, dim=(-1, -2), keepdim=True) - i_s,
        dim=(-1, -2))  # [B]
    ortho_loss = torch.mean(ortho_loss)

    return adj, CT_loss, ortho_loss  # [B, N, N], scalar, scalar
105 |
--------------------------------------------------------------------------------
/layers/GAP_layer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from torch_geometric.utils import to_dense_batch, to_dense_adj
4 | from torch_geometric.nn import GCNConv, DenseGraphConv
5 | from layers.utils.ein_utils import _rank3_diag, _rank3_trace
6 | from layers.utils.approximate_fiedler import approximate_Fiedler
7 | from layers.utils.approximate_fiedler import NLderivative_of_lambda2_wrt_adjacency, NLfiedler_values
8 | from layers.utils.approximate_fiedler import derivative_of_lambda2_wrt_adjacency, fiedler_values
9 | from layers.utils.approximate_fiedler import NLderivative_of_lambda2_wrt_adjacencyV2, NLfiedler_valuesV2
10 |
def dense_mincut_rewiring(x, adj, s, mask=None, derivative=None, EPS=1e-15,
                          device=None, mu=0.01, lambda_reg=2.0, n_steps=5):
    """Rewire dense graphs with gradient steps driven by a lambda_2 estimate.

    ``s`` is softmax-normalised into a soft bipartition (``k = 2``) from which
    approximate Fiedler vectors are derived. The selected helper pair yields
    a derivative tensor ``der`` and a per-graph value ``fvalues``; both are
    computed once from the input graph. The adjacency ``Ac`` (initialised to
    ``adj``) then takes ``n_steps`` steps of
    ``Ac -= mu * sym(2 * (Ac - adj) + 2 * lambda_reg * der * fvalues)``,
    and is finally row-softmaxed and masked by the original edges.

    Args:
        x: Unused; kept for interface compatibility with the other layers.
        adj: Dense adjacency, ``[N, N]`` or ``[B, N, N]``.
        s: Cluster logits, ``[N, 2]`` or ``[B, N, 2]``.
        mask: Unused; kept for interface compatibility.
        derivative: ``"laplacian"``, ``"normalized"`` or ``"normalizedv2"``;
            selects which derivative / value helpers are used.
        EPS: Small constant forwarded to the helpers and guarding the MinCut
            denominator.
        device: Device forwarded to the Fiedler helpers.
        mu: Gradient step size (previously hard-coded to 0.01).
        lambda_reg: Weight of the spectral term (previously hard-coded to 2.0
            after a chain of overwritten assignments).
        n_steps: Number of gradient steps (previously hard-coded to 5).

    Returns:
        ``(Ac, mincut_loss, ortho_loss)`` -- rewired adjacency ``[B, N, N]``
        and two scalar regularisers computed on the input adjacency.

    Raises:
        ValueError: if ``derivative`` is not one of the supported names.
            Previously this surfaced as an ``UnboundLocalError``, including
            for the default ``derivative=None``.
    """
    if derivative not in ("laplacian", "normalized", "normalizedv2"):
        raise ValueError(
            "derivative must be 'laplacian', 'normalized' or 'normalizedv2', "
            f"got {derivative!r}")

    k = 2  # bipartition: only the Fiedler vector is needed for the spectral gap

    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj  # [B, N, N]
    s = s.unsqueeze(0) if s.dim() == 2 else s  # [B, N, k]
    s = torch.softmax(s, dim=-1)

    # Degree vector/matrix and Laplacian of the input graph, computed once
    # and shared with the derivative helpers.
    d_flat = torch.einsum('ijk->ij', adj)  # [B, N]
    d = _rank3_diag(d_flat)  # [B, N, N]
    L = d - adj

    # Approximate Fiedler vectors from the soft bipartition.
    fiedlers = approximate_Fiedler(s, device)  # [B, N]

    if derivative == "laplacian":
        der = derivative_of_lambda2_wrt_adjacency(fiedlers, device)
        fvalues = fiedler_values(adj, fiedlers, EPS, device)
    elif derivative == "normalized":
        der = NLderivative_of_lambda2_wrt_adjacency(adj, d_flat, fiedlers, EPS, device)
        fvalues = NLfiedler_values(L, d_flat, fiedlers, EPS, device)
    else:  # "normalizedv2"
        der = NLderivative_of_lambda2_wrt_adjacencyV2(adj, d_flat, fiedlers, EPS, device)
        fvalues = NLfiedler_valuesV2(L, d, fiedlers, EPS, device)

    # der and fvalues do not change across steps, so this term is hoisted.
    spectral_grad = 2 * lambda_reg * der * fvalues.unsqueeze(1).unsqueeze(2)  # [B, N, N]

    Ac = adj.clone()
    for _ in range(n_steps):
        partialJ = 2 * (Ac - adj) + spectral_grad  # [B, N, N]
        # Symmetrise for an undirected graph: off-diagonal entries collect the
        # (i, j) and (j, i) contributions, the diagonal is counted once.
        dJ = (partialJ + torch.transpose(partialJ, 1, 2)
              - torch.diag_embed(torch.diagonal(partialJ, dim1=1, dim2=2)))
        Ac = Ac - mu * dJ

    # Row-normalise and keep only edges that exist in the input graph.
    Ac = torch.softmax(Ac, dim=-1)
    Ac = Ac * adj

    s_t = s.transpose(1, 2)  # [B, k, N]

    # MinCut regularisation on the input adjacency. EPS guards graphs with
    # no edges, where numerator and denominator would both be zero.
    mincut_num = _rank3_trace(torch.matmul(torch.matmul(s_t, adj), s))  # [B]
    mincut_den = _rank3_trace(torch.matmul(torch.matmul(s_t, d), s)) + EPS  # [B]
    mincut_loss = torch.mean(-(mincut_num / mincut_den))

    # Orthogonality regularisation against the normalised identity.
    ss = torch.matmul(s_t, s)  # [B, k, k]
    i_s = torch.eye(k).type_as(ss)  # [k, k]
    ortho_loss = torch.norm(
        ss / torch.norm(ss, dim=(-1, -2), keepdim=True) - i_s / torch.norm(i_s),
        dim=(-1, -2))  # [B]
    ortho_loss = torch.mean(ortho_loss)

    return Ac, mincut_loss, ortho_loss  # [B, N, N], scalar, scalar
--------------------------------------------------------------------------------
/layers/MinCut_Layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import to_dense_batch, to_dense_adj
3 | from torch_geometric.nn import GCNConv, DenseGraphConv
4 | from layers.utils.ein_utils import _rank3_diag, _rank3_trace
5 |
def dense_mincut_pool(x, adj, s, mask=None, EPS=1e-15):
    """MinCut pooling of a dense batch of graphs.

    Softmax-normalises ``s`` into a soft cluster assignment ``S``, pools
    features and adjacency with it, and returns the MinCut and orthogonality
    regularisers together with the symmetrically normalised coarsened graph.

    Args:
        x: Node features, ``[N, F]`` or ``[B, N, F]``.
        adj: Dense adjacency, ``[N, N]`` or ``[B, N, N]``.
        s: Cluster logits, ``[N, k]`` or ``[B, N, k]``.
        mask: Optional node mask reshapeable to ``[B, N]``; applied to both
            ``x`` and the softmaxed assignment.
        EPS: Small constant guarding degrees and square roots.

    Returns:
        ``(out, out_adj, mincut_loss, ortho_loss)``: pooled features
        ``[B, k, F]``, coarsened adjacency ``[B, k, k]`` with zeroed diagonal
        and ``D^{-1/2} A D^{-1/2}`` normalisation, and two scalar losses.
    """
    if x.dim() == 2:
        x = x.unsqueeze(0)
    if adj.dim() == 2:
        adj = adj.unsqueeze(0)
    if s.dim() == 2:
        s = s.unsqueeze(0)

    (batch_size, num_nodes, _), k = x.size(), s.size(-1)
    assign = torch.softmax(s, dim=-1)  # [B, N, k]

    if mask is not None:
        node_mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x = x * node_mask
        assign = assign * node_mask

    assign_t = assign.transpose(1, 2)  # [B, k, N]

    # Pooled features S^T X and pooled adjacency S^T A S.
    out = torch.matmul(assign_t, x)  # [B, k, F]
    out_adj = torch.matmul(torch.matmul(assign_t, adj), assign)  # [B, k, k]

    # MinCut loss: -Tr(S^T A S) / Tr(S^T D S), averaged over the batch.
    mincut_num = _rank3_trace(out_adj)
    degree = _rank3_diag(torch.einsum('ijk->ij', adj) + EPS)  # [B, N, N]
    mincut_den = _rank3_trace(torch.matmul(torch.matmul(assign_t, degree), assign))
    mincut_loss = torch.mean(-(mincut_num / mincut_den))

    # Orthogonality loss against the normalised identity, per graph then mean.
    gram = torch.matmul(assign_t, assign)  # [B, k, k]
    identity = torch.eye(k).type_as(gram)
    ortho_loss = torch.mean(torch.norm(
        gram / torch.norm(gram, dim=(-1, -2), keepdim=True)
        - identity / torch.norm(identity),
        dim=(-1, -2)))

    # Drop self-loops of the coarsened graph, then normalise symmetrically.
    diag_idx = torch.arange(k, device=out_adj.device)
    out_adj[:, diag_idx, diag_idx] = 0
    sqrt_deg = torch.sqrt(torch.einsum('ijk->ij', out_adj) + EPS)[:, None] + EPS  # [B, 1, k]
    out_adj = (out_adj / sqrt_deg) / sqrt_deg.transpose(1, 2)

    return out, out_adj, mincut_loss, ortho_loss
--------------------------------------------------------------------------------
/layers/utils/approximate_fiedler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import time
4 |
5 | from layers.utils.ein_utils import _rank3_diag
6 |
def approximate_Fiedler(s, device=None):
    """Build a sign-vector approximation of the Fiedler vector from ``s``.

    A node gets ``+1/sqrt(N)`` if its argmax cluster is 0 and ``-1/sqrt(N)``
    if it is 1, where ``N = s.size(1)``.

    Args:
        s: Soft assignment ``[B, N, k]``; intended for ``k = 2``.
        device: Device of the result. Defaults to ``s.device`` (previously
            the CPU, which failed with a device mismatch for CUDA ``s``).

    Returns:
        float32 tensor ``[B, N]``. Nodes whose argmax is neither 0 nor 1
        (only possible when ``k > 2``) are 0 rather than uninitialised memory.
    """
    num_graphs, num_nodes = s.size(0), s.size(1)
    if device is None:
        device = s.device
    maxcluster = torch.argmax(s, dim=2)  # [B, N], values in {0, 1} when k == 2
    value = 1 / np.sqrt(float(num_nodes))
    # zeros (not an uninitialised FloatTensor) so every entry is defined
    trimmed_s = torch.zeros(num_graphs, num_nodes, dtype=torch.float32, device=device)
    trimmed_s[maxcluster == 1] = -value
    trimmed_s[maxcluster == 0] = value
    return trimmed_s
19 |
def NLderivative_of_lambda2_wrt_adjacency(adj, d_flat, fiedlers, EPS, device):
    """Batched derivative term for the normalised-Laplacian lambda_2.

    For each graph, with Fiedler vector ``u``, ``s_i = sqrt(d_i + EPS) + EPS``,
    ``S = diag(s)`` and ``c_i = -0.5 * s_i / (d_i + EPS)``, returns the
    ``[N, N]`` matrix

        -( c_i * (u^T S^{-1} A^T u + u^T S^{-1} A u) + u_i u_j / (s_i s_j) )

    whose first term is constant along each row. This is the same formula
    the previous per-graph loop evaluated, now computed for the whole batch
    at once.

    Args:
        adj: Dense adjacency ``[B, N, N]``.
        d_flat: Node degrees ``[B, N]``.
        fiedlers: (Approximate) Fiedler vectors ``[B, N]``.
        EPS: Numerical guard inside the square root and the reciprocal.
        device: If not None, the result is moved to this device.

    Returns:
        Tensor ``[B, N, N]`` in the dtype of the inputs (previously the
        result was always float32).
    """
    sqrt_d = torch.sqrt(d_flat + EPS) + EPS  # s, [B, N]
    row_coef = -0.5 * (1 / (d_flat + EPS)) * sqrt_d  # c, [B, N]
    u = fiedlers  # [B, N]
    u_scaled = u / sqrt_d  # S^{-1} u, [B, N]

    # One scalar per graph: u^T S^{-1} A^T u and u^T S^{-1} A u.
    quad_adj_t = torch.einsum('bi,bji,bj->b', u_scaled, adj, u)
    quad_adj = torch.einsum('bi,bij,bj->b', u_scaled, adj, u)

    # Row-constant terms, broadcast along the last axis when summed.
    first = (row_coef * quad_adj_t.unsqueeze(1)).unsqueeze(2)  # [B, N, 1]
    second = (row_coef * quad_adj.unsqueeze(1)).unsqueeze(2)  # [B, N, 1]
    # S^{-1} u u^T S^{-1}
    outer = u_scaled.unsqueeze(2) * u_scaled.unsqueeze(1)  # [B, N, N]

    derivatives = -(first + second + outer)  # [B, N, N]
    if device is not None:
        derivatives = derivatives.to(device)
    return derivatives
75 |
def NLfiedler_values(L, d_flat, fiedlers, EPS, device):
    """Batched Rayleigh-quotient estimate used as the normalised lambda_2.

    With ``s = sqrt(d + EPS) + EPS``, ``f = u / s`` and ``g = u * s`` (``u``
    the Fiedler vector), each graph's value is
    ``N * |f^T L f / (g^T g + EPS)|``. This is the formula the previous
    per-graph loop evaluated, computed for the whole batch at once.

    Args:
        L: Laplacian ``[B, N, N]``.
        d_flat: Node degrees ``[B, N]``.
        fiedlers: (Approximate) Fiedler vectors ``[B, N]``.
        EPS: Numerical guard in the square root and the denominator.
        device: If not None, the result is moved to this device.

    Returns:
        Tensor ``[B]`` in the dtype of the inputs (previously always float32).
    """
    N = fiedlers.size(1)
    sqrt_d = torch.sqrt(d_flat + EPS) + EPS  # [B, N]
    f = fiedlers / sqrt_d  # [B, N]
    g = fiedlers * sqrt_d  # [B, N]
    num = torch.einsum('bi,bij,bj->b', f, L, f)  # f^T L f, [B]
    den = (g * g).sum(dim=1)  # g^T g, [B]
    values = N * torch.abs(num / (den + EPS))
    if device is not None:
        values = values.to(device)
    return values
105 |
def derivative_of_lambda2_wrt_adjacency(fiedlers, device):
    """Simple derivative of lambda_2 with respect to the adjacency matrix.

    For a graph with Fiedler vector ``u`` the result has entries
    ``out[i, j] = u_i**2 - u_i * u_j``, i.e. ``diag(u u^T) 1^T - u u^T``.

    The previous implementation materialised ``diag(u u^T) @ ones(N, N)`` per graph,
    an O(N^3) product that only broadcasts ``u_i**2`` along each row. This version
    computes the same values with O(N^2) broadcasting over the whole batch.

    Args:
        fiedlers: Fiedler vectors, shape [B, N].
        device: device the result is moved to.

    Returns:
        Tensor of shape [B, N, N]. The dtype follows ``fiedlers``.
    """
    u_col = fiedlers.unsqueeze(2)  # [B, N, 1]: u_i varies along rows
    u_row = fiedlers.unsqueeze(1)  # [B, 1, N]: u_j varies along columns
    derivatives = u_col * u_col - u_col * u_row  # [B, N, N]
    return derivatives.to(device)
126 |
def fiedler_values(adj, fiedlers, EPS, device):
    """Batched lambda_2 estimate from the unnormalized Laplacian.

    For each graph ``b`` with adjacency ``A`` and Fiedler vector ``f`` this returns
    ``N * |f^T L f / (f^T L_K f + EPS)|``. Here ``L = D - A`` and
    ``L_K = N*I - 1 1^T`` is the Laplacian of the complete graph on ``N`` nodes.
    For an exact Fiedler vector (orthogonal to the all-ones vector) the ratio equals
    lambda_2.

    The complete-graph Laplacian is identical for every graph. It is now built once
    instead of on every loop iteration, and the batch is processed with batched
    matrix products.

    Args:
        adj: batched adjacency matrices, shape [B, N, N].
        fiedlers: Fiedler vectors, shape [B, N].
        EPS: small constant for numerical stability.
        device: device the result is moved to.

    Returns:
        Tensor of shape [B]. The dtype follows ``fiedlers``.
    """
    B, N = fiedlers.size(0), fiedlers.size(1)
    # Unnormalized Laplacians L = D - A (row sums as degrees): [B, N, N]
    lap = torch.diag_embed(adj.sum(dim=2)) - adj
    # Complete-graph Laplacian: N-1 on the diagonal, -1 elsewhere: [N, N]
    eye = torch.eye(N, dtype=fiedlers.dtype, device=fiedlers.device)
    complete_lap = N * eye - torch.ones_like(eye)

    f = fiedlers.unsqueeze(2)   # [B, N, 1]
    f_t = f.transpose(1, 2)     # [B, 1, N]
    num = torch.matmul(f_t, torch.matmul(lap, f)).reshape(B)
    den = torch.matmul(f_t, torch.matmul(complete_lap, f)).reshape(B)
    return (N * torch.abs(num / (den + EPS))).to(device)
156 |
157 |
def NLderivative_of_lambda2_wrt_adjacencyV2(adj, d_flat, fiedlers, EPS, device): # fiedlers torch.Size([B, N])
    """
    Complex derivative of lambda_2 with respect to the adjacency (degree-normalized variant).

    For every graph ``b`` the returned [N, N] matrix is
    ``-(firstT + secT + u u^T)``, where ``u`` is the Fiedler vector. Row ``i`` of
    ``firstT`` is the constant ``dder_i * (u^T S AhatT u)``. Row ``i`` of ``secT`` is
    the constant ``dder_i * (u^T S Ahat u)``. The terms are defined as:
      - ``S = diag(sqrt(d + EPS) + EPS)``.
      - ``Ahat`` is ``adj`` with row ``i`` divided by ``sqrt(d_i + EPS) + EPS``.
      - ``AhatT`` is the same scaling applied to ``adj^T``.
    When ``adj`` is symmetric, ``Ahat == AhatT`` and therefore ``firstT == secT``.

    Args:
        adj: batched adjacency, indexed as [B, N, N].
        d_flat: node degrees; assumed to be [B, N] (TODO confirm with caller).
        fiedlers: Fiedler vectors, [B, N].
        EPS: small constant for numerical stability.
        device: device on which the result buffer and helper tensors are allocated.
    Returns:
        float32 tensor of shape [B, N, N] (allocated with torch.FloatTensor).
    """
    N = fiedlers.size(1)
    B = fiedlers.size(0)
    # Batched structures for the complex derivative
    d_flat2 = torch.sqrt(d_flat+EPS)[:, None] + EPS # sqrt-degrees, [B, 1, N]
    # NOTE: rebinds the argument name: from here on d_flat holds sqrt(d + EPS) + EPS, [B, N]
    d_flat = d_flat2.squeeze(1)
    d_half = _rank3_diag(d_flat) # diag(sqrt(d + EPS) + EPS), [B, N, N]
    Ahat = (adj/d_flat2.transpose(1, 2)) # row i of adj divided by sqrt-degree i: [B, N, N] / [B, N, 1] -> [B, N, N]
    AhatT = (adj.transpose(1,2)/d_flat2.transpose(1, 2)) # same row scaling applied to adj^T: [B, N, N]
    # NOTE(review): because d_flat was rebound above, dinv is ~1/sqrt(d) rather than 1/d.
    # As a result dder = -0.5 * dinv * d_flat2 is ~-0.5 for every node, since both
    # factors are sqrt(d) up to EPS. If a degree-dependent derivative was intended,
    # this looks like accidental shadowing -- confirm against the paper's formula.
    dinv = 1/(d_flat + EPS)[:, None]
    dder = -0.5*dinv*d_flat2
    dder = dder.transpose(1,2) # [B, N, 1]
    # Storage: uninitialised float32 buffer, every slice is overwritten in the loop
    derivatives = torch.FloatTensor(B, N, N).to(device)
    for b in range(B):
        # Eigenvectors
        u2 = fiedlers[b,:]
        u2 = u2.unsqueeze(1) # [N, 1]
        # First term: scalar u^T S AhatT u, then spread as dder_i along row i
        firstT = torch.matmul(torch.matmul(u2.T, torch.matmul(d_half[b,:,:], AhatT[b,:,:])), u2) # [1,N]x[N,N]x[N,1] -> [1,1]
        firstT = torch.matmul(torch.matmul(dder[b,:], firstT), torch.ones(N).unsqueeze(0).to(device)) # [N,1]x[1,1]x[1,N] -> [N,N]
        # Second term: same construction with Ahat instead of AhatT
        secT = torch.matmul(torch.matmul(u2.T, torch.matmul(d_half[b,:,:], Ahat[b,:,:])), u2) # [1,N]x[N,N]x[N,1] -> [1,1]
        secT = torch.matmul(torch.matmul(dder[b,:], secT), torch.ones(N).unsqueeze(0).to(device)) # [N,N]
        # Third term: despite the name this is the plain outer product u u^T (no degree scaling is applied here)
        Du2u2TD = torch.matmul(u2,u2.T) # [N,1] x [1,N] -> [N,N]
        dl2 = firstT + secT + Du2u2TD
        derivatives[b,:,:] = -dl2
    return derivatives # [B, N, N]
211 |
212 |
def NLfiedler_valuesV2(L, d, fiedlers, EPS, device):
    """Batched generalized ratio ``N * |f^T L f / (f^T d f + EPS)|`` per graph.

    The dead debug string literal that sat inside the per-graph loop is removed, and
    the loop is replaced with batched matrix products.

    Args:
        L: batch of (Laplacian-like) matrices, shape [B, N, N].
        d: batch of [N, N] matrices used in the denominator. Presumably the diagonal
            degree matrix -- confirm with caller.
        fiedlers: Fiedler vectors, shape [B, N].
        EPS: small constant for numerical stability.
        device: device the result is moved to.

    Returns:
        Tensor of shape [B]. The dtype follows the inputs (identical to the previous
        float32 result for float32 inputs).
    """
    B, N = fiedlers.size(0), fiedlers.size(1)
    f = fiedlers.unsqueeze(2)   # [B, N, 1]
    f_t = f.transpose(1, 2)     # [B, 1, N]
    num = torch.matmul(f_t, torch.matmul(L, f)).reshape(B)  # f^T L f
    den = torch.matmul(f_t, torch.matmul(d, f)).reshape(B)  # f^T d f
    return (N * torch.abs(num / (den + EPS))).to(device)
232 |
233 |
--------------------------------------------------------------------------------
/layers/utils/ein_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # Trace of a tensor [1,k,k]
3 | def _rank3_trace(x):
4 | return torch.einsum('ijj->i', x)
5 |
6 | # Diagonal version of a tensor [1,n] -> [1,n,n]
7 | def _rank3_diag(x):
8 | # Eye matrix of n=x.size(1): [n,n]
9 | eye = torch.eye(x.size(1)).type_as(x)
10 | #print(eye.size())
11 | #print(x.unsqueeze(2).size())
12 | # x.unsqueeze(2) adds a second dimension to [1,n] -> [1,n,1]
13 | # expand(*x.size(), x.size(1)) takes [1,n,1] and expands [1,n] with n -> [1,n,n]
14 | out = eye * x.unsqueeze(2).expand(*x.size(), x.size(1))
15 | return out
--------------------------------------------------------------------------------
/layers/utils/spectral_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import to_dense_adj
3 |
def unnormalized_laplacian_eigenvectors(L):
    """Eigen-decompose a graph Laplacian, with eigenvalues sorted ascending.

    Args:
        L: square Laplacian matrix [N, N] (floating point).

    Returns:
        ``(el, ev)``: ``el`` holds the N real eigenvalues in ascending order, and
        column ``i`` of ``ev`` [N, N] is an eigenvector for ``el[i]``.

    Symmetric input (the Laplacian of an undirected graph) uses
    ``torch.linalg.eigh``. It returns real eigenvalues in ascending order and
    orthonormal eigenvectors, even when eigenvalues repeat (e.g. disconnected
    graphs). Non-symmetric input falls back to the general ``torch.linalg.eig`` and
    keeps the real parts, as before.
    """
    if torch.equal(L, L.transpose(-2, -1)):
        el, ev = torch.linalg.eigh(L)  # already ascending
        return el, ev

    el, ev = torch.linalg.eig(L)
    el = torch.real(el)
    ev = torch.real(ev)
    idx = torch.argsort(el)
    return el[idx], ev[:, idx]
13 |
14 |
15 | # Compute Fiedler values of all the graphs
def compute_fiedler_vectors(dataset):
    """Compute the Fiedler value and Fiedler vector of every graph in ``dataset``.

    For each graph the unnormalized Laplacian ``L = D - A`` is decomposed. The
    second-smallest eigenvalue and its eigenvector are kept.

    Args:
        dataset: indexable collection of graphs exposing ``edge_index`` and
            ``num_nodes``.

    Returns:
        ``(values, vectors)``: two lists in dataset order with ``el[1]`` and
        ``ev[:, 1]`` for each graph. A graph with fewer than 2 nodes raises
        ``IndexError``, as before.
    """
    vectors = []
    values = []
    for g in range(len(dataset)):
        G = dataset[g]
        # Pass num_nodes explicitly. Otherwise to_dense_adj infers N from the largest
        # index in edge_index, silently dropping trailing isolated nodes, and the
        # Fiedler vector length would not match the graph's number of nodes.
        adj = to_dense_adj(G.edge_index, max_num_nodes=G.num_nodes).squeeze(0)  # [N, N]
        lap = torch.diag(adj.sum(dim=1)) - adj
        el, ev = unnormalized_laplacian_eigenvectors(lap)
        vectors.append(ev[:, 1])
        values.append(el[1])

    return values, vectors
40 |
41 |
42 |
--------------------------------------------------------------------------------
/nets.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.nn import Linear
5 | from torch_geometric.nn import DenseGraphConv
6 | from torch_geometric.utils import to_dense_batch, to_dense_adj
7 | from layers.CT_layer import dense_CT_rewiring
8 | from layers.MinCut_Layer import dense_mincut_pool
9 | from layers.GAP_layer import dense_mincut_rewiring
10 |
11 |
class GAPNet(torch.nn.Module):
    """Graph classifier: GAP-Layer rewiring -> dense conv -> MinCutPool -> dense conv -> MLP.

    Args:
        in_channels: number of input node features.
        out_channels: number of output classes.
        hidden_channels: width of the hidden node representations.
        derivative: forwarded unchanged to ``dense_mincut_rewiring``.
        EPS: numerical-stability constant forwarded to the rewiring/pooling layers.
        device: forwarded unchanged to ``dense_mincut_rewiring``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=32, derivative=None, EPS=1e-15, device=None):
        super().__init__()
        self.device = device
        self.derivative = derivative
        self.EPS = EPS

        # Submodules are created in the original order and under the original
        # attribute names, so seeded initialisation and saved state_dicts still match.
        self.conv1 = DenseGraphConv(hidden_channels, hidden_channels)
        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)

        fiedler_clusters = 2   # k1: two-way split used by the GAP (Fiedler) rewiring
        pooled_clusters = 16   # k2: number of MinCutPool clusters
        self.pool1 = Linear(hidden_channels, fiedler_clusters)
        self.pool2 = Linear(hidden_channels, pooled_clusters)

        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Classify a mini-batch of graphs.

        Args:
            x: node features of the whole mini-batch; presumably
                [total_nodes, in_channels] -- TODO confirm.
            edge_index: connectivity of the mini-batch.
            batch: graph id of every node.

        Returns:
            ``(log_probs, mincut_loss, ortho_loss)``. Each loss is the sum over the
            GAP rewiring layer and the MinCutPool layer.
        """
        dense_adj = to_dense_adj(edge_index, batch)        # [B, N, N]
        h, node_mask = to_dense_batch(x, batch)             # [B, N, F], [B, N]

        h = self.lin1(h)
        fiedler_assign = self.pool1(h)                      # [B, N, 2]

        if torch.isnan(dense_adj).any():
            print("adj nan")
        if torch.isnan(h).any():
            print("x nan")

        # GAP-Layer: rewire the adjacency from the 2-way assignment
        dense_adj, cut_loss_gap, ortho_loss_gap = dense_mincut_rewiring(
            h, dense_adj, fiedler_assign, node_mask,
            derivative=self.derivative, EPS=self.EPS, device=self.device,
        )
        h = self.conv1(h, dense_adj)

        # MinCutPool: coarsen to 16 clusters
        cluster_assign = self.pool2(h)                      # [B, N, 16]
        h, dense_adj, cut_loss_pool, ortho_loss_pool = dense_mincut_pool(
            h, dense_adj, cluster_assign, node_mask, EPS=self.EPS
        )
        h = self.conv2(h, dense_adj)

        # Sum readout over clusters, then the 2-layer MLP head
        graph_embedding = h.sum(dim=1)
        logits = self.lin3(F.relu(self.lin2(graph_embedding)))

        total_cut_loss = cut_loss_gap + cut_loss_pool
        total_ortho_loss = ortho_loss_gap + ortho_loss_pool
        return F.log_softmax(logits, dim=-1), total_cut_loss, total_ortho_loss
118 |
119 |
class CTNet(torch.nn.Module):
    """Graph classifier: CT-Layer rewiring -> dense conv -> MinCutPool -> dense conv -> MLP.

    Args:
        in_channels: number of input node features.
        out_channels: number of output classes.
        k_centers: output width of the assignment used by the CT rewiring layer.
        hidden_channels: width of the hidden node representations.
        EPS: numerical-stability constant forwarded to the rewiring/pooling layers.
    """

    def __init__(self, in_channels, out_channels, k_centers, hidden_channels=32, EPS=1e-15):
        super().__init__()
        self.EPS = EPS

        # Submodules are created in the original order and under the original
        # attribute names, so seeded initialisation and saved state_dicts still match.
        self.conv1 = DenseGraphConv(hidden_channels, hidden_channels)
        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)

        ct_width = k_centers     # k1: width of the CT embedding assignment
        self.pool1 = Linear(hidden_channels, ct_width)
        pooled_clusters = 16     # k2: number of MinCutPool clusters
        self.pool2 = Linear(hidden_channels, pooled_clusters)

        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Classify a mini-batch of graphs.

        Args:
            x: node features of the whole mini-batch; presumably
                [total_nodes, in_channels] -- TODO confirm.
            edge_index: connectivity of the mini-batch.
            batch: graph id of every node.

        Returns:
            ``(log_probs, ct_loss, mincut_loss)``:
              - ``ct_loss`` is the CT loss plus its orthogonality loss.
              - ``mincut_loss`` is the MinCutPool loss plus its orthogonality loss.
        """
        dense_adj = to_dense_adj(edge_index, batch)        # [B, N, N]
        h, node_mask = to_dense_batch(x, batch)             # [B, N, F], [B, N]

        h = self.lin1(h)
        ct_assign = self.pool1(h)                           # [B, N, k_centers]

        if torch.isnan(dense_adj).any():
            print("adj nan")
        if torch.isnan(h).any():
            print("x nan")

        # CT-Layer: rewire the adjacency
        dense_adj, ct_term, ct_ortho = dense_CT_rewiring(h, dense_adj, ct_assign, node_mask, EPS=self.EPS)
        h = self.conv1(h, dense_adj)

        # MinCutPool: coarsen to 16 clusters
        cluster_assign = self.pool2(h)                      # [B, N, 16]
        h, dense_adj, pool_cut, pool_ortho = dense_mincut_pool(h, dense_adj, cluster_assign, node_mask, EPS=self.EPS)
        h = self.conv2(h, dense_adj)

        # Sum readout over clusters, then the 2-layer MLP head
        graph_embedding = h.sum(dim=1)
        logits = self.lin3(F.relu(self.lin2(graph_embedding)))

        ct_total = ct_term + ct_ortho
        mincut_total = pool_cut + pool_ortho
        return F.log_softmax(logits, dim=-1), ct_total, mincut_total
218 |
219 |
class MinCutNet(torch.nn.Module):
    """Baseline graph classifier: dense conv -> MinCutPool (16 clusters) -> dense conv -> MLP.

    Unlike GAPNet / CTNet / DiffWire, no rewiring layer is applied to the input graph.

    Args:
        in_channels: number of input node features.
        out_channels: number of output classes.
        hidden_channels: width of the hidden node representations.
        EPS: numerical-stability constant forwarded to ``dense_mincut_pool``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=32, EPS=1e-15):
        super(MinCutNet, self).__init__()
        self.EPS = EPS
        # Dense GCN layers. Creation order and attribute names are unchanged, so
        # seeded initialisation and saved state_dicts still match.
        self.conv1 = DenseGraphConv(hidden_channels, hidden_channels)
        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)

        # Soft cluster assignment for MinCutPool
        num_of_centers2 = 16  # k2
        self.pool2 = Linear(hidden_channels, num_of_centers2)

        # MLPs towards out
        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Classify a mini-batch of graphs.

        Args:
            x: node features of the whole mini-batch; presumably
                [total_nodes, in_channels] -- TODO confirm.
            edge_index: connectivity of the mini-batch.
            batch: graph id of every node.

        Returns:
            ``(log_probs, mincut_loss, ortho_loss)``. The two MinCutPool losses are
            returned separately, not summed.
        """
        # Dense per-graph adjacency [B, N, N] and padded features [B, N, F] + mask [B, N]
        adj = to_dense_adj(edge_index, batch)
        x, mask = to_dense_batch(x, batch)

        x = self.lin1(x)

        if torch.isnan(adj).any():
            print("adj nan")
        if torch.isnan(x).any():
            print("x nan")

        # CONV1 on the original (not rewired) adjacency
        x = self.conv1(x, adj)

        # MinCutPool: coarsen to 16 clusters
        s = self.pool2(x)  # [B, N, 16]
        x, adj, mincut_loss, ortho_loss = dense_mincut_pool(x, adj, s, mask, EPS=self.EPS)

        # CONV2 on the coarsened graph
        x = self.conv2(x, adj)

        # Sum readout per graph, then the 2-layer MLP head
        x = x.sum(dim=1)
        x = F.relu(self.lin2(x))
        x = self.lin3(x)

        return F.log_softmax(x, dim=-1), mincut_loss, ortho_loss
275 |
276 |
class DiffWire(torch.nn.Module):
    """Graph classifier combining both rewiring layers before pooling.

    Pipeline:
      1. GAP-Layer rewiring.
      2. CT-Layer rewiring.
      3. Dense conv.
      4. MinCutPool (16 clusters).
      5. Dense conv.
      6. MLP head.

    Args:
        in_channels: number of input node features.
        out_channels: number of output classes.
        k_centers: output width of the assignment used by the CT rewiring layer.
        derivative: forwarded unchanged to ``dense_mincut_rewiring``.
        hidden_channels: width of the hidden node representations.
        EPS: numerical-stability constant forwarded to the rewiring/pooling layers.
        device: forwarded unchanged to ``dense_mincut_rewiring``.
    """

    def __init__(self, in_channels, out_channels, k_centers, derivative=None, hidden_channels=32, EPS=1e-15, device=None):
        super().__init__()

        self.EPS = EPS
        self.derivative = derivative
        self.device = device

        # Submodules are created in the original order and under the original
        # attribute names, so seeded initialisation and saved state_dicts still match.
        self.lin1 = Linear(in_channels, hidden_channels)     # input projection
        self.pool_rw = Linear(hidden_channels, 2)            # assignment for the GAP-Layer
        self.num_of_centers1 = k_centers                     # k1
        self.pool_ct = Linear(hidden_channels, self.num_of_centers1)  # assignment for the CT-Layer
        self.conv1 = DenseGraphConv(hidden_channels, hidden_channels)
        self.pool_mc = Linear(hidden_channels, 16)           # MinCutPool assignment
        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Classify a mini-batch of graphs.

        Returns:
            ``(log_probs, main_loss, ortho_loss)``:
              - ``main_loss`` = GAP mincut loss + CT loss + MinCutPool loss.
              - ``ortho_loss`` sums the three orthogonality losses.
        """
        dense_adj = to_dense_adj(edge_index, batch)        # [B, N, N]
        h, node_mask = to_dense_batch(x, batch)             # [B, N, F], [B, N]

        h = self.lin1(h)

        if torch.isnan(dense_adj).any():
            print("adj nan")
        if torch.isnan(h).any():
            print("x nan")

        # 1) GAP-Layer rewiring
        gap_assign = self.pool_rw(h)
        dense_adj, gap_cut, gap_ortho = dense_mincut_rewiring(
            h, dense_adj, gap_assign, node_mask,
            derivative=self.derivative, EPS=self.EPS, device=self.device,
        )

        # 2) CT-Layer rewiring on the already rewired adjacency
        ct_assign = self.pool_ct(h)
        dense_adj, ct_term, ct_ortho = dense_CT_rewiring(h, dense_adj, ct_assign, node_mask, EPS=self.EPS)

        # 3) Convolution on the rewired graph
        h = self.conv1(h, dense_adj)

        # 4) MinCutPool to 16 clusters, then convolution on the coarsened graph
        cluster_assign = self.pool_mc(h)
        h, dense_adj, pool_cut, pool_ortho = dense_mincut_pool(h, dense_adj, cluster_assign, node_mask, EPS=self.EPS)
        h = self.conv2(h, dense_adj)

        # Sum readout over clusters, then the 2-layer MLP head
        graph_embedding = h.sum(dim=1)
        logits = self.lin3(F.relu(self.lin2(graph_embedding)))

        main_loss = gap_cut + ct_term + pool_cut
        ortho_loss = gap_ortho + ct_ortho + pool_ortho
        return F.log_softmax(logits, dim=-1), main_loss, ortho_loss
356 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # DiffWire: Inductive Graph Rewiring via the Lovász Bound
2 |
3 | **Accepted at the First Learning on Graphs Conference 2022**
4 |
5 | [](https://openreview.net/forum?id=IXvfIex0mX6f&noteId=t5zJZuEIy1y)
6 | [](https://paperswithcode.com/sota/graph-classification-on-imdb-binary?p=diffwire-inductive-graph-rewiring-via-the)
7 |
8 |
9 |
10 |
11 |
12 | $$
13 | \left| \frac{1}{vol(G)}CT_{uv}-\left(\frac{1}{d_u} + \frac{1}{d_v}\right)\right|\le \frac{1}{\lambda_2}\frac{2}{d_{min}}
14 | $$
15 |
16 | ```bibtex
17 | @InProceedings{arnaiz2022diffwire,
18 | title = {{DiffWire: Inductive Graph Rewiring via the Lov{\'a}sz Bound}},
19 | author = {Arnaiz-Rodr{\'i}guez, Adri{\'a}n and Begga, Ahmed and Escolano, Francisco and Oliver, Nuria M},
20 | booktitle = {Proceedings of the First Learning on Graphs Conference},
21 | pages = {15:1--15:27},
22 | year = {2022},
23 | editor = {Rieck, Bastian and Pascanu, Razvan},
24 | volume = {198},
25 | series = {Proceedings of Machine Learning Research},
26 | month = {09--12 Dec},
27 | publisher = {PMLR},
28 | pdf = {https://proceedings.mlr.press/v198/arnaiz-rodri-guez22a/arnaiz-rodri-guez22a.pdf},
29 | url = {https://proceedings.mlr.press/v198/arnaiz-rodri-guez22a.html}
30 | }
31 | ```
32 |
33 | ## Dependencies
34 |
35 | Conda environment
36 | ```
37 | conda create --name <env> --file requirements.txt
38 | ```
39 |
40 | or
41 |
42 | ```
43 | conda env create -f environment_experiments.yml
44 | conda activate DiffWire
45 | ```
46 | ## Code organization
47 |
48 | * `datasets/`: script for creating synthetic datasets. For non-synthetic ones: we use PyG in `train.py`
49 | * `layers/`: Implementation of the proposed **GAP-Layer**, **CT-Layer**, and the baseline MinCutPool (based on its official repository).
50 | * `transforms/`: Implementation of the graph preprocessing baselines DIGL and SDRF, both based on the official repositories of those works.
51 | * `trained_models/`: files with the weight of some trained models.
52 | * `nets.py`: Implementation of GNNs used in our experiments.
53 | * `train.py`: Script with inline arguments for running the experiments.
54 |
55 | ## Run experiments
56 | ```bash
57 | python train.py --dataset REDDIT-BINARY --model CTNet --cuda cuda:0
58 | python train.py --dataset REDDIT-BINARY --model GAPNet --derivative laplacian --cuda cuda:0
59 | python train.py --dataset REDDIT-BINARY --model GAPNet --derivative normalized --cuda cuda:0
60 | ```
61 |
62 | `experiments_all.sh` lists all the experiments.
63 |
64 |
65 | ## Code Examples
66 |
67 | See jupyter notebook examples at the tutorial presented at ***The First Learning on Graphs Conference***: **[Graph Rewiring: From Theory to Applications in Fairness](https://github.com/ellisalicante/GraphRewiring-Tutorial)**
68 |
69 | * CT-Layer [](https://colab.research.google.com/github/ellisalicante/GraphRewiring-Tutorial/blob/main/3-Inductive-Rewiring-CTLayer.ipynb)
70 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=conda_forge
5 | _openmp_mutex=4.5=2_kmp_llvm
6 | blas=2.114=mkl
7 | blas-devel=3.9.0=14_linux64_mkl
8 | ca-certificates=2021.10.8=ha878542_0
9 | certifi=2021.10.8=pypi_0
10 | charset-normalizer=2.0.12=pypi_0
11 | cudatoolkit=11.5.1=hcf5317a_10
12 | cycler=0.11.0=pypi_0
13 | fonttools=4.33.3=pypi_0
14 | freetype=2.10.4=h0708190_1
15 | giflib=5.2.1=h36c2ea0_2
16 | idna=3.3=pypi_0
17 | jbig=2.1=h7f98852_2003
18 | jinja2=3.1.2=pypi_0
19 | joblib=1.1.0=pypi_0
20 | jpeg=9e=h166bdaf_1
21 | kiwisolver=1.4.2=pypi_0
22 | lcms2=2.12=hddcbb42_0
23 | ld_impl_linux-64=2.36.1=hea4e1c9_2
24 | lerc=3.0=h9c3ff4c_0
25 | libblas=3.9.0=14_linux64_mkl
26 | libcblas=3.9.0=14_linux64_mkl
27 | libdeflate=1.10=h7f98852_0
28 | libffi=3.4.2=h7f98852_5
29 | libgcc-ng=11.2.0=h1d223b6_16
30 | libgfortran-ng=11.2.0=h69a702a_16
31 | libgfortran5=11.2.0=h5c6108e_16
32 | libgomp=11.2.0=h1d223b6_16
33 | liblapack=3.9.0=14_linux64_mkl
34 | liblapacke=3.9.0=14_linux64_mkl
35 | libnsl=2.0.0=h7f98852_0
36 | libpng=1.6.37=h21135ba_2
37 | libstdcxx-ng=11.2.0=he4da1e4_16
38 | libtiff=4.3.0=h542a066_3
39 | libuv=1.43.0=h7f98852_0
40 | libwebp=1.2.2=h3452ae3_0
41 | libwebp-base=1.2.2=h7f98852_1
42 | libxcb=1.13=h7f98852_1004
43 | libzlib=1.2.11=h166bdaf_1014
44 | llvm-openmp=14.0.3=he0ac6c6_0
45 | lz4-c=1.9.3=h9c3ff4c_1
46 | markupsafe=2.1.1=pypi_0
47 | matplotlib=3.5.2=pypi_0
48 | mkl=2022.0.1=h8d4b97c_803
49 | mkl-devel=2022.0.1=ha770c72_804
50 | mkl-include=2022.0.1=h8d4b97c_803
51 | ncurses=6.3=h27087fc_1
52 | numpy=1.22.3=py38h99721a1_2
53 | openjpeg=2.4.0=hb52868f_1
54 | openssl=3.0.3=h166bdaf_0
55 | packaging=21.3=pypi_0
56 | pandas=1.4.2=pypi_0
57 | pillow=9.1.0=py38h0ee0e06_2
58 | pip=22.0.4=pyhd8ed1ab_0
59 | pthread-stubs=0.4=h36c2ea0_1001
60 | pyparsing=3.0.8=pypi_0
61 | python=3.8.12=h0744224_3_cpython
62 | python-dateutil=2.8.2=pypi_0
63 | python_abi=3.8=2_cp38
64 | pytorch=1.11.0=py3.8_cuda11.5_cudnn8.3.2_0
65 | pytorch-mutex=1.0=cuda
66 | pytz=2022.1=pypi_0
67 | readline=8.1=h46c0cb4_0
68 | requests=2.27.1=pypi_0
69 | scikit-learn=1.0.2=pypi_0
70 | scipy=1.8.0=pypi_0
71 | setuptools=62.1.0=py38h578d9bd_0
72 | six=1.16.0=pyh6c4a22f_0
73 | sqlite=3.38.5=h4ff8645_0
74 | tbb=2021.5.0=h924138e_1
75 | threadpoolctl=3.1.0=pypi_0
76 | tk=8.6.12=h27826a3_0
77 | torch-cluster=1.6.0=pypi_0
78 | torch-geometric=2.0.4=pypi_0
79 | torch-scatter=2.0.9=pypi_0
80 | torch-sparse=0.6.13=pypi_0
81 | torch-spline-conv=1.2.1=pypi_0
82 | torchaudio=0.11.0=py38_cu115
83 | torchvision=0.2.2=py_3
84 | tqdm=4.64.0=pypi_0
85 | typing_extensions=4.2.0=pyha770c72_1
86 | urllib3=1.26.9=pypi_0
87 | wheel=0.37.1=pyhd8ed1ab_0
88 | xorg-libxau=1.0.9=h7f98852_0
89 | xorg-libxdmcp=1.1.3=h7f98852_0
90 | xz=5.2.5=h516909a_1
91 | zlib=1.2.11=h166bdaf_1014
92 | zstd=1.5.2=ha95c52a_0
93 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 | import os
3 | import random
4 |
5 | from sklearn.model_selection import train_test_split
6 | from datasets.Erdos_Renyi_dataset import Erdos_Renyi_pyg
7 | from datasets.SBM_dataset import SBM_pyg
8 | from nets import CTNet, DiffWire, GAPNet, MinCutNet
9 | import torch
10 | import torch.nn.functional as F
11 | from torch_geometric.datasets import GNNBenchmarkDataset, TUDataset
12 | import torch_geometric.transforms as T
13 | from transforms import FeatureDegree, DIGLedges, SDRF, KNNGraph
14 | from torch_geometric.loader import DataLoader
15 | from torch_geometric.utils import to_dense_batch, to_dense_adj
16 | import time
17 | import argparse
18 | import numpy as np
'''
Dataset arguments:
    -Name:
        *TUDataset:
            + No features: [REDDIT-BINARY, IMDB-BINARY, COLLAB]
            + Featured: [MUTAG, ENZYMES, PROTEINS]
        *Benchmark (GNNBenchmarkDataset): [MNIST, CIFAR10, CSL]
        *Synthetic: [SBM, ERDOS]
Model arguments:
    -Name:
        MinCutNet(num_features, num_classes)
        GAPNet(num_features, num_classes, derivative=...)
        CTNet(num_features, num_classes, k_centers)
        DiffWire(num_features, num_classes, k_centers, derivative=...)
Other arguments:
    --lr: learning rate
    --wd: weight decay
'''

################### Arguments parameters ###################################
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    default="MUTAG",
    choices=["MUTAG","ENZYMES","PROTEINS","CIFAR10","MNIST","COLLAB","IMDB-BINARY","REDDIT-BINARY","CSL", "SBM", "ERDOS"],
    help="Dataset used for training and evaluation",
)
parser.add_argument(
    "--model",
    default="CTNet",
    choices=["CTNet","GAPNet","MinCutNet", "DiffWire"],
    help="GNN architecture (defined in nets.py)",
)
parser.add_argument(
    "--derivative",
    default="laplacian",
    choices=["laplacian","normalizedv2","normalized"],
    help="Only used if model is GAP",
)
parser.add_argument(
    "--cuda",
    default="cuda:0",
    choices=["cuda:0","cuda:1"],
    help="CUDA device to run on",
)
parser.add_argument(
    "--prepro",
    default=None,
    choices=[None,"digl", "sdrf", "knn"],
    help="Optional graph preprocessing baseline",
)
parser.add_argument(
    "--store",
    action="store_true",
    help="nada",
)
parser.add_argument(
    "--iter",
    type=int,
    default=10,
    help="Number of experiment iterations",
)
parser.add_argument(
    "--logs",
    default="logs",
    help="log folders",
)
parser.add_argument(
    "--lr", type=float, default=5e-4, help="Learning rate"
)
parser.add_argument(
    "--wd", type=float, default=1e-4, help="Weight decay"
)
args = parser.parse_args()
92 |
# Processing dataset
No_Features = ["COLLAB","IMDB-BINARY","REDDIT-BINARY", "CSL"]

# Select the optional preprocessing transform and the sub-folder name used to keep
# the preprocessed copy of the dataset separate from the raw one. Inside each branch
# args.prepro is already known, so the folder name is a plain constant; the previous
# `"/X" if args.prepro == "x" else ""` checks were always true.
if args.prepro == 'digl':
    preprocessing = DIGLedges(alpha=0.001)
    aux_prepro_folder = "/DIGL"

elif args.prepro == 'sdrf':
    preprocessing = SDRF(undirected=True, max_steps="dynamic", tau=20,
                         remove_edges=True, removal_bound=0)
    aux_prepro_folder = "/SDRF"

elif args.prepro == 'knn':
    preprocessing = KNNGraph(
        k=None,
        force_undirected=True
    )
    aux_prepro_folder = "/KNN"

elif args.prepro is None:
    aux_prepro_folder = ""
    preprocessing = None

else:
    raise NotImplementedError("Not implemented preprocessing")
117 |
# Per-dataset settings: (TRAIN_SPLIT, BATCH_SIZE, num_of_centers).
# num_of_centers is forwarded as k_centers to CTNet / DiffWire; for the TU
# datasets it is the mean number of nodes according to PyGeom.
_DATASET_SETTINGS = {
    "SBM": (800, 32, 200),
    "ERDOS": (800, 32, 200),
    "MUTAG": (150, 32, 17),            # 188 graphs
    "ENZYMES": (500, 32, 16),          # 600 graphs
    "PROTEINS": (1000, 64, 39),        # 1113 graphs
    "IMDB-BINARY": (800, 64, 20),      # 1000 graphs
    "REDDIT-BINARY": (1500, 64, 420),  # 2000 graphs
    "COLLAB": (4500, 64, 75),          # 5000 graphs
    "CSL": (120, 10, 42),
    "MNIST": (50000, 100, 100),
    "CIFAR10": (40000, 100, 100),
}
if args.dataset not in _DATASET_SETTINGS:
    # Fail fast: otherwise TRAIN_SPLIT / BATCH_SIZE / num_of_centers stay
    # undefined and the script crashes later with a NameError.
    raise Exception("Not dataset in list of datasets")
TRAIN_SPLIT, BATCH_SIZE, num_of_centers = _DATASET_SETTINGS[args.dataset]


def _with_degree_features(pre_transform):
    """Return ``(pre_transform, transform)`` for a dataset without node features.

    The node degree (FeatureDegree) becomes the node feature. Without rewiring it
    is computed once as the pre_transform; with a rewiring pre_transform it is
    applied as an on-access transform, i.e. on the already rewired graph.
    """
    if pre_transform is None:
        return FeatureDegree(), None
    return pre_transform, FeatureDegree()


if args.dataset == "SBM":
    dataset = SBM_pyg('./data/SBM_final' + aux_prepro_folder, nb_nodes1=200, nb_graphs1=500,
                      nb_nodes2=200, nb_graphs2=500,
                      p1=0.8, p2=0.5, qmin1=0.1, qmax1=0.15, qmin2=0.01, qmax2=0.1,
                      directed=False, pre_transform=preprocessing)

elif args.dataset == "ERDOS":
    # Own root folder: sharing './data/SBM_final' with the SBM dataset could make
    # one dataset pick up the other's processed files if their file names match.
    # NOTE(review): --prepro is not applied to this dataset.
    dataset = Erdos_Renyi_pyg('./data/ERDOS_final', nb_nodes1=200, nb_graphs1=500,
                              nb_nodes2=200, nb_graphs2=500,
                              p1_min=0.4, p1_max=0.6, p2_min=0.5, p2_max=0.8)

elif args.dataset not in GNNBenchmarkDataset.names:  # TUDataset
    tu_root = 'data' + os.sep + aux_prepro_folder + os.sep + 'TUDataset'
    if args.dataset in No_Features:
        preprocessing, processing = _with_degree_features(preprocessing)
        dataset = TUDataset(root=tu_root, name=args.dataset,
                            pre_transform=preprocessing, transform=processing,
                            use_node_attr=True)
    else:
        dataset = TUDataset(root=tu_root, name=args.dataset, pre_transform=preprocessing)

else:  # GNNBenchmarkDataset (CSL, MNIST, CIFAR10)
    gnn_root = 'data' + os.sep + aux_prepro_folder + os.sep + 'GNNBenchmarkDataset'
    if args.dataset in No_Features:
        preprocessing, processing = _with_degree_features(preprocessing)
        dataset = GNNBenchmarkDataset(root=gnn_root, name=args.dataset,
                                      pre_transform=preprocessing, transform=processing)
    else:
        dataset = GNNBenchmarkDataset(root=gnn_root, name=args.dataset,
                                      pre_transform=preprocessing)
197 |
##################### STATIC Variables #################################

device = args.cuda

N_EPOCH = 60  # training epochs per run

# Experiment name: <dataset>[<SBM details>]_<model>[DIGL|SDRF|KNN][_<derivative>]
exp_name = f"{args.dataset}"
if args.dataset == "SBM":
    exp_name += f"{dataset.details}"  # add sbm details
exp_name += f"_{args.model}"
# Tag every preprocessing (including knn) so its logs/checkpoints are not
# confused with runs on the original graphs.
exp_name += {"digl": "DIGL", "sdrf": "SDRF", "knn": "KNN"}.get(args.prepro, "")
if args.model == "GAPNet":
    exp_name += f"_{args.derivative}"  # add derivative details
exp_time = time.strftime('%d_%m_%y__%H_%M')
train_log_file = exp_name + f"_{exp_time}.txt"

# One fixed seed per run; --iter selects how many of them are used.
RandList = [12345, 42345, 64345, 54345, 74345, 47345, 54321, 14321, 94321, 84328]
RandList = RandList[:args.iter]

os.makedirs(args.logs, exist_ok=True)
if args.store:
    os.makedirs("models", exist_ok=True)
220 |
221 | ######################################################
222 | ######################################################
def train(epoch, loader):
    """Train the global ``model`` for one pass over ``loader``.

    Uses the module-level ``model``, ``optimizer`` and ``device``. The objective
    is ``F.nll_loss`` on the predictions plus the two auxiliary loss terms the
    model returns alongside them.

    Args:
        epoch: epoch number; accepted for the caller, not used here.
        loader: DataLoader of PyG batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y``.

    Returns:
        ``(mean_loss, accuracy)``: the loss weighted by ``y.size(0)`` per batch
        and divided by the dataset size, and the fraction of correct
        predictions, both measured while the weights were being updated.
    """
    model.train()
    running_loss = 0
    hits = 0
    for graph_batch in loader:
        graph_batch = graph_batch.to(device)
        optimizer.zero_grad()
        pred, mc_loss, o_loss = model(graph_batch.x, graph_batch.edge_index, graph_batch.batch)
        targets = graph_batch.y.view(-1)
        loss = F.nll_loss(pred, targets) + mc_loss + o_loss
        loss.backward()
        running_loss += graph_batch.y.size(0) * loss.item()
        optimizer.step()
        hits += pred.max(dim=1).indices.eq(targets).sum().item()
    n_graphs = len(loader.dataset)
    return running_loss / n_graphs, hits / n_graphs
238 |
@torch.no_grad()
def test(loader):
    """Evaluate the global ``model`` on ``loader`` without gradient tracking.

    Uses the module-level ``model`` and ``device``.

    Returns:
        ``(loss, accuracy)``: the NLL + auxiliary loss averaged over the whole
        loader, weighted by ``y.size(0)`` per batch exactly like ``train``, and
        the fraction of correct predictions. Previously only the last batch's
        loss tensor was returned, which made the logged test loss depend on
        how the final batch happened to be composed.
    """
    model.eval()
    loss_all = 0.0
    correct = 0
    for data in loader:
        data = data.to(device)
        pred, mc_loss, o_loss = model(data.x, data.edge_index, data.batch)
        targets = data.y.view(-1)
        loss = F.nll_loss(pred, targets) + mc_loss + o_loss
        loss_all += data.y.size(0) * loss.item()
        correct += pred.max(dim=1)[1].eq(targets).sum().item()

    n_graphs = len(loader.dataset)
    return loss_all / n_graphs, correct / n_graphs
250 |
251 | ######################################################
print(device)

ExperimentResult = []  # final test accuracy of every run

# Start a fresh log file with a one-line description of the configuration.
with open(args.logs + os.sep + train_log_file, 'w') as log_f:
    print("- M:", args.model, "- D:", dataset,
          "- Train_split:", TRAIN_SPLIT, "- B:", BATCH_SIZE,
          "- Centers (if CTNet):", num_of_centers, "- LAP (if GAPNet):", args.derivative,
          "- Classes", dataset.num_classes, "- Feats", dataset.num_features, file=log_f)

print("- M:", args.model, "- D:", dataset, "- Train_split:", TRAIN_SPLIT, "- B:", BATCH_SIZE)
265 |
EPS = 1e-10  # numerical epsilon forwarded to CTNet / DiffWire
for e, seed in enumerate(RandList):
    # Seed *before* building the model so that weight initialisation (and the
    # DataLoader shuffling below) is reproducible for every run.
    torch.manual_seed(seed)

    if args.model == 'CTNet':
        model = CTNet(dataset.num_features, dataset.num_classes, k_centers=num_of_centers,
                      EPS=EPS).to(device)
    elif args.model == 'GAPNet':
        model = GAPNet(dataset.num_features, dataset.num_classes, derivative=args.derivative,
                       device=device).to(device)
    elif args.model == 'MinCutNet':
        model = MinCutNet(dataset.num_features, dataset.num_classes).to(device)
    elif args.model == 'DiffWire':
        model = DiffWire(dataset.num_features, dataset.num_classes, k_centers=num_of_centers,
                         derivative="normalized", device=device, EPS=EPS).to(device)
    else:
        raise Exception(f"Not implemented model: {args.model}")

    # Stratified 85/15 split, deterministic for each seed.
    train_indices, test_indices = train_test_split(
        list(range(len(dataset.data.y))), test_size=0.15, stratify=dataset.data.y,
        random_state=seed, shuffle=True)
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)

    # Honour --lr / --wd (their defaults equal the previously hard-coded values).
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    print("Experimen run", seed)

    # Epochs are numbered 1..N_EPOCH (range(1, 60) used to stop one epoch short).
    for epoch in range(1, N_EPOCH + 1):
        start_time_epoch = time.time()
        train_loss, train_acc = train(epoch, train_loader)
        test_loss, test_acc = test(test_loader)
        time_lapse = time.time() - start_time_epoch

        with open(args.logs + os.sep + train_log_file, 'a') as f:
            print('Epoch: {:03d}, '
                  'Train Loss: {:.3f}, Train Acc: {:.3f}, '
                  'Test Loss: {:.3f}, Test Acc: {:.3f}'.format(epoch, train_loss,
                                                               train_acc, test_loss,
                                                               test_acc), file=f)
        print('{} - Epoch: {:03d}, '
              'Train Loss: {:.3f}, Train Acc: {:.3f}, '
              'Test Loss: {:.3f}, Test Acc: {:.3f}, Time: {:.2f}'.format(exp_name, epoch, train_loss,
                                                                         train_acc, test_loss,
                                                                         test_acc, time_lapse))

    if args.store:
        torch.save(model.state_dict(), f"models{os.sep}{exp_name}_{exp_time}_iter{e}.pth")
        print(f"Model saved in models{os.sep}{exp_name}_{exp_time}_iter{e}.pth")

    # The reported accuracy of a run is the test accuracy after the last epoch.
    ExperimentResult.append(test_acc)
    with open(args.logs + os.sep + train_log_file, 'a') as f:
        print('Result of run {} is {:.3f}'.format(e, test_acc), file=f)

with open(args.logs + os.sep + train_log_file, 'a') as f:
    print('Test Acc of {} execs {}'.format(len(ExperimentResult), ExperimentResult), file=f)
    print('{} +- {}'.format(np.mean(ExperimentResult), np.std(ExperimentResult)), file=f)
327 |
--------------------------------------------------------------------------------
/trained_models/CTNet/COLLAB_CTNet_17_05_22__08_56_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/COLLAB_CTNet_17_05_22__08_56_iter0.pth
--------------------------------------------------------------------------------
/trained_models/CTNet/ERDOS_CTNet_19_05_22__14_46_iter1.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/ERDOS_CTNet_19_05_22__14_46_iter1.pth
--------------------------------------------------------------------------------
/trained_models/CTNet/IMDB-BINARY_CTNet_17_05_22__08_56_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/IMDB-BINARY_CTNet_17_05_22__08_56_iter0.pth
--------------------------------------------------------------------------------
/trained_models/CTNet/MUTAG_CTNet_19_05_22__11_35_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/MUTAG_CTNet_19_05_22__11_35_iter0.pth
--------------------------------------------------------------------------------
/trained_models/CTNet/REDDIT-BINARY_CTNet_17_05_22__08_50_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/REDDIT-BINARY_CTNet_17_05_22__08_50_iter0.pth
--------------------------------------------------------------------------------
/trained_models/CTNet/SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_CTNet_18_05_22__22_18_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/CTNet/SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_CTNet_18_05_22__22_18_iter0.pth
--------------------------------------------------------------------------------
/trained_models/GAPNet/COLLAB_GAPNet_normalized_18_05_22__21_30_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/COLLAB_GAPNet_normalized_18_05_22__21_30_iter0.pth
--------------------------------------------------------------------------------
/trained_models/GAPNet/ERDOS_GAPNet_normalized_19_05_22__14_46_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/ERDOS_GAPNet_normalized_19_05_22__14_46_iter0.pth
--------------------------------------------------------------------------------
/trained_models/GAPNet/IMDB-BINARY_GAPNet_normalized_19_05_22__11_35_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/IMDB-BINARY_GAPNet_normalized_19_05_22__11_35_iter0.pth
--------------------------------------------------------------------------------
/trained_models/GAPNet/MUTAG_GAPNet_normalized_19_05_22__11_35_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/MUTAG_GAPNet_normalized_19_05_22__11_35_iter0.pth
--------------------------------------------------------------------------------
/trained_models/GAPNet/REDDIT-BINARY_GAPNet_normalized_19_05_22__10_09_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/REDDIT-BINARY_GAPNet_normalized_19_05_22__10_09_iter0.pth
--------------------------------------------------------------------------------
/trained_models/GAPNet/SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_GAPNet_normalized_18_05_22__21_44_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/GAPNet/SBM200nodesT1000graphsT80p1T50p2T10to15q1T1to10q2_GAPNet_normalized_18_05_22__21_44_iter0.pth
--------------------------------------------------------------------------------
/trained_models/Laplacian/COLLAB_GAPNet_laplacian_16_05_22__16_46_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/Laplacian/COLLAB_GAPNet_laplacian_16_05_22__16_46_iter0.pth
--------------------------------------------------------------------------------
/trained_models/Laplacian/REDDIT-BINARY_GAPNet_laplacian_16_05_22__11_04_iter0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdrianArnaiz/DiffWire/deced4bbe088827e39a9359fa368a8efa2b00cfd/trained_models/Laplacian/REDDIT-BINARY_GAPNet_laplacian_16_05_22__11_04_iter0.pth
--------------------------------------------------------------------------------
/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | #from transforms.sdrf.sdrf_transform import SDRF
2 | from transforms.transform_features import FeatureDegree
3 | from transforms.transform_features import DIGLedges
4 | from transforms.transform_features import KNNGraph
--------------------------------------------------------------------------------
/transforms/sdrf/curvature.py:
--------------------------------------------------------------------------------
1 | import math
2 | from numba import cuda
3 | import numpy as np
4 | import torch
5 | from torch_geometric.utils import (
6 | to_networkx,
7 | from_networkx,
8 | to_dense_adj,
9 | remove_self_loops,
10 | to_undirected,
11 | )
12 |
13 | from transforms.sdrf.utils import softmax
14 |
15 |
# numba CUDA kernel: each thread handles one entry (i, j) of the N x N matrix
# (launched by balanced_forman_curvature with 16 x 16 thread blocks).
# Writes into C[i, j] the Balanced Forman curvature of edge (i, j):
#   degree term    2/d_max + 2/d_min - 2
#   triangle term  (2/d_max + 1/d_min) * A2[i, j]      (A2 = A @ A)
#   4-cycle term   sharp_ij / (d_max * lambda_ij)
# Entries that are not edges (A[i, j] == 0) or have a zero degree get 0.
@cuda.jit(
    "void(float32[:,:], float32[:,:], float32[:], float32[:], int32, float32[:,:])"
)
def _balanced_forman_curvature(A, A2, d_in, d_out, N, C):
    i, j = cuda.grid(2)

    # The last blocks of the grid can contain threads outside the matrix.
    if (i < N) and (j < N):
        if A[i, j] == 0:
            C[i, j] = 0
            return

        # d_max / d_min: larger and smaller of the two endpoint degrees.
        if d_in[i] > d_out[j]:
            d_max = d_in[i]
            d_min = d_out[j]
        else:
            d_max = d_out[j]
            d_min = d_in[i]

        if d_max * d_min == 0:
            C[i, j] = 0
            return

        # sharp_ij counts the positive 4-cycle contributions through every k on
        # either side of (i, j); lambda_ij keeps the largest such contribution.
        sharp_ij = 0
        lambda_ij = 0
        for k in range(N):
            TMP = A[k, j] * (A2[i, k] - A[i, k]) * A[i, j]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

            TMP = A[i, k] * (A2[k, j] - A[k, j]) * A[i, j]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

        C[i, j] = (
            (2 / d_max) + (2 / d_min) - 2 + (2 / d_max + 1 / d_min) * A2[i, j] * A[i, j]
        )
        # The 4-cycle term is only defined when some contribution was positive.
        if lambda_ij > 0:
            C[i, j] += sharp_ij / (d_max * lambda_ij)
58 |
59 |
def balanced_forman_curvature(A, C=None):
    """Compute the Balanced Forman curvature of every edge of ``A`` on the GPU.

    Args:
        A: dense N x N adjacency matrix; the kernel signature requires float32,
            and sdrf passes it on the GPU.
        C: optional preallocated N x N float32 output buffer on the GPU; a new
            zero buffer is created when omitted.

    Returns:
        ``C``, filled in place by the ``_balanced_forman_curvature`` kernel.
    """
    n = A.shape[0]
    A2 = torch.matmul(A, A)
    d_in, d_out = A.sum(axis=0), A.sum(axis=1)
    if C is None:
        C = torch.zeros(n, n).cuda()

    # Enough 16 x 16 blocks to give every (i, j) entry its own thread.
    block = (16, 16)
    grid = tuple(math.ceil(n / side) for side in block)
    _balanced_forman_curvature[grid, block](A, A2, d_in, d_out, n, C)
    return C
75 |
76 |
# numba CUDA kernel: one thread per candidate pair (I, J), i.e. per entry of D.
# D[I, J] receives the Balanced Forman curvature that edge (x, y) would have
# after adding the edge (i_neighbors[I], j_neighbors[J]). The degree, triangle
# and 4-cycle quantities of _balanced_forman_curvature are recomputed with the
# new edge taken into account. Pairs with i == j or an existing edge get the
# sentinel -1000.
@cuda.jit(
    "void(float32[:,:], float32[:,:], float32, float32, int32, float32[:,:], int32, int32, int32[:], int32[:], int32, int32)"
)
def _balanced_forman_post_delta(
    A, A2, d_in_x, d_out_y, N, D, x, y, i_neighbors, j_neighbors, dim_i, dim_j
):
    I, J = cuda.grid(2)

    # The last blocks of the grid can contain threads outside D.
    if (I < dim_i) and (J < dim_j):
        i = i_neighbors[I]
        j = j_neighbors[J]

        # Not a valid new edge: self-pair or already connected.
        if (i == j) or (A[i, j] != 0):
            D[I, J] = -1000
            return

        # Difference in degree terms
        # (d_in_x / d_out_y are per-thread scalar copies, so this is local).
        if j == x:
            d_in_x += 1
        elif i == y:
            d_out_y += 1

        if d_in_x * d_out_y == 0:
            D[I, J] = 0
            return

        if d_in_x > d_out_y:
            d_max = d_in_x
            d_min = d_out_y
        else:
            d_max = d_out_y
            d_min = d_in_x

        # Difference in triangles term
        A2_x_y = A2[x, y]
        if (x == i) and (A[j, y] != 0):
            A2_x_y += A[j, y]
        elif (y == j) and (A[x, i] != 0):
            A2_x_y += A[x, i]

        # Difference in four-cycles term
        sharp_ij = 0
        lambda_ij = 0
        for z in range(N):
            # "+ 0" makes local copies that are adjusted for the added edge below.
            A_z_y = A[z, y] + 0
            A_x_z = A[x, z] + 0
            A2_z_y = A2[z, y] + 0
            A2_x_z = A2[x, z] + 0

            if (z == i) and (y == j):
                A_z_y += 1
            if (x == i) and (z == j):
                A_x_z += 1
            if (z == i) and (A[j, y] != 0):
                A2_z_y += A[j, y]
            if (x == i) and (A[j, z] != 0):
                A2_x_z += A[j, z]
            if (y == j) and (A[z, i] != 0):
                A2_z_y += A[z, i]
            if (z == j) and (A[x, i] != 0):
                A2_x_z += A[x, i]

            TMP = A_z_y * (A2_x_z - A_x_z) * A[x, y]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

            TMP = A_x_z * (A2_z_y - A_z_y) * A[x, y]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

        D[I, J] = (
            (2 / d_max) + (2 / d_min) - 2 + (2 / d_max + 1 / d_min) * A2_x_y * A[x, y]
        )
        if lambda_ij > 0:
            D[I, J] += sharp_ij / (d_max * lambda_ij)
156 |
157 |
def balanced_forman_post_delta(A, x, y, i_neighbors, j_neighbors, D=None):
    """Curvature of edge (x, y) after adding each candidate edge (i, j), on the GPU.

    D[I, J] receives the curvature edge (x, y) would have once the edge
    (i_neighbors[I], j_neighbors[J]) is added. Pairs with i == j or an already
    existing edge get the sentinel -1000.

    Args:
        A: dense N x N adjacency matrix (float32; on the GPU when called by sdrf).
        x, y: endpoints of the edge whose curvature is evaluated.
        i_neighbors: candidate source nodes (neighbours of x plus x itself in sdrf).
        j_neighbors: candidate target nodes (neighbours of y plus y itself in sdrf).
        D: optional preallocated len(i_neighbors) x len(j_neighbors) output buffer.

    Returns:
        D, filled in place by the ``_balanced_forman_post_delta`` kernel.
    """
    N = A.shape[0]
    A2 = torch.matmul(A, A)
    # In-degree of x and out-degree of y in the current graph.
    d_in = A[:, x].sum()
    d_out = A[y].sum()
    if D is None:
        D = torch.zeros(len(i_neighbors), len(j_neighbors)).cuda()

    # One thread per entry of D, in 16 x 16 blocks.
    threadsperblock = (16, 16)
    blockspergrid_x = math.ceil(D.shape[0] / threadsperblock[0])
    blockspergrid_y = math.ceil(D.shape[1] / threadsperblock[1])
    blockspergrid = (blockspergrid_x, blockspergrid_y)

    _balanced_forman_post_delta[blockspergrid, threadsperblock](
        A,
        A2,
        d_in,
        d_out,
        N,
        D,
        x,
        y,
        # NOTE(review): the kernel signature declares int32[:] here, while
        # np.array of Python ints is usually int64 - confirm numba accepts this.
        np.array(i_neighbors),
        np.array(j_neighbors),
        D.shape[0],
        D.shape[1],
    )
    return D
186 |
187 |
def sdrf(
    data,
    loops=10,
    remove_edges=True,
    removal_bound=0.5,
    tau=1,
    is_undirected=False,
):
    """Rewire ``data`` with Stochastic Discrete Ricci Flow (SDRF). Requires CUDA.

    Each iteration:
      1. computes the Balanced Forman curvature C of all edges;
      2. takes the most negatively curved edge (x, y) and, among the missing
         edges between the neighbourhoods of x and y, samples one to add with
         probabilities softmax(tau * improvement of the curvature of (x, y));
      3. if ``remove_edges``, removes the most positively curved edge when its
         curvature exceeds ``removal_bound``.
    The loop stops early when no edge can be added and none may be removed.

    Args:
        data: PyG graph exposing ``edge_index``.
        loops: maximum number of iterations.
        remove_edges: allow edge removal in step 3.
        removal_bound: curvature threshold for removal.
        tau: multiplies the improvements inside the softmax (larger = greedier).
        is_undirected: treat the graph as undirected (symmetric updates of A,
            undirected networkx graph).

    Returns:
        A new PyG ``Data`` built with ``from_networkx`` from the rewired graph.
    """
    edge_index = data.edge_index
    if is_undirected:
        edge_index = to_undirected(edge_index)
    # NOTE(review): to_dense_adj sizes A by the largest node index in edge_index,
    # which can be smaller than data.num_nodes when trailing nodes are isolated.
    A = to_dense_adj(remove_self_loops(edge_index)[0])[0]
    N = A.shape[0]
    G = to_networkx(data)
    if is_undirected:
        G = G.to_undirected()
    A = A.cuda()
    C = torch.zeros(N, N).cuda()

    # The loop variable is shadowed by x below; only the iteration count matters.
    for x in range(loops):
        can_add = True
        balanced_forman_curvature(A, C=C)
        # Most negatively curved edge (flattened argmin -> row/col).
        ix_min = C.argmin().item()
        x = ix_min // N
        y = ix_min % N

        if is_undirected:
            x_neighbors = list(G.neighbors(x)) + [x]
            y_neighbors = list(G.neighbors(y)) + [y]
        else:
            x_neighbors = list(G.successors(x)) + [x]
            y_neighbors = list(G.predecessors(y)) + [y]
        candidates = []
        for i in x_neighbors:
            for j in y_neighbors:
                if (i != j) and (not G.has_edge(i, j)):
                    candidates.append((i, j))

        if len(candidates):
            D = balanced_forman_post_delta(A, x, y, x_neighbors, y_neighbors)
            # Improvement of the curvature of (x, y) for every candidate edge.
            improvements = []
            for (i, j) in candidates:
                improvements.append(
                    (D - C[x, y])[x_neighbors.index(i), y_neighbors.index(j)].item()
                )

            # Sample the edge to add with probabilities proportional to exp(tau * improvement).
            k, l = candidates[
                np.random.choice(
                    range(len(candidates)), p=softmax(np.array(improvements), tau=tau)
                )
            ]
            G.add_edge(k, l)
            if is_undirected:
                A[k, l] = A[l, k] = 1
            else:
                A[k, l] = 1
        else:
            can_add = False
            if not remove_edges:
                break

        if remove_edges:
            # C still holds the curvature computed at the start of this
            # iteration, i.e. before the edge added above.
            ix_max = C.argmax().item()
            x = ix_max // N
            y = ix_max % N
            if C[x, y] > removal_bound:
                G.remove_edge(x, y)
                if is_undirected:
                    A[x, y] = A[y, x] = 0
                else:
                    A[x, y] = 0
            else:
                # Nothing to add and nothing to remove: the flow has converged.
                if can_add is False:
                    break

    return from_networkx(G)
264 |
--------------------------------------------------------------------------------
/transforms/sdrf/sdrf_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.transforms import BaseTransform
3 | from torch_geometric.data import Data
4 | from transforms.sdrf.curvature import sdrf
5 | from transforms.sdrf.utils import get_dataset
6 |
class SDRF(BaseTransform):
    """PyG transform that rewires each graph with SDRF (Stochastic Discrete Ricci Flow).

    Args:
        max_steps: number of SDRF iterations, or ``'dynamic'`` to run
            ``int(0.7 * num_nodes)`` iterations, resolved separately for every graph.
        remove_edges: if True, an iteration may also remove the edge with the
            largest curvature when it exceeds ``removal_bound``.
        removal_bound: curvature threshold for edge removal.
        tau: factor multiplying the curvature improvements inside the softmax
            that samples the edge to add (larger = greedier).
        undirected: treat the graph as undirected.
        use_edge_weigths: stored for API compatibility; not used by this transform.
    """

    def __init__(self,
                 max_steps: "int | str | None" = None,
                 remove_edges: bool = True,
                 removal_bound: float = 0.5,
                 tau: float = 1,
                 undirected: bool = False,
                 use_edge_weigths=True,
                 ):
        self.max_steps = max_steps
        self.remove_edges = remove_edges
        self.removal_bound = removal_bound
        self.tau = tau
        self.undirected = undirected
        self.use_edge_weigths = use_edge_weigths

    def _resolve_steps(self, num_nodes):
        """Return the number of SDRF iterations for a graph with ``num_nodes`` nodes."""
        if self.max_steps == 'dynamic':
            return int(0.7 * num_nodes)
        return int(self.max_steps)

    def __call__(self, graph_data):
        graph_data = get_dataset(graph_data, use_lcc=False)

        # Resolve the budget per graph WITHOUT writing it back to self.max_steps:
        # caching it would freeze 'dynamic' to the size of the first graph seen.
        loops = self._resolve_steps(graph_data.num_nodes)

        altered_data = sdrf(
            graph_data,
            loops=loops,
            remove_edges=self.remove_edges,
            removal_bound=self.removal_bound,
            tau=self.tau,
            is_undirected=self.undirected,
        )

        return Data(
            edge_index=torch.LongTensor(altered_data.edge_index),
            edge_attr=(torch.FloatTensor(altered_data.edge_attr)
                       if altered_data.edge_attr is not None else None),
            y=graph_data.y,
            x=graph_data.x,
            num_nodes=graph_data.num_nodes,
        )

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.max_steps})'
55 |
56 |
--------------------------------------------------------------------------------
/transforms/sdrf/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch_geometric.data import Data
3 | import torch
4 |
def softmax(a, tau=1):
    """Return ``exp(tau * a) / sum(exp(tau * a))`` over all elements of ``a``.

    The maximum is subtracted before exponentiating, which leaves the result
    unchanged mathematically but avoids overflow (``inf / inf = nan``) for
    large scores. ``a`` is converted to a float array first, so Python lists
    are scaled by ``tau`` instead of being repeated.

    Args:
        a: scores (array-like).
        tau: multiplier applied to the scores (inverse temperature).

    Returns:
        Array of probabilities with the shape of ``a``; empty input gives an
        empty array.
    """
    z = np.asarray(a, dtype=float) * tau
    if z.size == 0:
        return z
    exp_z = np.exp(z - z.max())
    return exp_z / exp_z.sum()
8 |
9 |
def get_node_mapper(lcc: np.ndarray) -> dict:
    """Map every node id in ``lcc`` to its position in ``lcc`` (0, 1, 2, ...).

    If an id occurs more than once, its last position wins.
    """
    return {node: new_id for new_id, node in enumerate(lcc)}
17 |
18 |
def remap_edges(edges: list, mapper: dict) -> list:
    """Relabel the endpoints of ``edges`` with ``mapper``.

    Args:
        edges: sequence of ``(source, target)`` pairs.
        mapper: old node id -> new node id.

    Returns:
        ``[sources, targets]`` as two lists of relabelled ids.
    """
    sources, targets = [e[0] for e in edges], [e[1] for e in edges]
    return [[mapper[s] for s in sources], [mapper[t] for t in targets]]
25 |
26 |
def get_component(dataset, start: int = 0) -> set:
    """Return all nodes reachable from ``start`` following edges ``row -> col``.

    The out-adjacency is built once in O(E), so the traversal is O(V + E)
    instead of rescanning every edge for each visited node.

    Args:
        dataset: graph exposing ``edge_index`` as a (2, E) CPU tensor.
        start: node the traversal starts from.

    Returns:
        Set of reachable node ids, including ``start``.
    """
    row, col = dataset.edge_index.numpy()
    successors = {}
    for src, dst in zip(row.tolist(), col.tolist()):
        successors.setdefault(src, []).append(dst)

    visited_nodes = set()
    queued_nodes = {start}
    while queued_nodes:
        current_node = queued_nodes.pop()
        visited_nodes.add(current_node)
        queued_nodes.update(
            n for n in successors.get(current_node, ()) if n not in visited_nodes
        )
    return visited_nodes
40 |
41 |
def get_largest_connected_component(dataset) -> np.ndarray:
    """Return the node ids of the largest component found by ``get_component``.

    Components are grown from the smallest node not yet covered, following
    edges in the ``row -> col`` direction. Ties go to the first component found.

    Returns:
        Node ids of the largest component, or an empty int64 array for a graph
        without nodes (``np.argmax`` of an empty list used to raise here).
    """
    remaining_nodes = set(range(dataset.num_nodes))
    comps = []
    while remaining_nodes:
        comp = get_component(dataset, min(remaining_nodes))
        comps.append(comp)
        remaining_nodes -= comp
    if not comps:
        return np.array([], dtype=np.int64)
    # max() returns the first of equally-sized components, like np.argmax did.
    return np.array(list(max(comps, key=len)))
51 |
52 |
def get_dataset(dataset: Data, use_lcc: bool = True):
    """Optionally restrict a graph to its largest connected component.

    Args:
        dataset: graph exposing ``x`` (may be None), ``y``, ``edge_index`` and
            ``num_nodes``.
        use_lcc: if False, ``dataset`` is returned unchanged.

    Returns:
        ``dataset`` itself, or a new ``Data`` with the LCC's node features, the
        edges between LCC nodes relabelled to 0..len(lcc)-1, the unchanged
        graph label, and ``num_nodes`` equal to the size of the LCC (previously
        the original node count was kept, inconsistent with ``x`` and the edges).
    """
    if not use_lcc:
        return dataset

    lcc = get_largest_connected_component(dataset)
    lcc_set = set(lcc.tolist())  # O(1) membership instead of scanning the array per edge

    x_new = dataset.x[lcc] if dataset.x is not None else None
    y_new = dataset.y  # for graph clf, same y

    row, col = dataset.edge_index.numpy()
    edges = [[i, j] for i, j in zip(row, col) if i in lcc_set and j in lcc_set]
    edges = remap_edges(edges, get_node_mapper(lcc))

    return Data(
        x=x_new,
        edge_index=torch.LongTensor(edges),
        y=y_new,
        num_nodes=len(lcc),
    )
76 |
--------------------------------------------------------------------------------
/transforms/transform_features.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import torch_geometric
5 | from torch_geometric.transforms import BaseTransform
6 | from torch_geometric.utils import degree, to_undirected
7 | from torch_geometric.utils.convert import to_scipy_sparse_matrix, from_scipy_sparse_matrix
8 |
9 | import scipy.sparse as sp
10 | import numpy as np
11 |
12 |
class FeatureDegree(BaseTransform):
    r"""Add each node's degree as one float feature column.

    Args:
        in_degree (bool, optional): If set to :obj:`True`, count incoming edges
            (``edge_index[1]``) instead of outgoing ones (``edge_index[0]``).
            (default: :obj:`False`)
        cat (bool, optional): If :obj:`True` and the graph already has node
            features, append the degree column to them; otherwise the degree
            column replaces ``data.x``. (default: :obj:`True`)
    """
    def __init__(self, in_degree=False, cat=True):
        self.in_degree = in_degree
        self.cat = cat

    def __call__(self, data):
        endpoint = 1 if self.in_degree else 0
        feats = data.x
        deg = degree(data.edge_index[endpoint], data.num_nodes, dtype=torch.float).unsqueeze(-1)

        if feats is None or not self.cat:
            data.x = deg
            return data

        if feats.dim() == 1:
            feats = feats.view(-1, 1)
        data.x = torch.cat([feats, deg.to(feats.dtype)], dim=-1)
        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.in_degree})'
43 |
44 |
class DIGLedges(BaseTransform):
    """Rewire a graph with a DIGL-style personalized-PageRank (PPR) diffusion.

    The dense PPR diffusion matrix is sparsified by keeping its ``data.num_edges``
    largest entries, so the rewired graph has at most as many (directed) edges
    as the input graph.

    Args:
        alpha: teleport probability of the PPR diffusion.
        use_edge_weigths: if True, the column-normalised diffusion weights are
            stored in ``data.edge_weight``.
    """
    def __init__(self, alpha: float, use_edge_weigths=False):
        self.alpha = alpha
        # NOTE(review): eps is only reported by __repr__; gdc() sparsifies by
        # top-k instead of by an eps threshold.
        self.eps = 0.005
        self.use_edge_weigths = use_edge_weigths

    def __call__(self, data):
        new_edges, new_weights = self.digl_edges(data.edge_index, data.num_edges)
        data.edge_index = new_edges

        if self.use_edge_weigths:
            data.edge_weight = new_weights

        return data

    def gdc(self, A: sp.csr_matrix, alpha: float, num_previous_edges):
        """Return the sparsified, column-normalised PPR diffusion of ``A``.

        Args:
            A: N x N adjacency matrix.
            alpha: teleport probability.
            num_previous_edges: number of largest diffusion entries to keep
                (0 or less keeps none).

        Returns:
            Column-normalised transition matrix on the sparsified diffusion
            graph. Columns without any kept entry stay all-zero.
        """
        # Explicit import: `sp.linalg` is only reachable once the submodule is loaded.
        from scipy.sparse.linalg import inv as sparse_inv

        N = A.shape[0]

        # Self-loops
        A_loop = sp.eye(N) + A

        # Symmetric transition matrix D^-1/2 (A + I) D^-1/2
        D_loop_vec = A_loop.sum(0).A1
        D_loop_invsqrt = sp.diags(1 / np.sqrt(D_loop_vec))
        T_sym = D_loop_invsqrt @ A_loop @ D_loop_invsqrt

        # PPR-based diffusion: S = alpha * (I - (1 - alpha) T_sym)^-1
        S = alpha * sparse_inv(sp.eye(N) - (1 - alpha) * T_sym)

        # Same role as the eps-threshold based on average degree, but dynamic for
        # every graph: keep only the num_previous_edges largest entries.
        dense_S = np.array(S.todense())
        k = max(int(num_previous_edges), 0)
        mask = np.ones(dense_S.shape, bool)
        if k > 0:  # a [-0:] slice would otherwise keep *every* entry
            top_k_idx = np.unravel_index(np.argsort(dense_S.ravel())[-k:], dense_S.shape)
            mask[top_k_idx] = False
        dense_S[mask] = 0
        S_tilde = sp.csr_matrix(dense_S)

        # Column-normalised transition matrix on graph S_tilde. Columns with no
        # kept entry have sum 0; dividing them by 1 keeps them zero instead of
        # filling them with NaN (0/0), which would become NaN-weighted edges.
        D_tilde_vec = S_tilde.sum(0).A1
        D_tilde_vec[D_tilde_vec == 0] = 1
        T_S = S_tilde / D_tilde_vec

        return T_S

    def get_top_k_matrix(self, A: np.ndarray, k: int = 128) -> np.ndarray:
        """Keep the k largest entries of each column of ``A`` and column-normalise.

        ``A`` is modified in place. If ``k`` >= number of nodes nothing is zeroed.
        """
        num_nodes = A.shape[0]
        row_idx = np.arange(num_nodes)
        A[A.argsort(axis=0)[:max(num_nodes - k, 0)], row_idx] = 0.
        norm = A.sum(axis=0)
        norm[norm <= 0] = 1  # avoid dividing by zero
        return A / norm

    def digl_edges(self, edges, num_previous_edges):
        """Return ``(edge_index, edge_weight)`` of the DIGL-rewired graph.

        An empty graph (or a budget of 0 edges) yields an empty edge set instead
        of building and inverting a 0 x 0 matrix.
        """
        if edges.numel() == 0 or num_previous_edges <= 0:
            return edges.new_empty((2, 0)), torch.empty(0, dtype=torch.float64)
        A0 = sp.csr_matrix(to_scipy_sparse_matrix(edges))
        new_sp_matrix = sp.csr_matrix(self.gdc(A0, self.alpha, num_previous_edges))
        new_edge_index, weights = from_scipy_sparse_matrix(new_sp_matrix)
        return new_edge_index, weights

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(alpha={self.alpha}, eps={self.eps})'
113 |
114 |
class KNNGraph(BaseTransform):
    r"""Replace the edges of a graph by a k-NN graph built on the node
    features :obj:`x` (functional name: :obj:`knn_graph`).

    If the graph has no node features, the node in-degree (from
    ``edge_index[1]``) is used as a 1-d feature to build the k-NN graph, and
    ``data.x`` is afterwards set to the in-degree on the new edges.

    Args:
        k (int, optional): The number of neighbors. If :obj:`None`, it is
            derived separately for every graph as
            ``int(num_edges / (num_nodes * 4))``. (default: :obj:`None`)
        loop (bool, optional): If :obj:`True`, the graph will contain
            self-loops. (default: :obj:`False`)
        force_undirected (bool, optional): If set to :obj:`True`, new edges
            will be undirected. (default: :obj:`True`)
        flow (string, optional): The flow direction when used in combination
            with message passing (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`).
            If set to :obj:`"source_to_target"`, every target node will have
            exactly :math:`k` source nodes pointing to it.
            (default: :obj:`"source_to_target"`)
        cosine (boolean, optional): If :obj:`True`, will use the cosine
            distance instead of euclidean distance to find nearest neighbors.
            (default: :obj:`False`)
        num_workers (int): Number of workers to use for computation. Has no
            effect in case :obj:`batch` is not :obj:`None`, or the input lies
            on the GPU. (default: :obj:`1`)
    """
    def __init__(
        self,
        k=None,
        loop=False,
        force_undirected=True,
        flow='source_to_target',
        cosine: bool = False,
        num_workers: int = 1,
    ):
        self.k = k
        self.loop = loop
        self.force_undirected = force_undirected
        self.flow = flow
        self.cosine = cosine
        self.num_workers = num_workers

    def __call__(self, data):
        had_features = True
        data.edge_attr = None  # old edge attributes do not match the new edges
        batch = data.batch if 'batch' in data else None

        if data.x is None:
            idx = data.edge_index[1]
            data.x = degree(idx, data.num_nodes, dtype=torch.float).unsqueeze(-1)
            had_features = False

        k = self.k
        if k is None:
            # Resolved per graph and NOT stored on self: caching it would freeze
            # k to the value derived from the first graph this transform sees.
            # NOTE(review): num_edges already counts both directions in PyG, so
            # num_edges / num_nodes is the mean degree and this is a quarter of
            # it (it can be 0 for sparse graphs) - confirm this is intended.
            k = int(data.num_edges / (data.num_nodes * 4))

        edge_index = torch_geometric.nn.knn_graph(
            data.x,
            k,
            batch,
            loop=self.loop,
            flow=self.flow,
            cosine=self.cosine,
            num_workers=self.num_workers,
        )
        if self.force_undirected:
            edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)

        data.edge_index = edge_index

        # Update degree feature so it describes the new edges.
        if not had_features:
            idx = data.edge_index[1]
            data.x = degree(idx, data.num_nodes, dtype=torch.float).unsqueeze(-1)

        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(k={self.k})'
--------------------------------------------------------------------------------