├── LICENCE.md ├── README.md ├── config.py ├── datasets ├── __init__.py ├── base_dataset.py ├── config.py ├── gnn_benchmark.py ├── grid.pt ├── lrgb.py ├── manifold.py ├── mnist.py ├── modelnet.py ├── moleculenet.py ├── transforms.py ├── tu.py └── weighted_mnist.py ├── dect ├── LICENSE.txt ├── README.rst ├── __init__.py ├── dect │ ├── __init__.py │ └── ect.py └── pyproject.toml ├── example.ipynb ├── experiment ├── DD │ └── ect_cnn_edges.yaml ├── Letter-high │ └── ect_cnn_edges.yaml ├── Letter-low │ └── ect_cnn_edges.yaml ├── Letter-med │ └── ect_cnn_edges.yaml ├── PROTEINS_full │ └── ect_cnn_edges.yaml ├── ablations │ ├── ablations0.yaml │ ├── ablations1.yaml │ ├── ablations2.yaml │ ├── ablations3.yaml │ ├── ablations4.yaml │ ├── ablations5.yaml │ ├── ablations6.yaml │ ├── ablations7.yaml │ └── ablations8.yaml ├── bzr │ └── ect_cnn_edges.yaml ├── cox2 │ └── ect_cnn_edges.yaml ├── dhfr │ └── ect_cnn_edges.yaml ├── gnn_mnist_classification │ └── ect_cnn_edges.yaml ├── lrgb │ ├── lrgb_cnn_edges.yaml │ ├── lrgb_cnn_points.yaml │ ├── lrgb_linear_edges.yaml │ └── lrgb_linear_points.yaml ├── manifold_classification │ ├── ect_cnn_faces.yaml │ └── ect_linear_faces.yaml └── weighted_mnist │ └── wect.yaml ├── figures └── ect_animation.gif ├── generate_experiments.py ├── loaders └── factory.py ├── logger.py ├── main.py ├── metrics └── metrics.py ├── models ├── __init__.py ├── base_model.py ├── config.py ├── ect_cnn.py ├── ect_linear.py ├── layers │ ├── __init__.py │ ├── layers.py │ └── layers_wect.py ├── wect_cnn.py └── wect_linear.py ├── notebooks └── dect.ipynb ├── requirements.txt ├── single_main.py └── utils └── __init__.py /LICENCE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Ernst Roell 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | 1. Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS 19 | IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED 20 | TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 24 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 25 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 26 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 27 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DECT - Differentiable Euler Characteristic Transform 2 | [![arXiv](https://img.shields.io/badge/arXiv-2310.07630-b31b1b.svg)](https://arxiv.org/abs/2310.07630) ![GitHub contributors](https://img.shields.io/github/contributors/aidos-lab/DECT) ![GitHub](https://img.shields.io/github/license/aidos-lab/DECT) 3 | 4 | This is the official repository for the ICLR 2024 paper: [Differentiable Euler Characteristic Transforms for Shape Classification](https://arxiv.org/abs/2310.07630). 5 | 6 | **Abstract** The Euler Characteristic Transform (ECT) has proven to be a powerful representation, combining geometrical and topological characteristics of shapes and graphs. However, the ECT was hitherto unable to learn task-specific representations. We overcome this issue and develop a novel computational layer that enables learning the ECT in an end-to-end fashion. Our method, the Differentiable Euler Characteristic Transform (DECT), is fast and computationally efficient, while exhibiting performance on a par with more complex models in both graph and point cloud classification tasks. Moreover, we show that this seemingly simple statistic provides the same topological expressivity as more complex topological deep learning layers. 7 | 8 | ![Animated-ECT](figures/ect_animation.gif) 9 | 10 | 11 | Please use the following citation for our work: 12 | 13 | ```bibtex 14 | @inproceedings{Roell24a, 15 | title = {Differentiable Euler Characteristic Transforms for Shape Classification}, 16 | author = {Ernst R{\"o}ell and Bastian Rieck}, 17 | year = 2024, 18 | booktitle = {International Conference on Learning Representations}, 19 | eprint = {2310.07630}, 20 | archiveprefix = {arXiv}, 21 | primaryclass = {cs.LG}, 22 | repository = {https://github.com/aidos-lab/DECT}, 23 | url = {https://openreview.net/forum?id=MO632iPq3I}, 24 | } 25 | ``` 26 | 27 | ## Installation 28 | Our code has been developed using Python 3.10 and PyTorch 2.0.1 installed 29 | with CUDA 11.7. 30 | After installing the above, install the requirements from `requirements.txt`: 31 | 32 | ```{bash} 33 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | ## Usage 38 | 39 | To run a single experiment, modify the path in `single_main.py` to point to the 40 | desired experiment and run it. 41 | The configuration files for each experiment can be found under the 42 | `experiment` folder and the parameters are in the `.yaml` files. 43 | 44 | To run all experiments in a folder, update the path in `main.py` to that 45 | folder and run `main.py`. 46 | 47 | All datasets are downloaded and preprocessed via the `torch_geometric` package 48 | when first run. 49 | The TU datasets are small and run fast, so for testing purposes it is recommended 50 | to run these first. 51 | 52 | Alternatively, for research purposes, the ECT is installable as a Python package. 53 | To install DECT, run the following in a terminal: 54 | 55 | ```{bash} 56 | pip install "git+https://github.com/aidos-lab/DECT/#subdirectory=dect" 57 | ``` 58 | 59 | For example usage, we provide the `example.ipynb` notebook; the code therein reproduces the 60 | ECT shown in the GIF of this README. 61 | The code is provided on an as-is basis. You are cordially invited to both contribute and 62 | provide feedback. Do not hesitate to contact us.
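As a quick orientation, below is a minimal usage sketch of the packaged layer on a small `torch_geometric` batch. It assumes the `dect` package has been installed as above and uses the `EctConfig`/`EctLayer` interface from `dect/dect/ect.py`; the parameter values and point clouds are illustrative only.

```{python}
import torch
from torch_geometric.data import Data, Batch

from dect.ect import EctConfig, EctLayer

# Two small 2D point clouds; the layer reads the coordinates from `x`.
batch = Batch.from_data_list(
    [Data(x=torch.rand(10, 2) - 0.5), Data(x=torch.rand(15, 2) - 0.5)]
)

# 32 directions, 32 filtration steps; ect_type="points" uses only the vertices.
config = EctConfig(num_thetas=32, bump_steps=32, num_features=2, ect_type="points")
layer = EctLayer(config)

ect = layer(batch)
print(ect.shape)  # (num_graphs, bump_steps, num_thetas) == (2, 32, 32)
```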
63 | 64 | ## Examples 65 | 66 | The core of our method, the differentiable computation of the Euler Characteristic 67 | transform, can be found in the `./models/layers/layers.py` file. 68 | Since the code is somewhat terse, highly vectorized and optimized for batch 69 | processing, we provide an example computation that illustrates the core 70 | principle of our method. 71 | 72 | 73 | ## License 74 | 75 | Our code is released under a BSD-3-Clause license. This license 76 | essentially permits you to freely use our code as desired, integrate it 77 | into your projects, and much more --- provided you acknowledge the 78 | original authors. Please refer to [LICENCE.md](./LICENCE.md) for more 79 | information. 80 | 81 | ## Issues 82 | 83 | This project is maintained by members of the [AIDOS Lab](https://github.com/aidos-lab). 84 | Please open an [issue](https://github.com/aidos-lab/TARDIS/issues) in 85 | case you encounter any problems. 86 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Protocol 2 | from dataclasses import dataclass, field 3 | 4 | 5 | @dataclass(frozen=True) 6 | class Config: 7 | meta: Any 8 | data: Any 9 | model: Any 10 | trainer: Any 11 | 12 | 13 | @dataclass 14 | class Meta: 15 | name: str 16 | project: str = "desct" 17 | tags: list[str] = field(default_factory=list) 18 | experiment_folder: str = "experiment" 19 | 20 | 21 | # ╭──────────────────────────────────────────────────────────╮ 22 | # │ Data Configurations │ 23 | # ╰──────────────────────────────────────────────────────────╯ 24 | 25 | 26 | # ╭──────────────────────────────────────────────────────────╮ 27 | # │ Model Configurations │ 28 | # ╰──────────────────────────────────────────────────────────╯ 29 | 30 | 31 | @dataclass 32 | class ModelConfig: 33 | name: str 34 | config: Any 35 | 36 | 37 | # ╭──────────────────────────────────────────────────────────╮ 38 | # │ Trainer configurations │ 39 | # ╰──────────────────────────────────────────────────────────╯ 40 | 41 | 42 | @dataclass 43 | class TrainerConfig: 44 | lr: float = 0.001 45 | num_epochs: int = 200 46 | num_reruns: int = 1 47 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # import os, sys 2 | # from pydoc import locate 3 | 4 | # """ 5 | # This package includes all the modules related to data loading and preprocessing. 6 | # """ 7 | 8 | # """ 9 | # The below code fetches all files in the datasets folder and imports them. Each 10 | # dataset can have its own file, or be grouped in sets.
11 | # """ 12 | # """ path = os.path.dirname(os.path.abspath(__file__)) """ 13 | # """ for py in [f[:-3] for f in os.listdir(path) if f.endswith('.py') and f != '__init__.py']: """ 14 | # """ mod = __import__('.'.join([__name__, py]), fromlist=[py]) """ 15 | # """ classes = [getattr(mod, x) for x in dir(mod) if isinstance(getattr(mod, x), type)] """ 16 | # """ for cls in classes: """ 17 | # """ setattr(sys.modules[__name__], cls.__name__, cls) """ 18 | # """""" 19 | # """""" 20 | # def load_datamodule(name=None,config=None): 21 | # dataset = locate(f'name') 22 | # print("loading", dataset) 23 | # if not dataset: 24 | # print(name) 25 | # raise AttributeError() 26 | # return dataset(config) 27 | -------------------------------------------------------------------------------- /datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from torch_geometric.loader import DataLoader 3 | from torch_geometric.data import Dataset 4 | from torch_geometric.loader import ImbalancedSampler 5 | 6 | class DataModule(ABC): 7 | train_ds: Dataset 8 | test_ds: Dataset 9 | val_ds: Dataset 10 | entire_ds: Dataset 11 | 12 | def __init__(self, root, batch_size, num_workers, pin_memory=True, drop_last=True): 13 | super().__init__() 14 | self.data_dir = root 15 | self.batch_size = batch_size 16 | self.num_workers = num_workers 17 | self.pin_memory = pin_memory 18 | self.drop_last = drop_last 19 | self.setup() 20 | 21 | @abstractmethod 22 | def setup(self): 23 | raise NotImplementedError() 24 | 25 | def train_dataloader(self): 26 | return DataLoader( 27 | self.train_ds, 28 | batch_size=self.batch_size, 29 | num_workers=self.num_workers, 30 | sampler=ImbalancedSampler(self.train_ds), 31 | # shuffle=True, 32 | pin_memory=self.pin_memory, 33 | # drop_last=self.drop_last, 34 | ) 35 | 36 | def val_dataloader(self): 37 | return DataLoader( 38 | self.val_ds, 39 | batch_size=self.batch_size, 40 | num_workers=self.num_workers, 41 | # sampler=ImbalancedSampler(self.val_ds), 42 | # shuffle=False, 43 | pin_memory=self.pin_memory, 44 | # drop_last=self.drop_last, 45 | ) 46 | 47 | def test_dataloader(self): 48 | return DataLoader( 49 | self.test_ds, 50 | batch_size=self.batch_size, 51 | num_workers=self.num_workers, 52 | # sampler=ImbalancedSampler(self.test_ds), 53 | shuffle=False, 54 | pin_memory=self.pin_memory, 55 | # drop_last=self.drop_last, 56 | ) 57 | -------------------------------------------------------------------------------- /datasets/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class DataModuleConfig: 6 | module: str 7 | root: str = "./data" 8 | num_workers: int = 0 9 | batch_size: int = 64 10 | pin_memory: bool = True 11 | drop_last: bool = False 12 | 13 | 14 | @dataclass 15 | class GNNBenchmarkDataModuleConfig(DataModuleConfig): 16 | name: str = "MNIST" 17 | module: str = "datasets.gnn_benchmark" 18 | 19 | 20 | @dataclass 21 | class MnistDataModuleConfig(DataModuleConfig): 22 | root: str = "./data/MNIST" 23 | module: str = "datasets.mnist" 24 | 25 | 26 | @dataclass 27 | class WeightedMnistDataModuleConfig(DataModuleConfig): 28 | root: str = "./data/WMNIST" 29 | module: str = "datasets.weighted_mnist" 30 | 31 | 32 | @dataclass 33 | class ModelNetDataModuleConfig(DataModuleConfig): 34 | root: str = "./data/modelnet" 35 | name: str = "10" 36 | module: str = "datasets.modelnet" 37 | samplepoints: int = 100 38 | 
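# The config dataclasses above and the DataModule ABC in base_dataset.py are meant to be
# used together: a config names its datamodule via the `module` field, and the datamodule
# consumes the config's root/batch_size/num_workers fields. A minimal sketch follows
# (not part of this file; the class names are taken from this repository, the concrete
# values are illustrative only):

from datasets.config import GNNBenchmarkDataModuleConfig
from datasets.gnn_benchmark import GNNBenchmarkDataModule

config = GNNBenchmarkDataModuleConfig(name="MNIST", batch_size=32)
datamodule = GNNBenchmarkDataModule(config)  # downloads and preprocesses the data on first use
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))  # a torch_geometric Batch with x, edge_index, y, batch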
-------------------------------------------------------------------------------- /datasets/gnn_benchmark.py: -------------------------------------------------------------------------------- 1 | from datasets.base_dataset import DataModule 2 | from torch_geometric.datasets import GNNBenchmarkDataset 3 | 4 | import torchvision.transforms as transforms 5 | from datasets.transforms import CenterTransform, ThresholdTransform 6 | 7 | from loaders.factory import register 8 | 9 | 10 | transforms_dict = { 11 | "MNIST": [ 12 | ThresholdTransform(), 13 | CenterTransform(), 14 | ], 15 | "CIFAR10": [ 16 | ThresholdTransform(), 17 | CenterTransform(), 18 | ], 19 | "PATTERN": [ 20 | CenterTransform(), 21 | ], 22 | } 23 | 24 | 25 | class GNNBenchmarkDataModule(DataModule): 26 | def __init__(self, config): 27 | self.config = config 28 | self.transform = transforms.Compose(transforms_dict[self.config.name]) 29 | super().__init__(config.root, config.batch_size, config.num_workers) 30 | 31 | def setup(self): 32 | self.train_ds = GNNBenchmarkDataset( 33 | name=self.config.name, 34 | root=self.config.root, 35 | pre_transform=self.transform, 36 | split="train", 37 | ) 38 | self.test_ds = GNNBenchmarkDataset( 39 | name=self.config.name, 40 | root=self.config.root, 41 | pre_transform=self.transform, 42 | split="test", 43 | ) 44 | self.val_ds = GNNBenchmarkDataset( 45 | name=self.config.name, 46 | root=self.config.root, 47 | pre_transform=self.transform, 48 | split="val", 49 | ) 50 | 51 | 52 | def initialize(): 53 | register("dataset", GNNBenchmarkDataModule) 54 | -------------------------------------------------------------------------------- /datasets/grid.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aidos-lab/dect-evaluation/87640a57eedb6f528569b8f66647a0e31828f156/datasets/grid.pt -------------------------------------------------------------------------------- /datasets/lrgb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import random_split 3 | from dataclasses import dataclass 4 | 5 | import torchvision.transforms as transforms 6 | from torch_geometric.transforms import BaseTransform 7 | from torch_geometric.datasets import LRGBDataset 8 | 9 | from datasets.transforms import CenterTransform 10 | from datasets.base_dataset import DataModule 11 | from loaders.factory import register 12 | from datasets.config import DataModuleConfig 13 | 14 | 15 | @dataclass 16 | class LRGBDataModuleConfig(DataModuleConfig): 17 | root: str = "./data/lrgb" 18 | module: str = "datasets.lrgb" 19 | name: str = "Peptides-func" 20 | 21 | 22 | class ConvertToFloat(BaseTransform): 23 | def __call__(self, data): 24 | """Convert node feature data type to float.""" 25 | data["x"] = data["x"].to(dtype=torch.float) 26 | return data 27 | 28 | 29 | class OneHotDecoding(BaseTransform): 30 | def __call__(self, data): 31 | """Adjust multi-class labels (reverse one-hot encoding). 32 | 33 | This is necessary because some data sets use one-hot encoding 34 | for their labels, wreaks havoc with some multi-class tasks. 
35 | """ 36 | label = data["y"] 37 | 38 | if len(label.shape) > 1: 39 | label = label.squeeze().tolist() 40 | 41 | if isinstance(label, list): 42 | label = label.index(1.0) 43 | 44 | data["y"] = torch.as_tensor([label], dtype=torch.long) 45 | 46 | return data 47 | 48 | 49 | class LrgbDataModule(DataModule): 50 | def __init__(self, config): 51 | self.config = config 52 | self.transform = transforms.Compose( 53 | [OneHotDecoding(), ConvertToFloat(), CenterTransform()] 54 | ) 55 | super().__init__( 56 | config.root, config.batch_size, config.num_workers, config.pin_memory 57 | ) 58 | 59 | def setup(self): 60 | self.entire_ds = torch.utils.data.ConcatDataset( 61 | [ 62 | LRGBDataset( 63 | root=self.config.root, 64 | pre_transform=self.transform, 65 | name=self.config.name, 66 | split="train" 67 | ), 68 | LRGBDataset( 69 | root=self.config.root, 70 | pre_transform=self.transform, 71 | name=self.config.name, 72 | split="test" 73 | ), 74 | LRGBDataset( 75 | root=self.config.root, 76 | pre_transform=self.transform, 77 | name=self.config.name, 78 | split="val" 79 | ) 80 | ] 81 | ) 82 | 83 | self.train_ds, self.test_ds = random_split( 84 | self.entire_ds, 85 | [ 86 | int(0.9 * len(self.entire_ds)), 87 | len(self.entire_ds) - int(0.9 * len(self.entire_ds)), 88 | ], 89 | ) # type: ignore 90 | 91 | self.train_ds, self.val_ds = random_split( 92 | self.train_ds, 93 | [ 94 | int(0.9 * len(self.train_ds)), 95 | len(self.train_ds) - int(0.9 * len(self.train_ds)), 96 | ], 97 | ) # type: ignore 98 | 99 | 100 | def initialize(): 101 | register("dataset", LrgbDataModule) 102 | -------------------------------------------------------------------------------- /datasets/manifold.py: -------------------------------------------------------------------------------- 1 | from datasets.base_dataset import DataModule, DataModuleConfig 2 | import torch 3 | import numpy as np 4 | import open3d as o3d 5 | import torch 6 | import pandas as pd 7 | import torch_geometric 8 | import os 9 | import torch 10 | from torch_geometric.data import Dataset, Data 11 | from torch_geometric.transforms import FaceToEdge, RandomRotate 12 | import shutil 13 | import torchvision.transforms as transforms 14 | from loaders.factory import register 15 | from dataclasses import dataclass 16 | import trimesh 17 | 18 | 19 | class CenterTransform(object): 20 | def __call__(self, data): 21 | data.x = data.pos 22 | data.pos = None 23 | data.x -= data.x.mean() 24 | data.x /= data.x.pow(2).sum(axis=1).sqrt().max() 25 | return data 26 | 27 | 28 | @dataclass 29 | class ManifoldDataModuleConfig(DataModuleConfig): 30 | module: str = "datasets.manifold" 31 | num_samples: int = 100 32 | 33 | 34 | def read_ply(path): 35 | mesh = trimesh.load_mesh(path) 36 | pos = torch.from_numpy(mesh.vertices).to(torch.float) 37 | edge = torch.from_numpy(mesh.edges).to(torch.long).t() 38 | face = torch.from_numpy(mesh.faces).to(torch.long).t() 39 | return Data(pos=pos, edge_index=edge, face=face) 40 | 41 | 42 | class ManifoldDataModule(DataModule): 43 | def __init__(self, config): 44 | self.config = config 45 | self.transform = transforms.Compose( 46 | [RandomRotate(0.7), FaceToEdge(remove_faces=False), CenterTransform()] 47 | ) 48 | super().__init__( 49 | config.root, config.batch_size, config.num_workers, config.pin_memory 50 | ) 51 | 52 | def prepare_data(self): 53 | self.train_ds = ManifoldDataset( 54 | self.config, split="train", pre_transform=self.transform 55 | ) 56 | self.test_ds = ManifoldDataset( 57 | self.config, split="test", pre_transform=self.transform 58 | ) 59 | 
self.val_ds = ManifoldDataset( 60 | self.config, split="val", pre_transform=self.transform 61 | ) 62 | self.entire_ds = ManifoldDataset( 63 | self.config, split="train", pre_transform=self.transform 64 | ) 65 | 66 | def setup(self): 67 | pass 68 | 69 | 70 | class ManifoldDataset(Dataset): 71 | """Represents a 2D segmentation dataset. 72 | 73 | Input params: 74 | configuration: Configuration dictionary. 75 | """ 76 | 77 | def __init__(self, config, split, pre_transform): 78 | super().__init__( 79 | root=config.root, 80 | transform=pre_transform, 81 | pre_transform=pre_transform, 82 | pre_filter=None, 83 | ) 84 | self.config = config 85 | self.split = split 86 | self.clean() 87 | self.files = [] 88 | self.create_spheres() 89 | self.create_mobius() 90 | self.create_torus() 91 | self.file_frame = pd.DataFrame(self.files, columns=["filename", "y"]) 92 | 93 | def clean(self): 94 | path = f"{self.config.root}/manifold/{self.split}" 95 | shutil.rmtree(path, ignore_errors=True) 96 | os.makedirs(path, exist_ok=True) 97 | 98 | def get(self, index): 99 | if torch.is_tensor(index): 100 | index = index.tolist() 101 | 102 | file_name = self.file_frame.iloc[index, 0] 103 | y = self.file_frame.iloc[index, 1] 104 | data = read_ply(file_name) 105 | data.y = torch.tensor([y]) 106 | return data 107 | 108 | def len(self): 109 | # return the size of the dataset 110 | return len(self.file_frame) 111 | 112 | def create_spheres(self, noise=None): 113 | if not noise: 114 | noise = 0.3 115 | 116 | for i in range(self.config.num_samples): 117 | base_mesh = o3d.geometry.TriangleMesh.create_sphere() 118 | vertices = np.asarray(base_mesh.vertices) 119 | vertices += np.random.uniform(0, noise, size=vertices.shape) 120 | base_mesh.vertices = o3d.utility.Vector3dVector(vertices) 121 | base_mesh.compute_vertex_normals() 122 | f_name = ( 123 | f"{self.config.root}/manifold/{self.split}/sphere_{self.split}_{i}.ply" 124 | ) 125 | o3d.io.write_triangle_mesh(f_name, base_mesh) 126 | self.files.append([f_name, int(0)]) 127 | 128 | def create_mobius(self, noise=None): 129 | if not noise: 130 | noise = 0.3 131 | for i in range(self.config.num_samples): 132 | base_mesh = o3d.geometry.TriangleMesh.create_mobius() 133 | vertices = np.asarray(base_mesh.vertices) 134 | vertices += np.random.uniform(0, noise, size=vertices.shape) 135 | base_mesh.vertices = o3d.utility.Vector3dVector(vertices) 136 | base_mesh.compute_vertex_normals() 137 | f_name = ( 138 | f"{self.config.root}/manifold/{self.split}/mobius_{self.split}_{i}.ply" 139 | ) 140 | o3d.io.write_triangle_mesh(f_name, base_mesh) 141 | self.files.append([f_name, int(1)]) 142 | 143 | def create_torus(self, noise=None): 144 | if not noise: 145 | noise = 0.3 146 | 147 | for i in range(self.config.num_samples): 148 | base_mesh = o3d.geometry.TriangleMesh.create_torus() 149 | vertices = np.asarray(base_mesh.vertices) 150 | vertices += np.random.uniform(0, noise, size=vertices.shape) 151 | base_mesh.vertices = o3d.utility.Vector3dVector(vertices) 152 | base_mesh.compute_vertex_normals() 153 | f_name = ( 154 | f"{self.config.root}/manifold/{self.split}/torus_{self.split}_{i}.ply" 155 | ) 156 | o3d.io.write_triangle_mesh(f_name, base_mesh) 157 | self.files.append([f_name, int(2)]) 158 | 159 | 160 | def initialize(): 161 | register("dataset", ManifoldDataModule) 162 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data 
import random_split 3 | 4 | from torchvision.datasets import MNIST 5 | import torchvision.transforms as transforms 6 | 7 | from torch_geometric.data import InMemoryDataset 8 | from torch_geometric.transforms import FaceToEdge 9 | 10 | from datasets.transforms import CenterTransform 11 | from datasets.base_dataset import DataModule 12 | from datasets.transforms import MnistTransform 13 | 14 | from loaders.factory import register 15 | 16 | 17 | class MnistDataModule(DataModule): 18 | def __init__(self, config): 19 | self.config = config 20 | self.transform = transforms.Compose( 21 | [MnistTransform(), FaceToEdge(remove_faces=False), CenterTransform()] 22 | ) 23 | super().__init__( 24 | config.root, config.batch_size, config.num_workers, config.pin_memory 25 | ) 26 | 27 | def setup(self): 28 | self.entire_ds = MnistDataset( 29 | root=self.config.root, pre_transform=self.transform, train=True 30 | ) 31 | self.train_ds, self.val_ds = random_split( 32 | self.entire_ds, 33 | [ 34 | int(0.9 * len(self.entire_ds)), 35 | len(self.entire_ds) - int(0.9 * len(self.entire_ds)), 36 | ], 37 | ) # type: ignore 38 | 39 | self.test_ds = MnistDataset( 40 | root=self.config.root, pre_transform=self.transform, train=False 41 | ) 42 | 43 | 44 | class MnistDataset(InMemoryDataset): 45 | def __init__( 46 | self, root, transform=None, pre_transform=None, train=True, pre_filter=None 47 | ): 48 | self.train = train 49 | self.root = root 50 | super().__init__(root, transform, pre_transform, pre_filter) 51 | if train: 52 | self.data, self.slices = torch.load(self.processed_paths[0]) 53 | else: 54 | self.data, self.slices = torch.load(self.processed_paths[1]) 55 | 56 | @property 57 | def raw_file_names(self): 58 | return ["MNIST"] 59 | 60 | @property 61 | def processed_file_names(self): 62 | return ["train.pt", "test.pt"] 63 | 64 | def download(self): 65 | if self.train: 66 | MNIST(f"{self.root}/raw/", train=True, download=True) 67 | else: 68 | MNIST(f"{self.root}/raw/", train=False, download=True) 69 | 70 | def process(self): 71 | train_ds = MNIST(f"{self.root}/raw/", train=True, download=True) 72 | test_ds = MNIST(f"{self.root}/raw/", train=False, download=True) 73 | 74 | if self.pre_transform is not None: 75 | train_data_list = [self.pre_transform(data) for data in train_ds] 76 | test_data_list = [self.pre_transform(data) for data in test_ds] 77 | 78 | train_data, train_slices = self.collate(train_data_list) 79 | torch.save((train_data, train_slices), self.processed_paths[0]) 80 | test_data, test_slices = self.collate(test_data_list) 81 | torch.save((test_data, test_slices), self.processed_paths[1]) 82 | 83 | 84 | def initialize(): 85 | register("dataset", MnistDataModule) 86 | -------------------------------------------------------------------------------- /datasets/modelnet.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import ModelNet 2 | from torch_geometric import transforms 3 | from datasets.base_dataset import DataModule 4 | from torch.utils.data import random_split 5 | import torch 6 | 7 | from torch_geometric.transforms import FaceToEdge 8 | from datasets.transforms import CenterTransform, ModelNetTransform 9 | 10 | from loaders.factory import register 11 | 12 | class CenterTransformNew(object): 13 | def __call__(self, data): 14 | data.x -= data.x.mean(axis=0) 15 | data.x /= data.x.pow(2).sum(axis=1).sqrt().max() 16 | return data 17 | 18 | 19 | 20 | 21 | 22 | class Normalize(object): 23 | def __call__(self, data): 24 | mean = 
data.x.mean(axis=0) 25 | std = data.x.std(axis=0) 26 | data.x = (data.x - mean) / std 27 | return data 28 | 29 | class ModelNetDataModule(DataModule): 30 | def __init__(self, config): 31 | self.config = config 32 | self.pre_transform = transforms.Compose( 33 | [ 34 | transforms.SamplePoints(self.config.samplepoints), 35 | ModelNetTransform(), 36 | CenterTransform(), 37 | # Normalize(), 38 | ] 39 | ) 40 | super().__init__( 41 | config.root, config.batch_size, config.num_workers, config.pin_memory 42 | ) 43 | 44 | self.setup() 45 | 46 | def setup(self): 47 | self.entire_ds = ModelNet( 48 | root=self.config.root, 49 | pre_transform=self.pre_transform, 50 | train=True, 51 | name=self.config.name, 52 | ) 53 | self.train_ds, self.val_ds = random_split( 54 | self.entire_ds, 55 | [ 56 | int(0.8 * len(self.entire_ds)), 57 | len(self.entire_ds) - int(0.8 * len(self.entire_ds)), 58 | ], 59 | ) # type: ignore 60 | self.test_ds = ModelNet( 61 | root=self.config.root, 62 | pre_transform=self.pre_transform, 63 | train=False, 64 | name=self.config.name, 65 | ) 66 | 67 | 68 | def initialize(): 69 | register("dataset", ModelNetDataModule) 70 | -------------------------------------------------------------------------------- /datasets/moleculenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import random_split 3 | from dataclasses import dataclass 4 | 5 | import torchvision.transforms as transforms 6 | from torch_geometric.transforms import BaseTransform 7 | from torch_geometric.datasets import MoleculeNet 8 | 9 | from datasets.transforms import CenterTransform 10 | from datasets.base_dataset import DataModule 11 | from loaders.factory import register 12 | from datasets.config import DataModuleConfig 13 | 14 | 15 | 16 | @dataclass 17 | class MoleculeNetDataModuleConfig(DataModuleConfig): 18 | root: str = "./data/moleculenet" 19 | module: str = "datasets.moleculenet" 20 | name: str = "HIV" 21 | 22 | 23 | 24 | class ConvertToFloat(BaseTransform): 25 | def __call__(self, data): 26 | """Convert node feature data type to float.""" 27 | data["x"] = data["x"].to(dtype=torch.float) 28 | return data 29 | 30 | 31 | class OneHotDecoding(BaseTransform): 32 | def __call__(self, data): 33 | """Adjust multi-class labels (reverse one-hot encoding). 34 | 35 | This is necessary because some data sets use one-hot encoding 36 | for their labels, wreaks havoc with some multi-class tasks. 
37 | """ 38 | label = data["y"] 39 | 40 | if len(label.shape) > 1: 41 | label = label.squeeze().tolist() 42 | 43 | if isinstance(label, list): 44 | label = label.index(1.0) 45 | 46 | data["y"] = torch.as_tensor([label], dtype=torch.long) 47 | 48 | return data 49 | 50 | 51 | class MoleculeNetDataModule(DataModule): 52 | def __init__(self, config): 53 | self.config = config 54 | self.transform = transforms.Compose( 55 | [OneHotDecoding(), ConvertToFloat(), CenterTransform()] 56 | ) 57 | super().__init__( 58 | config.root, config.batch_size, config.num_workers, config.pin_memory 59 | ) 60 | 61 | def setup(self): 62 | self.entire_ds = MoleculeNet( 63 | root=self.config.root, pre_transform=self.transform, name=self.config.name 64 | ) 65 | self.train_ds, self.test_ds = random_split( 66 | self.entire_ds, 67 | [ 68 | int(0.9 * len(self.entire_ds)), 69 | len(self.entire_ds) - int(0.9 * len(self.entire_ds)), 70 | ], 71 | ) # type: ignore 72 | 73 | self.train_ds, self.val_ds = random_split( 74 | self.train_ds, 75 | [ 76 | int(0.9 * len(self.train_ds)), 77 | len(self.train_ds) - int(0.9 * len(self.train_ds)), 78 | ], 79 | ) # type: ignore 80 | 81 | 82 | def initialize(): 83 | register("dataset", MoleculeNetDataModule) 84 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import degree 3 | from torch_geometric.data import Data 4 | import matplotlib.pyplot as plt 5 | import torchvision 6 | import vedo 7 | 8 | 9 | def plot_batch(data): 10 | coords = data.x.cpu().numpy() 11 | 12 | fig = plt.figure(figsize=(12, 12)) 13 | ax = fig.add_subplot(projection="3d") 14 | 15 | sequence_containing_x_vals = coords[:, 0] 16 | sequence_containing_y_vals = coords[:, 1] 17 | sequence_containing_z_vals = coords[:, 2] 18 | 19 | ax.scatter( 20 | sequence_containing_x_vals, 21 | sequence_containing_y_vals, 22 | sequence_containing_z_vals, 23 | ) 24 | ax.set_xlim([-1, 1]) 25 | ax.set_ylim([-1, 1]) 26 | ax.set_zlim([-1, 1]) 27 | plt.show() 28 | 29 | 30 | class ThresholdTransform(object): 31 | def __call__(self, data): 32 | data.x = torch.hstack([data.pos, data.x]) 33 | return data 34 | 35 | 36 | class CenterTransform(object): 37 | def __call__(self, data): 38 | data.x -= data.x.mean() 39 | data.x /= data.x.pow(2).sum(axis=1).sqrt().max() 40 | return data 41 | 42 | 43 | class Normalize(object): 44 | def __call__(self, data): 45 | mean = data.x.mean() 46 | std = data.x.std() 47 | data.x = (data.x - mean) / std 48 | return data 49 | 50 | 51 | class NormalizedDegree(object): 52 | def __init__(self, mean, std): 53 | self.mean = mean 54 | self.std = std 55 | 56 | def __call__(self, data): 57 | deg = degree(data.edge_index[0], dtype=torch.float) 58 | deg = (deg - self.mean) / self.std 59 | data.x = deg.view(-1, 1) 60 | return data 61 | 62 | 63 | class NCI109Transform(object): 64 | def __call__(self, data): 65 | deg = degree(data.edge_index[0], dtype=torch.float).unsqueeze(0).T 66 | atom_number = torch.argmax(data.x, dim=-1, keepdim=True) 67 | data.x = torch.hstack([deg, atom_number]) 68 | return data 69 | 70 | 71 | class ModelNetTransform(object): 72 | def __call__(self, data): 73 | data.x = data.pos 74 | data.pos = None 75 | return data 76 | 77 | 78 | class Project(object): 79 | def __call__(self, batch): 80 | batch.x = batch.x[:, :2] 81 | # scaling 82 | return batch 83 | 84 | 85 | class MnistTransform: 86 | def __init__(self): 87 | xcoords = 
torch.linspace(-0.5, 0.5, 28) 88 | ycoords = torch.linspace(-0.5, 0.5, 28) 89 | self.X, self.Y = torch.meshgrid(xcoords, ycoords) 90 | self.tr = torchvision.transforms.ToTensor() 91 | 92 | def __call__(self, data: tuple) -> Data: 93 | img, y = data 94 | img = self.tr(img) 95 | idx = torch.nonzero(img.squeeze(), as_tuple=True) 96 | gp = torch.vstack([self.X[idx], self.Y[idx]]).T 97 | dly = vedo.delaunay2d(gp, mode="xy", alpha=0.03).c("w").lc("o").lw(1) 98 | 99 | return Data( 100 | x=torch.tensor(dly.points()), 101 | face=torch.tensor(dly.faces(), dtype=torch.long).T, 102 | y=torch.tensor(y, dtype=torch.long), 103 | ) 104 | 105 | 106 | class WeightedMnistTransform: 107 | def __init__(self): 108 | self.grid = torch.load("./datasets/grid.pt") 109 | self.tr = torchvision.transforms.ToTensor() 110 | 111 | def __call__(self, data: tuple) -> Data: 112 | img, y = data 113 | img = self.tr(img) 114 | 115 | return Data( 116 | x=self.grid.x, 117 | node_weights = img.view(-1,1), 118 | edge_index=self.grid.edge_index, 119 | face=self.grid.face, 120 | y=torch.tensor(y, dtype=torch.long), 121 | ) -------------------------------------------------------------------------------- /datasets/tu.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import TUDataset 2 | 3 | from torch.utils.data import random_split 4 | 5 | from torch_geometric import transforms 6 | from torch_geometric.transforms import OneHotDegree 7 | 8 | from datasets.config import DataModuleConfig 9 | from datasets.base_dataset import DataModule 10 | from datasets.transforms import ( 11 | CenterTransform, 12 | NormalizedDegree, 13 | NCI109Transform, 14 | ) 15 | 16 | from dataclasses import dataclass 17 | from loaders.factory import register 18 | 19 | 20 | transforms_dict = { 21 | "DD": [CenterTransform()], 22 | "ENZYMES": [CenterTransform()], 23 | "IMDB-BINARY": [OneHotDegree(540), CenterTransform()], 24 | "Letter-high": [CenterTransform()], 25 | "Letter-med": [CenterTransform()], 26 | "Letter-low": [CenterTransform()], 27 | "PROTEINS_full": [CenterTransform()], 28 | "REDDIT-BINARY": [ 29 | NormalizedDegree(2.31, 20.66), 30 | CenterTransform(), 31 | ], 32 | "NCI1": [NCI109Transform(), CenterTransform()], 33 | "NCI109": [CenterTransform()], 34 | "BZR": [CenterTransform()], 35 | "COX2": [CenterTransform()], 36 | "FRANKENSTEIN": [CenterTransform()], 37 | "Fingerprint": [CenterTransform()], 38 | "Cuneiform": [CenterTransform()], 39 | "COLLAB": [CenterTransform()], 40 | "DHFR": [CenterTransform()], 41 | } 42 | 43 | # ╭──────────────────────────────────────────────────────────╮ 44 | # │ Datasets │ 45 | # ╰──────────────────────────────────────────────────────────╯ 46 | 47 | 48 | """ 49 | Define the dataset classes, provide dataset/dataloader parameters 50 | in the config file or overwrite them in the class definition. 
51 | """ 52 | 53 | 54 | @dataclass 55 | class TUBZRConfig(DataModuleConfig): 56 | module: str = "datasets.tu" 57 | name: str = "BZR" 58 | cleaned: bool = True 59 | use_node_attr: bool = True 60 | 61 | 62 | @dataclass 63 | class TUCOX2Config(DataModuleConfig): 64 | module: str = "datasets.tu" 65 | name: str = "COX2" 66 | cleaned: bool = True 67 | use_node_attr: bool = True 68 | 69 | 70 | @dataclass 71 | class TUFrankensteinConfig(DataModuleConfig): 72 | module: str = "datasets.tu" 73 | name: str = "FRANKENSTEIN" 74 | cleaned: bool = True 75 | use_node_attr: bool = True 76 | 77 | 78 | @dataclass 79 | class TUFingerprintConfig(DataModuleConfig): 80 | module: str = "datasets.tu" 81 | name: str = "Fingerprint" 82 | cleaned: bool = True 83 | use_node_attr: bool = True 84 | 85 | 86 | @dataclass 87 | class TUCuneiformConfig(DataModuleConfig): 88 | module: str = "datasets.tu" 89 | name: str = "Cuneiform" 90 | cleaned: bool = True 91 | use_node_attr: bool = True 92 | 93 | 94 | @dataclass 95 | class TUCollabConfig(DataModuleConfig): 96 | module: str = "datasets.tu" 97 | name: str = "COLLAB" 98 | cleaned: bool = True 99 | use_node_attr: bool = True 100 | 101 | 102 | @dataclass 103 | class TUDHFRConfig(DataModuleConfig): 104 | module: str = "datasets.tu" 105 | name: str = "DHFR" 106 | cleaned: bool = True 107 | use_node_attr: bool = True 108 | 109 | 110 | @dataclass 111 | class TUBBBPConfig(DataModuleConfig): 112 | module: str = "datasets.tu" 113 | name: str = "BBBP" 114 | cleaned: bool = True 115 | use_node_attr: bool = True 116 | 117 | 118 | @dataclass 119 | class TUNCI109Config(DataModuleConfig): 120 | module: str = "datasets.tu" 121 | name: str = "NCI109" 122 | cleaned: bool = True 123 | use_node_attr: bool = True 124 | 125 | 126 | @dataclass 127 | class TUNCI1Config(DataModuleConfig): 128 | module: str = "datasets.tu" 129 | name: str = "NCI1" 130 | cleaned: bool = True 131 | use_node_attr: bool = True 132 | 133 | 134 | @dataclass 135 | class TUDDConfig(DataModuleConfig): 136 | module: str = "datasets.tu" 137 | name: str = "DD" 138 | cleaned: bool = True 139 | use_node_attr: bool = True 140 | 141 | 142 | @dataclass 143 | class TUEnzymesConfig(DataModuleConfig): 144 | module: str = "datasets.tu" 145 | name: str = "ENZYMES" 146 | cleaned: bool = False 147 | use_node_attr: bool = True 148 | 149 | 150 | @dataclass 151 | class TUIMDBBConfig(DataModuleConfig): 152 | module: str = "datasets.tu" 153 | name: str = "IMDB-BINARY" 154 | cleaned: bool = True 155 | use_node_attr: bool = True 156 | 157 | 158 | @dataclass 159 | class TUProteinsFullConfig(DataModuleConfig): 160 | module: str = "datasets.tu" 161 | name: str = "PROTEINS_full" 162 | cleaned: bool = False 163 | use_node_attr: bool = True 164 | 165 | 166 | @dataclass 167 | class TURedditBConfig(DataModuleConfig): 168 | module: str = "datasets.tu" 169 | name: str = "REDDIT-BINARY" 170 | cleaned: bool = True 171 | use_node_attr: bool = True 172 | 173 | 174 | @dataclass 175 | class TULetterHighConfig(DataModuleConfig): 176 | name: str = "Letter-high" 177 | module: str = "datasets.tu" 178 | cleaned: bool = False 179 | drop_last: bool = False 180 | 181 | 182 | @dataclass 183 | class TULetterMedConfig(DataModuleConfig): 184 | name: str = "Letter-med" 185 | module: str = "datasets.tu" 186 | cleaned: bool = False 187 | drop_last: bool = False 188 | 189 | 190 | @dataclass 191 | class TULetterLowConfig(DataModuleConfig): 192 | name: str = "Letter-low" 193 | module: str = "datasets.tu" 194 | cleaned: bool = False 195 | drop_last: bool = False 196 | 197 | 198 | 
class TUDataModule(DataModule): 199 | """ 200 | This datamodule loads the base TUDatasets without transforming. 201 | See below how to add a transform the easiest way. 202 | """ 203 | 204 | def __init__(self, config): 205 | self.config = config 206 | super().__init__( 207 | config.root, 208 | config.batch_size, 209 | config.num_workers, 210 | drop_last=self.config.drop_last, 211 | ) 212 | 213 | def setup(self): 214 | self.entire_ds = TUDataset( 215 | pre_transform=transforms.Compose(transforms_dict[self.config.name]), 216 | name=self.config.name, 217 | root=self.config.root, 218 | cleaned=self.config.cleaned, 219 | use_node_attr=True, 220 | ) 221 | inter_ds, self.test_ds = random_split( 222 | self.entire_ds, 223 | [ 224 | int(0.8 * len(self.entire_ds)), 225 | len(self.entire_ds) - int(0.8 * len(self.entire_ds)), 226 | ], 227 | ) # type: ignore 228 | self.train_ds, self.val_ds = random_split(inter_ds, [int(0.8 * len(inter_ds)), len(inter_ds) - int(0.8 * len(inter_ds))]) # type: ignore 229 | 230 | 231 | def initialize(): 232 | register("dataset", TUDataModule) 233 | -------------------------------------------------------------------------------- /datasets/weighted_mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import random_split 3 | 4 | from torchvision.datasets import MNIST 5 | import torchvision.transforms as transforms 6 | 7 | from torch_geometric.data import InMemoryDataset 8 | from torch_geometric.transforms import FaceToEdge 9 | 10 | from datasets.transforms import CenterTransform 11 | from datasets.base_dataset import DataModule 12 | from datasets.transforms import WeightedMnistTransform 13 | 14 | from loaders.factory import register 15 | 16 | 17 | class WeightedMnistDataModule(DataModule): 18 | def __init__(self, config): 19 | self.config = config 20 | self.transform = transforms.Compose( 21 | [WeightedMnistTransform(), FaceToEdge(remove_faces=False), CenterTransform()] 22 | ) 23 | super().__init__( 24 | config.root, config.batch_size, config.num_workers, config.pin_memory 25 | ) 26 | 27 | def setup(self): 28 | self.entire_ds = torch.utils.data.ConcatDataset( 29 | [ 30 | WeightedMnistDataset( root=self.config.root, pre_transform=self.transform, train=True), 31 | WeightedMnistDataset( root=self.config.root, pre_transform=self.transform, train=False ) 32 | ] 33 | ) 34 | self.train_ds, self.test_ds = random_split( 35 | self.entire_ds, 36 | [ 37 | int(0.9 * len(self.entire_ds)), 38 | len(self.entire_ds) - int(0.9 * len(self.entire_ds)), 39 | ], 40 | ) # type: ignore 41 | 42 | self.train_ds, self.val_ds = random_split( 43 | self.train_ds, 44 | [ 45 | int(0.9 * len(self.train_ds)), 46 | len(self.train_ds) - int(0.9 * len(self.train_ds)), 47 | ], 48 | ) # type: ignore 49 | 50 | 51 | class WeightedMnistDataset(InMemoryDataset): 52 | """ 53 | This generates the Weighted Complexes for the WECT. 
54 | """ 55 | def __init__( 56 | self, root, transform=None, pre_transform=None, train=True, pre_filter=None 57 | ): 58 | self.train = train 59 | self.root = root 60 | super().__init__(root, transform, pre_transform, pre_filter) 61 | if train: 62 | self.data, self.slices = torch.load(self.processed_paths[0]) 63 | else: 64 | self.data, self.slices = torch.load(self.processed_paths[1]) 65 | 66 | @property 67 | def raw_file_names(self): 68 | return ["WMNIST"] 69 | 70 | @property 71 | def processed_file_names(self): 72 | return ["train.pt", "test.pt"] 73 | 74 | def download(self): 75 | if self.train: 76 | MNIST(f"{self.root}/raw/", train=True, download=True) 77 | else: 78 | MNIST(f"{self.root}/raw/", train=False, download=True) 79 | 80 | def process(self): 81 | train_ds = MNIST(f"{self.root}/raw/", train=True, download=True) 82 | test_ds = MNIST(f"{self.root}/raw/", train=False, download=True) 83 | 84 | if self.pre_transform is not None: 85 | train_data_list = [self.pre_transform(data) for data in train_ds] 86 | test_data_list = [self.pre_transform(data) for data in test_ds] 87 | 88 | train_data, train_slices = self.collate(train_data_list) 89 | torch.save((train_data, train_slices), self.processed_paths[0]) 90 | test_data, test_slices = self.collate(test_data_list) 91 | torch.save((test_data, test_slices), self.processed_paths[1]) 92 | 93 | 94 | def initialize(): 95 | register("dataset", WeightedMnistDataModule) 96 | -------------------------------------------------------------------------------- /dect/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Ernst Roell 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | 1. Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS 19 | IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED 20 | TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 24 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 25 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 26 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 27 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /dect/README.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aidos-lab/dect-evaluation/87640a57eedb6f528569b8f66647a0e31828f156/dect/README.rst -------------------------------------------------------------------------------- /dect/__init__.py: -------------------------------------------------------------------------------- 1 | from dect import * -------------------------------------------------------------------------------- /dect/dect/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aidos-lab/dect-evaluation/87640a57eedb6f528569b8f66647a0e31828f156/dect/dect/__init__.py -------------------------------------------------------------------------------- /dect/dect/ect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.data import Data 4 | 5 | from typing import Protocol 6 | from dataclasses import dataclass 7 | 8 | 9 | @dataclass(frozen=True) 10 | class EctConfig: 11 | num_thetas: int = 32 12 | bump_steps: int = 32 13 | R: float = 1.1 14 | ect_type: str = "points" 15 | device: str = "cpu" 16 | num_features: int = 3 17 | normalized: bool = False 18 | 19 | 20 | def compute_ecc(nh, index, lin, dim_size, out): 21 | ecc = torch.nn.functional.sigmoid(50 * torch.sub(lin, nh)) 22 | return torch.index_add(out, 1, index, ecc).movedim(0, 1) 23 | 24 | 25 | def compute_ecc_derivative(nh, index, lin, dim_size, out): 26 | ecc = torch.nn.functional.sigmoid(100 * torch.sub(lin, nh)) * ( 27 | 1 - torch.nn.functional.sigmoid(100 * torch.sub(lin, nh)) 28 | ) 29 | return torch.index_add(out, 1, index, ecc).movedim(0, 1) 30 | 31 | 32 | def compute_ect_points_derivative(data, v, lin, out): 33 | nh = data.x @ v 34 | return compute_ecc_derivative(nh, data.batch, lin, data.num_graphs, out) 35 | 36 | 37 | def compute_ect_points(data, v, lin, out): 38 | nh = data.x @ v 39 | return compute_ecc(nh, data.batch, lin, data.num_graphs, out) 40 | 41 | 42 | def compute_ect_edges(data, v, lin, out): 43 | nh = data.x @ v 44 | eh, _ = nh[data.edge_index].max(dim=0) 45 | return compute_ecc(nh, data.batch, lin, data.num_graphs, out) - compute_ecc( 46 | eh, data.batch[data.edge_index[0]], lin, data.num_graphs, out 47 | ) 48 | 49 | 50 | def compute_ect_faces(data, v, lin, out): 51 | nh = data.x @ v 52 | eh, _ = nh[data.edge_index].max(dim=0) 53 | fh, _ = nh[data.face].max(dim=0) 54 | return ( 55 | compute_ecc(nh, data.batch, lin, data.num_graphs, out) 56 | - compute_ecc(eh, data.batch[data.edge_index[0]], lin, data.num_graphs, out) 57 | + compute_ecc(fh, data.batch[data.face[0]], lin, data.num_graphs, out) 58 | ) 59 | 60 | 61 | class EctLayer(nn.Module): 62 | """docstring for EctLayer.""" 63 | 64 | def __init__(self, config: EctConfig, V=None): 65 | super().__init__() 66 | self.config = config 67 | self.lin = ( 68 | torch.linspace(-config.R, config.R, config.bump_steps) 69 | .view(-1, 1, 1) 70 | .to(config.device) 71 | ) 72 | 73 | if torch.is_tensor(V): 74 | self.v = V 75 | else: 76 | self.v = ( 77 | torch.rand(size=(config.num_features, config.num_thetas)) - 0.5 78 | ).T.to(config.device) 79 | self.v /= self.v.pow(2).sum(axis=1).sqrt().unsqueeze(1) 80 | self.v = self.v.T 81 | 82 | if config.ect_type == "points": 83 | self.compute_ect = compute_ect_points 84 | elif config.ect_type == "edges": 85 | 
self.compute_ect = compute_ect_edges 86 | elif config.ect_type == "faces": 87 | self.compute_ect = compute_ect_faces 88 | elif config.ect_type == "points_derivative": 89 | self.compute_ect = compute_ect_points_derivative 90 | 91 | def forward(self, data): 92 | out = torch.zeros( 93 | size=( 94 | self.config.bump_steps, 95 | data.batch.max().item() + 1, 96 | self.config.num_thetas, 97 | ), 98 | device=self.config.device, 99 | ) 100 | ect = self.compute_ect(data, self.v, self.lin, out) 101 | if self.config.normalized: 102 | return ect / torch.amax(ect, dim=(1, 2)).unsqueeze(1).unsqueeze(1) 103 | else: 104 | return ect 105 | -------------------------------------------------------------------------------- /dect/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "dect" 7 | version = "0.0.0" 8 | dependencies = [ 9 | "numpy", 10 | "torch", 11 | "torch_geometric" 12 | ] 13 | requires-python = ">=3.8" 14 | authors = [ 15 | {name = "Ernst Röell", email = "ernst.roeell@helmholtz-munich.de"}, 16 | {name = "Bastian Rieck", email = "bastian.rieck@helmholtz-munich.de"}, 17 | ] 18 | maintainers = [ 19 | {name = "Ernst Röell", email = "ernst.roeell@helmholtz-munich.de"}, 20 | ] 21 | description = "A fast package to compute the Euler Characteristic Transform" 22 | readme = "README.rst" 23 | license = {file = "LICENSE.txt"} 24 | keywords = ["euler", "characteristic", "topology", "tda", "transform"] 25 | classifiers = [ 26 | "Development Status :: 4 - Beta", 27 | "Programming Language :: Python" 28 | ] 29 | -------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAezElEQVR4nO3dfXRU9eHn8c+dGYlNfkB/SZx02zKlNA+tMduz28MJskY6FRFsf7GBs4C/ID7hc44rR9AYqzzURqg8WEWrVmytQuNjLFKrlDbyy6lI4+o55mRdktSFidomm4SGkKwDw8z+gfkWUCGTuXce368/O/F7v+Ekfefe7733a0UikYgAAJDkSvQEAADJgygAAAyiAAAwiAIAwCAKAACDKAAADKIAADCIAgDAIAoAAIMoAAAMogAAMIgCAMAgCgAAgygAAAyiAAAwiAIAwCAKAACDKAAADKIAADCIAgDAIAoAAIMoAAAMogAAMIgCAMAgCgAAgygAAAyiAAAwiAIAwCAKAACDKAAADKIAADCIAgDAIAoAAIMoAAAMogAAMIgCAMAgCgAAgygAAAyiAAAwiAIAwCAKAACDKAAADKIAADCIAgDAIAoAAIMoAAAMogAAMIgCAMAgCgAAgygAAAyiAAAwiAIAwCAKAACDKAAADKIAADCIAgDAIAoAAIMoAAAMogAAMIgCAMAgCgAAgygAAAyiAAAwiAIAwCAKAACDKAAADKIAADCIAgDAIAoAAIMoAAAMogAAMIgCAMAgCgAAgygAAAyiAAAwiAIAwCAKAACDKAAADKIAADCIAgDAIAoAAIMoAAAMogAAMIgCAMAgCgAAgygAAAyiAAAwiAIAwCAKAACDKAAADE+iJwDEw1AwpH19QzocCmucx6XJeTnKyeLHHzgZvxVIWx3dg9qyJ6CmvT0K9A8rctxnliRfbrb8JV5Vl/tUVDA+UdMEkooViUQip/8yIHV09Q+rrrFVzZ29crssHQ1//o/4yOcVhfmqryrTpNzsOM4USD5EAWmloSWgFdvaFApHThmDk7ldljwuS6sqS7Vwqs/BGQLJjSggbWxq6tC6He0xj7NsVrFq/EU2zAhIPdx9hLTQ0BKwJQiStG5Hu55pCdgyFpBqiAJSXlf/sFZsa7N1zLu3tamrf9jWMYFUQBSQ8uoaWxWKYv1gNELhiOoaW20dE0gFRAEpraN7UM2dvVEtKo/G0XBEzZ296uwZtHVcINkRBaS0LXsCcrssR8Z2uyw9/SZrC8gsRAEprWlvj+1nCSOOhiNqau9xZGwgWREFpKxDwZACDi8GB/qGNRQMOXoMIJkQBaSs/X1Dcvohm4ikfX1DDh8FSB5EASnrcCicVscBkgFRQMoa54nPj2+8jgMkA37akbIm5+XImfuO/sn65DhApiAKSFk5WR75HH6rqS8vm30XkFGIAlKav8Tr6HMK/mKvI2MDyYooIKVVl/scfU5h0TReo43MQhSQ0ooKxquiMF9uu08WImH9v//ztpZe/e/q7Oy0eXAgeREFpLzK/zSk0JHDsnNrkKwzPFr9b2fr3XffVWlpqe644w4dOnTItvGBZEUUkLIikYh+9rOf6dLKWfJ2vS7Lsu90YXVlqa6rnqf33ntPd9xxh+6//36VlJRo69attsYHSDZEASlpeHhYixcv1i233KJbbrlFu59er2Wzim0Ze/msEi34ZEvO7OxsrVy5Uu+9957OPfdcVVdXq6KiQu+8844txwKSDVFAytm3b5/OO+88vfDCC9q6davWrVsnj8ejGn+R1swtU5bHFfUdSW6XpSyPS2vnlukmf+GnPp88ebKef/55/eEPf9CBAwf0ne98R9dff716e3vt+raApMAezUgpf/zjH7VgwQKNHz9eL730kr797W9/6mu6+odV19iq5s5euV3WKe9OGvm8ojBf9VVlmjSK5x6OHDmihx9+WCtWrJBlWfrxj3+s66+/Xh5PbM8zDAVD2tc3pMOhsMZ5XJqcl8MzEog7ooCUEIlEtGHDBt1222264IIL9Jvf/EZ5eXmn/G86uge1ZU9ATe09CvQNn/DyPEvHHkzzF3u1aJpPhd7xUc+pp6dHd955pzZv3qxzzjlHDzzwgL773e9GNYaZ494eBfo/Y4652fKXeFVd7lNRQfRzBKJFFJD0hoaGtGTJEjU0NOj222/XT37yE7nd7ujGcPCv8JaWFt1888168803NX/+fN13333y+U79fIPTZzPAWBEFJLX3339fVVVV+utf/6onnnhC8+fPT/SUPlM4HNbTTz+t22+/XQMDA7rjjju0fPlynXnmmZ/62oaWgFZsa1MoHInqwTu3y5LHZWlVZakWTuWhOjiDKCBpvfbaa7r00kuVm5url156Seecc06ip3RaBw8e1D333KP7779fX/3qV7VhwwZdcskl5nbZTU0dWrejPebjLJtVrBp/UczjACfj7iM4YigYUttHA3oncEBtHw1EtXtZJBLR2rVrdfHFF2vatGlqaWlJiSBI0oQJE/TTn/5Ura2tKikpUVVVlS666CK99957amgJ2BIESVq3o13PtLB/NOzHmQJsY8ei6aFDh3TllVfq+eef15133qlVq1ZFvX6QLCKRiLZv365bbrlFHw58rC8v+bnCln3fS5bHpZ1LZ7DGAFsRBcTMrkXTjo4OVVVVaf/+/fr1r3+tqqqqeEzfcR9//LEuuOclfXAkW5bLvii4XZamT8nTU1eX2zYmwOUjxKShJaCZG3fpjff7JOm0C6cjn7/xfp9mbtylhk8ugbzyyiuaOnWqjhw5oj179qRNECSpa+CIPjw63tYgSMf+LZs7e9XZM2jruMhsRAFjtqmpQ7UvtioYCkf9+uqj4YiCobBqX2zVvLt/oR/84Ac6//zz9Ze//EVnn322QzNOjC17Ao7u+fD0m6wtwD5EAWNi56Lp/zzyZV1654N66aWXNHHiRFvGTCZNe3sc3fOhqb3HkbGRmYgCotbVP6wV29psHDGitzRFH/7jYxvHTA6HgiEF+ocdPUagbziqu7uAUyEKiFpdY6tCtv7laykUjqiusdXGMZPD/r4hOX0nR0TSvr4hh4+CTEEUEJWO7kE1d/bafjkkXRdND4fCaXUcpD+igKiwaBqdcZ74/IrF6zhIf/wkISosmkZncl6OnEnoP1mfHAewA1HAqLFoGr2cLI98Dj9x7MvLZt8F2IYoYNRYNB0bf4nX0Utu/mKvI2MjMxEFjBqLpmNTXe5z9JLbomm8Rhv2IQoYNRZNx6aoYLwqCvNtP1twuyxVFOaPadc44POk128fHMWi6djVV5XJY3MUPC5L9VVlto4JEAWMGoumYzcpN1urKkttHXN1ZSmvzYbtiAKiwqLp2C2c6tOyWcW2jLV8VokWsCUnHEAUEBUWTWNT4y/SmrllyvK4oo6r22Upy+PS2rlluslf6NAMkemIAqLComnsFk71aefSGZo+JU+STvtvOfL59Cl52rl0BmcIcBQ7ryFqXf3Dmrlxl4I23jqaqVtLmi1M23sU6PuMLUzzsuUv9mrRNF9GBBOJRxQwJg0tAdW+aN
[... remainder of base64-encoded PNG output omitted: matplotlib drawing of the example graph ...]", 11 | "text/plain": [ 12 | "<Figure size 640x480 with 1 Axes>
" 13 | ] 14 | }, 15 | "metadata": {}, 16 | "output_type": "display_data" 17 | } 18 | ], 19 | "source": [ 20 | "import torch\n", 21 | "import networkx as nx\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "from torch_geometric.data import Data, Batch\n", 24 | "import numpy as np \n", 25 | "import matplotlib.pyplot as plt\n", 26 | "from dect.dect.ect import EctLayer, EctConfig\n", 27 | "import networkx as nx\n", 28 | "\n", 29 | "\n", 30 | "pos = torch.tensor([[ 0.05531403, 0.31317299],\n", 31 | " [ 0.25020841, 0.51538039],\n", 32 | " [ 0.43322785, 0.37765552],\n", 33 | " [ 0.39449602, 0.04715373],\n", 34 | " [ 0.18697251, -0.10709096],\n", 35 | " [-0.05871538, -0.05422876],\n", 36 | " [-0.31735225, -0.27696388],\n", 37 | " [-0.52772282, -0.2229272 ],\n", 38 | " [-0.41642837, -0.59215184]])\n", 39 | "\n", 40 | "pos -= pos.mean()\n", 41 | "pos /= np.linalg.norm(pos,axis=0)\n", 42 | "ei = torch.tensor([[0, 0, 1, 2, 3, 4, 5, 6, 6],\n", 43 | " [1, 5, 2, 3, 4, 5, 6, 7, 8]],dtype=torch.long)\n", 44 | "\n", 45 | "\n", 46 | "\n", 47 | "# Vizualize graph\n", 48 | "\n", 49 | "G = nx.Graph()\n", 50 | "G.add_edge(0,1)\n", 51 | "G.add_edge(1,2)\n", 52 | "G.add_edge(2,3)\n", 53 | "G.add_edge(3,4)\n", 54 | "G.add_edge(4,5)\n", 55 | "G.add_edge(5,0)\n", 56 | "G.add_edge(5,6)\n", 57 | "G.add_edge(6,7)\n", 58 | "G.add_edge(6,8)\n", 59 | "\n", 60 | "fig, ax = plt.subplots()\n", 61 | "\n", 62 | "ax.set_xlim([-2,2])\n", 63 | "ax.set_ylim([-2,2])\n", 64 | "ax.set_aspect(1)\n", 65 | "\n", 66 | "nx.draw_kamada_kawai(G,ax=ax)# , node_size=10, pos= pos, node_color=\".5\")" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 27, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "# Initialize ect\n", 76 | "\n", 77 | "V = torch.vstack(\n", 78 | " [\n", 79 | " torch.sin(torch.linspace(0, 2 * torch.pi, 64)),\n", 80 | " torch.cos(torch.linspace(0, 2 * torch.pi, 64)),\n", 81 | " ]\n", 82 | " )\n", 83 | "\n", 84 | "CONFIG = EctConfig(num_thetas=64,bump_steps=64,ect_type=\"edges\",device=\"cpu\",num_features=2)\n", 85 | "ectlayer = EctLayer(config=CONFIG,V=V)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 36, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "batch = Batch.from_data_list(\n", 95 | " [\n", 96 | " Data(x=pos,edge_index=ei),\n", 97 | " ]\n", 98 | ")" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 43, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "" 110 | ] 111 | }, 112 | "execution_count": 43, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | }, 116 | { 117 | "data": { 118 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGfCAYAAAAZGgYhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAArxUlEQVR4nO3df3BV9Z3/8de9Se5NyI97CUISSsLSKS2ohSogZrG7LWbLMB0HF6ZrO3SW7Tp1dAMVsNOanYqt0xrWbivVxlBdFtrZsmzZGWzpjrBOrHHaBZSoU5UdipUtqZCglvxO7k3uPd8/HO+3l/v5WA7c8Ln38nzMnBnzuYeTz7k59749ua+8PwHP8zwBAHCZBV1PAABwZaIAAQCcoAABAJygAAEAnKAAAQCcoAABAJygAAEAnKAAAQCcoAABAJygAAEAnCierAO3tbXp29/+tnp6erRw4UI9+uijuuGGG/7kv0smkzp9+rQqKysVCAQma3oAgEnieZ4GBwc1c+ZMBYPvc5/jTYI9e/Z4oVDI+9d//Vfvtdde8774xS960WjU6+3t/ZP/tru725PExsbGxpbnW3d39/u+3wc8L/vNSJcuXaolS5bo+9//vqR372rq6+u1YcMG3Xvvve/7b/v7+xWNRnX7hvsVCpdme2oAgEkWj41px6PfUF9fnyKRiHW/rP8KLh6Pq6urSy0tLamxYDCopqYmHTp0KGP/WCymWCyW+npwcFCSFAqXKkwBAoC89ac+Rsl6COHtt99WIpFQTU1N2nhNTY16enoy9m9tbVUkEklt9fX12Z4SACAHOU/BtbS0qL+/P7V1d3e7nhIA4DLI+q/grrrqKhUVFam3tzdtvLe3V7W1tRn7h8NhhcPhbE8DAJDjsn4HFAqFtGjRInV0dKTGksmkOjo61NjYmO1vBwDIU5Pyd0CbN2/WunXrtHjxYt1www3atm2bhoeH9YUvfGEyvh0AIA9NSgG67bbb9NZbb2nLli3q6enRxz72MR04cCAjmAAAuHJNWieE9evXa/369ZN1eABAnnOeggMAXJkoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcoQAAAJ4pdTyBnBAKTd2zPm7xjA0Ce4g4IAOAEBQgA4AQFCADgBAUIAOAEBQgA4ITvAvTcc8/plltu0cyZMxUIBPTkk0+mPe55nrZs2aK6ujqVlZWpqalJJ06cyNZ8L1wgYNyKSoqNWyhcYtzCZaGMzbZvcUmxcQsEg8YNAK5kvt8Fh4eHtXDhQrW1tRkff+ihh/TII49o+/btOnLkiMrLy7VixQqNjY1d8mQBAIXD998BrVy5UitXrjQ+5nmetm3bpq997WtatWqVJOlHP/qRampq9OSTT+qzn/1sxr+JxWKKxWKprwcGBvxOCQCQh7L6e6CTJ0+qp6dHTU1NqbFIJKKlS5fq0KFDxn/T2tqqSCSS2urr67M5JQBAjspqAerp6ZEk1dTUpI3X1NSkHjtfS0uL+vv7U1t3d3c2pwQAyFHOW/GEw2GFw2HX0wAAXGZZLUC1tbWSpN7eXtXV1aXGe3t79bGPfSyb3yqNKVEWKi0x7hsuMxe7UFnogo9t6+02Hp8wjsdGYsbx+FjcOJ6YSFzw9wSAfJXVX8HNmTNHtbW16ujoSI0NDAzoyJEjamxszOa3AgDkOd93QENDQ3r99ddTX588eVIvv/yyqqur1dDQoI0bN+qb3/ym5s6dqzlz5ui+++7TzJkzdeutt2Zz3gCAPOe7AB09elSf/OQnU19v3rxZkrRu3Trt2rVLX/nKVzQ8PKw77rhDfX19uummm3TgwAGVlpZmb9YAgLznuwB94hOfkPc+n0cEAgE98MADeuCBBy5pYgCAwuY8BedHsLjIOF46JfPuqjxabtw3PMUcQiguNj8VgWDmQnVe0lyAk8mkcTxebg4bjPQPm8cHRzPGJsbNAQfCCQDyFQ3JAABOUIAAAE5QgAAATlCAAABOUIAAAE7kbArOtGhbWUWZcd+qaVUZY6Xl5r87ChpSbe9+Q8u4iTmMpyLLA8UllvGQ+ek3pf2Gzg0Z9yUdhwy2a9nPNeHn9fB+uA4Lnqld2YUuuMkdEADACQoQAMAJChAAwAkKEADACQoQAMCJnE3BlU4JK3xeB+3o9IhxX1N/N1sKI5CtdI/p2JaEXcDSw644ZF40LxTOHLcl6QbeHjCOT5gWtZOUTBj61ZFUyh2W6zNYZEgaWfa1jduaCJvGgxeYYvpTbP0RuQ7zkOW6MiWOA0UX9rPkDggA4AQFCADgBAUIAOAEBQgA4AQFCADgRM6m4CLTIyotm5I2ZuvvdqF9hy6GKdlmSxkVGZJKklRsSLVJ79MLznA+FVMrjfuGSs0rvA68Y07HxUZiGWO2fnKeJcEEH3yk2iSpxHJNmBKTRZZ0pe3YthRcYjwzMVlkSV3a0nH2Y5uvrXhsPGNsIm7eN2FJdJKau7zCZSHjeHRGNGNsbNS87/m4AwIAOEEBAgA4QQECADhBAQIAOJGzIYTS8lKVTUkPHVzusIFkDhzYwgO2BfMqppjHw8WWD5yLMj8Atn3IW1E5xTh+5ne9xnFT656RwRHjvvGxuHHc9sHyFc0SNiguMf+MbYGasgrzeKg080Nd08KF707FXyuepOFDfr/HtjEdWzKHEEYt1+HYcGZwRpLG45nHkEQ4IQuKDNetaeFPyXLNBi8swMQdEADACQoQAMAJChAAwAkKEADACQoQAMCJnE3BBYuKFDSkwSaLLd1jSh9VRSuM+0ammBNpJZbzKMrConm2YycbzEmgoCHtF7K02Bg6N2gcHxkcNY5fMek4w8+nxLK4YHm03DxeZblWwuafRVGxj/9XtF0/tnSYqVVUthZutCRGSwyvK9t1WNJvTsfZrk9Two5knIXl52y6Pssj5vc90/v0hb53cwcEAHCCAgQAcIICBABwggIEAHCCAgQAcCJnU3CTxbZYlyntJknR6sz+R1Vl5t5uxZZUWzALiSLbMUKWfnLTq8wL2MWmmftqmdgWJbP15BvuHzaO5206zvKcmxJvldXm57tiqjk5ZFt4zvY9/fZg83PsSWVbkC+QmUoLW16DRbYUqSUZaFqMMTZq7mt4pafjbIvMVU2PZIzZ3g8uBXdAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcKNgVnW+HUtpqprb9bWUlm4ik2bl6JcUz+EjVFAXP9N62I6refXEmRJR1XHc0Yi41eeDJOkiKGhMz7GeobyhizrZSZS2yrmZoSb7YUnO16y0qqLY8Zz9+2qmzI0q/M0pvMpO9sv3F8wpLQtK0em6+pOVty1bbKqSmROBnXLH
dAAAAnKEAAACcoQAAAJyhAAAAnfBWg1tZWLVmyRJWVlZoxY4ZuvfVWHT9+PG2fsbExNTc3a9q0aaqoqNCaNWvU29ub1UkDAPKfrxRcZ2enmpubtWTJEk1MTOgf//Ef9alPfUrHjh1Tefm7Kz9u2rRJ//Vf/6W9e/cqEolo/fr1Wr16tX71q19NygnYBC2pj9LyUuN4wksax88NZfY3s/U2Syb9JWRMq5NKUpEhfRW2rbgZDhvHSw3pPUmaEspMt9gSXBPj5qSauXuUPR2XTGQ+t7a+cV7S/HOYTMFic8LQ1sfN9HyRdps8tufQ1pvMlI5LJsyvzeEBW/9C87WfSJjHTdd4LiXmyirM/SttSUJbai7bfBWgAwcOpH29a9cuzZgxQ11dXfqLv/gL9ff3a8eOHdq9e7eWL18uSdq5c6fmz5+vw4cP68Ybb8zezAEAee2Sylx//7vZ+urqaklSV1eXxsfH1dTUlNpn3rx5amho0KFDh4zHiMViGhgYSNsAAIXvogtQMpnUxo0btWzZMl177bWSpJ6eHoVCIUWj0bR9a2pq1NPTYzxOa2urIpFIaquvr7/YKQEA8shFF6Dm5ma9+uqr2rNnzyVNoKWlRf39/amtu7v7ko4HAMgPF9WKZ/369fr5z3+u5557TrNmzUqN19bWKh6Pq6+vL+0uqLe3V7W1tcZjhcNhhS0fpF8K22JVE3FzgCA2Ym5HY2oZ4zds4JcpnDBq+aB8JGwJJ1RMMY5XlmaGMKrLzR9EDoXNH9AaP3CVFLLMJWoIJyQsrXhGh0aN49n4QNf2wWp5lfm5srUpMQUOCBtcfn7CCdaFAUvN12zcsoDd2PCYef+xzP1tbX4mM5xgC9TYAkKTscicH77ugDzP0/r167Vv3z4988wzmjNnTtrjixYtUklJiTo6OlJjx48f16lTp9TY2JidGQMACoKvO6Dm5mbt3r1bP/3pT1VZWZn6XCcSiaisrEyRSES33367Nm/erOrqalVVVWnDhg1qbGwkAQcASOOrALW3t0uSPvGJT6SN79y5U3/3d38nSXr44YcVDAa1Zs0axWIxrVixQo899lhWJgsAKBy+CpC1RfkfKS0tVVtbm9ra2i56UgCAwkcvOACAEwW7IF1iwpzUSkyY0262u7sLuevLNlM3moAleWZrGRKPmRfNi1VmjkfLzSmwyFRzi563Lce2NdEJT8lMOdpSObaU4njMnEqyMiSkTPOQpMhV5rmUWFJ9JN5ym+nnY2uVZEvLhsssLa4srbxM6c2RwRHjvraEXTbaUNkSnWUV5nm7vpa5AwIAOEEBAgA4QQECADhBAQIAOEEBAgA4UbApuKTPRIk3yf3dLpUnfyk92/kPTQxmjE3Ezak2Wz85W6LIlu4x9WCbUmleIGvc0n/t3NlzxnFTrz5JKjYs6mdLu9nSca4TQsge288yUGTuhWbrG2hLzZleE7YU5dC5IeP42Ii5z5ztGjctXGm7xoOW83SNOyAAgBMUIACAExQgAIATFCAAgBMUIACAEwWbgsv1VFu22M7TT2pudMicvklY+s+FS0PGcVsKzsSWyrGtXBkbNffws62gWlmd2ceuPGJO9dkST7hy+U3NlRiuoSLL6qS28eAfzNehbRXW8mh5xljY0qsuYFhlWXL/PskrDwDgBAUIAOAEBQgA4AQFCADgBAUIAOBEwabgYGZKvSRk7jUVGzEnz2y94LLBduxoTdTX/lWGnnK52g8L+c+UmgsWmf//3rY6aZFl/xFL6nSq4TVREjL3n7P2jLSuY2w4xiQk5rgDAgA4QQECADhBAQIAOEEBAgA4kbMhhEAwkNE+wnXbiEJle15tH1COx8wL2Plhaw1SFDQHBWyLxk2tmWocNwUokpbWQsBk8NvOx9ZGJ2q5xk2tpSbGzYEi22t2YnzCOJ4wHMcLXHh7L9vr+3zcAQEAnKAAAQCcoAABAJygAAEAnKAAAQCcyNkUXDAYUPC8RZ5sqSzScZPD72J3NqZETJEtCWRJu0Ui5oXqysPm/c+89U7G2PDAiHFfrh/kAlsbnQ/UTTeOhwwL201Ykp6xCXMKbiRmXkTStABkbMy8rykxFyQFBwDIZRQgAIATFCAAgBMUIACAExQgAIATOZuCK51SqtLzeiPFDckMSZqYMPQtItl02Vn7uxkSb7ZFuaZGK43jFWHz/sWWNN306mjGWNyS4hmPm/thcQ1hMtheJxVT/SU9TUosay6Gi81v9VNC5mPHSjNfb4NjY8Z9R4czxz3LIpfn4w4IAOAEBQgA4AQFCADgBAUIAOAEBQgA4ETOpuDKK8o0pSJ9xb9hS3uhCUuPL1xethUgTYm36miVcd9KQ/pGUkZfwNS45XtWGI5TMdWcsOs722cc99vzDrgQtj6I0yPm14TtGjdJGlYnleyvzWLL60olmX3pbMcoMhwjGLiw1Ye5AwIAOEEBAgA4QQECADhBAQIAOOErhNDe3q729nb93//9nyTpmmuu0ZYtW7Ry5UpJ0tjYmO655x7t2bNHsVhMK1as0GOPPaaamhrfE4uNT6gonr6IUtLSGsX04RgfIF9+RUXm/58pm5IZCDB9cClJE0nzh5dBy4ertuOYxmdUmT/kHR0aNY6PGVqMSLTowYWxtdwpj5Ybx8tCIeP4RMLc1sYUOLCFEMYtx7Ad2/Q6HJ8wt6yKj2eOxwxjJr7ugGbNmqWtW7eqq6tLR48e1fLly7Vq1Sq99tprkqRNmzZp//792rt3rzo7O3X69GmtXr3az7cAAFwhfN0B3XLLLWlff+tb31J7e7sOHz6sWbNmaceOHdq9e7eWL18uSdq5c6fmz5+vw4cP68Ybb8zerAEAee+iPwNKJBLas2ePhoeH1djYqK6uLo2Pj6upqSm1z7x589TQ0KBDhw5ZjxOLxTQwMJC2AQAKn+8C9Morr6iiokLhcFh33nmn9u3bp6uvvlo9PT0KhUKKRqNp+9fU1Kinp8d6vNbWVkUikdRWX1/v+yQAAPnHdwH6yEc+opdffllHjhzRXXfdpXXr1unYsWMXPYGWlhb19/entu7u7os+FgAgf/huxRMKhfShD31IkrRo0SK98MIL+t73vqfbbrtN8XhcfX19aXdBvb29qq2ttR4vHA4rbFhwaXRwWOevaZRMmBMeniX5gcvLllIcGc5MmZmSM5JUXGxuUxKyLKhVYmlrYmrpYztGdXXEON47Nm4cn0heWMIHk2tSX/e2Y1va0ZiSuLaWO1dZWu4kLQnQgTFzSjNuWIgzbkmqTVheb0nDMSQpkcicS9IwJkmeYd6jI8PGfc93yX8HlEwmFYvFtGjRIpWUlKijoyP12PHjx3Xq1Ck1NjZe6rcBABQYX3dALS0tWrlypRoaGjQ4OKjdu3fr2Wef1cGDBxWJRHT77bdr8+bNqq6uVlVVlTZs2KDGxkYScACADL4K0NmzZ/W3f/u3OnPmjCKRiBYsWKCDBw/qr/7qryRJDz/8sILBoNasWZP2h6gAAJzPVwHasWPH+z5eWlqqtrY2tbW1XdKkAACFj15wAAAncnZBuvF4QsXFF5Y2ojdXbrCle
EaHMnuqBYMx474BSz8524J0xSXmpFEymnlNRKdMMewpRcrKjOMDVebxoXNDGWN+r0Frgssy7ifxZVs4LF/5TbuZdrcmuCzHtqXDSkrN/dqKDOnNMsv1U2ZY7E2SBsfMvQfffqffOJ4wzNGUSJP8J4j9POemfcfj5ufvfNwBAQCcoAABAJygAAEAnKAAAQCcoAABAJzI2RScl/RIt+UZ288rKUNfKXNYRwFLWslmPGZOfJl6WRVbknQVhr5xkjQ9au4RFxvJTPDFR+PGfa0pK8t5jsct/efimYlQWzLQFoKzpuMsK3f6OYbvvmyGa8X6XNmuK8tzaEpjJifM+yYsK4ImLL3TyirNScrK6sqMseqqzDFJGombr5W3/2BOu5muN8n8fE1mfzw/78cXui93QAAAJyhAAAAnKEAAACcoQAAAJyhAAAAncjYFh8LhKz0j874BS1LLlvoxJYfe+kOfcd/gVdXG8VJLz66KaEXG2Dsj7xj3jY+ZE0+2ZNPYkHn1y3FTCs7SN8+WVPOVjstWPzkfvcZsaTdrutISpTT1Q/PbC812bNPPQZKmGNJxE5Zj/KFvwDg+NmzuBWc/z/xPCXMHBABwggIEAHCCAgQAcIICBABwghAC8oLvBd8MYYaRQfMH/GcD54zj1dEq43hZKHNRMltbmME/DBrHRy1hA1PLHdvxbcEMm1xaqM7XgmfZWuzP30GMw+MxS6ukicyfW/cbb/o5tO9QSSHgDggA4AQFCADgBAUIAOAEBQgA4AQFCADgBCk45DU/C77Z2uLYWqAM9Q8bx0vLMxewsyW1bGk36wJ2tpX6jPte8K7IkuKSIuP4cF/mtfKO5boKTwkbxysi5eb9DdebVBjpOO6AAABOUIAAAE5QgAAATlCAAABOUIAAAE6QgkNesKXdEuMJ4/jwQGYqydrHy9J/zfY9p1SWZYyVGRYkk8yJOcmegkNuCxaZU3Cmnn+JCfO1absOE+Pm63BayLwwYnEo/9++uQMCADhBAQIAOEEBAgA4QQECADhBAQIAOJH/MQoUFN9pt/4h4/iAKZVkOUbS0n/N1t/NlGCbsBy7rCIzMSfZ+8/FLOPIDbZ+gsbEm+Vanhg3jw8PjBjHSyvMPQmrqiszxgLB/LqnyK/ZAgAKBgUIAOAEBQgA4AQFCADgBCEETDpjsMDyAe24pS2OacEvSRrqywwbSFLMT6sby1xsxuOZrVSGzpnDEPZ2PubWPbY2LUlLWxdcXrZ2Ob7YFlG0/IwH3u43jodLQ5ljebZ4HXdAAAAnKEAAACcoQAAAJyhAAAAnKEAAACcuKQW3detWtbS06O6779a2bdskSWNjY7rnnnu0Z88exWIxrVixQo899phqamqyMV/kAFuyy5ruMbS0iY/GjPva2pGM9JtTcHFLasxvss0Xw7FNyTjJno4rqzCnlUKGZJMkjQ2NXuDkUGhsic5+QzquusT8lm5bvM51Ou6i74BeeOEF/eAHP9CCBQvSxjdt2qT9+/dr79696uzs1OnTp7V69epLnigAoLBcVAEaGhrS2rVr9cQTT2jq1Kmp8f7+fu3YsUPf/e53tXz5ci1atEg7d+7U//zP/+jw4cNZmzQAIP9dVAFqbm7Wpz/9aTU1NaWNd3V1aXx8PG183rx5amho0KFDh4zHisViGhgYSNsAAIXP92dAe/bs0YsvvqgXXngh47Genh6FQiFFo9G08ZqaGvX09BiP19raqm984xt+pwEAyHO+7oC6u7t1991368c//rFKS80fovrV0tKi/v7+1Nbd3Z2V4wIAcpuvO6Curi6dPXtW119/fWoskUjoueee0/e//30dPHhQ8XhcfX19aXdBvb29qq2tNR4zHA4rHA5f3OwxqZIJc28qW8DMlgSLjWQm3kYHzaku20JtE7YeXJOZdvPDuviYed4jlvMvLinK2pRQICzXlikxWhI2pygj0yPG8aJit9ebrwJ0880365VXXkkb+8IXvqB58+bpq1/9qurr61VSUqKOjg6tWbNGknT8+HGdOnVKjY2N2Zs1ACDv+SpAlZWVuvbaa9PGysvLNW3atNT47bffrs2bN6u6ulpVVVXasGGDGhsbdeONN2Zv1gCAvJf15RgefvhhBYNBrVmzJu0PUQEA+GOXXICeffbZtK9LS0vV1tamtra2Sz00AKCA0QsOAOAEK6JCXjJpHB8bNvdr85N2e3c8M9lmW/k0mTDPJWfSbn5Z5m1bWdN6/sB5TCuoDv7BvEJwSbjEOF4eMa/MGyy6POk47oAAAE5QgAAATlCAAABOUIAAAE5QgAAATpCCu8KYEm+29JppxUVJmrCk4CbGzb3jEoa0Tt6m2iaZLZEIXAhbQrXvrT7jeLDIfA9iWrF3MpJx3AEBAJygAAEAnKAAAQCcoAABAJzI2RCC53nyzvugOhAIOJpN7jr/OfqjB4zDpsBB31vmsMHokHnRtIJrlwMUCstrMD4aN473W177pvfa0nLzwqGBYOZ9jPV96TzcAQEAnKAAAQCcoAABAJygAAEAnKAAAQCcyN0UXCKZkbaytY24EtJxtlSJrXWLrb3OwDsDGWMjgyPGfU0LXgHIP7b3CVvSNRDMfE/1vCrjvqXlmW17vAtcWJE7IACAExQgAIATFCAAgBMUIACAExQgAIATOZuCGx4YVmI8PfllSltIUnFJ5kJJ1mSc38Scn/5mWTq2KfGWmLCk3UbNabehc0PGcVPijbQbcGXyk47zkub3K9OCk2Oj5mTt+bgDAgA4QQECADhBAQIAOEEBAgA4QQECADiRsym4gXcGFCsdTxuLj5lX9TOl44pD5lMLGlbvez9JS0okG8e29XebiE9kjNl6u9n6uNmeKxJvAP4U0/uErW9cIpG5b2zMvO/5uAMCADhBAQIAOEEBAgA4QQECADiRsyGE2GhcSqbXx8S4+QP0seGxjLFQuMS4b7A4s23P+/HzoX1Rib+n09TCQjKHEMZj44Y9pYnxzH0le4sNALgYfha/jMXMIajzcQcEAHCCAgQAcIICBABwggIEAHCCAgQAcCJnU3DyvIwF22yJL1OazJYasy5UZ53GhS9Il61jJxOZaRPrPPwsmAcA2WZ6D7rA9yXugAAATlCAAABOUIAAAE5QgAAATlCAAABO+CpAX//61xUIBNK2efPmpR4fGxtTc3Ozpk2bpoqKCq1Zs0a9vb3Zm+17ybjzNi+ZzNgS4xPGbcLnZjvOZB7bdD62cweAfOX7Duiaa67RmTNnUtsvf/nL1GObNm3S/v37tXfvXnV2dur06dNavXp1VicMACgMvv8OqLi4WLW1tRnj/f392rFjh3bv3q3ly5dLknbu3Kn58+fr8OHDuvHGG43Hi8ViisX+fzfVgYEBv1MCAOQh33dAJ06c0MyZM/XBD35Qa9eu1alTpyRJXV1dGh8fV1NTU2rfefPmqaGhQYcOHbIer7W1VZFIJLXV19dfxGkAAPKNrwK0dOlS7dq1SwcOHFB7e7tOnjypj3/84xocHFRPT49CoZCi0Wjav6mpqVFPT4/1mC0tLerv709t3d3dF3UiAID84utXcCtX
rkz994IFC7R06VLNnj1bP/nJT1RWVnZREwiHwwqHwxf1bwEA+euSYtjRaFQf/vCH9frrr6u2tlbxeFx9fX1p+/T29ho/M3LGlibLRspsMo8NAAXmkgrQ0NCQfvvb36qurk6LFi1SSUmJOjo6Uo8fP35cp06dUmNj4yVPFABQWHz9Cu7LX/6ybrnlFs2ePVunT5/W/fffr6KiIn3uc59TJBLR7bffrs2bN6u6ulpVVVXasGGDGhsbrQk4AMCVy1cB+v3vf6/Pfe5zeueddzR9+nTddNNNOnz4sKZPny5JevjhhxUMBrVmzRrFYjGtWLFCjz322KRMHACQ3wKenwVvLoOBgQFFIhHd9eVWhcOlrqcDAPApFhtT+z+/m3Cuqqqy7kcvOACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAExQgAIATFCAAgBMUIACAE74L0JtvvqnPf/7zmjZtmsrKyvTRj35UR48eTT3ueZ62bNmiuro6lZWVqampSSdOnMjqpAEA+c9XATp37pyWLVumkpISPfXUUzp27Ji+853vaOrUqal9HnroIT3yyCPavn27jhw5ovLycq1YsUJjY2NZnzwAIH8V+9n5n/7pn1RfX6+dO3emxubMmZP6b8/ztG3bNn3ta1/TqlWrJEk/+tGPVFNToyeffFKf/exnszRtAEC+83UH9LOf/UyLFy/WZz7zGc2YMUPXXXednnjiidTjJ0+eVE9Pj5qamlJjkUhES5cu1aFDh4zHjMViGhgYSNsAAIXPVwF644031N7errlz5+rgwYO666679KUvfUk//OEPJUk9PT2SpJqamrR/V1NTk3rsfK2trYpEIqmtvr7+Ys4DAJBnfBWgZDKp66+/Xg8++KCuu+463XHHHfriF7+o7du3X/QEWlpa1N/fn9q6u7sv+lgAgPzhqwDV1dXp6quvThubP3++Tp06JUmqra2VJPX29qbt09vbm3rsfOFwWFVVVWkbAKDw+SpAy5Yt0/Hjx9PGfvOb32j27NmS3g0k1NbWqqOjI/X4wMCAjhw5osbGxixMFwBQKHyl4DZt2qQ///M/14MPPqi/+Zu/0fPPP6/HH39cjz/+uCQpEAho48aN+uY3v6m5c+dqzpw5uu+++zRz5kzdeuutkzF/AECe8lWAlixZon379qmlpUUPPPCA5syZo23btmnt2rWpfb7yla9oeHhYd9xxh/r6+nTTTTfpwIEDKi0tzfrkAQD5K+B5nud6En9sYGBAkUhEd325VeEwRQsA8k0sNqb2f343YPZ+n+vTCw4A4AQFCADgBAUIAOAEBQgA4AQFCADgBAUIAOAEBQgA4AQFCADgBAUIAOAEBQgA4AQFCADgBAUIAOCEr27Yl8N7vVHjsTHHMwEAXIz33r//VK/rnOuG/fvf/1719fWupwEAuETd3d2aNWuW9fGcK0DJZFKnT59WZWWlBgcHVV9fr+7u7oJeqntgYIDzLBBXwjlKnGehyfZ5ep6nwcFBzZw5U8Gg/ZOenPsVXDAYTFXMQCAgSaqqqiroH/57OM/CcSWco8R5FppsnmckEvmT+xBCAAA4QQECADiR0wUoHA7r/vvvVzgcdj2VScV5Fo4r4RwlzrPQuDrPnAshAACuDDl9BwQAKFwUIACAExQgAIATFCAAgBMUIACAEzldgNra2vRnf/ZnKi0t1dKlS/X888+7ntIlee6553TLLbdo5syZCgQCevLJJ9Me9zxPW7ZsUV1dncrKytTU1KQTJ064mexFam1t1ZIlS1RZWakZM2bo1ltv1fHjx9P2GRsbU3Nzs6ZNm6aKigqtWbNGvb29jmZ8cdrb27VgwYLUX443NjbqqaeeSj1eCOd4vq1btyoQCGjjxo2psUI4z69//esKBAJp27x581KPF8I5vufNN9/U5z//eU2bNk1lZWX66Ec/qqNHj6Yev9zvQTlbgP7jP/5Dmzdv1v33368XX3xRCxcu1IoVK3T27FnXU7tow8PDWrhwodra2oyPP/TQQ3rkkUe0fft2HTlyROXl5VqxYoXGxvKnM3hnZ6eam5t1+PBhPf300xofH9enPvUpDQ8Pp/bZtGmT9u/fr71796qzs1OnT5/W6tWrHc7av1mzZmnr1q3q6urS0aNHtXz5cq1atUqvvfaapMI4xz/2wgsv6Ac/+IEWLFiQNl4o53nNNdfozJkzqe2Xv/xl6rFCOcdz585p2bJlKikp0VNPPaVjx47pO9/5jqZOnZra57K/B3k56oYbbvCam5tTXycSCW/mzJlea2urw1lljyRv3759qa+TyaRXW1vrffvb306N9fX1eeFw2Pv3f/93BzPMjrNnz3qSvM7OTs/z3j2nkpISb+/eval9/vd//9eT5B06dMjVNLNi6tSp3r/8y78U3DkODg56c+fO9Z5++mnvL//yL727777b87zC+Vnef//93sKFC42PFco5ep7nffWrX/Vuuukm6+Mu3oNy8g4oHo+rq6tLTU1NqbFgMKimpiYdOnTI4cwmz8mTJ9XT05N2zpFIREuXLs3rc+7v75ckVVdXS5K6uro0Pj6edp7z5s1TQ0ND3p5nIpHQnj17NDw8rMbGxoI7x+bmZn36059OOx+psH6WJ06c0MyZM/XBD35Qa9eu1alTpyQV1jn+7Gc/0+LFi/WZz3xGM2bM0HXXXacnnngi9biL96CcLEBvv/22EomEampq0sZramrU09PjaFaT673zKqRzTiaT2rhxo5YtW6Zrr71W0rvnGQqFFI1G0/bNx/N85ZVXVFFRoXA4rDvvvFP79u3T1VdfXVDnuGfPHr344otqbW3NeKxQznPp0qXatWuXDhw4oPb2dp08eVIf//jHNTg4WDDnKElvvPGG2tvbNXfuXB08eFB33XWXvvSlL+mHP/yhJDfvQTm3HAMKR3Nzs1599dW036cXko985CN6+eWX1d/fr//8z//UunXr1NnZ6XpaWdPd3a27775bTz/9tEpLS11PZ9KsXLky9d8LFizQ0qVLNXv2bP3kJz9RWVmZw5llVzKZ1OLFi/Xggw9Kkq677jq9+uqr2r59u9atW+dkTjl5B3TVVVepqKgoI2nS29ur2tpaR7OaXO+dV6Gc8/r16/Xzn/9cv/jFL9JWRKytrVU8HldfX1/a/vl4nqFQSB/60Ie0aNEitba2auHChfre975XMOfY1dWls2fP6vrrr1dxcbGKi4vV2dmpRx55RMXFxaqpqSmI8zxfNBrVhz/8Yb3++usF87OUpLq6Ol199dVpY/Pnz0/9utHFe1BOFqBQKKRFixapo6MjNZZMJtXR0aHGxka
HM5s8c+bMUW1tbdo5DwwM6MiRI3l1zp7naf369dq3b5+eeeYZzZkzJ+3xRYsWqaSkJO08jx8/rlOnTuXVeZokk0nFYrGCOcebb75Zr7zyil5++eXUtnjxYq1duzb134VwnucbGhrSb3/7W9XV1RXMz1KSli1blvEnEb/5zW80e/ZsSY7egyYl2pAFe/bs8cLhsLdr1y7v2LFj3h133OFFo1Gvp6fH9dQu2uDgoPfSSy95L730kifJ++53v+u99NJL3u9+9zvP8zxv69atXjQa9X760596v/71r71Vq1Z5c+bM8UZHRx3P/MLdddddXiQS8Z599lnvzJkzqW1kZCS1z5133uk1NDR4zzzzjHf06FGvsbHRa2xsdDhr/+69916vs7PTO3nypPfrX//au/fee71AIOD993//t+d5hXGOJn+cgvO8wjjPe+65x3v22We9kydPer/61a+8pqYm76qrrvLOnj3reV5hnKPned7zzz/vFRcXe9/61re8EydOeD/+8Y+9KVOmeP/2b/+W2udyvwflbAHyPM979NFHvYaGBi8UCnk33HCDd/jwYddTuiS/+MUvPEkZ27p16zzPezcGed9993k1NTVeOBz2br75Zu/48eNuJ+2T6fwkeTt37kztMzo66v3DP/yDN3XqVG/KlCneX//1X3tnzpxxN+mL8Pd///fe7NmzvVAo5E2fPt27+eabU8XH8wrjHE3OL0CFcJ633XabV1dX54VCIe8DH/iAd9ttt3mvv/566vFCOMf37N+/37v22mu9cDjszZs3z3v88cfTHr/c70GsBwQAcCInPwMCABQ+ChAAwAkKEADACQoQAMAJChAAwAkKEADACQoQAMAJChAAwAkKEADACQoQAMAJChAAwIn/Bwk7b+108yWgAAAAAElFTkSuQmCC", 119 | "text/plain": [ 120 | "
" 121 | ] 122 | }, 123 | "metadata": {}, 124 | "output_type": "display_data" 125 | } 126 | ], 127 | "source": [ 128 | "res = ectlayer(batch)\n", 129 | "img = res.squeeze().numpy()\n", 130 | "plt.imshow(img,cmap=\"bone\",vmin=-3,vmax=3)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [] 139 | } 140 | ], 141 | "metadata": { 142 | "kernelspec": { 143 | "display_name": "Python 3", 144 | "language": "python", 145 | "name": "python3" 146 | }, 147 | "language_info": { 148 | "codemirror_mode": { 149 | "name": "ipython", 150 | "version": 3 151 | }, 152 | "file_extension": ".py", 153 | "mimetype": "text/x-python", 154 | "name": "python", 155 | "nbconvert_exporter": "python", 156 | "pygments_lexer": "ipython3", 157 | "version": "3.10.11" 158 | } 159 | }, 160 | "nbformat": 4, 161 | "nbformat_minor": 2 162 | } 163 | -------------------------------------------------------------------------------- /experiment/DD/ect_cnn_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | pin_memory: true 11 | drop_last: false 12 | name: DD 13 | cleaned: true 14 | use_node_attr: true 15 | model: 16 | module: models.ect_cnn 17 | ectconfig: 18 | num_thetas: 32 19 | bump_steps: 32 20 | batch_size: 128 21 | R: 1.1 22 | num_features: 89 23 | ecc_type: edges 24 | device: cuda 25 | num_classes: 2 26 | hidden: 50 27 | trainer: 28 | lr: 0.001 29 | num_epochs: 100 30 | num_reruns: 5 31 | -------------------------------------------------------------------------------- /experiment/Letter-high/ect_cnn_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: ['dev'] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 33 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-high 14 | cleaned: false 15 | model: 16 | module: models.ect_cnn 17 | ectconfig: 18 | num_thetas: 16 19 | bump_steps: 16 20 | R: 1.1 21 | ecc_type: edges 22 | device: cuda 23 | num_features: 2 24 | num_classes: 15 25 | hidden: 25 26 | trainer: 27 | lr: 0.001 28 | num_epochs: 100 29 | num_reruns: 5 30 | -------------------------------------------------------------------------------- /experiment/Letter-low/ect_cnn_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-low 14 | cleaned: false 15 | model: 16 | module: models.ect_cnn 17 | ect_config: 18 | num_thetas: 16 19 | bump_steps: 16 20 | batch_size: 128 21 | R: 1.1 22 | num_features: 2 23 | device: cuda 24 | ecc_type: edges 25 | num_classes: 15 26 | hidden: 25 27 | trainer: 28 | lr: 0.001 29 | num_epochs: 100 30 | num_reruns: 5 31 | -------------------------------------------------------------------------------- /experiment/Letter-med/ect_cnn_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | 
module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-med 14 | cleaned: false 15 | model: 16 | module: models.ect_cnn 17 | ectconfig: 18 | num_thetas: 16 19 | bump_steps: 16 20 | batch_size: 128 21 | R: 1.1 22 | num_features: 2 23 | device: cuda 24 | ecc_type: edges 25 | num_classes: 15 26 | hidden: 25 27 | trainer: 28 | lr: 0.001 29 | num_epochs: 100 30 | num_reruns: 5 31 | -------------------------------------------------------------------------------- /experiment/PROTEINS_full/ect_cnn_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 128 11 | pin_memory: true 12 | drop_last: false 13 | name: PROTEINS_full 14 | cleaned: false 15 | use_node_attr: true 16 | model: 17 | module: models.ect_cnn 18 | ectconfig: 19 | num_thetas: 32 20 | bump_steps: 32 21 | batch_size: 128 22 | R: 1.1 23 | num_features: 32 24 | device: cuda 25 | ecc_type: edges 26 | num_classes: 2 27 | hidden: 50 28 | trainer: 29 | lr: 0.001 30 | num_epochs: 100 31 | num_reruns: 5 32 | -------------------------------------------------------------------------------- /experiment/ablations/ablations0.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-low 14 | cleaned: false 15 | model: 16 | module: models.ect_linear_edges 17 | num_thetas: 2 18 | bump_steps: 16 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 2 22 | num_classes: 15 23 | hidden: 25 24 | trainer: 25 | lr: 0.001 26 | num_epochs: 100 27 | num_reruns: 5 28 | -------------------------------------------------------------------------------- /experiment/ablations/ablations1.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-low 14 | cleaned: false 15 | model: 16 | module: models.ect_linear_edges 17 | num_thetas: 4 18 | bump_steps: 16 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 2 22 | num_classes: 15 23 | hidden: 25 24 | trainer: 25 | lr: 0.001 26 | num_epochs: 100 27 | num_reruns: 5 28 | -------------------------------------------------------------------------------- /experiment/ablations/ablations2.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-low 14 | cleaned: false 15 | model: 16 | module: models.ect_linear_edges 17 | num_thetas: 6 18 | bump_steps: 16 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 2 22 | num_classes: 15 23 | hidden: 25 24 | trainer: 25 | lr: 0.001 26 | num_epochs: 100 27 | num_reruns: 5 28 | -------------------------------------------------------------------------------- 
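Note: the `ablations0.yaml` through `ablations8.yaml` files in this folder are identical apart from `model.num_thetas`, which steps through 2, 4, ..., 18 on Letter-low. A rough, hypothetical sketch of regenerating such a sweep with OmegaConf (the same library `generate_experiments.py` further down uses to write configs); the template path and loop bounds here are assumptions, not part of the repository:

```python
# Hypothetical sweep sketch: rewrite model.num_thetas in a template config and
# save one file per value, mirroring the ablations0-8 layout in this folder.
from omegaconf import OmegaConf

for i, num_thetas in enumerate(range(2, 20, 2)):  # 2, 4, ..., 18 -> nine files
    cfg = OmegaConf.load("experiment/ablations/ablations0.yaml")  # assumed template path
    cfg.model.num_thetas = num_thetas  # the only field that differs between the files
    with open(f"experiment/ablations/ablations{i}.yaml", "w") as f:
        OmegaConf.save(cfg, f)  # write the modified config back out as YAML
```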
/experiment/ablations/ablations3.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-low 14 | cleaned: false 15 | model: 16 | module: models.ect_linear_edges 17 | num_thetas: 8 18 | bump_steps: 16 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 2 22 | num_classes: 15 23 | hidden: 25 24 | trainer: 25 | lr: 0.001 26 | num_epochs: 100 27 | num_reruns: 5 28 | -------------------------------------------------------------------------------- /experiment/ablations/ablations4.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-low 14 | cleaned: false 15 | model: 16 | module: models.ect_linear_edges 17 | num_thetas: 10 18 | bump_steps: 16 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 2 22 | num_classes: 15 23 | hidden: 25 24 | trainer: 25 | lr: 0.001 26 | num_epochs: 100 27 | num_reruns: 5 28 | -------------------------------------------------------------------------------- /experiment/ablations/ablations5.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-low 14 | cleaned: false 15 | model: 16 | module: models.ect_linear_edges 17 | num_thetas: 12 18 | bump_steps: 16 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 2 22 | num_classes: 15 23 | hidden: 25 24 | trainer: 25 | lr: 0.001 26 | num_epochs: 100 27 | num_reruns: 5 28 | -------------------------------------------------------------------------------- /experiment/ablations/ablations6.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-low 14 | cleaned: false 15 | model: 16 | module: models.ect_linear_edges 17 | num_thetas: 14 18 | bump_steps: 16 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 2 22 | num_classes: 15 23 | hidden: 25 24 | trainer: 25 | lr: 0.001 26 | num_epochs: 100 27 | num_reruns: 5 28 | -------------------------------------------------------------------------------- /experiment/ablations/ablations7.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-low 14 | cleaned: false 15 | model: 16 | module: models.ect_linear_edges 17 | num_thetas: 16 18 | bump_steps: 16 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 2 22 | num_classes: 15 23 | hidden: 25 24 | trainer: 25 | lr: 0.001 26 | num_epochs: 100 
27 | num_reruns: 5 28 | -------------------------------------------------------------------------------- /experiment/ablations/ablations8.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | name: Letter-low 14 | cleaned: false 15 | model: 16 | module: models.ect_linear_edges 17 | num_thetas: 18 18 | bump_steps: 16 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 2 22 | num_classes: 15 23 | hidden: 25 24 | trainer: 25 | lr: 0.001 26 | num_epochs: 100 27 | num_reruns: 5 28 | -------------------------------------------------------------------------------- /experiment/bzr/ect_cnn_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 64 11 | pin_memory: true 12 | drop_last: false 13 | name: BZR 14 | cleaned: true 15 | use_node_attr: true 16 | model: 17 | module: models.ect_cnn 18 | ectconfig: 19 | num_thetas: 16 20 | bump_steps: 16 21 | batch_size: 128 22 | R: 1.1 23 | num_features: 38 24 | device: cuda 25 | ecc_type: edges 26 | num_classes: 2 27 | hidden: 25 28 | trainer: 29 | lr: 0.001 30 | num_epochs: 100 31 | num_reruns: 5 32 | -------------------------------------------------------------------------------- /experiment/cox2/ect_cnn_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 64 11 | pin_memory: true 12 | drop_last: false 13 | name: COX2 14 | cleaned: true 15 | use_node_attr: true 16 | model: 17 | module: models.ect_cnn 18 | ectconfig: 19 | num_thetas: 32 20 | bump_steps: 32 21 | batch_size: 128 22 | R: 1.1 23 | num_features: 38 24 | device: cuda 25 | ecc_type: edges 26 | num_classes: 2 27 | hidden: 50 28 | trainer: 29 | lr: 0.001 30 | num_epochs: 200 31 | num_reruns: 5 32 | -------------------------------------------------------------------------------- /experiment/dhfr/ect_cnn_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.tu 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 64 11 | pin_memory: true 12 | drop_last: false 13 | name: DHFR 14 | cleaned: true 15 | use_node_attr: true 16 | model: 17 | module: models.ect_cnn 18 | ectconfig: 19 | num_thetas: 16 20 | bump_steps: 16 21 | batch_size: 128 22 | R: 1.1 23 | num_features: 38 24 | device: cuda 25 | ecc_type: edges 26 | num_classes: 2 27 | hidden: 25 28 | trainer: 29 | lr: 0.001 30 | num_epochs: 200 31 | num_reruns: 5 32 | -------------------------------------------------------------------------------- /experiment/gnn_mnist_classification/ect_cnn_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.gnn_benchmark 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 
256 11 | pin_memory: false 12 | drop_last: false 13 | name: MNIST 14 | model: 15 | module: models.ect_cnn 16 | ectconfig: 17 | num_thetas: 32 18 | bump_steps: 32 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 3 22 | device: cuda 23 | ecc_type: edges 24 | num_classes: 10 25 | hidden: 100 26 | trainer: 27 | lr: 0.001 28 | num_epochs: 100 29 | num_reruns: 5 30 | -------------------------------------------------------------------------------- /experiment/lrgb/lrgb_cnn_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: dect-lrgb-test 3 | project: dect 4 | tags: ['dev','edges','cnn'] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.lrgb 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 64 11 | pin_memory: true 12 | drop_last: false 13 | name: Peptides-func 14 | model: 15 | module: models.ect_cnn 16 | ectconfig: 17 | num_thetas: 16 18 | bump_steps: 16 19 | R: 1.1 20 | ecc_type: edges 21 | device: cuda 22 | num_features: 9 23 | num_classes: 10 24 | hidden: 25 25 | trainer: 26 | lr: 0.001 27 | num_epochs: 100 28 | num_reruns: 5 29 | -------------------------------------------------------------------------------- /experiment/lrgb/lrgb_cnn_points.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: dect-lrgb-test 3 | project: dect 4 | tags: ['dev','cnn','points'] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.lrgb 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 64 11 | pin_memory: true 12 | drop_last: false 13 | name: Peptides-func 14 | model: 15 | module: models.ect_cnn 16 | ectconfig: 17 | num_thetas: 16 18 | bump_steps: 16 19 | R: 1.1 20 | ecc_type: points 21 | device: cuda 22 | num_features: 9 23 | num_classes: 10 24 | hidden: 25 25 | trainer: 26 | lr: 0.001 27 | num_epochs: 100 28 | num_reruns: 5 29 | -------------------------------------------------------------------------------- /experiment/lrgb/lrgb_linear_edges.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: dect-lrgb-test 3 | project: dect 4 | tags: ['dev','linear','edges'] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.lrgb 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 64 11 | pin_memory: true 12 | drop_last: false 13 | name: Peptides-func 14 | model: 15 | module: models.ect_linear 16 | ectconfig: 17 | num_thetas: 16 18 | bump_steps: 16 19 | R: 1.1 20 | ecc_type: edges 21 | device: cuda 22 | num_features: 9 23 | num_classes: 10 24 | hidden: 25 25 | trainer: 26 | lr: 0.001 27 | num_epochs: 100 28 | num_reruns: 5 29 | -------------------------------------------------------------------------------- /experiment/lrgb/lrgb_linear_points.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: dect-lrgb-test 3 | project: dect 4 | tags: ['dev','linear','points'] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.lrgb 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 64 11 | pin_memory: true 12 | drop_last: false 13 | name: Peptides-func 14 | model: 15 | module: models.ect_linear 16 | ectconfig: 17 | num_thetas: 16 18 | bump_steps: 16 19 | R: 1.1 20 | ecc_type: points 21 | device: cuda 22 | num_features: 9 23 | num_classes: 10 24 | hidden: 25 25 | trainer: 26 | lr: 0.001 27 | num_epochs: 100 28 | num_reruns: 5 29 | -------------------------------------------------------------------------------- 
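For orientation, a minimal sketch (not part of the repository itself) of loading one of the experiment files above with OmegaConf, which the repository already uses in `generate_experiments.py`, and reading the fields that the data, model, and trainer modules consume:

```python
# Sketch only: inspect an experiment config; field names follow the YAML above.
from omegaconf import OmegaConf

cfg = OmegaConf.load("experiment/lrgb/lrgb_cnn_edges.yaml")
print(cfg.data.module)                 # datasets.lrgb
print(cfg.model.module)                # models.ect_cnn
print(cfg.model.ectconfig.num_thetas)  # 16 -> number of ECT directions
print(cfg.model.ectconfig.ecc_type)    # edges
print(cfg.trainer.num_epochs)          # 100
```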
/experiment/manifold_classification/ect_cnn_faces.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.manifold 8 | root: ./data 9 | num_workers: 0 10 | batch_size: 32 11 | pin_memory: true 12 | drop_last: false 13 | num_samples: 100 14 | model: 15 | module: models.ect_linear 16 | ectconfig: 17 | num_thetas: 8 18 | bump_steps: 8 19 | batch_size: 128 20 | R: 1.1 21 | num_features: 3 22 | device: cuda 23 | ecc_type: faces 24 | num_classes: 10 25 | hidden: 10 26 | trainer: 27 | lr: 0.001 28 | num_epochs: 10 29 | num_reruns: 5 30 | -------------------------------------------------------------------------------- /experiment/manifold_classification/ect_linear_faces.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: desct-test-new 3 | project: desct 4 | tags: [] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.manifold 8 | root: ./data 9 | num_workers: 0 10 | pin_memory: true 11 | drop_last: false 12 | num_samples: 100 13 | model: 14 | module: models.ect_linear 15 | ectconfig: 16 | bump_steps: 16 17 | R: 1.1 18 | num_features: 3 19 | num_classes: 3 20 | device: cuda 21 | ecc_type: faces 22 | hidden: 20 23 | trainer: 24 | lr: 0.001 25 | num_epochs: 100 26 | num_reruns: 5 27 | -------------------------------------------------------------------------------- /experiment/weighted_mnist/wect.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | name: wect-mnist-classification 3 | project: dect 4 | tags: ['dev'] 5 | experiment_folder: experiment 6 | data: 7 | module: datasets.weighted_mnist 8 | root: ./data/wect 9 | num_workers: 0 10 | batch_size: 128 11 | pin_memory: true 12 | drop_last: false 13 | cleaned: false 14 | model: 15 | module: models.wect_linear 16 | ectconfig: 17 | num_thetas: 32 18 | bump_steps: 32 19 | R: 1.1 20 | ecc_type: faces 21 | device: cuda 22 | num_features: 2 23 | num_classes: 10 24 | hidden: 25 25 | trainer: 26 | lr: 0.001 27 | num_epochs: 100 28 | num_reruns: 1 29 | -------------------------------------------------------------------------------- /figures/ect_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aidos-lab/dect-evaluation/87640a57eedb6f528569b8f66647a0e31828f156/figures/ect_animation.gif -------------------------------------------------------------------------------- /generate_experiments.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import OmegaConf 3 | import shutil 4 | from datasets.gnn_benchmark import GNNBenchmarkDataModuleConfig 5 | from datasets.modelnet import ModelNetDataModuleConfig 6 | from datasets.manifold import ManifoldDataModuleConfig 7 | from datasets.tu import ( 8 | TUDataModule, 9 | TUEnzymesConfig, 10 | TUIMDBBConfig, 11 | TUDDConfig, 12 | TUProteinsFullConfig, 13 | TURedditBConfig, 14 | TULetterHighConfig, 15 | TULetterLowConfig, 16 | TULetterMedConfig, 17 | TUNCI1Config, 18 | TUNCI109Config, 19 | TUBZRConfig, 20 | TUCOX2Config, 21 | TUFrankensteinConfig, 22 | TUFingerprintConfig, 23 | TUDHFRConfig, 24 | ) 25 | from models.base_model import ECTModelConfig 26 | from config import Config, TrainerConfig, Meta 27 | 28 | # ╭──────────────────────────────────────────────────────────╮ 29 | # │ Helper methods │ 30 | # 
╰──────────────────────────────────────────────────────────╯ 31 | 32 | 33 | def create_experiment_folder(path): 34 | shutil.rmtree(path, ignore_errors=True) 35 | os.makedirs(path) 36 | 37 | 38 | def save_config(cfg, path): 39 | c = OmegaConf.create(cfg) 40 | with open(path, "w") as f: 41 | OmegaConf.save(c, f) 42 | 43 | # ╭──────────────────────────────────────────────────────────╮ 44 | # │ Experiments │ 45 | # ╰──────────────────────────────────────────────────────────╯ 46 | 47 | 48 | # def tu_bbbp(experiment_folder="experiment", trainer=None, meta=None) -> None: 49 | # experiment = f"./{experiment_folder}/BBBP" 50 | # create_experiment_folder(experiment) 51 | 52 | # modules = [ 53 | # # "models.ect_cnn_points", 54 | # "models.ect_cnn_edges", 55 | # # "models.ect_linear_points", 56 | # # "models.ect_linear_edges", 57 | # ] 58 | 59 | # # Create the dataset config. 60 | # data = TUNCI1Config() 61 | 62 | # for module in modules: 63 | # modelconfig = ECTModelConfig( 64 | # module=module, num_features=30, num_classes=2, num_thetas=32, bump_steps=32 65 | # ) 66 | 67 | # config = Config(meta, data, modelconfig, trainer) 68 | # save_config( 69 | # config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 70 | # ) 71 | 72 | # , 73 | # TUFrankensteinConfig, 74 | # TUFingerprintConfig, 75 | # TUDHFRConfig, 76 | 77 | 78 | def tu_dhfr(experiment_folder="experiment", trainer=None, meta=None) -> None: 79 | experiment = f"./{experiment_folder}/dhfr" 80 | create_experiment_folder(experiment) 81 | 82 | modules = [ 83 | "models.ect_cnn_edges", 84 | ] 85 | trainer = TrainerConfig(lr=0.001, num_epochs=200, num_reruns=5) 86 | # Create the dataset config. 87 | data = TUDHFRConfig() 88 | 89 | for module in modules: 90 | modelconfig = ECTModelConfig( 91 | module=module, 92 | num_features=38, 93 | num_classes=2, 94 | num_thetas=32, 95 | bump_steps=32, 96 | hidden=40, 97 | ) 98 | 99 | config = Config(meta, data, modelconfig, trainer) 100 | save_config( 101 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 102 | ) 103 | 104 | 105 | def tu_fingerprint(experiment_folder="experiment", trainer=None, meta=None) -> None: 106 | experiment = f"./{experiment_folder}/fingerprint" 107 | create_experiment_folder(experiment) 108 | 109 | modules = [ 110 | "models.ect_cnn_edges", 111 | ] 112 | # trainer = TrainerConfig(lr=0.001, num_epochs=100, num_reruns=5) 113 | # Create the dataset config. 114 | data = TUFingerprintConfig(cleaned=False) 115 | 116 | for module in modules: 117 | modelconfig = ECTModelConfig( 118 | module=module, 119 | num_features=2, 120 | num_classes=6, 121 | num_thetas=32, 122 | bump_steps=32, 123 | hidden=40, 124 | ) 125 | 126 | config = Config(meta, data, modelconfig, trainer) 127 | save_config( 128 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 129 | ) 130 | 131 | 132 | def tu_frankenstein(experiment_folder="experiment", trainer=None, meta=None) -> None: 133 | experiment = f"./{experiment_folder}/frankenstein" 134 | create_experiment_folder(experiment) 135 | 136 | modules = [ 137 | "models.ect_cnn_edges", 138 | ] 139 | # trainer = TrainerConfig(lr=0.001, num_epochs=100, num_reruns=5) 140 | # Create the dataset config. 
141 | data = TUFrankensteinConfig() 142 | 143 | for module in modules: 144 | modelconfig = ECTModelConfig( 145 | module=module, 146 | num_features=780, 147 | num_classes=2, 148 | num_thetas=32, 149 | bump_steps=32, 150 | hidden=20, 151 | ) 152 | 153 | config = Config(meta, data, modelconfig, trainer) 154 | save_config( 155 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 156 | ) 157 | 158 | 159 | def tu_cox2(experiment_folder="experiment", trainer=None, meta=None) -> None: 160 | experiment = f"./{experiment_folder}/cox2" 161 | create_experiment_folder(experiment) 162 | 163 | modules = [ 164 | "models.ect_cnn_edges", 165 | ] 166 | trainer = TrainerConfig(lr=0.001, num_epochs=200, num_reruns=5) 167 | # Create the dataset config. 168 | data = TUCOX2Config() 169 | 170 | for module in modules: 171 | modelconfig = ECTModelConfig( 172 | module=module, num_features=38, num_classes=2, num_thetas=32, bump_steps=32 173 | ) 174 | 175 | config = Config(meta, data, modelconfig, trainer) 176 | save_config( 177 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 178 | ) 179 | 180 | 181 | def tu_bzr(experiment_folder="experiment", trainer=None, meta=None) -> None: 182 | experiment = f"./{experiment_folder}/bzr" 183 | create_experiment_folder(experiment) 184 | 185 | modules = [ 186 | "models.ect_cnn_edges", 187 | ] 188 | trainer = TrainerConfig(lr=0.001, num_epochs=200, num_reruns=5) 189 | # Create the dataset config. 190 | data = TUBZRConfig() 191 | 192 | for module in modules: 193 | modelconfig = ECTModelConfig( 194 | module=module, 195 | num_features=38, 196 | num_classes=2, 197 | num_thetas=32, 198 | bump_steps=32, 199 | hidden=50, 200 | ) 201 | 202 | config = Config(meta, data, modelconfig, trainer) 203 | save_config( 204 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 205 | ) 206 | 207 | 208 | def tu_nci1(experiment_folder="experiment", trainer=None, meta=None) -> None: 209 | experiment = f"./{experiment_folder}/nci1" 210 | create_experiment_folder(experiment) 211 | 212 | modules = [ 213 | # "models.ect_cnn_points", 214 | "models.ect_cnn_edges", 215 | # "models.ect_linear_points", 216 | # "models.ect_linear_edges", 217 | ] 218 | trainer = TrainerConfig(lr=0.001, num_epochs=100, num_reruns=5) 219 | # Create the dataset config. 220 | data = TUNCI1Config() 221 | 222 | for module in modules: 223 | modelconfig = ECTModelConfig( 224 | module=module, num_features=30, num_classes=2, num_thetas=32, bump_steps=32 225 | ) 226 | 227 | config = Config(meta, data, modelconfig, trainer) 228 | save_config( 229 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 230 | ) 231 | 232 | 233 | def tu_nci109(experiment_folder="experiment", trainer=None, meta=None) -> None: 234 | experiment = f"./{experiment_folder}/nci109" 235 | create_experiment_folder(experiment) 236 | trainer = TrainerConfig(lr=0.001, num_epochs=500, num_reruns=5) 237 | modules = [ 238 | # "models.ect_cnn_points", 239 | "models.ect_cnn_edges", 240 | # "models.ect_linear_points", 241 | # "models.ect_linear_edges", 242 | ] 243 | 244 | # Create the dataset config. 
245 | data = TUNCI109Config(batch_size=64) 246 | 247 | for module in modules: 248 | modelconfig = ECTModelConfig( 249 | module=module, 250 | num_features=36, 251 | num_classes=2, 252 | num_thetas=64, 253 | bump_steps=64, 254 | hidden=50, 255 | ) 256 | 257 | config = Config(meta, data, modelconfig, trainer) 258 | save_config( 259 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 260 | ) 261 | 262 | 263 | def tu_reddit_b(experiment_folder="experiment", trainer=None, meta=None) -> None: 264 | experiment = f"./{experiment_folder}/REDDIT-BINARY" 265 | create_experiment_folder(experiment) 266 | 267 | modules = [ 268 | # "models.ect_cnn_points", 269 | "models.ect_cnn_edges", 270 | # "models.ect_linear_points", 271 | # "models.ect_linear_edges", 272 | ] 273 | 274 | # Create the dataset config. 275 | data = TURedditBConfig() 276 | 277 | for module in modules: 278 | modelconfig = ECTModelConfig( 279 | module=module, num_features=1, num_classes=2, num_thetas=32, bump_steps=32 280 | ) 281 | 282 | config = Config(meta, data, modelconfig, trainer) 283 | save_config( 284 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 285 | ) 286 | 287 | 288 | def tu_imdb_b(experiment_folder="experiment", trainer=None, meta=None) -> None: 289 | experiment = f"./{experiment_folder}/IMDB-BINARY" 290 | create_experiment_folder(experiment) 291 | 292 | modules = [ 293 | # "models.ect_cnn_points", 294 | "models.ect_cnn_edges", 295 | # "models.ect_linear_points", 296 | # "models.ect_linear_edges", 297 | ] 298 | 299 | # Create the dataset config. 300 | data = TUIMDBBConfig() 301 | 302 | for module in modules: 303 | modelconfig = ECTModelConfig( 304 | module=module, num_features=541, num_classes=2, num_thetas=32, bump_steps=32 305 | ) 306 | 307 | config = Config(meta, data, modelconfig, trainer) 308 | save_config( 309 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 310 | ) 311 | 312 | 313 | def tu_letter_low_classification( 314 | experiment_folder="experiment", trainer=None, meta=None 315 | ) -> None: 316 | experiment = f"./{experiment_folder}/Letter-low" 317 | create_experiment_folder(experiment) 318 | 319 | modules = [ 320 | # "models.ect_cnn_points", 321 | "models.ect_cnn_edges", 322 | # "models.ect_linear_points", 323 | # "models.ect_linear_edges", 324 | ] 325 | 326 | # Create the dataset config. 327 | data = TULetterLowConfig(batch_size=32) 328 | 329 | for module in modules: 330 | modelconfig = ECTModelConfig( 331 | module=module, num_features=2, num_classes=15, hidden=100 332 | ) 333 | 334 | config = Config(meta, data, modelconfig, trainer) 335 | save_config( 336 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 337 | ) 338 | 339 | 340 | def tu_letter_med_classification( 341 | experiment_folder="experiment", trainer=None, meta=None 342 | ) -> None: 343 | experiment = f"./{experiment_folder}/Letter-med" 344 | create_experiment_folder(experiment) 345 | 346 | modules = [ 347 | # "models.ect_cnn_points", 348 | "models.ect_cnn_edges", 349 | # "models.ect_linear_points", 350 | # "models.ect_linear_edges", 351 | ] 352 | 353 | # Create the dataset config. 
354 | data = TULetterMedConfig(batch_size=32) 355 | 356 | for module in modules: 357 | modelconfig = ECTModelConfig( 358 | module=module, num_features=2, num_classes=15, hidden=100 359 | ) 360 | 361 | config = Config(meta, data, modelconfig, trainer) 362 | save_config( 363 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 364 | ) 365 | 366 | 367 | def tu_letter_high_classification( 368 | experiment_folder="experiment", trainer=None, meta=None 369 | ) -> None: 370 | experiment = f"./{experiment_folder}/Letter-high" 371 | create_experiment_folder(experiment) 372 | 373 | modules = [ 374 | # "models.ect_cnn_points", 375 | "models.ect_cnn_edges", 376 | # "models.ect_linear_points", 377 | # "models.ect_linear_edges", 378 | ] 379 | trainer = TrainerConfig(lr=0.001, num_epochs=150, num_reruns=5) 380 | # Create the dataset config. 381 | data = TULetterHighConfig(batch_size=32) 382 | 383 | for module in modules: 384 | modelconfig = ECTModelConfig( 385 | module=module, num_features=2, num_classes=15, hidden=100 386 | ) 387 | 388 | config = Config(meta, data, modelconfig, trainer) 389 | save_config( 390 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 391 | ) 392 | 393 | 394 | def tu_dd(experiment_folder="experiment", trainer=None, meta=None) -> None: 395 | experiment = f"./{experiment_folder}/DD" 396 | create_experiment_folder(experiment) 397 | 398 | modules = [ 399 | # "models.ect_cnn_points", 400 | "models.ect_cnn_edges", 401 | # "models.ect_linear_points", 402 | # "models.ect_linear_edges", 403 | ] 404 | 405 | # Create the dataset config. 406 | data = TUDDConfig() 407 | 408 | for module in modules: 409 | modelconfig = ECTModelConfig( 410 | module=module, num_features=89, num_classes=2, num_thetas=32, bump_steps=32 411 | ) 412 | 413 | config = Config(meta, data, modelconfig, trainer) 414 | save_config( 415 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 416 | ) 417 | 418 | 419 | def tu_enzymes(experiment_folder="experiment", trainer=None, meta=None) -> None: 420 | experiment = f"./{experiment_folder}/ENZYMES" 421 | create_experiment_folder(experiment) 422 | 423 | modules = [ 424 | # "models.ect_cnn_points", 425 | "models.ect_cnn_edges", 426 | # "models.ect_linear_points", 427 | # "models.ect_linear_edges", 428 | ] 429 | 430 | # Create the dataset config. 431 | data = TUEnzymesConfig() 432 | 433 | for module in modules: 434 | modelconfig = ECTModelConfig( 435 | module=module, num_features=21, num_classes=6, num_thetas=32, bump_steps=32 436 | ) 437 | 438 | config = Config(meta, data, modelconfig, trainer) 439 | save_config( 440 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 441 | ) 442 | 443 | 444 | def tu_proteins(experiment_folder="experiment", trainer=None, meta=None) -> None: 445 | experiment = f"./{experiment_folder}/PROTEINS_full" 446 | create_experiment_folder(experiment) 447 | 448 | trainer = TrainerConfig(lr=0.001, num_epochs=100, num_reruns=5) 449 | 450 | modules = [ 451 | # "models.ect_cnn_points", 452 | "models.ect_cnn_edges", 453 | # "models.ect_linear_points", 454 | # "models.ect_linear_edges", 455 | ] 456 | 457 | # Create the dataset config. 
458 | data = TUProteinsFullConfig(batch_size=128) 459 | 460 | for module in modules: 461 | modelconfig = ECTModelConfig( 462 | module=module, 463 | num_features=32, 464 | num_classes=2, 465 | num_thetas=32, 466 | bump_steps=32, 467 | hidden=50, 468 | ) 469 | 470 | config = Config(meta, data, modelconfig, trainer) 471 | save_config( 472 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 473 | ) 474 | 475 | 476 | def gnn_classification( 477 | name="MNIST", experiment_folder="experiment", trainer=None, meta=None 478 | ) -> None: 479 | experiment = f"./{experiment_folder}/gnn_{name.lower()}_classification" 480 | create_experiment_folder(experiment) 481 | 482 | modules = [ 483 | # "models.ect_cnn_points", 484 | "models.ect_cnn_edges", 485 | # "models.ect_linear_points", 486 | # "models.ect_linear_edges", 487 | ] 488 | trainer = TrainerConfig(lr=0.001, num_epochs=200, num_reruns=5) 489 | # Create the dataset config. 490 | data = GNNBenchmarkDataModuleConfig( 491 | module="datasets.gnn_benchmark", 492 | batch_size=64, 493 | name=name, 494 | pin_memory=False, 495 | ) 496 | 497 | if name == "MNIST" or name == "PATTERN": 498 | num_features = 3 499 | else: 500 | num_features = 5 501 | 502 | if name == "PATTERN": 503 | num_classes = 2 504 | else: 505 | num_classes = 10 506 | 507 | for module in modules: 508 | # Create linear points model config 509 | modelconfig = ECTModelConfig( 510 | module=module, 511 | num_features=num_features, 512 | hidden=200, 513 | num_classes=num_classes, 514 | ) 515 | 516 | config = Config(meta, data, modelconfig, trainer) 517 | save_config( 518 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 519 | ) 520 | 521 | 522 | def gnn_modelnet_classification( 523 | name="10", experiment_folder="experiment", trainer=None, meta=None 524 | ) -> None: 525 | experiment = f"./{experiment_folder}/gnn_modelnet_{name}_classification" 526 | create_experiment_folder(experiment) 527 | 528 | modules = [ 529 | "models.ect_cnn_points", 530 | # "models.ect_linear_points", 531 | ] 532 | 533 | for module in modules: 534 | for samplepoints in [100, 1000, 5000]: 535 | # Create linear points model config 536 | modelconfig = ECTModelConfig(module=module) 537 | 538 | # Create the dataset config. 539 | data = ModelNetDataModuleConfig( 540 | root=f"./data/modelnet_{name}_{samplepoints}", 541 | module="datasets.modelnet", 542 | samplepoints=samplepoints, 543 | name=name, 544 | drop_last=True, 545 | ) 546 | 547 | config = Config(meta, data, modelconfig, trainer) 548 | save_config( 549 | config, 550 | os.path.join( 551 | experiment, f"{module.split(sep='.')[1]}_{samplepoints}.yaml" 552 | ), 553 | ) 554 | 555 | 556 | def manifold_classification( 557 | experiment_folder="experiment", trainer=None, meta=None 558 | ) -> None: 559 | """ 560 | This experiment trains a ect cnn and linear model to distinguish 561 | three classes, 562 | - a noisy torus, 563 | - a sphere and 564 | - a mobius strip. 565 | 566 | Models used: 567 | - ECTLinear 568 | - ECTCNN 569 | """ 570 | 571 | experiment = f"./{experiment_folder}/manifold_classification" 572 | create_experiment_folder(experiment) 573 | 574 | modules = [ 575 | # "models.ect_cnn_points", 576 | # "models.ect_cnn_edges", 577 | # "models.ect_linear_points", 578 | # "models.ect_linear_edges", 579 | "models.ect_linear_faces", 580 | "models.ect_cnn_faces", 581 | ] 582 | 583 | # Create the dataset config. 
584 | data = ManifoldDataModuleConfig(module="datasets.manifold", batch_size=32) 585 | 586 | for module in modules: 587 | # Create linear points model config 588 | modelconfig = ECTModelConfig(module=module) 589 | 590 | config = Config(meta, data, modelconfig, trainer) 591 | save_config( 592 | config, os.path.join(experiment, f"{module.split(sep='.')[1]}.yaml") 593 | ) 594 | 595 | 596 | def theta_sweep(experiment_folder="experiment", trainer=None, meta=None): 597 | experiment = f"./{experiment_folder}/theta_sweep" 598 | create_experiment_folder(experiment) 599 | 600 | modules = [ 601 | "models.ect_linear_points", 602 | ] 603 | 604 | # Create the dataset config. 605 | data = GNNBenchmarkDataModuleConfig( 606 | module="datasets.gnn_benchmark", name="MNIST", pin_memory=False 607 | ) 608 | 609 | for module in modules: 610 | for theta in range(1, 32): 611 | # Create linear points model config 612 | modelconfig = ECTModelConfig(module=module, num_thetas=theta, hidden=100) 613 | 614 | config = Config(meta, data, modelconfig, trainer) 615 | save_config( 616 | config, 617 | os.path.join( 618 | experiment, f"{module.split(sep='.')[1]}_{theta:03d}.yaml" 619 | ), 620 | ) 621 | 622 | 623 | if __name__ == "__main__": 624 | # Create Trainer Config 625 | 626 | trainer = TrainerConfig(lr=0.001, num_epochs=100, num_reruns=5) 627 | # Create meta data 628 | meta = Meta("desct-test-new") 629 | experiment_folder = "experiment" 630 | # tu_bzr(experiment_folder, trainer, meta) 631 | # tu_cox2(experiment_folder, trainer, meta) 632 | # tu_frankenstein(experiment_folder, trainer, meta) 633 | # tu_fingerprint(experiment_folder, trainer, meta) 634 | # tu_dhfr(experiment_folder, trainer, meta) 635 | 636 | # tu_cox2(experiment_folder, trainer, meta) 637 | # tu_proteins(experiment_folder, trainer, meta) 638 | # tu_dd(experiment_folder, trainer, meta) 639 | # tu_imdb_b(experiment_folder, trainer, meta) 640 | # tu_reddit_b(experiment_folder, trainer, meta) 641 | # tu_letter_high_classification(experiment_folder, trainer, meta) 642 | # tu_letter_med_classification(experiment_folder, trainer, meta) 643 | # tu_letter_low_classification(experiment_folder, trainer, meta) 644 | 645 | # ogb_mol(experiment_folder, trainer, meta) 646 | # # gnn_classification("MNIST", experiment_folder, trainer, meta) 647 | # gnn_classification("CIFAR10", experiment_folder, trainer, meta) 648 | # # gnn_modelnet_classification("10", experiment_folder, trainer, meta) 649 | # gnn_modelnet_classification("40", experiment_folder, trainer, meta) 650 | manifold_classification(experiment_folder, trainer, meta) 651 | # theta_sweep(experiment_folder, trainer, meta) 652 | -------------------------------------------------------------------------------- /loaders/factory.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | creation_funcs = {} 4 | 5 | 6 | def register(name: str, module): 7 | """Register game characters""" 8 | print(f"Registerd {name}") 9 | creation_funcs[name] = module 10 | 11 | 12 | def unregister(character_type: str): 13 | """Register game characters""" 14 | creation_funcs.pop(character_type, None) 15 | 16 | 17 | def load_module(name, config): 18 | module = importlib.import_module(config.module) # type: ignore 19 | module.initialize() 20 | try: 21 | creation_func = creation_funcs[name] 22 | return creation_func(config) 23 | except KeyError: 24 | raise ValueError from None 25 | 26 | 27 | class PluginInterface: 28 | """PluginInterface docstring""" 29 | 30 | @staticmethod 31 | def 
initialize() -> None: 32 | """Initialize the plugin""" 33 | 34 | 35 | def import_module(name: str) -> PluginInterface: 36 | return importlib.import_module(name) # type: ignore 37 | 38 | 39 | def load_plugin(plugin_name: str) -> None: 40 | """Load the plugins""" 41 | plugin = import_module(plugin_name) 42 | plugin.initialize() 43 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from typing import Callable, Any 4 | import time 5 | import functools 6 | 7 | import wandb 8 | 9 | 10 | class Logger: 11 | def __init__(self, dev: bool = True, wandb_logger: bool = True): 12 | self.dev = dev 13 | self.wandb = None 14 | 15 | self.logger = logging.getLogger(__name__) 16 | self.logger.setLevel(logging.DEBUG) 17 | 18 | timestr = time.strftime("%Y%m%d-%H%M%S") 19 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") 20 | 21 | ch = logging.StreamHandler() 22 | ch.setLevel(logging.DEBUG) 23 | ch.setFormatter(formatter) 24 | 25 | fh = logging.FileHandler(filename=f"./logs/test-{timestr}.logs", mode="a") 26 | fh.setLevel(logging.DEBUG) 27 | fh.setFormatter(formatter) 28 | 29 | self.logger.addHandler(ch) 30 | self.logger.addHandler(fh) 31 | 32 | def wandb_init(self, config): 33 | self.wandb = wandb.init( 34 | project=config.project, 35 | name=config.name, 36 | tags=config.tags + ["dev"], 37 | ) 38 | 39 | def log(self, msg: str | None = None, params: dict[str, str] | None = None) -> None: 40 | if msg: 41 | self.logger.info(msg=msg) 42 | 43 | if params and self.wandb: 44 | self.wandb.log(params) 45 | 46 | def log_config(self, config): 47 | if self.wandb: 48 | self.wandb.config.update(config) 49 | else: 50 | raise ValueError() 51 | 52 | 53 | def timing(logger: Logger, dev: bool = True): 54 | if dev: 55 | 56 | def decorator(func: Callable[..., Any]) -> Callable[..., Any]: 57 | @functools.wraps(func) 58 | def wrapper(*args: Any, **kwargs: Any) -> Any: 59 | logger.log(f"Calling {func.__name__}") 60 | ts = time.time() 61 | value = func(*args, **kwargs) 62 | te = time.time() 63 | logger.log(f"Finished {func.__name__}") 64 | if logger: 65 | logger.log("func:%r took: %2.4f sec" % (func.__name__, te - ts)) 66 | return value 67 | 68 | return wrapper 69 | 70 | return decorator 71 | else: 72 | 73 | def decorator(func: Callable[..., Any]) -> Callable[..., Any]: 74 | @functools.wraps(func) 75 | def wrapper(*args: Any, **kwargs: Any) -> Any: 76 | value = func(*args, **kwargs) 77 | return value 78 | 79 | return wrapper 80 | 81 | return decorator 82 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import OmegaConf 3 | from utils import count_parameters 4 | from utils import listdir 5 | from logger import Logger, timing 6 | from metrics.metrics import compute_confusion, compute_acc 7 | import loaders.factory as loader 8 | import time 9 | 10 | import torchmetrics 11 | 12 | torch.cuda.empty_cache() 13 | mylogger = Logger() 14 | 15 | 16 | class EarlyStopper: 17 | def __init__(self, patience=10, min_delta=0.01): 18 | self.patience = patience 19 | self.min_delta = min_delta 20 | self.counter = 0 21 | self.min_validation_loss = torch.inf 22 | 23 | def __call__(self, validation_loss): 24 | if validation_loss < self.min_validation_loss: 25 | self.min_validation_loss = validation_loss 26 | self.counter = 0 
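# An improved validation loss resets the patience counter; the elif below counts epochs whose loss exceeds the best value by more than min_delta and triggers early stopping once patience is exhausted.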
27 | elif validation_loss > (self.min_validation_loss + self.min_delta): 28 | self.counter += 1 29 | if self.counter >= self.patience: 30 | print("stopped early") 31 | return True 32 | return False 33 | 34 | 35 | def clip_grad(model, max_norm): 36 | total_norm = 0 37 | for p in model.parameters(): 38 | param_norm = p.grad.data.norm(2) 39 | total_norm += param_norm**2 40 | total_norm = total_norm ** (0.5) 41 | clip_coef = max_norm / (total_norm + 1e-6) 42 | if clip_coef < 1: 43 | for p in model.parameters(): 44 | p.grad.data.mul_(clip_coef) 45 | return total_norm 46 | 47 | 48 | class Experiment: 49 | def __init__(self, experiment, logger, dev=True): 50 | self.config = OmegaConf.load(experiment) 51 | self.dev = dev 52 | self.logger = logger 53 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 54 | 55 | self.logger.log("Setup") 56 | self.logger.wandb_init(self.config.meta) 57 | 58 | # Load the dataset 59 | self.dm = loader.load_module("dataset", self.config.data) 60 | print(self.dm.entire_ds[0].x.shape) 61 | print(self.config.model) 62 | # Load the model 63 | model = loader.load_module("model", self.config.model) 64 | 65 | # Send model to device 66 | self.model = model.to(self.device) 67 | 68 | # Loss function and optimizer. 69 | self.loss_fn = torch.nn.CrossEntropyLoss() 70 | self.optimizer = torch.optim.Adam( 71 | [{"params": self.model.parameters()}], 72 | lr=self.config.trainer.lr, 73 | # weight_decay=1e-7, 74 | # eps=1e-4, 75 | ) 76 | # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 77 | # self.optimizer, mode="min", factor=0.2, patience=10, verbose=True 78 | # ) 79 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR( 80 | self.optimizer, milestones=[500], gamma=0.1 81 | ) 82 | 83 | self.early_stopper = EarlyStopper() 84 | 85 | self.accuracy_list = [] 86 | 87 | # Log info 88 | # Log info 89 | self.logger.log(f"Configurations:\n {OmegaConf.to_yaml(self.config)}") 90 | self.logger.log( 91 | f"{self.config.model.module} has {count_parameters(self.model)} trainable parameters" 92 | ) 93 | 94 | @timing(mylogger) 95 | def run(self): 96 | """ 97 | Runs an experiment given the loaded config files. 
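Trains for the configured number of epochs, logs train/validation metrics every ten epochs, and returns the test loss and accuracy.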
98 | """ 99 | start = time.time() 100 | for epoch in range(self.config.trainer.num_epochs): 101 | self.run_epoch() 102 | 103 | # if self.early_stopper(val_loss): 104 | # break 105 | 106 | if epoch % 10 == 0: 107 | end = time.time() 108 | self.compute_metrics(epoch, end - start) 109 | start = time.time() 110 | 111 | # Compute test accuracy 112 | loss, acc, roc = compute_acc( 113 | self.model, self.dm.test_dataloader(), self.config.model.num_classes 114 | ) 115 | self.accuracy_list.append(acc) 116 | 117 | # Log statements 118 | self.logger.log( 119 | f"Test accuracy: {acc:.3f}", 120 | params={ 121 | # "thetas": self.config.model.num_thetas, 122 | "test_acc": acc, 123 | "test_loss": loss, 124 | }, 125 | ) 126 | return loss, acc 127 | 128 | def run_epoch(self): 129 | self.model.train() 130 | for batch in self.dm.train_dataloader(): 131 | batch_gpu, y_gpu = batch.to(self.device), batch.y.to(self.device) 132 | self.optimizer.zero_grad() 133 | pred = self.model(batch_gpu) 134 | loss = self.loss_fn(pred, y_gpu) 135 | loss.backward() 136 | clip_grad(self.model, 5) 137 | self.optimizer.step() 138 | # raise "hello" 139 | 140 | val_loss, _, _ = compute_acc( 141 | self.model, self.dm.val_dataloader(), self.config.model.num_classes 142 | ) 143 | self.scheduler.step() 144 | return val_loss 145 | 146 | # del batch_gpu, y_gpu, pred, loss 147 | 148 | def compute_metrics(self, epoch, run_time): 149 | val_loss, val_acc, val_roc = compute_acc( 150 | self.model, self.dm.val_dataloader(), self.config.model.num_classes 151 | ) 152 | train_loss, train_acc, _ = compute_acc( 153 | self.model, self.dm.train_dataloader(), self.config.model.num_classes 154 | ) 155 | 156 | # Log statements to console 157 | self.logger.log( 158 | msg=f"epoch {epoch} | Train Loss {train_loss.item():.3f} | Val Loss {val_loss.item():.3f} | Train Accuracy {train_acc:.3f} | Val Accuracy {val_acc:.3f} | Run time {run_time:.2f} ", 159 | params={"epoch": epoch, "val_acc": val_acc}, 160 | ) 161 | 162 | 163 | def compute_avg(acc: torch.Tensor): 164 | # torch.save(self.model.ectlayer.v, "test.pt") 165 | final_acc_mean = torch.mean(acc) 166 | final_acc_std = torch.std(acc) 167 | print(acc) 168 | # Log statements 169 | mylogger.log( 170 | f"Final accuracy {final_acc_mean:.4f} with std {final_acc_std:.4f}.", 171 | ) 172 | 173 | 174 | def main(): 175 | experiments = [ 176 | # "DD", 177 | # "ENZYMES", 178 | # "IMDB-BINARY", 179 | # "Letter-high", 180 | # "Letter-med", 181 | # "Letter-low", 182 | # "gnn_mnist_classification", 183 | # "gnn_cifar10_classification", 184 | # "PROTEINS_full", 185 | # "REDDIT-BINARY", 186 | # "OGB-MOLHIV" 187 | # "dhfr", 188 | # "bzr", 189 | # "cox2", 190 | "lrgb" 191 | ] 192 | 193 | for experiment in experiments: 194 | for config in listdir(f"./experiment/{experiment}"): 195 | accs = [] 196 | for _ in range(5): 197 | print("Running experiment", experiment, config) 198 | exp = Experiment(config, logger=mylogger, dev=True) 199 | loss, acc = exp.run() 200 | accs.append(acc) 201 | compute_avg(torch.tensor(accs)) 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- /metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.metrics import confusion_matrix 3 | from torchmetrics.classification import AUROC 4 | 5 | 6 | def compute_confusion(model, loader): 7 | y_true = [] 8 | y_pred = [] 9 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | 
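# Run the loader through the model without tracking gradients; predicted classes are obtained below as the argmax over the logits.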
11 | with torch.no_grad(): 12 | for batch in loader: 13 | batch_gpu, y_gpu = batch.to(device), batch.y.to(device) 14 | y_pred.append(model(batch_gpu)) 15 | y_true.append(y_gpu) 16 | 17 | y_true = torch.cat(y_true) 18 | y_pred = torch.cat(y_pred).max(axis=1)[1] 19 | cfm = confusion_matrix( 20 | y_true.cpu().detach().numpy(), y_pred.cpu().detach().numpy() 21 | ) 22 | return cfm 23 | 24 | 25 | import torchmetrics 26 | 27 | 28 | def compute_acc(model, loader, num_classes=None): 29 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 30 | model = model.eval() 31 | acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device) 32 | # auroc = AUROC(task="multiclass", num_classes=num_classes).to(device) 33 | loss_fn = torch.nn.CrossEntropyLoss() 34 | loss = torch.tensor([0.0], device=device) 35 | with torch.no_grad(): 36 | for batch in loader: 37 | batch_gpu, y_gpu = batch.to(device), batch.y.to(device) 38 | logits = model(batch_gpu) 39 | loss += loss_fn(logits, y_gpu) 40 | # auroc(logits, y_gpu) 41 | acc(logits, y_gpu) 42 | a = acc.compute() 43 | # roc = auroc.compute() 44 | acc.reset() 45 | # auroc.reset() 46 | return loss, a, torch.tensor([0]) 47 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # import os, sys 2 | # import importlib 3 | # from pydoc import locate 4 | # from abc import ABC, abstractmethod 5 | # import torch 6 | # import torch.nn.functional as F 7 | # from torch.utils.data import random_split 8 | 9 | # """ 10 | # This package includes all the modules related to data loading and preprocessing. 11 | # """ 12 | 13 | # """ 14 | # The below code fetches all files in the datasets folder and imports them. Each 15 | # dataset can have its own file, or grouped in sets. 16 | # """ 17 | # """ path = os.path.dirname(os.path.abspath(__file__)) """ 18 | # """ for py in [f[:-3] for f in os.listdir(path) if f.endswith('.py') and f != '__init__.py']: """ 19 | # """ mod = __import__('.'.join([__name__, py]), fromlist=[py]) """ 20 | # """ classes = [getattr(mod, x) for x in dir(mod) if isinstance(getattr(mod, x), type)] """ 21 | # """ for cls in classes: """ 22 | # """ setattr(sys.modules[__name__], cls.__name__, cls) """ 23 | 24 | 25 | # def load_model(name, config): 26 | # model = locate(f"models.{name}") 27 | # if not model: 28 | # print(f"Tried to load model {name}") 29 | # raise AttributeError() 30 | # return model(config) 31 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from dataclasses import dataclass 3 | 4 | 5 | class BaseModel(nn.Module): 6 | """ 7 | This is an abstract base model for the model class. 
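Concrete models store their configuration object here and implement their own forward pass.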
8 | """ 9 | 10 | def __init__(self, config): 11 | super().__init__() 12 | self.config = config 13 | -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass(frozen=True) 5 | class EctConfig: 6 | num_thetas: int = 64 7 | bump_steps: int = 64 8 | R: float = 1.1 9 | ect_type: str = "points" 10 | device: str = 'cuda:0' 11 | num_features: int = 3 12 | 13 | 14 | @dataclass(frozen=True) 15 | class ModelConfig: 16 | ectconfig: EctConfig = EctConfig() 17 | num_features: int = 3 18 | num_classes: int = 10 19 | hidden: int = 50 20 | -------------------------------------------------------------------------------- /models/ect_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import functools 4 | import operator 5 | 6 | 7 | from models.layers.layers import EctLayer 8 | from models.config import ModelConfig 9 | from loaders.factory import register 10 | from models.base_model import BaseModel 11 | 12 | 13 | class EctCnnModel(BaseModel): 14 | def __init__(self, config: ModelConfig): 15 | super().__init__(config) 16 | self.ectlayer = EctLayer(config.ectconfig) 17 | self.conv = nn.Sequential( 18 | nn.Conv2d(1, 8, kernel_size=3), 19 | nn.MaxPool2d(2), 20 | nn.Conv2d(8, 16, kernel_size=3), 21 | nn.MaxPool2d(2), 22 | ) 23 | num_features = functools.reduce( 24 | operator.mul, 25 | list( 26 | self.conv( 27 | torch.rand( 28 | 1, config.ectconfig.bump_steps, config.ectconfig.num_thetas 29 | ) 30 | ).shape 31 | ), 32 | ) 33 | 34 | self.linear = nn.Sequential( 35 | nn.Linear(num_features, config.hidden), 36 | nn.ReLU(), 37 | nn.Linear(config.hidden, config.hidden), 38 | nn.ReLU(), 39 | nn.Linear(config.hidden, config.num_classes), 40 | ) 41 | 42 | def forward(self, batch): 43 | x = self.ectlayer(batch).unsqueeze(1) 44 | x = self.conv(x).view(x.size(0), -1) 45 | x = self.linear(x) 46 | return x 47 | 48 | 49 | def initialize(): 50 | register("model", EctCnnModel) 51 | -------------------------------------------------------------------------------- /models/ect_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.base_model import BaseModel 4 | 5 | from models.layers.layers import EctLayer 6 | from models.config import EctConfig 7 | 8 | 9 | class EctLinearModel(BaseModel): 10 | def __init__(self, config: EctConfig): 11 | super().__init__(config) 12 | self.ectlayer = EctLayer(config.ectconfig) 13 | 14 | self.linear = nn.Sequential( 15 | nn.Linear(self.config.ectconfig.num_thetas * self.config.ectconfig.bump_steps, config.hidden), 16 | nn.ReLU(), 17 | nn.Linear(config.hidden, config.hidden), 18 | nn.ReLU(), 19 | nn.Linear(config.hidden, config.num_classes), 20 | ) 21 | 22 | def forward(self, batch): 23 | x = self.ectlayer(batch).reshape( 24 | -1, self.config.ectconfig.num_thetas * self.config.ectconfig.bump_steps 25 | ) 26 | x = self.linear(x) 27 | return x 28 | 29 | 30 | from loaders.factory import register 31 | 32 | 33 | def initialize(): 34 | register("model", EctLinearModel) 35 | -------------------------------------------------------------------------------- /models/layers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aidos-lab/dect-evaluation/87640a57eedb6f528569b8f66647a0e31828f156/models/layers/__init__.py -------------------------------------------------------------------------------- /models/layers/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_scatter import segment_coo 4 | import geotorch 5 | from models.config import EctConfig 6 | from torch_geometric.data import Data 7 | 8 | from typing import Protocol 9 | from dataclasses import dataclass 10 | 11 | 12 | def compute_ecc(nh, index, lin, dim_size): 13 | ecc = torch.nn.functional.sigmoid(200 * torch.sub(lin, nh)) 14 | return segment_coo(ecc, index.view(1, -1), dim_size=dim_size).movedim(0, 1) 15 | 16 | 17 | def compute_ect_points(data, v, lin): 18 | nh = data.x @ v 19 | return compute_ecc(nh, data.batch, lin, data.num_graphs) 20 | 21 | 22 | def compute_ect_edges(data, v, lin): 23 | nh = data.x @ v 24 | eh, _ = nh[data.edge_index].max(dim=0) 25 | return compute_ecc(nh, data.batch, lin, dim_size=data.num_graphs) - compute_ecc( 26 | eh, data.batch[data.edge_index[0]], lin, dim_size=data.num_graphs 27 | ) 28 | 29 | 30 | def compute_ect_faces(data, v, lin): 31 | nh = data.x @ v 32 | eh, _ = nh[data.edge_index].max(dim=0) 33 | fh, _ = nh[data.face].max(dim=0) 34 | return ( 35 | compute_ecc(nh, data.batch, lin, dim_size=data.num_graphs) 36 | - compute_ecc(eh, data.batch[data.edge_index[0]], lin, dim_size=data.num_graphs) 37 | + compute_ecc(fh, data.batch[data.face[0]], lin, dim_size=data.num_graphs) 38 | ) 39 | 40 | 41 | class EctLayer(nn.Module): 42 | """docstring for EctLayer.""" 43 | 44 | def __init__(self, config: EctConfig, fixed=False): 45 | super().__init__() 46 | self.fixed = fixed 47 | self.lin = ( 48 | torch.linspace(-config.R, config.R, config.bump_steps) 49 | .view(-1, 1, 1) 50 | .to(config.device) 51 | ) 52 | if self.fixed: 53 | self.v = torch.vstack( 54 | [ 55 | torch.sin(torch.linspace(0, 2 * torch.pi, config.num_thetas)), 56 | torch.cos(torch.linspace(0, 2 * torch.pi, config.num_thetas)), 57 | ] 58 | ).to(config.device) 59 | else: 60 | self.v = (torch.rand(size=(config.num_features, config.num_thetas)) - 0.5).T.to(config.device) 61 | self.v /= self.v.pow(2).sum(axis=1).sqrt().unsqueeze(1) 62 | self.v = nn.Parameter(self.v.T) 63 | 64 | if config.ect_type == "points": 65 | self.compute_ect = compute_ect_points 66 | elif config.ect_type == "edges": 67 | self.compute_ect = compute_ect_edges 68 | elif config.ect_type == "faces": 69 | self.compute_ect = compute_ect_faces 70 | 71 | def postinit(self): 72 | if not self.fixed: 73 | geotorch.constraints.sphere(self, "v") 74 | 75 | def forward(self, data): 76 | return self.compute_ect(data, self.v, self.lin) 77 | -------------------------------------------------------------------------------- /models/layers/layers_wect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_scatter import segment_coo 4 | import geotorch 5 | from models.config import EctConfig 6 | from torch_geometric.data import Data 7 | 8 | from typing import Protocol 9 | from dataclasses import dataclass 10 | 11 | 12 | def compute_wecc(nh, index, lin, weight,out): 13 | ecc = torch.nn.functional.sigmoid(500 * torch.sub(lin, nh)) * weight.view(1,-1,1) 14 | res = torch.index_add(out,1, index, ecc).movedim(0, 1) 15 | return res 16 | 17 | 18 | def compute_wect_points(data, v, lin, out): 19 | # Compute the weights 20 
| edge_weights,_ = data.node_weights[data.edge_index].max(axis=0) 21 | face_weights,_ = data.node_weights[data.face].max(axis=0) 22 | 23 | nh = data.x @ v 24 | eh, _ = nh[data.edge_index].min(dim=0) 25 | fh, _ = nh[data.face].min(dim=0) 26 | return ( 27 | compute_wecc(nh, data.batch, lin, data.node_weights,out) 28 | ) 29 | 30 | def compute_wect_edges(data, v, lin, out): 31 | # Compute the weights 32 | edge_weights,_ = data.node_weights[data.edge_index].max(axis=0) 33 | face_weights,_ = data.node_weights[data.face].max(axis=0) 34 | nh = data.x @ v 35 | eh, _ = nh[data.edge_index].min(dim=0) 36 | fh, _ = nh[data.face].min(dim=0) 37 | return ( 38 | compute_wecc(nh, data.batch, lin, data.node_weights,out) 39 | - compute_wecc(eh, data.batch[data.edge_index[0]], lin, edge_weights,out) 40 | ) 41 | 42 | def compute_wect_faces(data, v, lin, out): 43 | # Compute the weights 44 | edge_weights,_ = data.node_weights[data.edge_index].max(axis=0) 45 | face_weights,_ = data.node_weights[data.face].max(axis=0) 46 | nh = data.x @ v 47 | eh, _ = nh[data.edge_index].min(dim=0) 48 | fh, _ = nh[data.face].min(dim=0) 49 | return ( 50 | compute_wecc(nh, data.batch, lin, data.node_weights,out) 51 | - compute_wecc(eh, data.batch[data.edge_index[0]], lin, edge_weights,out) 52 | + compute_wecc(fh, data.batch[data.face[0]], lin, face_weights,out) 53 | ) 54 | 55 | 56 | class WectLayer(nn.Module): 57 | def __init__(self, config: EctConfig, fixed=False): 58 | super().__init__() 59 | self.config = config 60 | self.fixed = fixed 61 | self.lin = ( 62 | torch.linspace(-config.R, config.R, config.bump_steps) 63 | .view(-1, 1, 1) 64 | .to(config.device) 65 | ) 66 | if self.fixed: 67 | self.v = torch.vstack( 68 | [ 69 | torch.sin(torch.linspace(0, 2 * torch.pi, 256)), 70 | torch.cos(torch.linspace(0, 2 * torch.pi, 256)), 71 | ] 72 | ).to(config.device) 73 | else: 74 | self.v = torch.nn.Parameter( 75 | torch.rand(size=(config.num_features, config.num_thetas)) - 0.5 76 | ).to(config.device) 77 | 78 | if config.ecc_type == "points": 79 | self.compute_wect = compute_wect_points 80 | elif config.ecc_type == "edges": 81 | self.compute_wect = compute_wect_edges 82 | elif config.ecc_type == "faces": 83 | self.compute_wect = compute_wect_faces 84 | 85 | def __post_init__(self): 86 | if self.fixed: 87 | geotorch.constraints.sphere(self, "v") 88 | 89 | def forward(self, data): 90 | out = torch.zeros( 91 | self.config.bump_steps, 92 | data.batch.max() + 1, 93 | self.config.num_thetas, 94 | dtype=torch.float32, 95 | device=self.config.device, 96 | ) 97 | return self.compute_wect(data, self.v, self.lin, out) 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /models/wect_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import functools 4 | import operator 5 | 6 | 7 | from models.layers.layers_wect import WectLayer 8 | from models.config import ModelConfig 9 | from loaders.factory import register 10 | from models.base_model import BaseModel 11 | 12 | 13 | class EctCnnModel(BaseModel): 14 | def __init__(self, config: ModelConfig): 15 | super().__init__(config) 16 | self.ectlayer = WectLayer(config.ectconfig) 17 | self.conv = nn.Sequential( 18 | nn.Conv2d(1, 8, kernel_size=3), 19 | nn.MaxPool2d(2), 20 | nn.Conv2d(8, 16, kernel_size=3), 21 | nn.MaxPool2d(2), 22 | ) 23 | num_features = functools.reduce( 24 | operator.mul, 25 | list( 26 | self.conv( 27 | torch.rand( 28 | 1, config.ectconfig.bump_steps, 
config.ectconfig.num_thetas 29 | ) 30 | ).shape 31 | ), 32 | ) 33 | 34 | self.linear = nn.Sequential( 35 | nn.Linear(num_features, config.hidden), 36 | nn.ReLU(), 37 | nn.Linear(config.hidden, config.hidden), 38 | nn.ReLU(), 39 | nn.Linear(config.hidden, config.num_classes), 40 | ) 41 | 42 | def forward(self, batch): 43 | x = self.ectlayer(batch).unsqueeze(1) 44 | x = self.conv(x).view(x.size(0), -1) 45 | x = self.linear(x) 46 | return x 47 | 48 | 49 | def initialize(): 50 | register("model", EctCnnModel) 51 | -------------------------------------------------------------------------------- /models/wect_linear.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.base_model import BaseModel 3 | 4 | from models.layers.layers_wect import WectLayer 5 | from models.config import EctConfig 6 | 7 | 8 | class WectLinearModel(BaseModel): 9 | def __init__(self, config: EctConfig): 10 | super().__init__(config) 11 | 12 | self.ectlayer = WectLayer(config.ectconfig) 13 | print("ran THIS") 14 | self.linear = nn.Sequential( 15 | nn.Linear(self.config.ectconfig.num_thetas * self.config.ectconfig.bump_steps, config.hidden), 16 | nn.ReLU(), 17 | nn.Linear(config.hidden, config.hidden), 18 | nn.ReLU(), 19 | nn.Linear(config.hidden, config.num_classes), 20 | ) 21 | 22 | def forward(self, batch): 23 | x = self.ectlayer(batch).reshape( 24 | -1, self.config.ectconfig.num_thetas * self.config.ectconfig.bump_steps 25 | ) 26 | x = self.linear(x) 27 | return x 28 | 29 | 30 | from loaders.factory import register 31 | 32 | 33 | def initialize(): 34 | register("model", WectLinearModel) 35 | -------------------------------------------------------------------------------- /notebooks/dect.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch \n", 10 | "import torch_geometric\n", 11 | "from torch_geometric.data import Data" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 3, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import matplotlib.pyplot as plt\n", 21 | "import matplotlib_inline" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 11, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# Basic dataset with three points,three edges and one face.\n", 31 | "points_coordinates = torch.tensor([[0.5, 0.0], [-0.5, 0.0], [0.5, 0.5]])\n", 32 | "edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long)\n", 33 | "face_index=torch.tensor([[0], [1], [2]], dtype=torch.long)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 14, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# We pick a direction between 0 and 2*pi\n", 43 | "theta = torch.tensor(0.0) \n", 44 | "xi = torch.tensor([torch.sin(theta),torch.cos(theta)])" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 16, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "tensor([0.0000, 0.0000, 0.5000])\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "# Next we compute the node heights as the inner product of the vertex coordinates \n", 62 | "# and the direction.\n", 63 | "node_heigth = points_coordinates @ xi \n", 64 | "print(node_heigth)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 20, 70 | 
"metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "tensor([0.0000, 0.5000, 0.5000])" 76 | ] 77 | }, 78 | "execution_count": 20, 79 | "metadata": {}, 80 | "output_type": "execute_result" 81 | } 82 | ], 83 | "source": [ 84 | "# The \"height\" of each edge is defined as the maximum of the node heights of the \n", 85 | "# vertices it is spanned by. Since the edge indices are given as tuples, we lookup the \n", 86 | "# edge height tuples in the node heights vector (indexing is the same) and \n", 87 | "# compute the column-wise maximum. \n", 88 | "\n", 89 | "edge_height_tuples = node_heigth[edge_index]\n", 90 | "edge_height = edge_height_tuples.max(dim=0)[0]\n", 91 | "edge_height" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 21, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "data": { 101 | "text/plain": [ 102 | "tensor([0.5000])" 103 | ] 104 | }, 105 | "execution_count": 21, 106 | "metadata": {}, 107 | "output_type": "execute_result" 108 | } 109 | ], 110 | "source": [ 111 | "# For the face heights we perform the same computation.\n", 112 | "face_height_tuples = node_heigth[face_index]\n", 113 | "face_height = face_height_tuples.max(dim=0)[0]\n", 114 | "face_height" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "# The two critical points are 0 and 1/2 and these are the places the ect changes. \n", 124 | "# Below zero the ecc is zero, between 0 and 1/2 it is 2-1 = 1 (2 points 1 edges)\n", 125 | "# and above 1/2 it is 3-3+1=1. \n", 126 | "# We find these numbers by counting all the points edges and faces below a certain\n", 127 | "# value. " 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 28, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "tensor([[0., 0., 0.],\n", 140 | " [0., 0., 0.],\n", 141 | " [0., 0., 0.],\n", 142 | " [0., 0., 0.],\n", 143 | " [0., 0., 0.],\n", 144 | " [0., 0., 0.],\n", 145 | " [0., 0., 0.],\n", 146 | " [0., 0., 0.],\n", 147 | " [0., 0., 0.],\n", 148 | " [0., 0., 0.],\n", 149 | " [0., 0., 0.],\n", 150 | " [0., 0., 0.],\n", 151 | " [1., 1., 0.],\n", 152 | " [1., 1., 0.],\n", 153 | " [1., 1., 0.],\n", 154 | " [1., 1., 0.],\n", 155 | " [1., 1., 0.],\n", 156 | " [1., 1., 0.],\n", 157 | " [1., 1., 1.],\n", 158 | " [1., 1., 1.],\n", 159 | " [1., 1., 1.],\n", 160 | " [1., 1., 1.],\n", 161 | " [1., 1., 1.],\n", 162 | " [1., 1., 1.],\n", 163 | " [1., 1., 1.]])\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "# Instead of counting points, we assign each element an indicator function that \n", 169 | "# zero below the critical point and 1 above it. \n", 170 | "# To do so, we translate the indicator function for each point, edge and face. \n", 171 | "\n", 172 | "# Discretize interval in 25 steps\n", 173 | "interval = torch.linspace(-1,1,25).view(-1,1)\n", 174 | "\n", 175 | "translated_nodes = interval - node_heigth \n", 176 | "ecc_points = torch.heaviside(translated_nodes,values=torch.tensor([1.0]))\n", 177 | "print(ecc_points)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 31, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "# Note that the 0 is at index 13 and 1/2 is at index 20. Indeed this is where \n", 187 | "# the curves change value. \n", 188 | "\n", 189 | "# We do the same for the faces and edges." 
190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 30, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "tensor([[0., 0., 0.],\n", 202 | " [0., 0., 0.],\n", 203 | " [0., 0., 0.],\n", 204 | " [0., 0., 0.],\n", 205 | " [0., 0., 0.],\n", 206 | " [0., 0., 0.],\n", 207 | " [0., 0., 0.],\n", 208 | " [0., 0., 0.],\n", 209 | " [0., 0., 0.],\n", 210 | " [0., 0., 0.],\n", 211 | " [0., 0., 0.],\n", 212 | " [0., 0., 0.],\n", 213 | " [1., 0., 0.],\n", 214 | " [1., 0., 0.],\n", 215 | " [1., 0., 0.],\n", 216 | " [1., 0., 0.],\n", 217 | " [1., 0., 0.],\n", 218 | " [1., 0., 0.],\n", 219 | " [1., 1., 1.],\n", 220 | " [1., 1., 1.],\n", 221 | " [1., 1., 1.],\n", 222 | " [1., 1., 1.],\n", 223 | " [1., 1., 1.],\n", 224 | " [1., 1., 1.],\n", 225 | " [1., 1., 1.]])\n" 226 | ] 227 | } 228 | ], 229 | "source": [ 230 | "\n", 231 | "\n", 232 | "# Discretize interval in 25 steps\n", 233 | "interval = torch.linspace(-1,1,25).view(-1,1)\n", 234 | "\n", 235 | "translated_edges = interval - edge_height \n", 236 | "ecc_edges = torch.heaviside(translated_edges,values=torch.tensor([1.0]))\n", 237 | "print(ecc_edges)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 32, 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "name": "stdout", 247 | "output_type": "stream", 248 | "text": [ 249 | "tensor([[0., 0., 0.],\n", 250 | " [0., 0., 0.],\n", 251 | " [0., 0., 0.],\n", 252 | " [0., 0., 0.],\n", 253 | " [0., 0., 0.],\n", 254 | " [0., 0., 0.],\n", 255 | " [0., 0., 0.],\n", 256 | " [0., 0., 0.],\n", 257 | " [0., 0., 0.],\n", 258 | " [0., 0., 0.],\n", 259 | " [0., 0., 0.],\n", 260 | " [0., 0., 0.],\n", 261 | " [1., 0., 0.],\n", 262 | " [1., 0., 0.],\n", 263 | " [1., 0., 0.],\n", 264 | " [1., 0., 0.],\n", 265 | " [1., 0., 0.],\n", 266 | " [1., 0., 0.],\n", 267 | " [1., 1., 1.],\n", 268 | " [1., 1., 1.],\n", 269 | " [1., 1., 1.],\n", 270 | " [1., 1., 1.],\n", 271 | " [1., 1., 1.],\n", 272 | " [1., 1., 1.],\n", 273 | " [1., 1., 1.]])\n" 274 | ] 275 | } 276 | ], 277 | "source": [ 278 | "# Discretize interval in 25 steps\n", 279 | "interval = torch.linspace(-1,1,25).view(-1,1)\n", 280 | "\n", 281 | "translated_edges = interval - edge_height \n", 282 | "ecc_edges = torch.heaviside(translated_edges,values=torch.tensor([1.0]))\n", 283 | "print(ecc_edges)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 33, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "tensor([[0.],\n", 296 | " [0.],\n", 297 | " [0.],\n", 298 | " [0.],\n", 299 | " [0.],\n", 300 | " [0.],\n", 301 | " [0.],\n", 302 | " [0.],\n", 303 | " [0.],\n", 304 | " [0.],\n", 305 | " [0.],\n", 306 | " [0.],\n", 307 | " [0.],\n", 308 | " [0.],\n", 309 | " [0.],\n", 310 | " [0.],\n", 311 | " [0.],\n", 312 | " [0.],\n", 313 | " [1.],\n", 314 | " [1.],\n", 315 | " [1.],\n", 316 | " [1.],\n", 317 | " [1.],\n", 318 | " [1.],\n", 319 | " [1.]])\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "# Discretize interval in 25 steps\n", 325 | "interval = torch.linspace(-1,1,25).view(-1,1)\n", 326 | "\n", 327 | "translated_faces = interval - face_height \n", 328 | "ecc_faces = torch.heaviside(translated_faces,values=torch.tensor([1.0]))\n", 329 | "print(ecc_faces)" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 37, 335 | "metadata": {}, 336 | "outputs": [ 337 | { 338 | "data": { 339 | "text/plain": [ 340 | "tensor([0., 0., 0., 
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.,\n", 341 | " 1., 1., 1., 1., 1., 1., 1.])" 342 | ] 343 | }, 344 | "execution_count": 37, 345 | "metadata": {}, 346 | "output_type": "execute_result" 347 | } 348 | ], 349 | "source": [ 350 | "# The ect along this direction is then computed by first computing the sum of\n", 351 | "# columns in each of the three matrices and then by computing the \n", 352 | "# alternating sum of the three matrices.\n", 353 | "\n", 354 | "ecc = ecc_points.sum(axis=1) - ecc_edges.sum(axis=1) + ecc_faces.sum(axis=1) \n", 355 | "ecc" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "# We can indeed verify that at index 13 the value changes from 0 to 1 (which is)\n", 365 | "# the origin in our coordinate system." 366 | ] 367 | } 368 | ], 369 | "metadata": { 370 | "kernelspec": { 371 | "display_name": ".venv", 372 | "language": "python", 373 | "name": "python3" 374 | }, 375 | "language_info": { 376 | "codemirror_mode": { 377 | "name": "ipython", 378 | "version": 3 379 | }, 380 | "file_extension": ".py", 381 | "mimetype": "text/x-python", 382 | "name": "python", 383 | "nbconvert_exporter": "python", 384 | "pygments_lexer": "ipython3", 385 | "version": "3.10.11" 386 | } 387 | }, 388 | "nbformat": 4, 389 | "nbformat_minor": 2 390 | } 391 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ansi2html==1.8.0 2 | antlr4-python3-runtime==4.9.3 3 | appdirs==1.4.4 4 | asttokens==2.2.1 5 | attrs==23.1.0 6 | backcall==0.2.0 7 | certifi==2023.7.22 8 | charset-normalizer==3.2.0 9 | click==8.1.6 10 | colorama==0.4.6 11 | comm==0.1.4 12 | ConfigArgParse==1.7 13 | dash==2.11.1 14 | dash-core-components==2.0.0 15 | dash-html-components==2.0.0 16 | dash-table==5.0.0 17 | decorator==5.1.1 18 | docker-pycreds==0.4.0 19 | executing==1.2.0 20 | fastjsonschema==2.18.0 21 | filelock==3.9.0 22 | Flask==2.2.5 23 | geotorch==0.3.0 24 | gitdb==4.0.10 25 | GitPython==3.1.32 26 | idna==3.4 27 | ipython==8.14.0 28 | ipywidgets==8.1.0 29 | itsdangerous==2.1.2 30 | jedi==0.19.0 31 | Jinja2==3.1.2 32 | joblib==1.3.1 33 | jsonschema==4.19.0 34 | jsonschema-specifications==2023.7.1 35 | jupyter_core==5.3.1 36 | jupyterlab-widgets==3.0.8 37 | MarkupSafe==2.1.3 38 | matplotlib-inline==0.1.6 39 | mpmath==1.2.1 40 | nbformat==5.7.0 41 | nest-asyncio==1.5.7 42 | networkx==3.0 43 | numpy==1.25.2 44 | omegaconf==2.3.0 45 | open3d==0.17.0 46 | packaging==23.1 47 | pandas==2.0.3 48 | parso==0.8.3 49 | pathtools==0.1.2 50 | pickleshare==0.7.5 51 | Pillow==9.3.0 52 | platformdirs==3.10.0 53 | plotly==5.15.0 54 | prompt-toolkit==3.0.39 55 | protobuf==4.23.4 56 | psutil==5.9.5 57 | pure-eval==0.2.2 58 | Pygments==2.16.1 59 | pyparsing==3.1.1 60 | python-dateutil==2.8.2 61 | pytz==2023.3 62 | pywin32==306 63 | PyYAML==6.0.1 64 | referencing==0.30.2 65 | requests==2.31.0 66 | retrying==1.3.4 67 | rpds-py==0.9.2 68 | scikit-learn==1.3.0 69 | scipy==1.11.1 70 | sentry-sdk==1.29.2 71 | setproctitle==1.3.2 72 | six==1.16.0 73 | smmap==5.0.0 74 | stack-data==0.6.2 75 | sympy==1.11.1 76 | tenacity==8.2.2 77 | threadpoolctl==3.2.0 78 | torch==2.0.1+cu117 79 | torch-geometric==2.3.1 80 | torchaudio==2.0.2+cu117 81 | torchvision==0.15.2+cu117 82 | tqdm==4.65.1 83 | traitlets==5.9.0 84 | typing_extensions==4.4.0 85 | tzdata==2023.3 86 | urllib3==2.0.4 87 | wandb==0.15.8 88 | 
wcwidth==0.2.6 89 | Werkzeug==2.2.3 90 | widgetsnbextension==4.0.8 91 | -------------------------------------------------------------------------------- /single_main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from main import compute_avg 3 | 4 | from main import Experiment 5 | from logger import Logger, timing 6 | 7 | mylogger = Logger() 8 | 9 | def main(): 10 | accs = [] 11 | for _ in range(1): 12 | print("Running experiment", "ect_cnn_best.yaml") 13 | exp = Experiment( 14 | "./experiment/lrgb/lrgb_linear_edges.yaml", logger=mylogger, dev=True 15 | ) 16 | # exp = Experiment( 17 | # "./experiment/manifold_classification/ect_cnn_faces.yaml", 18 | # logger=mylogger, 19 | # dev=True, 20 | # ) 21 | loss, acc = exp.run() 22 | accs.append(acc) 23 | compute_avg(torch.tensor(accs)) 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | from types import SimpleNamespace 3 | import math 4 | import numpy as np 5 | import os 6 | from pathlib import Path 7 | import torch 8 | from torch.optim import lr_scheduler 9 | import argparse 10 | import json 11 | 12 | class Parser(argparse.ArgumentParser): 13 | def __init__(self): 14 | super(Parser,self).__init__() 15 | self.add_argument( 16 | '-c', '--config', 17 | default="./config.json", 18 | type=str, 19 | help="Choose configuration for the experiment" 20 | ) 21 | 22 | def parse(self): 23 | """Loads config file if a string was passed 24 | and returns the input if a dictionary was passed. 25 | """ 26 | config_file = self.parse_args().config 27 | if isinstance(config_file, str): 28 | with open(config_file) as json_file: 29 | return json.load(json_file, object_hook=lambda d: SimpleNamespace(**d)) 30 | elif isinstance(config_file, dict): 31 | return config_file 32 | else: 33 | raise AttributeError() 34 | 35 | # def transfer_to_device(x, device): 36 | # """Transfers pytorch tensors or lists of tensors to GPU. This 37 | # function is recursive to be able to deal with lists of lists. 38 | # """ 39 | # if isinstance(x, list): 40 | # for i in range(len(x)): 41 | # x[i] = transfer_to_device(x[i], device) 42 | # else: 43 | # x = x.to(device) 44 | # return x 45 | # 46 | # 47 | # def get_scheduler(optimizer, configuration, last_epoch=-1): 48 | # """Return a learning rate scheduler. 49 | # """ 50 | # if configuration['lr_policy'] == 'step': 51 | # scheduler = lr_scheduler.StepLR(optimizer, step_size=configuration['lr_decay_iters'], gamma=0.3, last_epoch=last_epoch) 52 | # else: 53 | # return NotImplementedError('learning rate policy [{0}] is not implemented'.format(configuration['lr_policy'])) 54 | # return scheduler 55 | # 56 | # 57 | # def stack_all(list, dim=0): 58 | # """Stack all iterables of torch tensors in a list (i.e. [[(tensor), (tensor)], [(tensor), (tensor)]]) 59 | # """ 60 | # return [torch.stack(s, dim) for s in list] 61 | 62 | 63 | 64 | def count_parameters(model): 65 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 66 | 67 | 68 | def listdir(d): 69 | return [os.path.join(d, f) for f in os.listdir(d)] 70 | --------------------------------------------------------------------------------
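As a closing illustration, the following is a minimal sketch (not a file in the repository) of how the EctLayer from models/layers/layers.py can be applied to a single point cloud. It assumes the repository root is on the Python path and that torch, torch_geometric, torch_scatter and geotorch are installed; the parameter values are chosen only for the example.

import torch
from torch_geometric.data import Data, Batch

from models.config import EctConfig
from models.layers.layers import EctLayer

# A small ECT: 16 directions, 16 threshold steps, 2-dimensional inputs, on CPU.
config = EctConfig(num_thetas=16, bump_steps=16, num_features=2, device="cpu")
layer = EctLayer(config)  # learnable directions, randomly initialised with unit norm

# The triangle from notebooks/dect.ipynb, wrapped in a batch containing one graph.
points = Data(x=torch.tensor([[0.5, 0.0], [-0.5, 0.0], [0.5, 0.5]]))
batch = Batch.from_data_list([points])

ect = layer(batch)
print(ect.shape)  # (num_graphs, bump_steps, num_thetas), i.e. torch.Size([1, 16, 16])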