├── .cruft.json ├── .flake8 ├── .gitattributes ├── .github ├── dependabot.yml └── workflows │ └── check-links.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── LICENSE.txt ├── README.md ├── codespell.txt ├── mlc_config.json ├── mypy.ini ├── notebooks ├── 002_preproc_data.ipynb ├── 009_build_graphs_ml.ipynb ├── 020_one_shot_object_condensation.ipynb ├── 030_edge_classification.ipynb └── 040_three_shot_object_condensation.ipynb └── pyproject.toml /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "git@github.com:klieret/python-cookiecutter.git", 3 | "commit": "c650bc109b1f408ded56bbb421747630309e10f8", 4 | "checkout": null, 5 | "context": { 6 | "cookiecutter": { 7 | "project_name": "gnn-tracking-tutorials", 8 | "package_name": "gnn_tracking_tut", 9 | "description": "Tutorials and onboarding for the GNN Tracking project", 10 | "user": "gnn-tracking", 11 | "url": "https://github.com/gnn-tracking/tutorials", 12 | "full_name": "Kilian Lieret, Gage deZoort", 13 | "email": "kilian.lieret@posteo.de", 14 | "maintainer": "Kilian Lieret, Gage deZoort", 15 | "maintainer_email": "kilian.lieret@posteo.de", 16 | "year": "2023", 17 | "_copy_without_render": ["*.css"], 18 | "_template": "git@github.com:klieret/python-cookiecutter.git" 19 | } 20 | }, 21 | "directory": null 22 | } 23 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 80 3 | select = C,E,F,W,B,B950 4 | ignore = E203, E501, W503 5 | exclude = 6 | .git, 7 | __pycache__, 8 | notebooks, 9 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | 2 | *.ipynb diff=jupyternotebook 3 | 4 | *.ipynb merge=jupyternotebook 5 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Maintain dependencies for GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | -------------------------------------------------------------------------------- /.github/workflows/check-links.yaml: -------------------------------------------------------------------------------- 1 | name: Check Markdown links 2 | 3 | on: 4 | push: 5 | pull_request: 6 | schedule: 7 | - cron: "0 0 1 * *" 8 | 9 | jobs: 10 | markdown-link-check: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@master 14 | - uses: gaurav-nelson/github-action-markdown-link-check@v1 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ============================================================================= 2 | # PROJECT SPECIFIC IGNORES 3 | # ============================================================================= 4 | *.pt 5 | **/lightning_logs/** 6 | **/wandb/** 7 | 8 | # ============================================================================= 9 | # GENERAL PYTHON GITIGNORE 10 | # ============================================================================= 11 | # Created by https://www.toptal.com/developers/gitignore/api/python 12 | # Edit at 
https://www.toptal.com/developers/gitignore?templates=python 13 | 14 | ### Python ### 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | cover/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | .pybuilder/ 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | # For a library or package, you might want to ignore these files since the code is 101 | # intended to run in multiple environments; otherwise, check them in: 102 | # .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 109 | #Pipfile.lock 110 | 111 | # poetry 112 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 113 | # This is especially recommended for binary packages to ensure reproducibility, and is more 114 | # commonly ignored for libraries. 115 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 116 | #poetry.lock 117 | 118 | # pdm 119 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 120 | #pdm.lock 121 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 122 | # in version control. 123 | # https://pdm.fming.dev/#use-with-ide 124 | .pdm.toml 125 | 126 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 127 | __pypackages__/ 128 | 129 | # Celery stuff 130 | celerybeat-schedule 131 | celerybeat.pid 132 | 133 | # SageMath parsed files 134 | *.sage.py 135 | 136 | # Environments 137 | .env 138 | .venv 139 | env/ 140 | venv/ 141 | ENV/ 142 | env.bak/ 143 | venv.bak/ 144 | 145 | # Spyder project settings 146 | .spyderproject 147 | .spyproject 148 | 149 | # Rope project settings 150 | .ropeproject 151 | 152 | # mkdocs documentation 153 | /site 154 | 155 | # mypy 156 | .mypy_cache/ 157 | .dmypy.json 158 | dmypy.json 159 | 160 | # Pyre type checker 161 | .pyre/ 162 | 163 | # pytype static type analyzer 164 | .pytype/ 165 | 166 | # Cython debug symbols 167 | cython_debug/ 168 | 169 | # PyCharm 170 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 171 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 172 | # and can be added to the global gitignore or merged into this file. For a more nuclear 173 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 174 | #.idea/ 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python 177 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.8.0 4 | hooks: 5 | - id: black 6 | - id: black-jupyter 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.6.0 9 | hooks: 10 | - id: check-case-conflict 11 | - id: check-merge-conflict 12 | - id: detect-private-key 13 | - id: end-of-file-fixer 14 | exclude: '.*\.ipynb' 15 | - id: trailing-whitespace 16 | - repo: https://github.com/pycqa/isort 17 | rev: 5.13.2 18 | hooks: 19 | - id: isort 20 | name: isort (python) 21 | args: ["--profile", "black", "-a", "", "--append-only"] 22 | - repo: https://github.com/PyCQA/flake8 23 | rev: "7.1.1" 24 | hooks: 25 | - id: flake8 26 | additional_dependencies: ["flake8-bugbear"] 27 | - repo: https://github.com/pre-commit/mirrors-mypy 28 | rev: "v1.11.2" 29 | hooks: 30 | - id: mypy 31 | exclude: 'docs/source/conf\.py' 32 | - repo: https://github.com/codespell-project/codespell 33 | rev: "v2.3.0" 34 | hooks: 35 | - id: codespell 36 | args: ["-I", "codespell.txt"] 37 | exclude: '.*\.ipynb' 38 | - repo: https://github.com/asottile/pyupgrade 39 | rev: v3.17.0 40 | hooks: 41 | - id: pyupgrade 42 | args: ["--py37-plus"] 43 | - repo: https://github.com/asottile/setup-cfg-fmt 44 | rev: "v2.5.0" 45 | hooks: 46 | - id: setup-cfg-fmt 47 | args: [--include-version-classifiers, --max-py-version=3.10] 48 | - repo: https://github.com/hadialqattan/pycln 49 | rev: v2.4.0 50 | hooks: 51 | - id: pycln 52 | args: [--config=pyproject.toml] 53 | 54 | ci: 55 | autoupdate_schedule: monthly 56 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 4 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
5 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2023 Kilian Lieret 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # GNN Tracking Tutorials 4 | 5 | 6 | 7 | 8 | 9 | 10 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/gnn-tracking/tutorials/main.svg)](https://results.pre-commit.ci/latest/github/gnn-tracking/tutorials/main) 11 | [![link checker](https://github.com/gnn-tracking/tutorials/actions/workflows/check-links.yaml/badge.svg)](https://github.com/gnn-tracking/tutorials/actions/workflows/check-links.yaml) 12 | [![gitmoji](https://img.shields.io/badge/gitmoji-%20😜%20😍-FFDD67.svg)](https://gitmoji.dev) 13 | [![License](https://img.shields.io/github/license/gnn-tracking/tutorials)](https://github.com/gnn-tracking/tutorials/blob/master/LICENSE.txt) 14 | [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black) 15 | [![PR welcome](https://img.shields.io/badge/PR-Welcome-%23FF8300.svg)](https://git-scm.com/book/en/v2/GitHub-Contributing-to-a-Project) 16 | 17 |
18 | 19 | ## 📝 Description 20 | 21 | Tutorials and onboarding for the GNN Tracking project 22 | 23 | ## 📦 Installation 24 | 25 | 1. Follow the instructions from [the main library](https://github.com/gnn-tracking/gnn_tracking) 26 | to set up the conda environment and install the package. 27 | 28 | 2. Run `pytest` on the main package 29 | 30 | 3. Download the [trackml dataset](https://competitions.codalab.org/competitions/20112) (see note below). 31 | Note that the full dataset is linked in "Dataset description and other files" in the "Participate" 32 | tab (O(100GB) of data). The data files of the "Starting Kit"/"Public Data" are only a tiny fraction 33 | of this. If you just want to check that everything works or are currently waiting to get access to 34 | the data (see note below), you can use the small dataset from [test-data][]. 35 | 36 | [test-data]: https://github.com/gnn-tracking/test-data 37 | 38 | > [!Note] 39 | > The website [competitions.codalab.org](https://competitions.codalab.org/) does no longer accept new 40 | > registrations (needed to access the data set). If you do not already have an account there, 41 | > contact us for the dataset. 42 | 43 | > [!Warning] 44 | > Do not use the data from the similar Kaggle challenge: This is an older version of the same 45 | > data that is missing some bug fixes! 46 | 47 | ## 🧰 Development setup 48 | 49 | ```bash 50 | pip3 install pre-commit 51 | pre-commit install 52 | nbdime config-git enable 53 | ``` 54 | -------------------------------------------------------------------------------- /codespell.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnn-tracking/tutorials/099636640c25b4c9f30a17c178534dfa1c09f0fe/codespell.txt -------------------------------------------------------------------------------- /mlc_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "ignorePatterns": [ 3 | { 4 | "pattern": "https://github.com/.*/issues.*" 5 | }, 6 | { 7 | "pattern": "https://github.com/.*/pulls.*" 8 | } 9 | ] 10 | } 11 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | follow_imports = silent 4 | exclude = docs/ 5 | check_untyped_defs = True 6 | # disallow_untyped_defs = True 7 | -------------------------------------------------------------------------------- /notebooks/009_build_graphs_ml.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "source": [ 10 | "# Building graphs with Metric Learning\n", 11 | "\n", 12 | "This notebook shows how to build graphs using a metric learning strategy. For this, every hit is independently projected to a latent space using a fully connected neural network. The network is trained to put hits from the same particle close to each other and hits from different particles far from each other. An initial graph can then be constructed by connecting hits that are close in this space.\n", 13 | "This strategy has been adapted by ExaTrkx, see for example [section 5.2 here.](https://link.springer.com/10.1140/epjc/s10052-021-09675-8)\n", 14 | "\n", 15 | "This notebook also serves as an introduction to the new pytorch lightning-based framework." 
16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": { 22 | "vscode": { 23 | "languageId": "python" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "from functools import partial\n", 29 | "from pathlib import Path\n", 30 | "\n", 31 | "import torch\n", 32 | "\n", 33 | "from gnn_tracking.training.ml import MLModule\n", 34 | "from gnn_tracking.models.graph_construction import GraphConstructionFCNN\n", 35 | "from gnn_tracking.metrics.losses.metric_learning import GraphConstructionHingeEmbeddingLoss\n", 36 | "from pytorch_lightning import Trainer\n", 37 | "from gnn_tracking.utils.loading import TrackingDataModule\n", 38 | "from gnn_tracking.training.callbacks import PrintValidationMetrics\n", 39 | "from gnn_tracking.utils.versioning import assert_version_geq\n", 40 | "\n", 41 | "from torch_geometric.data import Data\n", 42 | "from torch import nn\n", 43 | "from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin\n", 44 | "\n", 45 | "assert_version_geq(\"23.12.0\")" 46 | ] 47 | }, 48 | { 49 | "attachments": {}, 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Step 1: Configuring the data" 54 | ] 55 | }, 56 | { 57 | "attachments": {}, 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "The configuration for train/val/test data and its dataloader is held in the `TrackingDataModule` (subclass of `LightningDataModule`)." 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 2, 67 | "metadata": { 68 | "vscode": { 69 | "languageId": "python" 70 | } 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "data_dir = Path.cwd().resolve().parent.parent / \"test-data\" / \"data\" / \"point_clouds\" / \"v8\"\n", 75 | "assert data_dir.is_dir()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "metadata": { 82 | "vscode": { 83 | "languageId": "python" 84 | } 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "dm = TrackingDataModule(\n", 89 | " train=dict(\n", 90 | " dirs=[data_dir],\n", 91 | " stop=1,\n", 92 | " ),\n", 93 | " val=dict(\n", 94 | " dirs=[data_dir],\n", 95 | " start=1,\n", 96 | " stop=2,\n", 97 | " ),\n", 98 | " identifier=\"point_clouds_v8\"\n", 99 | " # could also configure a 'test' set here\n", 100 | ")" 101 | ] 102 | }, 103 | { 104 | "attachments": {}, 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "Other keys allow to configure the loaders (batch size, number of workers, etc.). See the docstring of `TrackingDataModule` for details." 109 | ] 110 | }, 111 | { 112 | "attachments": {}, 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "### Details (for understanding)" 117 | ] 118 | }, 119 | { 120 | "attachments": {}, 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "Note that all of the following will be done implicitly by the `Trainer` and you won't have to worry about it. But if you want to inspect the data, you can do so.\n", 125 | "\n", 126 | "When calling the `setup` method, the `LightningDataModule` initializes instances of `TrackingDataset` (`torch_geometric.Dataset`) for each of these. 
We can get the corresponding dataloaders by calling `dm.train_dataloader()` and analog for validation and test.\n", 127 | "\n", 128 | "Example:" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 4, 134 | "metadata": { 135 | "vscode": { 136 | "languageId": "python" 137 | } 138 | }, 139 | "outputs": [ 140 | { 141 | "name": "stderr", 142 | "output_type": "stream", 143 | "text": [ 144 | "\u001b[32m[14:44:10] INFO: DataLoader will load 1 graphs (out of 2 available).\u001b[0m\n", 145 | "\u001b[36m[14:44:10] DEBUG: First graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21000_s0.pt, last graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21000_s0.pt\u001b[0m\n", 146 | "\u001b[32m[14:44:10] INFO: DataLoader will load 1 graphs (out of 2 available).\u001b[0m\n", 147 | "\u001b[36m[14:44:10] DEBUG: First graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21001_s0.pt, last graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21001_s0.pt\u001b[0m\n" 148 | ] 149 | }, 150 | { 151 | "data": { 152 | "text/plain": [ 153 | "{'train': TrackingDataset(), 'val': TrackingDataset()}" 154 | ] 155 | }, 156 | "execution_count": 4, 157 | "metadata": {}, 158 | "output_type": "execute_result" 159 | } 160 | ], 161 | "source": [ 162 | "# This is called by the Trainer automatically and sets up the datasets\n", 163 | "dm.setup(stage=\"fit\") # 'fit' combines 'train' and 'val'\n", 164 | "# Now the datasets are available:\n", 165 | "dm.datasets" 166 | ] 167 | }, 168 | { 169 | "attachments": {}, 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "For example, we can inspect the first element of the training dataset:" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 5, 179 | "metadata": { 180 | "vscode": { 181 | "languageId": "python" 182 | } 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "data = dm.datasets[\"train\"][0]" 187 | ] 188 | }, 189 | { 190 | "attachments": {}, 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "To get the corresponding dataloaders, use one of the methods (but again, you probalby won't need to):" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 6, 200 | "metadata": { 201 | "vscode": { 202 | "languageId": "python" 203 | } 204 | }, 205 | "outputs": [ 206 | { 207 | "data": { 208 | "text/plain": [ 209 | "(,\n", 210 | " )" 211 | ] 212 | }, 213 | "execution_count": 6, 214 | "metadata": {}, 215 | "output_type": "execute_result" 216 | } 217 | ], 218 | "source": [ 219 | "dm.train_dataloader(), dm.val_dataloader()" 220 | ] 221 | }, 222 | { 223 | "attachments": {}, 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "## Step 2: Configuring a model" 228 | ] 229 | }, 230 | { 231 | "attachments": {}, 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "We write a normal `torch.nn.Module`. The easiest way is to import one of the modules that we have already written in the `gnn_tracking` librar." 
236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 7, 241 | "metadata": { 242 | "vscode": { 243 | "languageId": "python" 244 | } 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "model = GraphConstructionFCNN(in_dim=14, out_dim=8, depth=5, hidden_dim=64)" 249 | ] 250 | }, 251 | { 252 | "attachments": {}, 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "However, you can also write your own. Here is a very simple one:" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 8, 262 | "metadata": { 263 | "vscode": { 264 | "languageId": "python" 265 | } 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "class DemoGraphConstructionModel(nn.Module, HyperparametersMixin):\n", 270 | " def __init__(\n", 271 | " self,\n", 272 | " in_dim: int,\n", 273 | " hidden_dim: int,\n", 274 | " out_dim: int,\n", 275 | " depth: int = 5,\n", 276 | " ):\n", 277 | " super().__init__()\n", 278 | " # This is made available by the HyperparametersMixin\n", 279 | " # all of our hyperparameters from the __init__ arguments\n", 280 | " # are saved to self.hparams (but we don't need this in this\n", 281 | " # example)\n", 282 | " self.save_hyperparameters()\n", 283 | " assert depth > 2\n", 284 | " _layers = [\n", 285 | " nn.Linear(in_dim, hidden_dim),\n", 286 | " nn.ReLU(),\n", 287 | " ]\n", 288 | " for _ in range(depth - 2):\n", 289 | " _layers.append(nn.Linear(hidden_dim, hidden_dim))\n", 290 | " _layers.append(nn.ReLU())\n", 291 | " _layers.append(nn.Linear(hidden_dim, out_dim))\n", 292 | " self._model = nn.Sequential(*_layers)\n", 293 | "\n", 294 | " def forward(self, data: Data):\n", 295 | " # Our trainer class will expect us to return a dictionary, where\n", 296 | " # the key H has the transformed latent space.\n", 297 | " return {\"H\": self._model(data.x)}" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "Uncomment the next line to use model we just wrote (rather than the default)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 9, 310 | "metadata": { 311 | "vscode": { 312 | "languageId": "python" 313 | } 314 | }, 315 | "outputs": [], 316 | "source": [ 317 | "# model = DemoGraphConstructionModel(in_dim=14, out_dim=8, hidden_dim=64)" 318 | ] 319 | }, 320 | { 321 | "attachments": {}, 322 | "cell_type": "markdown", 323 | "metadata": {}, 324 | "source": [ 325 | "If you are familiar with normal pytorch, there was only few differences:\n", 326 | "\n", 327 | "1. We inherit from `HyperparamsMixin`\n", 328 | "2. 
We call `self.save_hyperparameters()`" 329 | ] 330 | }, 331 | { 332 | "attachments": {}, 333 | "cell_type": "markdown", 334 | "metadata": {}, 335 | "source": [ 336 | "### Details (for understanding)" 337 | ] 338 | }, 339 | { 340 | "attachments": {}, 341 | "cell_type": "markdown", 342 | "metadata": {}, 343 | "source": [ 344 | "We saved all hyperparameters:" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 10, 350 | "metadata": { 351 | "vscode": { 352 | "languageId": "python" 353 | } 354 | }, 355 | "outputs": [ 356 | { 357 | "data": { 358 | "text/plain": [ 359 | "\"alpha\": 0.6\n", 360 | "\"depth\": 5\n", 361 | "\"hidden_dim\": 64\n", 362 | "\"in_dim\": 14\n", 363 | "\"out_dim\": 8" 364 | ] 365 | }, 366 | "execution_count": 10, 367 | "metadata": {}, 368 | "output_type": "execute_result" 369 | } 370 | ], 371 | "source": [ 372 | "model.hparams" 373 | ] 374 | }, 375 | { 376 | "attachments": {}, 377 | "cell_type": "markdown", 378 | "metadata": {}, 379 | "source": [ 380 | "Note how `depth=5` was saved despite not being specified explicitly (it was recognized as a default parameter)." 381 | ] 382 | }, 383 | { 384 | "attachments": {}, 385 | "cell_type": "markdown", 386 | "metadata": {}, 387 | "source": [ 388 | "As always, you can simply evaluate the `model` on a piece of data:" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 11, 394 | "metadata": { 395 | "vscode": { 396 | "languageId": "python" 397 | } 398 | }, 399 | "outputs": [], 400 | "source": [ 401 | "out = model(data)" 402 | ] 403 | }, 404 | { 405 | "attachments": {}, 406 | "cell_type": "markdown", 407 | "metadata": {}, 408 | "source": [ 409 | "## Step 3: Configuring loss functions, metrics and the lightning module" 410 | ] 411 | }, 412 | { 413 | "attachments": {}, 414 | "cell_type": "markdown", 415 | "metadata": {}, 416 | "source": [ 417 | "The pytorch model is bundled together with a set of loss functions (just one here), that we backpropagate from in the training step, and a set of metrics. Together, these components make up the `LightningModule` that we pass to the pytorch lightning `Trainer` for training.\n", 418 | "\n", 419 | "If you were familiar with our previous `TCNTrainer` training class, this `MLModule` now fulfills (almost) the exact same role." 
420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 12, 425 | "metadata": { 426 | "vscode": { 427 | "languageId": "python" 428 | } 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "lmodel = MLModule(\n", 433 | " model=model,\n", 434 | " loss_fct=GraphConstructionHingeEmbeddingLoss(\n", 435 | " lw_repulsive=0.5,\n", 436 | " max_num_neighbors=10,\n", 437 | " ),\n", 438 | " optimizer=partial(torch.optim.Adam, lr=1e-4),\n", 439 | ")" 440 | ] 441 | }, 442 | { 443 | "attachments": {}, 444 | "cell_type": "markdown", 445 | "metadata": {}, 446 | "source": [ 447 | "### Details (for understanding)" 448 | ] 449 | }, 450 | { 451 | "attachments": {}, 452 | "cell_type": "markdown", 453 | "metadata": {}, 454 | "source": [ 455 | "Again, all hyperparameters are accessible (even the ones that weren't explicitly specified but only set by default):" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 13, 461 | "metadata": { 462 | "vscode": { 463 | "languageId": "python" 464 | } 465 | }, 466 | "outputs": [ 467 | { 468 | "data": { 469 | "text/plain": [ 470 | "\"gc_scanner\": None\n", 471 | "\"loss_fct\": {'class_path': 'gnn_tracking.metrics.losses.metric_learning.GraphConstructionHingeEmbeddingLoss', 'init_args': {'lw_repulsive': 0.5, 'r_emb': 1.0, 'max_num_neighbors': 10, 'pt_thld': 0.9, 'max_eta': 4.0, 'p_attr': 1.0, 'p_rep': 1.0}}\n", 472 | "\"model\": {'class_path': 'gnn_tracking.models.graph_construction.GraphConstructionFCNN', 'init_args': {'in_dim': 14, 'hidden_dim': 64, 'out_dim': 8, 'depth': 5, 'alpha': 0.6}}\n", 473 | "\"preproc\": None" 474 | ] 475 | }, 476 | "execution_count": 13, 477 | "metadata": {}, 478 | "output_type": "execute_result" 479 | } 480 | ], 481 | "source": [ 482 | "lmodel.hparams" 483 | ] 484 | }, 485 | { 486 | "attachments": {}, 487 | "cell_type": "markdown", 488 | "metadata": {}, 489 | "source": [ 490 | "As you can see, any _objects_ that were passed to the model are also saved to the hyperparameters in a way that we can bring them back." 
491 | ] 492 | }, 493 | { 494 | "attachments": {}, 495 | "cell_type": "markdown", 496 | "metadata": {}, 497 | "source": [ 498 | "The loss function takes output from the model and the data and returns two separate losses:" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": 14, 504 | "metadata": { 505 | "vscode": { 506 | "languageId": "python" 507 | } 508 | }, 509 | "outputs": [ 510 | { 511 | "data": { 512 | "text/plain": [ 513 | "Data(x=[66114, 14], edge_index=[2, 229066], y=[0], layer=[66114], particle_id=[66114], pt=[66114], reconstructable=[66114], sector=[66114], eta=[66114], n_hits=[66114], n_layers_hit=[66114])" 514 | ] 515 | }, 516 | "execution_count": 14, 517 | "metadata": {}, 518 | "output_type": "execute_result" 519 | } 520 | ], 521 | "source": [ 522 | "data" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 17, 528 | "metadata": { 529 | "vscode": { 530 | "languageId": "python" 531 | } 532 | }, 533 | "outputs": [ 534 | { 535 | "data": { 536 | "text/plain": [ 537 | "MultiLossFctReturn(loss_dct={'attractive': tensor(0.0303, grad_fn=), 'repulsive': tensor(0.9888, grad_fn=)}, weight_dct={'attractive': 1.0, 'repulsive': 1.0}, extra_metrics={})" 538 | ] 539 | }, 540 | "execution_count": 17, 541 | "metadata": {}, 542 | "output_type": "execute_result" 543 | } 544 | ], 545 | "source": [ 546 | "loss_fct = GraphConstructionHingeEmbeddingLoss()\n", 547 | "loss_fct(\n", 548 | " x=out[\"H\"],\n", 549 | " particle_id=data.particle_id,\n", 550 | " batch=data.batch,\n", 551 | " edge_index=data.edge_index,\n", 552 | " pt=data.pt,\n", 553 | " eta=data.eta,\n", 554 | " reconstructable=data.reconstructable,\n", 555 | " true_edge_index=data.edge_index,\n", 556 | " max_num_neighbors=2,\n", 557 | ")" 558 | ] 559 | }, 560 | { 561 | "attachments": {}, 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "Both parts of the loss functions are combined with the loss weight we have configured above (weight of 1 for attractive, weight of 0.5 for repulsive). 
All of this is done in `MLModule.get_losses` (returning the total loss and a dictionary of the individual losses):" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 16, 571 | "metadata": { 572 | "vscode": { 573 | "languageId": "python" 574 | } 575 | }, 576 | "outputs": [ 577 | { 578 | "data": { 579 | "text/plain": [ 580 | "(tensor(0.5280, grad_fn=),\n", 581 | " {'attractive': tensor(0.0303, grad_fn=),\n", 582 | " 'repulsive': tensor(0.9954, grad_fn=),\n", 583 | " 'attractive_weighted': 0.03026364929974079,\n", 584 | " 'repulsive_weighted': 0.4976946711540222,\n", 585 | " 'total': 0.527958333492279})" 586 | ] 587 | }, 588 | "execution_count": 16, 589 | "metadata": {}, 590 | "output_type": "execute_result" 591 | } 592 | ], 593 | "source": [ 594 | "lmodel.get_losses(out, data)" 595 | ] 596 | }, 597 | { 598 | "attachments": {}, 599 | "cell_type": "markdown", 600 | "metadata": {}, 601 | "source": [ 602 | "## Step 4: Training" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": 29, 608 | "metadata": { 609 | "vscode": { 610 | "languageId": "python" 611 | } 612 | }, 613 | "outputs": [ 614 | { 615 | "name": "stderr", 616 | "output_type": "stream", 617 | "text": [ 618 | "GPU available: False, used: False\n", 619 | "TPU available: False, using: 0 TPU cores\n", 620 | "IPU available: False, using: 0 IPUs\n", 621 | "HPU available: False, using: 0 HPUs\n", 622 | "\u001b[32m[14:56:36] INFO: DataLoader will load 1 graphs (out of 2 available).\u001b[0m\n", 623 | "\u001b[36m[14:56:36] DEBUG: First graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21000_s0.pt, last graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21000_s0.pt\u001b[0m\n", 624 | "\u001b[32m[14:56:36] INFO: DataLoader will load 1 graphs (out of 2 available).\u001b[0m\n", 625 | "\u001b[36m[14:56:36] DEBUG: First graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21001_s0.pt, last graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21001_s0.pt\u001b[0m\n", 626 | "\n", 627 | " | Name | Type | Params\n", 628 | "-----------------------------------------------------------------\n", 629 | "0 | model | GraphConstructionFCNN | 17.8 K\n", 630 | "1 | loss_fct | GraphConstructionHingeEmbeddingLoss | 0 \n", 631 | "-----------------------------------------------------------------\n", 632 | "17.8 K Trainable params\n", 633 | "0 Non-trainable params\n", 634 | "17.8 K Total params\n", 635 | "0.071 Total estimated model params size (MB)\n" 636 | ] 637 | }, 638 | { 639 | "name": "stdout", 640 | "output_type": "stream", 641 | "text": [ 642 | "Epoch 0: 100%|███████████████████████████████████████████████████████████| 1/1 [01:38<00:00, 0.01it/s, v_num=2, attractive_train=0.0301, repulsive_train=0.995, attractive_weighted_train=0.0301, repulsive_weighted_train=0.498, total_train=0.528]" 643 | ] 644 | }, 645 | { 646 | "data": { 647 | "text/html": [ 648 | "
\n"
 649 |       ],
 650 |       "text/plain": []
 651 |      },
 652 |      "metadata": {},
 653 |      "output_type": "display_data"
 654 |     },
 655 |     {
 656 |      "name": "stdout",
 657 |      "output_type": "stream",
 658 |      "text": [
 659 |       "\n",
 660 |       "\u001b[3m              Validation epoch=0               \u001b[0m                                                                                                                                                                                              \n",
 661 |       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓\n",
 662 |       "┃\u001b[1m \u001b[0m\u001b[1mMetric                   \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m  Value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mError\u001b[0m\u001b[1m \u001b[0m┃\n",
 663 |       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩\n",
 664 |       "│\u001b[1;95m \u001b[0m\u001b[1;95mattractive               \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.03383\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m  nan\u001b[0m\u001b[1;95m \u001b[0m│\n",
 665 |       "│ attractive_train          │ 0.03009 │   nan │\n",
 666 |       "│ attractive_weighted       │ 0.03383 │   nan │\n",
 667 |       "│ attractive_weighted_train │ 0.03009 │   nan │\n",
 668 |       "│\u001b[1;95m \u001b[0m\u001b[1;95mrepulsive                \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.99553\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m  nan\u001b[0m\u001b[1;95m \u001b[0m│\n",
 669 |       "│ repulsive_train           │ 0.99535 │   nan │\n",
 670 |       "│ repulsive_weighted        │ 0.49777 │   nan │\n",
 671 |       "│ repulsive_weighted_train  │ 0.49767 │   nan │\n",
 672 |       "│\u001b[1;95m \u001b[0m\u001b[1;95mtotal                    \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.53160\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m  nan\u001b[0m\u001b[1;95m \u001b[0m│\n",
 673 |       "│ total_train               │ 0.52777 │   nan │\n",
 674 |       "└───────────────────────────┴─────────┴───────┘\n",
 675 |       "\n",
 676 |       "Epoch 0: 100%|███████████████████████████████████████████████████████████| 1/1 [03:02<00:00,  0.01it/s, v_num=2, attractive_train=0.0301, repulsive_train=0.995, attractive_weighted_train=0.0301, repulsive_weighted_train=0.498, total_train=0.528]"
 677 |      ]
 678 |     },
 679 |     {
 680 |      "name": "stderr",
 681 |      "output_type": "stream",
 682 |      "text": [
 683 |       "`Trainer.fit` stopped: `max_epochs=1` reached.\n"
 684 |      ]
 685 |     },
 686 |     {
 687 |      "name": "stdout",
 688 |      "output_type": "stream",
 689 |      "text": [
 690 |       "Epoch 0: 100%|███████████████████████████████████████████████████████████| 1/1 [03:02<00:00,  0.01it/s, v_num=2, attractive_train=0.0301, repulsive_train=0.995, attractive_weighted_train=0.0301, repulsive_weighted_train=0.498, total_train=0.528]\n"
 691 |      ]
 692 |     }
 693 |    ],
 694 |    "source": [
 695 |     "trainer = Trainer(\n",
 696 |     "    max_epochs=1,\n",
 697 |     "    accelerator=\"cpu\",\n",
 698 |     "    log_every_n_steps=1,\n",
 699 |     "    callbacks=[PrintValidationMetrics()],\n",
 700 |     ")\n",
 701 |     "trainer.fit(model=lmodel, datamodule=dm)"
 702 |    ]
 703 |   },
 704 |   {
 705 |    "attachments": {},
 706 |    "cell_type": "markdown",
 707 |    "metadata": {
 708 |     "collapsed": false
 709 |    },
 710 |    "source": [
 711 |     "### If there are issues with the progress bar\n",
 712 |     "\n",
 713 |     "The lightning progress bar can be finnicky when combined with printing the validation results to the command line, especially when running from a Jupyter notebook. Here's a couple of things to try:\n",
 714 |     "\n",
 715 |     "* set `enable_progress_bar=False` in the `Trainer` initialization to disable the progress bar\n",
 716 |     "* use `callbacks=[pytorch_lightning.callbacks.RichProgressBar(leave=True), ...]` in the `Trainer` initialization (this is a prettier progress bar, anyway). I\n",
 717 |     "* use `callbacks=[gnn_tracking.utils.lightning.SimpleTqdmProgressBar(leave=True), ...]`\n",
 718 |     "* remove the `PrintValidationMetrics` callback"
 719 |    ]
 720 |   },
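  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For concreteness, here is a minimal sketch of the first two options (`enable_progress_bar` and `RichProgressBar` are standard pytorch lightning API; the trainer names are just illustrative):\n",
    "\n",
    "```python\n",
    "from pytorch_lightning.callbacks import RichProgressBar\n",
    "\n",
    "# Option 1: disable the progress bar entirely\n",
    "quiet_trainer = Trainer(\n",
    "    max_epochs=1,\n",
    "    accelerator=\"cpu\",\n",
    "    enable_progress_bar=False,\n",
    "    callbacks=[PrintValidationMetrics()],\n",
    ")\n",
    "\n",
    "# Option 2: a Rich-based progress bar that persists after each epoch\n",
    "rich_trainer = Trainer(\n",
    "    max_epochs=1,\n",
    "    accelerator=\"cpu\",\n",
    "    callbacks=[RichProgressBar(leave=True), PrintValidationMetrics()],\n",
    ")\n",
    "```"
   ]
  },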
 721 |   {
 722 |    "attachments": {},
 723 |    "cell_type": "markdown",
 724 |    "metadata": {},
 725 |    "source": [
 726 |     "## Restoring a pre-trained model"
 727 |    ]
 728 |   },
 729 |   {
 730 |    "attachments": {},
 731 |    "cell_type": "markdown",
 732 |    "metadata": {},
 733 |    "source": [
 734 |     "Take a look at the `lightning_logs` directory:"
 735 |    ]
 736 |   },
 737 |   {
 738 |    "cell_type": "code",
 739 |    "execution_count": 31,
 740 |    "metadata": {
 741 |     "vscode": {
 742 |      "languageId": "python"
 743 |     }
 744 |    },
 745 |    "outputs": [
 746 |     {
 747 |      "name": "stdout",
 748 |      "output_type": "stream",
 749 |      "text": [
 750 |       "version_0  version_1  version_2\n"
 751 |      ]
 752 |     }
 753 |    ],
 754 |    "source": [
 755 |     "! ls lightning_logs"
 756 |    ]
 757 |   },
 758 |   {
 759 |    "cell_type": "markdown",
 760 |    "metadata": {},
 761 |    "source": [
 762 |     "Take the latest version number in the following"
 763 |    ]
 764 |   },
 765 |   {
 766 |    "cell_type": "code",
 767 |    "execution_count": 32,
 768 |    "metadata": {
 769 |     "vscode": {
 770 |      "languageId": "python"
 771 |     }
 772 |    },
 773 |    "outputs": [
 774 |     {
 775 |      "name": "stdout",
 776 |      "output_type": "stream",
 777 |      "text": [
 778 |       "'epoch=0-step=1.ckpt'\n"
 779 |      ]
 780 |     }
 781 |    ],
 782 |    "source": [
 783 |     "! ls lightning_logs/version_2/checkpoints"
 784 |    ]
 785 |   },
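  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you don't want to hard-code the version number, a small sketch like the following picks the newest checkpoint automatically (plain `pathlib`, assuming the default `version_<n>` directory naming shown above):\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "\n",
    "log_dir = Path(\"lightning_logs\")\n",
    "# Sort the version_* directories by their numeric suffix and take the newest\n",
    "latest = max(log_dir.glob(\"version_*\"), key=lambda p: int(p.name.split(\"_\")[-1]))\n",
    "ckpt_path = next(iter((latest / \"checkpoints\").glob(\"*.ckpt\")))\n",
    "print(ckpt_path)\n",
    "```"
   ]
  },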
 786 |   {
 787 |    "attachments": {},
 788 |    "cell_type": "markdown",
 789 |    "metadata": {},
 790 |    "source": [
 791 |     "Navigate to one of the versions and take a look at the `hparams.yaml` file. It should contain exactly the hyperparameters from the run.\n"
 792 |    ]
 793 |   },
 794 |   {
 795 |    "cell_type": "code",
 796 |    "execution_count": 33,
 797 |    "metadata": {
 798 |     "vscode": {
 799 |      "languageId": "python"
 800 |     }
 801 |    },
 802 |    "outputs": [
 803 |     {
 804 |      "name": "stdout",
 805 |      "output_type": "stream",
 806 |      "text": [
 807 |       "model:\n",
 808 |       "  class_path: gnn_tracking.models.graph_construction.GraphConstructionFCNN\n",
 809 |       "  init_args:\n",
 810 |       "    in_dim: 14\n",
 811 |       "    hidden_dim: 64\n",
 812 |       "    out_dim: 8\n",
 813 |       "    depth: 5\n",
 814 |       "    alpha: 0.6\n",
 815 |       "preproc: null\n",
 816 |       "loss_fct:\n",
 817 |       "  class_path: gnn_tracking.metrics.losses.metric_learning.GraphConstructionHingeEmbeddingLoss\n",
 818 |       "  init_args:\n",
 819 |       "    lw_repulsive: 0.5\n",
 820 |       "    r_emb: 1.0\n",
 821 |       "    max_num_neighbors: 10\n",
 822 |       "    pt_thld: 0.9\n",
 823 |       "    max_eta: 4.0\n",
 824 |       "    p_attr: 1.0\n",
 825 |       "    p_rep: 1.0\n",
 826 |       "gc_scanner: null\n"
 827 |      ]
 828 |     }
 829 |    ],
 830 |    "source": [
 831 |     "! cat lightning_logs/version_2/hparams.yaml"
 832 |    ]
 833 |   },
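  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since this is plain YAML, you can also read it back programmatically, e.g., with PyYAML (a sketch; the keys mirror the file printed above):\n",
    "\n",
    "```python\n",
    "import yaml\n",
    "\n",
    "with open(\"lightning_logs/version_2/hparams.yaml\") as f:\n",
    "    hparams = yaml.safe_load(f)\n",
    "\n",
    "print(hparams[\"model\"][\"init_args\"][\"out_dim\"])  # -> 8\n",
    "```"
   ]
  },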
 834 |   {
 835 |    "cell_type": "markdown",
 836 |    "metadata": {},
 837 |    "source": [
 838 |     "Similarly, you can check out the `config.yaml` file for additional config values affecting the `Trainer` and other elements of the training process."
 839 |    ]
 840 |   },
 841 |   {
 842 |    "attachments": {},
 843 |    "cell_type": "markdown",
 844 |    "metadata": {},
 845 |    "source": [
 846 |     "We can bring back the trained model by loading one of the checkpoints:"
 847 |    ]
 848 |   },
 849 |   {
 850 |    "cell_type": "code",
 851 |    "execution_count": 34,
 852 |    "metadata": {
 853 |     "vscode": {
 854 |      "languageId": "python"
 855 |     }
 856 |    },
 857 |    "outputs": [
 858 |     {
 859 |      "name": "stderr",
 860 |      "output_type": "stream",
 861 |      "text": [
 862 |       "\u001b[36m[15:03:04] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction\u001b[0m\n",
 863 |       "\u001b[36m[15:03:04] DEBUG: Getting class GraphConstructionHingeEmbeddingLoss from module gnn_tracking.metrics.losses.metric_learning\u001b[0m\n"
 864 |      ]
 865 |     }
 866 |    ],
 867 |    "source": [
 868 |     "restored_model = MLModule.load_from_checkpoint(\n",
 869 |     "    \"lightning_logs/version_2/checkpoints/epoch=0-step=1.ckpt\"\n",
 870 |     ")"
 871 |    ]
 872 |   },
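  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, you can evaluate the restored network on an event (a sketch; the `LightningModule` exposes the wrapped network as `.model`, as in the summary table printed during training):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "restored_model.eval()\n",
    "with torch.no_grad():\n",
    "    # The model returns a dict with the latent embedding under \"H\"\n",
    "    out = restored_model.model(data.to(restored_model.device))\n",
    "print(out[\"H\"].shape)  # one 8-dimensional embedding per hit\n",
    "```"
   ]
  },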
 873 |   {
 874 |    "attachments": {},
 875 |    "cell_type": "markdown",
 876 |    "metadata": {},
 877 |    "source": [
 878 |     "Note how we didn't have to specify any the hyperparameters again.\n",
 879 |     "\n",
 880 |     "However, we can easily change some of them by adding them as additional keyword arguments."
 881 |    ]
 882 |   },
 883 |   {
 884 |    "cell_type": "code",
 885 |    "execution_count": 36,
 886 |    "metadata": {
 887 |     "vscode": {
 888 |      "languageId": "python"
 889 |     }
 890 |    },
 891 |    "outputs": [
 892 |     {
 893 |      "name": "stderr",
 894 |      "output_type": "stream",
 895 |      "text": [
 896 |       "\u001b[36m[15:03:23] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction\u001b[0m\n"
 897 |      ]
 898 |     }
 899 |    ],
 900 |    "source": [
 901 |     "restored_model_modified = MLModule.load_from_checkpoint(\n",
 902 |     "    \"lightning_logs/version_2/checkpoints/epoch=0-step=1.ckpt\",\n",
 903 |     "    loss_fct=GraphConstructionHingeEmbeddingLoss(lw_repulsive=0.1, max_num_neighbors=5),\n",
 904 |     ")"
 905 |    ]
 906 |   },
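  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can confirm that the override took effect by inspecting the hyperparameters again (a sketch; the nested layout matches the `hparams` dump shown earlier):\n",
    "\n",
    "```python\n",
    "print(restored_model_modified.hparams[\"loss_fct\"])\n",
    "# expect lw_repulsive: 0.1 and max_num_neighbors: 5 in the init_args\n",
    "```"
   ]
  },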
 907 |   {
 908 |    "attachments": {},
 909 |    "cell_type": "markdown",
 910 |    "metadata": {},
 911 |    "source": [
 912 |     "Note that you cannot modify the model architecture however (but you could in principle change the `beta` parameter of the residual connections)."
 913 |    ]
 914 |   },
 915 |   {
 916 |    "attachments": {},
 917 |    "cell_type": "markdown",
 918 |    "metadata": {},
 919 |    "source": [
 920 |     "## Running all of this from the command line"
 921 |    ]
 922 |   },
 923 |   {
 924 |    "attachments": {},
 925 |    "cell_type": "markdown",
 926 |    "metadata": {},
 927 |    "source": [
 928 |     "All of the following can be achieved by running the following command:\n",
 929 |     "\n",
 930 |     "```bash\n",
 931 |     "python3 gnn_tracking/trainers/run.py fit --model configs/model.yml --data configs/data.yml  --trainer.accelerator cpu --trainer.accelerator cpu\n",
 932 |     "```\n",
 933 |     "\n",
 934 |     "with the data config file\n",
 935 |     "\n",
 936 |     "```yaml\n",
 937 |     "train:\n",
 938 |     "  dirs:\n",
 939 |     "    - /path/to/your/dir\n",
 940 |     "  stop: 5\n",
 941 |     "test:\n",
 942 |     "  dirs:\n",
 943 |     "    - /path/to/your/dir\n",
 944 |     "  star: 10\n",
 945 |     "  stop: 15\n",
 946 |     "val:\n",
 947 |     "  dirs:\n",
 948 |     "    - /path/to/your/dir\n",
 949 |     "  start: 5\n",
 950 |     "  stop: 10\n",
 951 |     "identifier: point_clouds_v8\n",
 952 |     "```\n",
 953 |     "\n",
 954 |     "and model config file:\n",
 955 |     "\n",
 956 |     "```yaml\n",
 957 |     "class_path: gnn_tracking.training.ml.MLModule\n",
 958 |     "init_args:\n",
 959 |     "  model:\n",
 960 |     "    class_path: gnn_tracking.models.graph_construction.GraphConstructionFCNN\n",
 961 |     "    init_args:\n",
 962 |     "      in_dim: 14\n",
 963 |     "      out_dim: 8\n",
 964 |     "      hidden_dim: 512\n",
 965 |     "      depth: 5\n",
 966 |     "  loss_fct:\n",
 967 |     "    class_path: gnn_tracking.metrics.losses.GraphConstructionHingeEmbeddingLoss\n",
 968 |     "    init_args:\n",
 969 |     "      lw_repulsive: 0.5\n",
 970 |     "  optimizer:\n",
 971 |     "    class_path: torch.optim.Adam\n",
 972 |     "    init_args:\n",
 973 |     "      lr: 0.0001\n",
 974 |     "```"
 975 |    ]
 976 |   },
 977 |   {
 978 |    "attachments": {},
 979 |    "cell_type": "markdown",
 980 |    "metadata": {},
 981 |    "source": [
 982 |     "To quickly override one of the options, you can simply add them to the command line, e.g., `--model.init_args.loss_fct.init_args.lw_repulsive=0.1` or `--model.model.init_args.depth=6`."
 983 |    ]
 984 |   },
 985 |   {
 986 |    "attachments": {},
 987 |    "cell_type": "markdown",
 988 |    "metadata": {},
 989 |    "source": [
 990 |     "## Advanced: Connecting with Weights & Biases\n",
 991 |     "\n",
 992 |     "Weights and Biases (wandb.ai) is a great tool to log all of your runs to. \n",
 993 |     "It's also very easy to set up in principle by adding a callback to the `Trainer`.\n",
 994 |     "\n",
 995 |     "However, first you need to create an account (it's free!). If you collaborate with us, you probably want to reach out to us so that we can add you to our project (and can see each other's runs).\n",
 996 |     "\n",
 997 |     "Once you have your account, copy your API key into the file `~/.wandb_api_key` on the server from which you run your ML models.\n",
 998 |     "\n",
 999 |     "Because we want to later identify our current trial among other trials (and have an easy-to-remember name), let's first create an identifier:\n"
1000 |    ]
1001 |   },
1002 |   {
1003 |    "cell_type": "code",
1004 |    "execution_count": 37,
1005 |    "metadata": {
1006 |     "vscode": {
1007 |      "languageId": "python"
1008 |     }
1009 |    },
1010 |    "outputs": [],
1011 |    "source": [
1012 |     "from gnn_tracking.utils.nomenclature import random_trial_name"
1013 |    ]
1014 |   },
1015 |   {
1016 |    "cell_type": "code",
1017 |    "execution_count": 38,
1018 |    "metadata": {
1019 |     "vscode": {
1020 |      "languageId": "python"
1021 |     }
1022 |    },
1023 |    "outputs": [
1024 |     {
1025 |      "data": {
1026 |       "text/html": [
1027 |        "
─────────────────────────── demonic-logical-platypus ───────────────────────────\n",
1028 |        "
\n" 1029 | ], 1030 | "text/plain": [ 1031 | "\u001b[92m─────────────────────────── \u001b[0m\u001b[1;33mdemonic-logical-platypus\u001b[0m\u001b[92m ───────────────────────────\u001b[0m\n" 1032 | ] 1033 | }, 1034 | "metadata": {}, 1035 | "output_type": "display_data" 1036 | } 1037 | ], 1038 | "source": [ 1039 | "name = random_trial_name()" 1040 | ] 1041 | }, 1042 | { 1043 | "attachments": {}, 1044 | "cell_type": "markdown", 1045 | "metadata": {}, 1046 | "source": [ 1047 | "\n", 1048 | "After this, let's set up the logger:" 1049 | ] 1050 | }, 1051 | { 1052 | "cell_type": "code", 1053 | "execution_count": null, 1054 | "metadata": { 1055 | "vscode": { 1056 | "languageId": "python" 1057 | } 1058 | }, 1059 | "outputs": [ 1060 | { 1061 | "name": "stderr", 1062 | "output_type": "stream", 1063 | "text": [ 1064 | "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", 1065 | "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id godlike-buzzard-of-wonder.\n" 1066 | ] 1067 | }, 1068 | { 1069 | "data": { 1070 | "text/html": [ 1071 | "Tracking run with wandb version 0.15.4" 1072 | ], 1073 | "text/plain": [ 1074 | "" 1075 | ] 1076 | }, 1077 | "metadata": {}, 1078 | "output_type": "display_data" 1079 | }, 1080 | { 1081 | "data": { 1082 | "text/html": [ 1083 | "W&B syncing is set to `offline` in this directory.
Run `wandb online` or set WANDB_MODE=online to enable cloud syncing." 1084 | ], 1085 | "text/plain": [ 1086 | "" 1087 | ] 1088 | }, 1089 | "metadata": {}, 1090 | "output_type": "display_data" 1091 | } 1092 | ], 1093 | "source": [ 1094 | "from pytorch_lightning.loggers import WandbLogger\n", 1095 | "\n", 1096 | "\n", 1097 | "wandb_logger = WandbLogger(\n", 1098 | " project=\"ml\",\n", 1099 | " group=\"first\",\n", 1100 | " offline=True, # <-- see notes below\n", 1101 | " version=name,\n", 1102 | ")" 1103 | ] 1104 | }, 1105 | { 1106 | "attachments": {}, 1107 | "cell_type": "markdown", 1108 | "metadata": {}, 1109 | "source": [ 1110 | "We want to keep our checkpoints locally, so let's also initialize the default logger (which would be replaced by `WandbLogger` if we don't add it manually):" 1111 | ] 1112 | }, 1113 | { 1114 | "cell_type": "code", 1115 | "execution_count": null, 1116 | "metadata": { 1117 | "vscode": { 1118 | "languageId": "python" 1119 | } 1120 | }, 1121 | "outputs": [], 1122 | "source": [ 1123 | "from pytorch_lightning.loggers import TensorBoardLogger\n", 1124 | "\n", 1125 | "\n", 1126 | "tb_logger = TensorBoardLogger(\".\", version=name)" 1127 | ] 1128 | }, 1129 | { 1130 | "attachments": {}, 1131 | "cell_type": "markdown", 1132 | "metadata": {}, 1133 | "source": [ 1134 | "Now we'd have all the places in place, if it weren't for one subtlety: The Princeton compute nodes don't have internet connectivity.\n", 1135 | "This is also why we set `offline=True` to the `WandbLogger`. But that's not a problem, because we have internet on the head node `della-gpu`\n", 1136 | "(just not on the compute node). So we can simply run `wandb sync /path/to/run/dir` afterwards.\n", 1137 | "However, because this is annoying, I wrote a package `wandb-osh` to help with this.\n", 1138 | "\n", 1139 | "To install it, run:" 1140 | ] 1141 | }, 1142 | { 1143 | "cell_type": "code", 1144 | "execution_count": null, 1145 | "metadata": { 1146 | "vscode": { 1147 | "languageId": "python" 1148 | } 1149 | }, 1150 | "outputs": [ 1151 | { 1152 | "name": "stdout", 1153 | "output_type": "stream", 1154 | "text": [ 1155 | "Requirement already satisfied: wandb-osh in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (1.0.4)\n", 1156 | "Requirement already satisfied: colorlog in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb-osh) (6.7.0)\n", 1157 | "Requirement already satisfied: wandb in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb-osh) (0.15.4)\n", 1158 | "Requirement already satisfied: Click!=8.0.0,>=7.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (8.1.3)\n", 1159 | "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (3.1.31)\n", 1160 | "Requirement already satisfied: requests<3,>=2.0.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (2.31.0)\n", 1161 | "Requirement already satisfied: psutil>=5.0.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (5.9.5)\n", 1162 | "Requirement already satisfied: sentry-sdk>=1.0.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (1.21.1)\n", 1163 | "Requirement already satisfied: docker-pycreds>=0.4.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (0.4.0)\n", 1164 | "Requirement already satisfied: PyYAML in 
/Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (6.0)\n", 1165 | "Requirement already satisfied: pathtools in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (0.1.2)\n", 1166 | "Requirement already satisfied: setproctitle in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (1.3.2)\n", 1167 | "Requirement already satisfied: setuptools in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (67.7.2)\n", 1168 | "Requirement already satisfied: appdirs>=1.4.3 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (1.4.4)\n", 1169 | "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (3.20.3)\n", 1170 | "Requirement already satisfied: six>=1.4.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from docker-pycreds>=0.4.0->wandb->wandb-osh) (1.16.0)\n", 1171 | "Requirement already satisfied: gitdb<5,>=4.0.1 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb->wandb-osh) (4.0.10)\n", 1172 | "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb->wandb-osh) (3.1.0)\n", 1173 | "Requirement already satisfied: idna<4,>=2.5 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb->wandb-osh) (3.4)\n", 1174 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb->wandb-osh) (2.0.3)\n", 1175 | "Requirement already satisfied: certifi>=2017.4.17 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb->wandb-osh) (2023.5.7)\n", 1176 | "Requirement already satisfied: smmap<6,>=3.0.1 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb->wandb-osh) (3.0.5)\n" 1177 | ] 1178 | } 1179 | ], 1180 | "source": [ 1181 | "! pip3 install wandb-osh" 1182 | ] 1183 | }, 1184 | { 1185 | "attachments": {}, 1186 | "cell_type": "markdown", 1187 | "metadata": {}, 1188 | "source": [ 1189 | "Now let's put everything together: " 1190 | ] 1191 | }, 1192 | { 1193 | "cell_type": "code", 1194 | "execution_count": null, 1195 | "metadata": { 1196 | "vscode": { 1197 | "languageId": "python" 1198 | } 1199 | }, 1200 | "outputs": [], 1201 | "source": [ 1202 | "from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback\n", 1203 | "\n", 1204 | "\n", 1205 | "trainer = Trainer(\n", 1206 | " max_epochs=3,\n", 1207 | " accelerator=\"cpu\",\n", 1208 | " log_every_n_steps=1,\n", 1209 | " callbacks=[\n", 1210 | " TriggerWandbSyncLightningCallback(),\n", 1211 | " PrintValidationMetrics(),\n", 1212 | " ],\n", 1213 | " logger=[\n", 1214 | " wandb_logger,\n", 1215 | " tb_logger,\n", 1216 | " ],\n", 1217 | ")" 1218 | ] 1219 | }, 1220 | { 1221 | "attachments": {}, 1222 | "cell_type": "markdown", 1223 | "metadata": {}, 1224 | "source": [ 1225 | "To sync your run, simply start the `wandb-osh` command line utility on `della-gpu`.\n", 1226 | "For more information on how this works, see [here](https://github.com/klieret/wandb-offline-sync-hook)." 
1227 | ] 1228 | } 1229 | ], 1230 | "metadata": { 1231 | "kernelspec": { 1232 | "display_name": "Python 3 (ipykernel)", 1233 | "language": "python", 1234 | "name": "python3" 1235 | }, 1236 | "language_info": { 1237 | "name": "", 1238 | "version": "" 1239 | } 1240 | }, 1241 | "nbformat": 4, 1242 | "nbformat_minor": 1 1243 | } 1244 | -------------------------------------------------------------------------------- /notebooks/020_one_shot_object_condensation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": false 7 | }, 8 | "source": [ 9 | "# One shot object condensation\n", 10 | "\n", 11 | "This notebook shows how you can implement a model that directly goes from point cloud data to object condensation." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 63, 17 | "metadata": { 18 | "collapsed": false 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "from pathlib import Path\n", 23 | "from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin\n", 24 | "from torch import nn\n", 25 | "from torch_geometric.nn.conv import GravNetConv\n", 26 | "from torch_geometric.data import Data\n", 27 | "from pytorch_lightning import Trainer\n", 28 | "\n", 29 | "from gnn_tracking.metrics.losses.oc import CondensationLossTiger\n", 30 | "import torch\n", 31 | "from functools import partial\n", 32 | "\n", 33 | "from gnn_tracking.training.callbacks import PrintValidationMetrics\n", 34 | "from gnn_tracking.training.tc import TCModule\n", 35 | "from gnn_tracking.utils.loading import TrackingDataModule\n", 36 | "from gnn_tracking.utils.versioning import assert_version_geq\n", 37 | "\n", 38 | "assert_version_geq(\"23.12.0\")" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "collapsed": false 45 | }, 46 | "source": [ 47 | "## 1. Configure data" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 64, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "data_dir = (\n", 57 | " Path.cwd().resolve().parent.parent / \"test-data\" / \"data\" / \"point_clouds\" / \"v8\"\n", 58 | ")\n", 59 | "assert data_dir.is_dir()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 65, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "dm = TrackingDataModule(\n", 69 | " train=dict(\n", 70 | " dirs=[data_dir],\n", 71 | " stop=1,\n", 72 | " ),\n", 73 | " val=dict(\n", 74 | " dirs=[data_dir],\n", 75 | " start=1,\n", 76 | " stop=2,\n", 77 | " ),\n", 78 | " identifier=\"point_clouds_v8\",\n", 79 | " # could also configure a 'test' set here\n", 80 | ")" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": { 86 | "collapsed": false 87 | }, 88 | "source": [ 89 | "## 2. 
Write a model" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 66, 95 | "metadata": { 96 | "collapsed": false 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "class DemoGravNet(nn.Module, HyperparametersMixin):\n", 101 | " def __init__(self, in_dim: int = 14, depth: int = 1, k: int = 2):\n", 102 | " super().__init__()\n", 103 | " self.save_hyperparameters()\n", 104 | " layers = [\n", 105 | " GravNetConv(\n", 106 | " in_channels=in_dim,\n", 107 | " out_channels=in_dim,\n", 108 | " space_dimensions=3,\n", 109 | " propagate_dimensions=3,\n", 110 | " k=k,\n", 111 | " )\n", 112 | " for _ in range(depth)\n", 113 | " ]\n", 114 | " self._embedding = nn.Sequential(*layers)\n", 115 | " self._beta = nn.Sequential(\n", 116 | " nn.Linear(in_dim, 1),\n", 117 | " nn.Sigmoid(),\n", 118 | " )\n", 119 | "\n", 120 | " def forward(self, data: Data):\n", 121 | " latent = self._embedding(data.x)\n", 122 | " beta = self._beta(latent).squeeze()\n", 123 | " eps = 1e-6\n", 124 | " beta = beta.clamp(eps, 1 - eps)\n", 125 | " return {\n", 126 | " \"B\": beta,\n", 127 | " \"H\": latent,\n", 128 | " }" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 67, 134 | "metadata": { 135 | "collapsed": false 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "model = DemoGravNet()" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": { 145 | "collapsed": false 146 | }, 147 | "source": [ 148 | "## 3. Configure loss functions and weights" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 68, 154 | "metadata": { 155 | "collapsed": false 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "# The loss functions can be memory hungry. Here we override `data_preproc` to place a tighter pt cut on\n", 160 | "# the data to easy computation (since this is just a demo).\n", 161 | "class PtCut(HyperparametersMixin):\n", 162 | " def __call__(self, data: Data):\n", 163 | " mask = data.pt > 4\n", 164 | " data = data.subgraph(mask)\n", 165 | " return data" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 71, 171 | "metadata": { 172 | "collapsed": false 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner\n", 177 | "\n", 178 | "# TC for track condensation\n", 179 | "lmodel = TCModule(\n", 180 | " model=model,\n", 181 | " loss_fct=CondensationLossTiger(\n", 182 | " lw_repulsive=2.0,\n", 183 | " ),\n", 184 | " optimizer=partial(torch.optim.Adam, lr=1e-4),\n", 185 | " cluster_scanner=DBSCANHyperParamScanner(n_trials=5, n_jobs=1),\n", 186 | " preproc=PtCut(),\n", 187 | ")" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": { 193 | "collapsed": false 194 | }, 195 | "source": [ 196 | "## 4. Train the model" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 72, 202 | "metadata": { 203 | "collapsed": false 204 | }, 205 | "outputs": [ 206 | { 207 | "name": "stderr", 208 | "output_type": "stream", 209 | "text": [ 210 | "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. 
HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3 ...\n", 211 | "GPU available: False, used: False\n", 212 | "TPU available: False, using: 0 TPU cores\n", 213 | "IPU available: False, using: 0 IPUs\n", 214 | "HPU available: False, using: 0 HPUs\n", 215 | "\u001b[32m[15:52:19] INFO: DataLoader will load 1 graphs (out of 2 available).\u001b[0m\n", 216 | "\u001b[36m[15:52:19] DEBUG: First graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21000_s0.pt, last graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21000_s0.pt\u001b[0m\n", 217 | "\u001b[32m[15:52:19] INFO: DataLoader will load 1 graphs (out of 2 available).\u001b[0m\n", 218 | "\u001b[36m[15:52:19] DEBUG: First graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21001_s0.pt, last graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21001_s0.pt\u001b[0m\n", 219 | "\n", 220 | " | Name | Type | Params\n", 221 | "---------------------------------------------------\n", 222 | "0 | model | DemoGravNet | 399 \n", 223 | "1 | loss_fct | CondensationLossTiger | 0 \n", 224 | "---------------------------------------------------\n", 225 | "399 Trainable params\n", 226 | "0 Non-trainable params\n", 227 | "399 Total params\n", 228 | "0.002 Total estimated model params size (MB)\n" 229 | ] 230 | }, 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "Sanity Checking: | | 0/? [00:00\n" 383 | ], 384 | "text/plain": [] 385 | }, 386 | "metadata": {}, 387 | "output_type": "display_data" 388 | }, 389 | { 390 | "name": "stdout", 391 | "output_type": "stream", 392 | "text": [ 393 | "\n", 394 | "\u001b[3m Validation epoch=0 \u001b[0m \n", 395 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━┓\n", 396 | "┃\u001b[1m \u001b[0m\u001b[1mMetric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mError\u001b[0m\u001b[1m \u001b[0m┃\n", 397 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━┩\n", 398 | "│\u001b[1;95m \u001b[0m\u001b[1;95mattractive \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m55245000.00000\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n", 399 | "│ attractive_train │ 93512936.00000 │ nan │\n", 400 | "│ attractive_weighted │ 55245000.00000 │ nan │\n", 401 | "│ attractive_weighted_train │ 93512936.00000 │ nan │\n", 402 | "│ best_dbscan_eps │ 0.15979 │ nan │\n", 403 | "│ best_dbscan_min_samples │ 4.00000 │ nan │\n", 404 | "│ coward │ 0.03412 │ nan │\n", 405 | "│ coward_train │ 0.05263 │ nan │\n", 406 | "│ coward_weighted │ 0.00000 │ nan │\n", 407 | "│ coward_weighted_train │ 0.00000 │ nan │\n", 408 | "│ n_rep │ 1.00000 │ nan │\n", 409 | "│ n_rep_train │ 0.00000 │ nan │\n", 410 | "│ noise │ nan │ nan │\n", 411 | "│ noise_train │ nan │ nan │\n", 412 | "│ noise_weighted │ nan │ nan │\n", 413 | "│ noise_weighted_train │ nan │ nan │\n", 414 | "│\u001b[1;95m \u001b[0m\u001b[1;95mrepulsive \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m 0.00112\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n", 415 | "│ repulsive_train │ 0.00000 │ nan │\n", 416 | "│ repulsive_weighted │ 0.00223 │ nan │\n", 417 | "│ repulsive_weighted_train │ 0.00000 │ nan │\n", 418 | "│ total │ nan │ nan 
│\n", 419 | "│ total_train │ nan │ nan │\n", 420 | "│ trk.double_majority │ 0.00000 │ nan │\n", 421 | "│ trk.double_majority_pt0.5 │ 0.00000 │ nan │\n", 422 | "│\u001b[1;95m \u001b[0m\u001b[1;95mtrk.double_majority_pt0.9 \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m 0.00000\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n", 423 | "│ trk.double_majority_pt1.5 │ 0.00000 │ nan │\n", 424 | "│ trk.fake_double_majority │ nan │ nan │\n", 425 | "│ trk.fake_double_majority_pt0.5 │ nan │ nan │\n", 426 | "│ trk.fake_double_majority_pt0.9 │ nan │ nan │\n", 427 | "│ trk.fake_double_majority_pt1.5 │ nan │ nan │\n", 428 | "│ trk.fake_lhc │ nan │ nan │\n", 429 | "│ trk.fake_lhc_pt0.5 │ nan │ nan │\n", 430 | "│ trk.fake_lhc_pt0.9 │ nan │ nan │\n", 431 | "│ trk.fake_lhc_pt1.5 │ nan │ nan │\n", 432 | "│ trk.fake_perfect │ nan │ nan │\n", 433 | "│ trk.fake_perfect_pt0.5 │ nan │ nan │\n", 434 | "│ trk.fake_perfect_pt0.9 │ nan │ nan │\n", 435 | "│ trk.fake_perfect_pt1.5 │ nan │ nan │\n", 436 | "│ trk.i_batch │ 0.00000 │ nan │\n", 437 | "│ trk.lhc │ nan │ nan │\n", 438 | "│ trk.lhc_pt0.5 │ nan │ nan │\n", 439 | "│\u001b[1;95m \u001b[0m\u001b[1;95mtrk.lhc_pt0.9 \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n", 440 | "│ trk.lhc_pt1.5 │ nan │ nan │\n", 441 | "│ trk.n_cleaned_clusters │ 0.00000 │ nan │\n", 442 | "│ trk.n_cleaned_clusters_pt0.5 │ 0.00000 │ nan │\n", 443 | "│ trk.n_cleaned_clusters_pt0.9 │ 0.00000 │ nan │\n", 444 | "│ trk.n_cleaned_clusters_pt1.5 │ 0.00000 │ nan │\n", 445 | "│ trk.n_particles │ 17.00000 │ nan │\n", 446 | "│ trk.n_particles_pt0.5 │ 17.00000 │ nan │\n", 447 | "│ trk.n_particles_pt0.9 │ 17.00000 │ nan │\n", 448 | "│ trk.n_particles_pt1.5 │ 17.00000 │ nan │\n", 449 | "│ trk.perfect │ 0.00000 │ nan │\n", 450 | "│ trk.perfect_pt0.5 │ 0.00000 │ nan │\n", 451 | "│\u001b[1;95m \u001b[0m\u001b[1;95mtrk.perfect_pt0.9 \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m 0.00000\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n", 452 | "│ trk.perfect_pt1.5 │ 0.00000 │ nan │\n", 453 | "└────────────────────────────────┴────────────────┴───────┘\n", 454 | "\n", 455 | "Epoch 0: 100%|█| 1/1 [00:11<00:00, 0.09it/s, v_num=3, attractive_train=9.35e+7, repulsive_train=0.000, coward_train=0.0526, noise_train=nan.0, attractive_weighted_train=9.35e+7, repulsive_weighted_train=0.000, coward_weighted_train=0.000, noise" 456 | ] 457 | }, 458 | { 459 | "name": "stderr", 460 | "output_type": "stream", 461 | "text": [ 462 | "`Trainer.fit` stopped: `max_epochs=1` reached.\n" 463 | ] 464 | }, 465 | { 466 | "name": "stdout", 467 | "output_type": "stream", 468 | "text": [ 469 | "Epoch 0: 100%|█| 1/1 [00:11<00:00, 0.09it/s, v_num=3, attractive_train=9.35e+7, repulsive_train=0.000, coward_train=0.0526, noise_train=nan.0, attractive_weighted_train=9.35e+7, repulsive_weighted_train=0.000, coward_weighted_train=0.000, noise\n" 470 | ] 471 | } 472 | ], 473 | "source": [ 474 | "trainer = Trainer(\n", 475 | " max_epochs=1,\n", 476 | " accelerator=\"cpu\",\n", 477 | " log_every_n_steps=1,\n", 478 | " callbacks=[PrintValidationMetrics()],\n", 479 | ")\n", 480 | "trainer.fit(model=lmodel, datamodule=dm)" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "metadata": { 487 | "collapsed": false 488 | }, 489 | "outputs": [], 490 
| "source": [] 491 | } 492 | ], 493 | "metadata": { 494 | "kernelspec": { 495 | "display_name": "Python 3 (ipykernel)", 496 | "language": "python", 497 | "name": "python3" 498 | }, 499 | "language_info": { 500 | "codemirror_mode": { 501 | "name": "ipython", 502 | "version": 3 503 | }, 504 | "file_extension": ".py", 505 | "mimetype": "text/x-python", 506 | "name": "python", 507 | "nbconvert_exporter": "python", 508 | "pygments_lexer": "ipython3", 509 | "version": "3.10.11" 510 | } 511 | }, 512 | "nbformat": 4, 513 | "nbformat_minor": 0 514 | } 515 | -------------------------------------------------------------------------------- /notebooks/030_edge_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": false 7 | }, 8 | "source": [ 9 | "# Edge classification\n", 10 | "\n", 11 | "This notebooks shows how to classify edges of a graph. In many GNN tracking approaches, we start from an initial graph (e.g., built from a point cloud with the strategy described in `009_build_graphs_ml.ipynb`). We then try to falsify all edges that connected hits of two different particles. If edge classification (EC) would be perfect, we could then reconstruct tracks as connected components of the graph.\n", 12 | "For our object condensation approach, EC is only an auxiliary step. Edges are only considered for message passing but are not important for the final decision on how tracks look. However, EC is still important to help the model to learn quickly.\n", 13 | "\n", 14 | "For background on pytorch lightning, see `009_build_graphs_ml.ipynb`." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 76, 20 | "metadata": { 21 | "collapsed": false 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "from pytorch_lightning import Trainer\n", 26 | "from torch import nn\n", 27 | "from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin\n", 28 | "import torch\n", 29 | "from functools import partial\n", 30 | "\n", 31 | "from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt\n", 32 | "\n", 33 | "from gnn_tracking.metrics.losses.ec import EdgeWeightFocalLoss\n", 34 | "from gnn_tracking.training.callbacks import PrintValidationMetrics\n", 35 | "from gnn_tracking.training.ec import ECModule\n", 36 | "\n", 37 | "from gnn_tracking.utils.loading import TrackingDataModule\n", 38 | "\n", 39 | "\n", 40 | "from gnn_tracking.utils.versioning import assert_version_geq\n", 41 | "\n", 42 | "assert_version_geq(\"23.12.0\")" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": { 48 | "collapsed": false 49 | }, 50 | "source": [ 51 | "We can either directly load graphs (from disk), or we load point clouds and build edges on the fly using the module from `009_build_graphs_ml.ipynb`." 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": { 57 | "collapsed": false 58 | }, 59 | "source": [ 60 | "## From on-disk graphs" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": { 66 | "collapsed": false 67 | }, 68 | "source": [ 69 | "### 1. Setting up the data\n", 70 | "\n", 71 | "If you are not working on Princeton's `della`, you can download these example graphs [here](https://cernbox.cern.ch/s/4xYL99cd7zNe0VK). Note that this is simplified data (pt > 1 GeV truth cut) and a single event has been broken up into 32 sectors." 
72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 85, 77 | "metadata": { 78 | "collapsed": false 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "dm = TrackingDataModule(\n", 83 | " train=dict(\n", 84 | " dirs=[\n", 85 | " \"/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all\"\n", 86 | " ],\n", 87 | " stop=28_000,\n", 88 | " # If you run into memory issues, reduce this\n", 89 | " batch_size=10,\n", 90 | " ),\n", 91 | " val=dict(\n", 92 | " dirs=[\n", 93 | " \"/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all\"\n", 94 | " ],\n", 95 | " start=28_000,\n", 96 | " stop=28_100,\n", 97 | " ),\n", 98 | " identifier=\"graphs_v1\",\n", 99 | ")" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": { 105 | "collapsed": false 106 | }, 107 | "source": [ 108 | "### 2. Defining the module" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 86, 114 | "metadata": { 115 | "collapsed": false 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "class SillyEC(nn.Module, HyperparametersMixin):\n", 120 | " def __init__(\n", 121 | " self,\n", 122 | " node_in_dim: int,\n", 123 | " edge_in_dim: int,\n", 124 | " hidden_dim: int = 12,\n", 125 | " ):\n", 126 | " super().__init__()\n", 127 | " self.save_hyperparameters()\n", 128 | " self.node_in_dim = node_in_dim\n", 129 | " self.edge_in_dim = edge_in_dim\n", 130 | " self.hidden_dim = hidden_dim\n", 131 | "\n", 132 | " self.fcnn = nn.Sequential(\n", 133 | " nn.Linear(edge_in_dim, hidden_dim),\n", 134 | " nn.ReLU(),\n", 135 | " nn.Linear(hidden_dim, hidden_dim),\n", 136 | " nn.ReLU(),\n", 137 | " nn.Linear(hidden_dim, 1),\n", 138 | " nn.Sigmoid(),\n", 139 | " )\n", 140 | "\n", 141 | " def forward(self, data):\n", 142 | " w = self.fcnn(data.edge_attr).squeeze()\n", 143 | " return {\"W\": w}" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 87, 149 | "metadata": { 150 | "collapsed": false 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "model = SillyEC(node_in_dim=6, edge_in_dim=4, hidden_dim=128)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": { 160 | "collapsed": false 161 | }, 162 | "source": [ 163 | "### 3. Setting up the loss functions and the lightning module" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 88, 169 | "metadata": { 170 | "collapsed": false 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "lmodel = ECModule(\n", 175 | " model=model,\n", 176 | " loss_fct=EdgeWeightFocalLoss(alpha=0.3),\n", 177 | " optimizer=partial(torch.optim.Adam, lr=1e-4),\n", 178 | ")" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": { 184 | "collapsed": false 185 | }, 186 | "source": [ 187 | "### 4. Starting training" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 89, 193 | "metadata": { 194 | "collapsed": false 195 | }, 196 | "outputs": [ 197 | { 198 | "name": "stderr", 199 | "output_type": "stream", 200 | "text": [ 201 | "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. 
HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3 ...\n", 202 | "GPU available: False, used: False\n", 203 | "TPU available: False, using: 0 TPU cores\n", 204 | "IPU available: False, using: 0 IPUs\n", 205 | "HPU available: False, using: 0 HPUs\n", 206 | "\u001b[32m[16:06:55] INFO: DataLoader will load 28000 graphs (out of 28800 available).\u001b[0m\n", 207 | "\u001b[36m[16:06:55] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21974_s9.pt\u001b[0m\n", 208 | "\u001b[32m[16:06:56] INFO: DataLoader will load 100 graphs (out of 28800 available).\u001b[0m\n", 209 | "\u001b[36m[16:06:56] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21975_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21978_s11.pt\u001b[0m\n", 210 | "\n", 211 | " | Name | Type | Params\n", 212 | "-------------------------------------------------\n", 213 | "0 | model | SillyEC | 17.3 K\n", 214 | "1 | loss_fct | EdgeWeightFocalLoss | 0 \n", 215 | "-------------------------------------------------\n", 216 | "17.3 K Trainable params\n", 217 | "0 Non-trainable params\n", 218 | "17.3 K Total params\n", 219 | "0.069 Total estimated model params size (MB)\n" 220 | ] 221 | }, 222 | { 223 | "name": "stdout", 224 | "output_type": "stream", 225 | "text": [ 226 | "Sanity Checking: | | 0/? [00:00\n" 261 | ], 262 | "text/plain": [] 263 | }, 264 | "metadata": {}, 265 | "output_type": "display_data" 266 | }, 267 | { 268 | "name": "stdout", 269 | "output_type": "stream", 270 | "text": [ 271 | "\n", 272 | " \n", 273 | " \n", 274 | "\u001b[3m Validation epoch=0 \u001b[0m \n", 275 | "┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓\n", 276 | "┃\u001b[1m \u001b[0m\u001b[1mMetric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Error\u001b[0m\u001b[1m \u001b[0m┃\n", 277 | "┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩\n", 278 | "│ max_ba │ 0.81276 │ 0.00218 │\n", 279 | "│ max_ba_loc │ 0.42698 │ 0.00009 │\n", 280 | "│ max_ba_loc_pt0.5 │ 0.42698 │ 0.00009 │\n", 281 | "│ max_ba_loc_pt0.9 │ 0.42698 │ 0.00009 │\n", 282 | "│ max_ba_loc_pt1.5 │ 0.42663 │ 0.00017 │\n", 283 | "│ max_ba_pt0.5 │ 0.81276 │ 0.00218 │\n", 284 | "│ max_ba_pt0.9 │ 0.81276 │ 0.00218 │\n", 285 | "│ max_ba_pt1.5 │ 0.80939 │ 0.00313 │\n", 286 | "│ max_f1 │ 0.63106 │ 0.00708 │\n", 287 | "│ max_f1_loc │ 0.42714 │ 0.00000 │\n", 288 | "│ max_f1_loc_pt0.5 │ 0.42714 │ 0.00000 │\n", 289 | "│ max_f1_loc_pt0.9 │ 0.42714 │ 0.00000 │\n", 290 | "│ max_f1_loc_pt1.5 │ 0.42749 │ 0.00013 │\n", 291 | "│ max_f1_pt0.5 │ 0.63106 │ 0.00708 │\n", 292 | "│ max_f1_pt0.9 │ 0.63106 │ 0.00708 │\n", 293 | "│ max_f1_pt1.5 │ 0.53800 │ 0.00936 │\n", 294 | "│ max_mcc │ 0.53974 │ 0.00572 │\n", 295 | "│ max_mcc_loc │ 0.42719 │ 0.00005 │\n", 296 | "│ max_mcc_loc_pt0.5 │ 0.42719 │ 0.00005 │\n", 297 | "│ max_mcc_loc_pt0.9 │ 0.42719 │ 0.00005 │\n", 298 | "│ max_mcc_loc_pt1.5 │ 0.42779 │ 0.00017 │\n", 299 | "│ max_mcc_pt0.5 │ 0.53974 │ 0.00572 │\n", 300 | "│\u001b[1;95m \u001b[0m\u001b[1;95mmax_mcc_pt0.9 \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.53974\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m 
\u001b[0m\u001b[1;95m0.00572\u001b[0m\u001b[1;95m \u001b[0m│\n", 301 | "│ max_mcc_pt1.5 │ 0.48179 │ 0.00770 │\n", 302 | "│ roc_auc │ 0.87181 │ 0.00222 │\n", 303 | "│ roc_auc_0.001FPR │ 0.50003 │ 0.00008 │\n", 304 | "│ roc_auc_0.001FPR_pt0.5 │ 0.50003 │ 0.00008 │\n", 305 | "│ roc_auc_0.001FPR_pt0.9 │ 0.50003 │ 0.00008 │\n", 306 | "│ roc_auc_0.001FPR_pt1.5 │ 0.50805 │ 0.00270 │\n", 307 | "│ roc_auc_0.01FPR │ 0.51681 │ 0.00309 │\n", 308 | "│ roc_auc_0.01FPR_pt0.5 │ 0.51681 │ 0.00309 │\n", 309 | "│ roc_auc_0.01FPR_pt0.9 │ 0.51681 │ 0.00309 │\n", 310 | "│ roc_auc_0.01FPR_pt1.5 │ 0.53808 │ 0.00470 │\n", 311 | "│ roc_auc_pt0.5 │ 0.87181 │ 0.00222 │\n", 312 | "│ roc_auc_pt0.9 │ 0.87181 │ 0.00222 │\n", 313 | "│ roc_auc_pt1.5 │ 0.87136 │ 0.00296 │\n", 314 | "│\u001b[1;95m \u001b[0m\u001b[1;95mtotal \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.06456\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n", 315 | "│ total_train │ 0.06461 │ nan │\n", 316 | "│ tpr_eq_tnr │ 0.81249 │ 0.00221 │\n", 317 | "│ tpr_eq_tnr_loc │ 0.42693 │ 0.00010 │\n", 318 | "│ tpr_eq_tnr_loc_pt0.5 │ 0.42693 │ 0.00010 │\n", 319 | "│ tpr_eq_tnr_loc_pt0.9 │ 0.42693 │ 0.00010 │\n", 320 | "│ tpr_eq_tnr_loc_pt1.5 │ 0.42693 │ 0.00012 │\n", 321 | "│ tpr_eq_tnr_pt0.5 │ 0.81249 │ 0.00221 │\n", 322 | "│\u001b[1;95m \u001b[0m\u001b[1;95mtpr_eq_tnr_pt0.9 \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.81249\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.00221\u001b[0m\u001b[1;95m \u001b[0m│\n", 323 | "│ tpr_eq_tnr_pt1.5 │ 0.80665 │ 0.00382 │\n", 324 | "└────────────────────────┴─────────┴─────────┘\n", 325 | "\n", 326 | "Epoch 0: 4%|█████▉ | 100/2800 [00:32<14:50, 3.03it/s, v_num=8, total_train=0.0646]" 327 | ] 328 | }, 329 | { 330 | "name": "stderr", 331 | "output_type": "stream", 332 | "text": [ 333 | "`Trainer.fit` stopped: `max_steps=100` reached.\n" 334 | ] 335 | }, 336 | { 337 | "name": "stdout", 338 | "output_type": "stream", 339 | "text": [ 340 | "Epoch 0: 4%|█████▉ | 100/2800 [00:32<14:50, 3.03it/s, v_num=8, total_train=0.0646]\n" 341 | ] 342 | } 343 | ], 344 | "source": [ 345 | "trainer = Trainer(\n", 346 | " max_steps=100,\n", 347 | " val_check_interval=100,\n", 348 | " accelerator=\"cpu\",\n", 349 | " log_every_n_steps=1,\n", 350 | " callbacks=[PrintValidationMetrics()],\n", 351 | ")\n", 352 | "trainer.fit(model=lmodel, datamodule=dm)" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "metadata": { 358 | "collapsed": false 359 | }, 360 | "source": [ 361 | "## With graphs built on-the-fly from point clouds" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "metadata": { 367 | "collapsed": false 368 | }, 369 | "source": [ 370 | "Step 1: Configure data module to load point clouds (rather than graphs).\n", 371 | "Step 2: Add `MLGraphConstructionFromChkpt` as preproc." 
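The cell that follows covers Step 2. For Step 1, here is a minimal sketch: it simply points the `TrackingDataModule` from above at point-cloud files instead of pre-built graphs. The directory path is a placeholder — substitute the point-cloud files used in `020_one_shot_object_condensation.ipynb` or your own.

```python3
from gnn_tracking.utils.loading import TrackingDataModule

# Same data module as before, but fed point clouds instead of pre-built
# graphs; the graphs are then constructed on the fly by the `preproc`
# configured in the next cell.
dm = TrackingDataModule(
    train=dict(dirs=["/path/to/point_clouds/v8"], stop=28_000),
    val=dict(dirs=["/path/to/point_clouds/v8"], start=28_000, stop=28_100),
    identifier="point_clouds_v8",
)
```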
372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "metadata": { 378 | "collapsed": false 379 | }, 380 | "outputs": [], 381 | "source": [ 382 | "lmodel = ECModule(\n", 383 | " model=model,\n", 384 | " loss_fct=EdgeWeightFocalLoss(alpha=0.3),\n", 385 | " optimizer=partial(torch.optim.Adam, lr=1e-4),\n", 386 | " preproc=MLGraphConstructionFromChkpt(\n", 387 | " ml_class_name=\"gnn_tracking.models.graph_construction.GraphConstructionFCNN\",\n", 388 | " ml_chkpt_path=\"/path/to/your/checkpoint\",\n", 389 | " ),\n", 390 | ")" 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "metadata": { 396 | "collapsed": false 397 | }, 398 | "source": [ 399 | "Instead of `MLGraphConstructionFromChkpt`, you can also take a look at `MLGraphConstruction`, which simply takes a model (that you can instantiate in any way)." 400 | ] 401 | } 402 | ], 403 | "metadata": { 404 | "kernelspec": { 405 | "display_name": "Python 3 (ipykernel)", 406 | "language": "python", 407 | "name": "python3" 408 | }, 409 | "language_info": { 410 | "codemirror_mode": { 411 | "name": "ipython", 412 | "version": 3 413 | }, 414 | "file_extension": ".py", 415 | "mimetype": "text/x-python", 416 | "name": "python", 417 | "nbconvert_exporter": "python", 418 | "pygments_lexer": "ipython3", 419 | "version": "3.10.11" 420 | } 421 | }, 422 | "nbformat": 4, 423 | "nbformat_minor": 0 424 | } 425 | -------------------------------------------------------------------------------- /notebooks/040_three_shot_object_condensation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": false 7 | }, 8 | "source": [ 9 | "# Three-shot object condensation\n", 10 | "\n", 11 | "This notebook sketches how to implement the pipeline of \"graph construction (GC)\" > \"edge classification (EC)\" > \"object condensation (OC)\".\n", 12 | "There are multiple options." 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": { 18 | "collapsed": false 19 | }, 20 | "source": [ 21 | "## Using graphs on disk\n", 22 | "\n", 23 | "`020_one_shot_object_condensation.ipynb` built graphs using kNN as part of the `GravNetConv`. But if you directly load graphs from disk, you can simply use any GNN and everything will work." 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "collapsed": false 30 | }, 31 | "source": [ 32 | "## From point clouds on disk and a pre-trained GC + EC\n", 33 | "\n", 34 | "We simply follow the last example from the EC notebook:" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": { 40 | "collapsed": false 41 | }, 42 | "source": [ 43 | "```python3\n", 44 | "lmodel = TCModule(\n", 45 | " model=model,\n", 46 | " ...,\n", 47 | " preproc=MLGraphConstructionFromChkpt(\n", 48 | " ml_class_name=\"gnn_tracking.models.graph_construction.GraphConstructionFCNN\",\n", 49 | " ml_chkpt_path=\"/path/to/your/checkpoint\",\n", 50 | " ec_class_name=\"gnn_tracking.models.edge_classifier.ECForGraphTCN\",\n", 51 | " ec_chkpt_path=\"/path/to/your/checkpoint\",\n", 52 | " ),\n", 53 | ")\n", 54 | "```" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "collapsed": false 61 | }, 62 | "source": [ 63 | "Other than the `preproc` step, everything can be set up like in the EC notebook."
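To round off the picture: whichever way the graphs reach the OC model, the condensation stage ends by clustering the learned latent space, which is what a `cluster_scanner` such as the `DBSCANHyperParamScanner` from the one-shot notebook tunes `eps` and `min_samples` for. Below is a self-contained toy illustration of that final step using synthetic data and plain scikit-learn — it is not the `gnn_tracking` API, and all names are made up.

```python3
import numpy as np
from sklearn.cluster import DBSCAN

rng = np.random.default_rng(0)
# Pretend latent positions H: three hits each from two well-separated particles.
H = np.concatenate([rng.normal(0.0, 0.05, (3, 2)), rng.normal(1.0, 0.05, (3, 2))])
labels = DBSCAN(eps=0.3, min_samples=2).fit_predict(H)
print(labels)  # -> [0 0 0 1 1 1]; a label of -1 would mark a noise hit
```

The better the condensation loss has pulled hits of the same particle together (and pushed different particles apart), the less sensitive the result is to the exact choice of `eps` and `min_samples`.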
64 | ] 65 | } 66 | ], 67 | "metadata": { 68 | "kernelspec": { 69 | "display_name": "Python 3", 70 | "language": "python", 71 | "name": "python3" 72 | }, 73 | "language_info": { 74 | "codemirror_mode": { 75 | "name": "ipython", 76 | "version": 2 77 | }, 78 | "file_extension": ".py", 79 | "mimetype": "text/x-python", 80 | "name": "python", 81 | "nbconvert_exporter": "python", 82 | "pygments_lexer": "ipython2", 83 | "version": "2.7.6" 84 | } 85 | }, 86 | "nbformat": 4, 87 | "nbformat_minor": 0 88 | } 89 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | minversion = "6.0" 3 | addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config", "--cov-branch"] 4 | xfail_strict = true 5 | testpaths = ["tests"] 6 | 7 | [tool.pycln] 8 | all = true 9 | --------------------------------------------------------------------------------