├── .cruft.json
├── .flake8
├── .gitattributes
├── .github
│   ├── dependabot.yml
│   └── workflows
│       └── check-links.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── LICENSE.txt
├── README.md
├── codespell.txt
├── mlc_config.json
├── mypy.ini
├── notebooks
│   ├── 002_preproc_data.ipynb
│   ├── 009_build_graphs_ml.ipynb
│   ├── 020_one_shot_object_condensation.ipynb
│   ├── 030_edge_classification.ipynb
│   └── 040_three_shot_object_condensation.ipynb
└── pyproject.toml

/.cruft.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "template": "git@github.com:klieret/python-cookiecutter.git",
 3 |   "commit": "c650bc109b1f408ded56bbb421747630309e10f8",
 4 |   "checkout": null,
 5 |   "context": {
 6 |     "cookiecutter": {
 7 |       "project_name": "gnn-tracking-tutorials",
 8 |       "package_name": "gnn_tracking_tut",
 9 |       "description": "Tutorials and onboarding for the GNN Tracking project",
10 |       "user": "gnn-tracking",
11 |       "url": "https://github.com/gnn-tracking/tutorials",
12 |       "full_name": "Kilian Lieret, Gage deZoort",
13 |       "email": "kilian.lieret@posteo.de",
14 |       "maintainer": "Kilian Lieret, Gage deZoort",
15 |       "maintainer_email": "kilian.lieret@posteo.de",
16 |       "year": "2023",
17 |       "_copy_without_render": ["*.css"],
18 |       "_template": "git@github.com:klieret/python-cookiecutter.git"
19 |     }
20 |   },
21 |   "directory": null
22 | }
23 | 
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 80
3 | select = C,E,F,W,B,B950
4 | ignore = E203, E501, W503
5 | exclude =
6 |     .git,
7 |     __pycache__,
8 |     notebooks,
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | 
2 | *.ipynb diff=jupyternotebook
3 | 
4 | *.ipynb merge=jupyternotebook
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 |   # Maintain dependencies for GitHub Actions
4 |   - package-ecosystem: "github-actions"
5 |     directory: "/"
6 |     schedule:
7 |       interval: "weekly"
--------------------------------------------------------------------------------
/.github/workflows/check-links.yaml:
--------------------------------------------------------------------------------
 1 | name: Check Markdown links
 2 | 
 3 | on:
 4 |   push:
 5 |   pull_request:
 6 |   schedule:
 7 |     - cron: "0 0 1 * *"
 8 | 
 9 | jobs:
10 |   markdown-link-check:
11 |     runs-on: ubuntu-latest
12 |     steps:
13 |       - uses: actions/checkout@master
14 |       - uses: gaurav-nelson/github-action-markdown-link-check@v1
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
  1 | # =============================================================================
  2 | # PROJECT SPECIFIC IGNORES
  3 | # =============================================================================
  4 | *.pt
  5 | **/lightning_logs/**
  6 | **/wandb/**
  7 | 
  8 | # =============================================================================
  9 | # GENERAL PYTHON GITIGNORE
 10 | # =============================================================================
 11 | # Created by https://www.toptal.com/developers/gitignore/api/python
 12 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
 13 | 
 14 | ### Python ###
 15 | # Byte-compiled / optimized / DLL files
 16 | __pycache__/
 17 | *.py[cod]
 18 | *$py.class
 19 | 
 20 | # C extensions
 21 | *.so
 22 | 
 23 | # Distribution / packaging
 24 | .Python
 25 | build/
 26 | develop-eggs/
 27 | dist/
 28 | downloads/
 29 | eggs/
 30 | .eggs/
 31 | lib/
 32 | lib64/
 33 | parts/
 34 | sdist/
 35 | var/
 36 | wheels/
 37 | share/python-wheels/
 38 | *.egg-info/
 39 | .installed.cfg
 40 | *.egg
 41 | MANIFEST
 42 | 
 43 | # PyInstaller
 44 | # Usually these files are written by a python script from a template
 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
 46 | *.manifest
 47 | *.spec
 48 | 
 49 | # Installer logs
 50 | pip-log.txt
 51 | pip-delete-this-directory.txt
 52 | 
 53 | # Unit test / coverage reports
 54 | htmlcov/
 55 | .tox/
 56 | .nox/
 57 | .coverage
 58 | .coverage.*
 59 | .cache
 60 | nosetests.xml
 61 | coverage.xml
 62 | *.cover
 63 | *.py,cover
 64 | .hypothesis/
 65 | .pytest_cache/
 66 | cover/
 67 | 
 68 | # Translations
 69 | *.mo
 70 | *.pot
 71 | 
 72 | # Django stuff:
 73 | *.log
 74 | local_settings.py
 75 | db.sqlite3
 76 | db.sqlite3-journal
 77 | 
 78 | # Flask stuff:
 79 | instance/
 80 | .webassets-cache
 81 | 
 82 | # Scrapy stuff:
 83 | .scrapy
 84 | 
 85 | # Sphinx documentation
 86 | docs/_build/
 87 | 
 88 | # PyBuilder
 89 | .pybuilder/
 90 | target/
 91 | 
 92 | # Jupyter Notebook
 93 | .ipynb_checkpoints
 94 | 
 95 | # IPython
 96 | profile_default/
 97 | ipython_config.py
 98 | 
 99 | # pyenv
100 | # For a library or package, you might want to ignore these files since the code is
101 | # intended to run in multiple environments; otherwise, check them in:
102 | # .python-version
103 | 
104 | # pipenv
105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
108 | # install all needed dependencies.
109 | #Pipfile.lock
110 | 
111 | # poetry
112 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
113 | # This is especially recommended for binary packages to ensure reproducibility, and is more
114 | # commonly ignored for libraries.
115 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
116 | #poetry.lock
117 | 
118 | # pdm
119 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
120 | #pdm.lock
121 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
122 | # in version control.
123 | # https://pdm.fming.dev/#use-with-ide
124 | .pdm.toml
125 | 
126 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
127 | __pypackages__/
128 | 
129 | # Celery stuff
130 | celerybeat-schedule
131 | celerybeat.pid
132 | 
133 | # SageMath parsed files
134 | *.sage.py
135 | 
136 | # Environments
137 | .env
138 | .venv
139 | env/
140 | venv/
141 | ENV/
142 | env.bak/
143 | venv.bak/
144 | 
145 | # Spyder project settings
146 | .spyderproject
147 | .spyproject
148 | 
149 | # Rope project settings
150 | .ropeproject
151 | 
152 | # mkdocs documentation
153 | /site
154 | 
155 | # mypy
156 | .mypy_cache/
157 | .dmypy.json
158 | dmypy.json
159 | 
160 | # Pyre type checker
161 | .pyre/
162 | 
163 | # pytype static type analyzer
164 | .pytype/
165 | 
166 | # Cython debug symbols
167 | cython_debug/
168 | 
169 | # PyCharm
170 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
171 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
172 | # and can be added to the global gitignore or merged into this file. For a more nuclear
173 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
174 | #.idea/
175 | 
176 | # End of https://www.toptal.com/developers/gitignore/api/python
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
 1 | repos:
 2 |   - repo: https://github.com/psf/black
 3 |     rev: 24.8.0
 4 |     hooks:
 5 |       - id: black
 6 |       - id: black-jupyter
 7 |   - repo: https://github.com/pre-commit/pre-commit-hooks
 8 |     rev: v4.6.0
 9 |     hooks:
10 |       - id: check-case-conflict
11 |       - id: check-merge-conflict
12 |       - id: detect-private-key
13 |       - id: end-of-file-fixer
14 |         exclude: '.*\.ipynb'
15 |       - id: trailing-whitespace
16 |   - repo: https://github.com/pycqa/isort
17 |     rev: 5.13.2
18 |     hooks:
19 |       - id: isort
20 |         name: isort (python)
21 |         args: ["--profile", "black", "-a", "", "--append-only"]
22 |   - repo: https://github.com/PyCQA/flake8
23 |     rev: "7.1.1"
24 |     hooks:
25 |       - id: flake8
26 |         additional_dependencies: ["flake8-bugbear"]
27 |   - repo: https://github.com/pre-commit/mirrors-mypy
28 |     rev: "v1.11.2"
29 |     hooks:
30 |       - id: mypy
31 |         exclude: 'docs/source/conf\.py'
32 |   - repo: https://github.com/codespell-project/codespell
33 |     rev: "v2.3.0"
34 |     hooks:
35 |       - id: codespell
36 |         args: ["-I", "codespell.txt"]
37 |         exclude: '.*\.ipynb'
38 |   - repo: https://github.com/asottile/pyupgrade
39 |     rev: v3.17.0
40 |     hooks:
41 |       - id: pyupgrade
42 |         args: ["--py37-plus"]
43 |   - repo: https://github.com/asottile/setup-cfg-fmt
44 |     rev: "v2.5.0"
45 |     hooks:
46 |       - id: setup-cfg-fmt
47 |         args: [--include-version-classifiers, --max-py-version=3.10]
48 |   - repo: https://github.com/hadialqattan/pycln
49 |     rev: v2.4.0
50 |     hooks:
51 |       - id: pycln
52 |         args: [--config=pyproject.toml]
53 | 
54 | ci:
55 |   autoupdate_schedule: monthly
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | 
3 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
4 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
5 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
 1 | The MIT License (MIT)
 2 | 
 3 | Copyright (c) 2023 Kilian Lieret
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
─────────────────────────── demonic-logical-platypus ───────────────────────────\n",
1028 | "\n"
1029 | ],
1030 | "text/plain": [
1031 | "\u001b[92m─────────────────────────── \u001b[0m\u001b[1;33mdemonic-logical-platypus\u001b[0m\u001b[92m ───────────────────────────\u001b[0m\n"
1032 | ]
1033 | },
1034 | "metadata": {},
1035 | "output_type": "display_data"
1036 | }
1037 | ],
1038 | "source": [
1039 | "name = random_trial_name()"
1040 | ]
1041 | },
1042 | {
1043 | "attachments": {},
1044 | "cell_type": "markdown",
1045 | "metadata": {},
1046 | "source": [
1047 | "\n",
1048 | "After this, let's set up the logger:"
1049 | ]
1050 | },
1051 | {
1052 | "cell_type": "code",
1053 | "execution_count": null,
1054 | "metadata": {
1055 | "vscode": {
1056 | "languageId": "python"
1057 | }
1058 | },
1059 | "outputs": [
1060 | {
1061 | "name": "stderr",
1062 | "output_type": "stream",
1063 | "text": [
1064 | "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
1065 | "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id godlike-buzzard-of-wonder.\n"
1066 | ]
1067 | },
1068 | {
1069 | "data": {
1070 | "text/html": [
1071 | "Tracking run with wandb version 0.15.4"
1072 | ],
1073 | "text/plain": [
1074 | "W&B syncing is set to `offline` in this directory. Run `wandb online` or set WANDB_MODE=online to enable cloud syncing."
1084 | ],
1085 | "text/plain": [
1086 | ""
1087 | ]
1088 | },
1089 | "metadata": {},
1090 | "output_type": "display_data"
1091 | }
1092 | ],
1093 | "source": [
1094 | "from pytorch_lightning.loggers import WandbLogger\n",
1095 | "\n",
1096 | "\n",
1097 | "wandb_logger = WandbLogger(\n",
1098 | " project=\"ml\",\n",
1099 | " group=\"first\",\n",
1100 | " offline=True, # <-- see notes below\n",
1101 | " version=name,\n",
1102 | ")"
1103 | ]
1104 | },
1105 | {
1106 | "attachments": {},
1107 | "cell_type": "markdown",
1108 | "metadata": {},
1109 | "source": [
1110 | "We want to keep our checkpoints locally, so let's also initialize the default logger (which would be replaced by `WandbLogger` if we don't add it manually):"
1111 | ]
1112 | },
1113 | {
1114 | "cell_type": "code",
1115 | "execution_count": null,
1116 | "metadata": {
1117 | "vscode": {
1118 | "languageId": "python"
1119 | }
1120 | },
1121 | "outputs": [],
1122 | "source": [
1123 | "from pytorch_lightning.loggers import TensorBoardLogger\n",
1124 | "\n",
1125 | "\n",
1126 | "tb_logger = TensorBoardLogger(\".\", version=name)"
1127 | ]
1128 | },
1129 | {
1130 | "attachments": {},
1131 | "cell_type": "markdown",
1132 | "metadata": {},
1133 | "source": [
1134 | "Now we'd have all the places in place, if it weren't for one subtlety: The Princeton compute nodes don't have internet connectivity.\n",
1135 | "This is also why we set `offline=True` to the `WandbLogger`. But that's not a problem, because we have internet on the head node `della-gpu`\n",
1136 | "(just not on the compute node). So we can simply run `wandb sync /path/to/run/dir` afterwards.\n",
1137 | "However, because this is annoying, I wrote a package `wandb-osh` to help with this.\n",
1138 | "\n",
1139 | "To install it, run:"
1140 | ]
1141 | },
1142 | {
1143 | "cell_type": "code",
1144 | "execution_count": null,
1145 | "metadata": {
1146 | "vscode": {
1147 | "languageId": "python"
1148 | }
1149 | },
1150 | "outputs": [
1151 | {
1152 | "name": "stdout",
1153 | "output_type": "stream",
1154 | "text": [
1155 | "Requirement already satisfied: wandb-osh in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (1.0.4)\n",
1156 | "Requirement already satisfied: colorlog in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb-osh) (6.7.0)\n",
1157 | "Requirement already satisfied: wandb in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb-osh) (0.15.4)\n",
1158 | "Requirement already satisfied: Click!=8.0.0,>=7.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (8.1.3)\n",
1159 | "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (3.1.31)\n",
1160 | "Requirement already satisfied: requests<3,>=2.0.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (2.31.0)\n",
1161 | "Requirement already satisfied: psutil>=5.0.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (5.9.5)\n",
1162 | "Requirement already satisfied: sentry-sdk>=1.0.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (1.21.1)\n",
1163 | "Requirement already satisfied: docker-pycreds>=0.4.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (0.4.0)\n",
1164 | "Requirement already satisfied: PyYAML in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (6.0)\n",
1165 | "Requirement already satisfied: pathtools in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (0.1.2)\n",
1166 | "Requirement already satisfied: setproctitle in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (1.3.2)\n",
1167 | "Requirement already satisfied: setuptools in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (67.7.2)\n",
1168 | "Requirement already satisfied: appdirs>=1.4.3 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (1.4.4)\n",
1169 | "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from wandb->wandb-osh) (3.20.3)\n",
1170 | "Requirement already satisfied: six>=1.4.0 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from docker-pycreds>=0.4.0->wandb->wandb-osh) (1.16.0)\n",
1171 | "Requirement already satisfied: gitdb<5,>=4.0.1 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb->wandb-osh) (4.0.10)\n",
1172 | "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb->wandb-osh) (3.1.0)\n",
1173 | "Requirement already satisfied: idna<4,>=2.5 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb->wandb-osh) (3.4)\n",
1174 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb->wandb-osh) (2.0.3)\n",
1175 | "Requirement already satisfied: certifi>=2017.4.17 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb->wandb-osh) (2023.5.7)\n",
1176 | "Requirement already satisfied: smmap<6,>=3.0.1 in /Users/fuchur/micromamba/envs/gnn/lib/python3.11/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb->wandb-osh) (3.0.5)\n"
1177 | ]
1178 | }
1179 | ],
1180 | "source": [
1181 | "! pip3 install wandb-osh"
1182 | ]
1183 | },
1184 | {
1185 | "attachments": {},
1186 | "cell_type": "markdown",
1187 | "metadata": {},
1188 | "source": [
1189 | "Now let's put everything together: "
1190 | ]
1191 | },
1192 | {
1193 | "cell_type": "code",
1194 | "execution_count": null,
1195 | "metadata": {
1196 | "vscode": {
1197 | "languageId": "python"
1198 | }
1199 | },
1200 | "outputs": [],
1201 | "source": [
1202 | "from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback\n",
1203 | "\n",
1204 | "\n",
1205 | "trainer = Trainer(\n",
1206 | " max_epochs=3,\n",
1207 | " accelerator=\"cpu\",\n",
1208 | " log_every_n_steps=1,\n",
1209 | " callbacks=[\n",
1210 | " TriggerWandbSyncLightningCallback(),\n",
1211 | " PrintValidationMetrics(),\n",
1212 | " ],\n",
1213 | " logger=[\n",
1214 | " wandb_logger,\n",
1215 | " tb_logger,\n",
1216 | " ],\n",
1217 | ")"
1218 | ]
1219 | },
1220 | {
1221 | "attachments": {},
1222 | "cell_type": "markdown",
1223 | "metadata": {},
1224 | "source": [
1225 | "To sync your run, simply start the `wandb-osh` command line utility on `della-gpu`.\n",
1226 | "For more information on how this works, see [here](https://github.com/klieret/wandb-offline-sync-hook)."
1227 | ]
1228 | }
1229 | ],
1230 | "metadata": {
1231 | "kernelspec": {
1232 | "display_name": "Python 3 (ipykernel)",
1233 | "language": "python",
1234 | "name": "python3"
1235 | },
1236 | "language_info": {
1237 | "name": "",
1238 | "version": ""
1239 | }
1240 | },
1241 | "nbformat": 4,
1242 | "nbformat_minor": 1
1243 | }
1244 |
--------------------------------------------------------------------------------
/notebooks/020_one_shot_object_condensation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": false
7 | },
8 | "source": [
9 | "# One shot object condensation\n",
10 | "\n",
11 | "This notebook shows how you can implement a model that directly goes from point cloud data to object condensation."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 63,
17 | "metadata": {
18 | "collapsed": false
19 | },
20 | "outputs": [],
21 | "source": [
22 | "from pathlib import Path\n",
23 | "from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin\n",
24 | "from torch import nn\n",
25 | "from torch_geometric.nn.conv import GravNetConv\n",
26 | "from torch_geometric.data import Data\n",
27 | "from pytorch_lightning import Trainer\n",
28 | "\n",
29 | "from gnn_tracking.metrics.losses.oc import CondensationLossTiger\n",
30 | "import torch\n",
31 | "from functools import partial\n",
32 | "\n",
33 | "from gnn_tracking.training.callbacks import PrintValidationMetrics\n",
34 | "from gnn_tracking.training.tc import TCModule\n",
35 | "from gnn_tracking.utils.loading import TrackingDataModule\n",
36 | "from gnn_tracking.utils.versioning import assert_version_geq\n",
37 | "\n",
38 | "assert_version_geq(\"23.12.0\")"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "collapsed": false
45 | },
46 | "source": [
47 | "## 1. Configure data"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 64,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "data_dir = (\n",
57 | " Path.cwd().resolve().parent.parent / \"test-data\" / \"data\" / \"point_clouds\" / \"v8\"\n",
58 | ")\n",
59 | "assert data_dir.is_dir()"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 65,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "dm = TrackingDataModule(\n",
69 | " train=dict(\n",
70 | " dirs=[data_dir],\n",
71 | " stop=1,\n",
72 | " ),\n",
73 | " val=dict(\n",
74 | " dirs=[data_dir],\n",
75 | " start=1,\n",
76 | " stop=2,\n",
77 | " ),\n",
78 | " identifier=\"point_clouds_v8\",\n",
79 | " # could also configure a 'test' set here\n",
80 | ")"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {
86 | "collapsed": false
87 | },
88 | "source": [
89 | "## 2. Write a model"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 66,
95 | "metadata": {
96 | "collapsed": false
97 | },
98 | "outputs": [],
99 | "source": [
100 | "class DemoGravNet(nn.Module, HyperparametersMixin):\n",
101 | " def __init__(self, in_dim: int = 14, depth: int = 1, k: int = 2):\n",
102 | " super().__init__()\n",
103 | " self.save_hyperparameters()\n",
104 | " layers = [\n",
105 | " GravNetConv(\n",
106 | " in_channels=in_dim,\n",
107 | " out_channels=in_dim,\n",
108 | " space_dimensions=3,\n",
109 | " propagate_dimensions=3,\n",
110 | " k=k,\n",
111 | " )\n",
112 | " for _ in range(depth)\n",
113 | " ]\n",
114 | " self._embedding = nn.Sequential(*layers)\n",
115 | " self._beta = nn.Sequential(\n",
116 | " nn.Linear(in_dim, 1),\n",
117 | " nn.Sigmoid(),\n",
118 | " )\n",
119 | "\n",
120 | " def forward(self, data: Data):\n",
121 | " latent = self._embedding(data.x)\n",
122 | " beta = self._beta(latent).squeeze()\n",
123 | " eps = 1e-6\n",
124 | " beta = beta.clamp(eps, 1 - eps)\n",
125 | " return {\n",
126 | " \"B\": beta,\n",
127 | " \"H\": latent,\n",
128 | " }"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 67,
134 | "metadata": {
135 | "collapsed": false
136 | },
137 | "outputs": [],
138 | "source": [
139 | "model = DemoGravNet()"
140 | ]
141 | },
142 | {
143 | "cell_type": "markdown",
144 | "metadata": {
145 | "collapsed": false
146 | },
147 | "source": [
148 | "## 3. Configure loss functions and weights"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 68,
154 | "metadata": {
155 | "collapsed": false
156 | },
157 | "outputs": [],
158 | "source": [
159 | "# The loss functions can be memory hungry. Here we override `data_preproc` to place a tighter pt cut on\n",
160 | "# the data to easy computation (since this is just a demo).\n",
161 | "class PtCut(HyperparametersMixin):\n",
162 | " def __call__(self, data: Data):\n",
163 | " mask = data.pt > 4\n",
164 | " data = data.subgraph(mask)\n",
165 | " return data"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 71,
171 | "metadata": {
172 | "collapsed": false
173 | },
174 | "outputs": [],
175 | "source": [
176 | "from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner\n",
177 | "\n",
178 | "# TC for track condensation\n",
179 | "lmodel = TCModule(\n",
180 | " model=model,\n",
181 | " loss_fct=CondensationLossTiger(\n",
182 | " lw_repulsive=2.0,\n",
183 | " ),\n",
184 | " optimizer=partial(torch.optim.Adam, lr=1e-4),\n",
185 | " cluster_scanner=DBSCANHyperParamScanner(n_trials=5, n_jobs=1),\n",
186 | " preproc=PtCut(),\n",
187 | ")"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "metadata": {
193 | "collapsed": false
194 | },
195 | "source": [
196 | "## 4. Train the model"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 72,
202 | "metadata": {
203 | "collapsed": false
204 | },
205 | "outputs": [
206 | {
207 | "name": "stderr",
208 | "output_type": "stream",
209 | "text": [
210 | "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3 ...\n",
211 | "GPU available: False, used: False\n",
212 | "TPU available: False, using: 0 TPU cores\n",
213 | "IPU available: False, using: 0 IPUs\n",
214 | "HPU available: False, using: 0 HPUs\n",
215 | "\u001b[32m[15:52:19] INFO: DataLoader will load 1 graphs (out of 2 available).\u001b[0m\n",
216 | "\u001b[36m[15:52:19] DEBUG: First graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21000_s0.pt, last graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21000_s0.pt\u001b[0m\n",
217 | "\u001b[32m[15:52:19] INFO: DataLoader will load 1 graphs (out of 2 available).\u001b[0m\n",
218 | "\u001b[36m[15:52:19] DEBUG: First graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21001_s0.pt, last graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21001_s0.pt\u001b[0m\n",
219 | "\n",
220 | " | Name | Type | Params\n",
221 | "---------------------------------------------------\n",
222 | "0 | model | DemoGravNet | 399 \n",
223 | "1 | loss_fct | CondensationLossTiger | 0 \n",
224 | "---------------------------------------------------\n",
225 | "399 Trainable params\n",
226 | "0 Non-trainable params\n",
227 | "399 Total params\n",
228 | "0.002 Total estimated model params size (MB)\n"
229 | ]
230 | },
231 | {
232 | "name": "stdout",
233 | "output_type": "stream",
234 | "text": [
235 | "Sanity Checking: | | 0/? [00:00, ?it/s]"
236 | ]
237 | },
238 | {
239 | "name": "stderr",
240 | "output_type": "stream",
241 | "text": [
242 | "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.\n"
243 | ]
244 | },
245 | {
246 | "name": "stdout",
247 | "output_type": "stream",
248 | "text": [
249 | "Sanity Checking DataLoader 0: 0%| | 0/1 [00:00, ?it/s]"
250 | ]
251 | },
252 | {
253 | "name": "stderr",
254 | "output_type": "stream",
255 | "text": [
256 | "No CUDA runtime is found, using CUDA_HOME='/scratch/gpfs/kl5675/micromamba/envs/gnn'\n"
257 | ]
258 | },
259 | {
260 | "name": "stdout",
261 | "output_type": "stream",
262 | "text": [
263 | " \r"
264 | ]
265 | },
266 | {
267 | "name": "stderr",
268 | "output_type": "stream",
269 | "text": [
270 | "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:212: You called `self.log('n_rep', ...)` in your `validation_step` but the value needs to be floating to be reduced. Converting it to torch.float32. You can silence this warning by converting the value to floating point yourself. If you don't intend to reduce the value (for instance when logging the global step or epoch) then you can use `self.logger.log_metrics({'n_rep': ...})` instead.\n",
271 | "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.\n"
272 | ]
273 | },
274 | {
275 | "name": "stdout",
276 | "output_type": "stream",
277 | "text": [
278 | "Epoch 0: 0%| | 0/1 [00:00, ?it/s]"
279 | ]
280 | },
281 | {
282 | "name": "stderr",
283 | "output_type": "stream",
284 | "text": [
285 | "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:212: You called `self.log('n_rep_train', ...)` in your `training_step` but the value needs to be floating to be reduced. Converting it to torch.float32. You can silence this warning by converting the value to floating point yourself. If you don't intend to reduce the value (for instance when logging the global step or epoch) then you can use `self.logger.log_metrics({'n_rep_train': ...})` instead.\n"
286 | ]
287 | },
288 | {
289 | "name": "stdout",
290 | "output_type": "stream",
291 | "text": [
292 | "Epoch 0: 100%|█| 1/1 [00:10<00:00, 0.09it/s, v_num=3, attractive_train=9.35e+7, repulsive_train=0.000, coward_train=0.0526, noise_train=nan.0, attractive_weighted_train=9.35e+7, repulsive_weighted_train=0.000, coward_weighted_train=0.000, noise"
293 | ]
294 | },
295 | {
296 | "name": "stderr",
297 | "output_type": "stream",
298 | "text": [
299 | "NaN or Inf found in input tensor.\n",
300 | "NaN or Inf found in input tensor.\n",
301 | "NaN or Inf found in input tensor.\n"
302 | ]
303 | },
304 | {
305 | "name": "stdout",
306 | "output_type": "stream",
307 | "text": []
308 | },
309 | {
310 | "name": "stderr",
311 | "output_type": "stream",
312 | "text": [
313 | "NaN or Inf found in input tensor.\n",
314 | "NaN or Inf found in input tensor.\n",
315 | "NaN or Inf found in input tensor.\n",
316 | "NaN or Inf found in input tensor.\n",
317 | "NaN or Inf found in input tensor.\n",
318 | "NaN or Inf found in input tensor.\n",
319 | "NaN or Inf found in input tensor.\n",
320 | "NaN or Inf found in input tensor.\n",
321 | "NaN or Inf found in input tensor.\n",
322 | "NaN or Inf found in input tensor.\n",
323 | "NaN or Inf found in input tensor.\n",
324 | "NaN or Inf found in input tensor.\n",
325 | "NaN or Inf found in input tensor.\n",
326 | "NaN or Inf found in input tensor.\n",
327 | "NaN or Inf found in input tensor.\n",
328 | "NaN or Inf found in input tensor.\n",
329 | "NaN or Inf found in input tensor.\n",
330 | "NaN or Inf found in input tensor.\n",
331 | "NaN or Inf found in input tensor.\n",
332 | "NaN or Inf found in input tensor.\n",
333 | "NaN or Inf found in input tensor.\n",
334 | "NaN or Inf found in input tensor.\n",
335 | "NaN or Inf found in input tensor.\n",
336 | "NaN or Inf found in input tensor.\n",
337 | "NaN or Inf found in input tensor.\n",
338 | "NaN or Inf found in input tensor.\n",
339 | "NaN or Inf found in input tensor.\n",
340 | "NaN or Inf found in input tensor.\n",
341 | "NaN or Inf found in input tensor.\n",
342 | "NaN or Inf found in input tensor.\n",
343 | "NaN or Inf found in input tensor.\n",
344 | "NaN or Inf found in input tensor.\n",
345 | "NaN or Inf found in input tensor.\n",
346 | "NaN or Inf found in input tensor.\n",
347 | "NaN or Inf found in input tensor.\n",
348 | "NaN or Inf found in input tensor.\n",
349 | "NaN or Inf found in input tensor.\n",
350 | "NaN or Inf found in input tensor.\n",
351 | "NaN or Inf found in input tensor.\n",
352 | "NaN or Inf found in input tensor.\n",
353 | "NaN or Inf found in input tensor.\n",
354 | "NaN or Inf found in input tensor.\n",
355 | "NaN or Inf found in input tensor.\n",
356 | "NaN or Inf found in input tensor.\n",
357 | "NaN or Inf found in input tensor.\n",
358 | "NaN or Inf found in input tensor.\n",
359 | "NaN or Inf found in input tensor.\n",
360 | "NaN or Inf found in input tensor.\n",
361 | "NaN or Inf found in input tensor.\n",
362 | "NaN or Inf found in input tensor.\n",
363 | "NaN or Inf found in input tensor.\n",
364 | "NaN or Inf found in input tensor.\n",
365 | "NaN or Inf found in input tensor.\n",
366 | "NaN or Inf found in input tensor.\n",
367 | "NaN or Inf found in input tensor.\n",
368 | "NaN or Inf found in input tensor.\n",
369 | "NaN or Inf found in input tensor.\n",
370 | "NaN or Inf found in input tensor.\n",
371 | "NaN or Inf found in input tensor.\n",
372 | "NaN or Inf found in input tensor.\n",
373 | "NaN or Inf found in input tensor.\n",
374 | "NaN or Inf found in input tensor.\n",
375 | "NaN or Inf found in input tensor.\n",
376 | "NaN or Inf found in input tensor.\n"
377 | ]
378 | },
379 | {
380 | "data": {
381 | "text/html": [
382 | "\n"
383 | ],
384 | "text/plain": []
385 | },
386 | "metadata": {},
387 | "output_type": "display_data"
388 | },
389 | {
390 | "name": "stdout",
391 | "output_type": "stream",
392 | "text": [
393 | "\n",
394 | "\u001b[3m Validation epoch=0 \u001b[0m \n",
395 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━┓\n",
396 | "┃\u001b[1m \u001b[0m\u001b[1mMetric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mError\u001b[0m\u001b[1m \u001b[0m┃\n",
397 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━┩\n",
398 | "│\u001b[1;95m \u001b[0m\u001b[1;95mattractive \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m55245000.00000\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n",
399 | "│ attractive_train │ 93512936.00000 │ nan │\n",
400 | "│ attractive_weighted │ 55245000.00000 │ nan │\n",
401 | "│ attractive_weighted_train │ 93512936.00000 │ nan │\n",
402 | "│ best_dbscan_eps │ 0.15979 │ nan │\n",
403 | "│ best_dbscan_min_samples │ 4.00000 │ nan │\n",
404 | "│ coward │ 0.03412 │ nan │\n",
405 | "│ coward_train │ 0.05263 │ nan │\n",
406 | "│ coward_weighted │ 0.00000 │ nan │\n",
407 | "│ coward_weighted_train │ 0.00000 │ nan │\n",
408 | "│ n_rep │ 1.00000 │ nan │\n",
409 | "│ n_rep_train │ 0.00000 │ nan │\n",
410 | "│ noise │ nan │ nan │\n",
411 | "│ noise_train │ nan │ nan │\n",
412 | "│ noise_weighted │ nan │ nan │\n",
413 | "│ noise_weighted_train │ nan │ nan │\n",
414 | "│\u001b[1;95m \u001b[0m\u001b[1;95mrepulsive \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m 0.00112\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n",
415 | "│ repulsive_train │ 0.00000 │ nan │\n",
416 | "│ repulsive_weighted │ 0.00223 │ nan │\n",
417 | "│ repulsive_weighted_train │ 0.00000 │ nan │\n",
418 | "│ total │ nan │ nan │\n",
419 | "│ total_train │ nan │ nan │\n",
420 | "│ trk.double_majority │ 0.00000 │ nan │\n",
421 | "│ trk.double_majority_pt0.5 │ 0.00000 │ nan │\n",
422 | "│\u001b[1;95m \u001b[0m\u001b[1;95mtrk.double_majority_pt0.9 \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m 0.00000\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n",
423 | "│ trk.double_majority_pt1.5 │ 0.00000 │ nan │\n",
424 | "│ trk.fake_double_majority │ nan │ nan │\n",
425 | "│ trk.fake_double_majority_pt0.5 │ nan │ nan │\n",
426 | "│ trk.fake_double_majority_pt0.9 │ nan │ nan │\n",
427 | "│ trk.fake_double_majority_pt1.5 │ nan │ nan │\n",
428 | "│ trk.fake_lhc │ nan │ nan │\n",
429 | "│ trk.fake_lhc_pt0.5 │ nan │ nan │\n",
430 | "│ trk.fake_lhc_pt0.9 │ nan │ nan │\n",
431 | "│ trk.fake_lhc_pt1.5 │ nan │ nan │\n",
432 | "│ trk.fake_perfect │ nan │ nan │\n",
433 | "│ trk.fake_perfect_pt0.5 │ nan │ nan │\n",
434 | "│ trk.fake_perfect_pt0.9 │ nan │ nan │\n",
435 | "│ trk.fake_perfect_pt1.5 │ nan │ nan │\n",
436 | "│ trk.i_batch │ 0.00000 │ nan │\n",
437 | "│ trk.lhc │ nan │ nan │\n",
438 | "│ trk.lhc_pt0.5 │ nan │ nan │\n",
439 | "│\u001b[1;95m \u001b[0m\u001b[1;95mtrk.lhc_pt0.9 \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n",
440 | "│ trk.lhc_pt1.5 │ nan │ nan │\n",
441 | "│ trk.n_cleaned_clusters │ 0.00000 │ nan │\n",
442 | "│ trk.n_cleaned_clusters_pt0.5 │ 0.00000 │ nan │\n",
443 | "│ trk.n_cleaned_clusters_pt0.9 │ 0.00000 │ nan │\n",
444 | "│ trk.n_cleaned_clusters_pt1.5 │ 0.00000 │ nan │\n",
445 | "│ trk.n_particles │ 17.00000 │ nan │\n",
446 | "│ trk.n_particles_pt0.5 │ 17.00000 │ nan │\n",
447 | "│ trk.n_particles_pt0.9 │ 17.00000 │ nan │\n",
448 | "│ trk.n_particles_pt1.5 │ 17.00000 │ nan │\n",
449 | "│ trk.perfect │ 0.00000 │ nan │\n",
450 | "│ trk.perfect_pt0.5 │ 0.00000 │ nan │\n",
451 | "│\u001b[1;95m \u001b[0m\u001b[1;95mtrk.perfect_pt0.9 \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m 0.00000\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n",
452 | "│ trk.perfect_pt1.5 │ 0.00000 │ nan │\n",
453 | "└────────────────────────────────┴────────────────┴───────┘\n",
454 | "\n",
455 | "Epoch 0: 100%|█| 1/1 [00:11<00:00, 0.09it/s, v_num=3, attractive_train=9.35e+7, repulsive_train=0.000, coward_train=0.0526, noise_train=nan.0, attractive_weighted_train=9.35e+7, repulsive_weighted_train=0.000, coward_weighted_train=0.000, noise"
456 | ]
457 | },
458 | {
459 | "name": "stderr",
460 | "output_type": "stream",
461 | "text": [
462 | "`Trainer.fit` stopped: `max_epochs=1` reached.\n"
463 | ]
464 | },
465 | {
466 | "name": "stdout",
467 | "output_type": "stream",
468 | "text": [
469 | "Epoch 0: 100%|█| 1/1 [00:11<00:00, 0.09it/s, v_num=3, attractive_train=9.35e+7, repulsive_train=0.000, coward_train=0.0526, noise_train=nan.0, attractive_weighted_train=9.35e+7, repulsive_weighted_train=0.000, coward_weighted_train=0.000, noise\n"
470 | ]
471 | }
472 | ],
473 | "source": [
474 | "trainer = Trainer(\n",
475 | " max_epochs=1,\n",
476 | " accelerator=\"cpu\",\n",
477 | " log_every_n_steps=1,\n",
478 | " callbacks=[PrintValidationMetrics()],\n",
479 | ")\n",
480 | "trainer.fit(model=lmodel, datamodule=dm)"
481 | ]
482 | },
483 | {
484 | "cell_type": "code",
485 | "execution_count": null,
486 | "metadata": {
487 | "collapsed": false
488 | },
489 | "outputs": [],
490 | "source": []
491 | }
492 | ],
493 | "metadata": {
494 | "kernelspec": {
495 | "display_name": "Python 3 (ipykernel)",
496 | "language": "python",
497 | "name": "python3"
498 | },
499 | "language_info": {
500 | "codemirror_mode": {
501 | "name": "ipython",
502 | "version": 3
503 | },
504 | "file_extension": ".py",
505 | "mimetype": "text/x-python",
506 | "name": "python",
507 | "nbconvert_exporter": "python",
508 | "pygments_lexer": "ipython3",
509 | "version": "3.10.11"
510 | }
511 | },
512 | "nbformat": 4,
513 | "nbformat_minor": 0
514 | }
515 |
--------------------------------------------------------------------------------
/notebooks/030_edge_classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": false
7 | },
8 | "source": [
9 | "# Edge classification\n",
10 | "\n",
11 | "This notebooks shows how to classify edges of a graph. In many GNN tracking approaches, we start from an initial graph (e.g., built from a point cloud with the strategy described in `009_build_graphs_ml.ipynb`). We then try to falsify all edges that connected hits of two different particles. If edge classification (EC) would be perfect, we could then reconstruct tracks as connected components of the graph.\n",
12 | "For our object condensation approach, EC is only an auxiliary step. Edges are only considered for message passing but are not important for the final decision on how tracks look. However, EC is still important to help the model to learn quickly.\n",
13 | "\n",
14 | "For background on pytorch lightning, see `009_build_graphs_ml.ipynb`."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 76,
20 | "metadata": {
21 | "collapsed": false
22 | },
23 | "outputs": [],
24 | "source": [
25 | "from pytorch_lightning import Trainer\n",
26 | "from torch import nn\n",
27 | "from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin\n",
28 | "import torch\n",
29 | "from functools import partial\n",
30 | "\n",
31 | "from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt\n",
32 | "\n",
33 | "from gnn_tracking.metrics.losses.ec import EdgeWeightFocalLoss\n",
34 | "from gnn_tracking.training.callbacks import PrintValidationMetrics\n",
35 | "from gnn_tracking.training.ec import ECModule\n",
36 | "\n",
37 | "from gnn_tracking.utils.loading import TrackingDataModule\n",
38 | "\n",
39 | "\n",
40 | "from gnn_tracking.utils.versioning import assert_version_geq\n",
41 | "\n",
42 | "assert_version_geq(\"23.12.0\")"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {
48 | "collapsed": false
49 | },
50 | "source": [
51 | "We can either directly load graphs (from disk), or we load point clouds and build edges on the fly using the module from `009_build_graphs_ml.ipynb`."
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {
57 | "collapsed": false
58 | },
59 | "source": [
60 | "## From on-disk graphs"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {
66 | "collapsed": false
67 | },
68 | "source": [
69 | "### 1. Setting up the data\n",
70 | "\n",
71 | "If you are not working on Princeton's `della`, you can download these example graphs [here](https://cernbox.cern.ch/s/4xYL99cd7zNe0VK). Note that this is simplified data (pt > 1 GeV truth cut) and a single event has been broken up into 32 sectors."
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 85,
77 | "metadata": {
78 | "collapsed": false
79 | },
80 | "outputs": [],
81 | "source": [
82 | "dm = TrackingDataModule(\n",
83 | " train=dict(\n",
84 | " dirs=[\n",
85 | " \"/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all\"\n",
86 | " ],\n",
87 | " stop=28_000,\n",
88 | " # If you run into memory issues, reduce this\n",
89 | " batch_size=10,\n",
90 | " ),\n",
91 | " val=dict(\n",
92 | " dirs=[\n",
93 | " \"/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all\"\n",
94 | " ],\n",
95 | " start=28_000,\n",
96 | " stop=28_100,\n",
97 | " ),\n",
98 | " identifier=\"graphs_v1\",\n",
99 | ")"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {
105 | "collapsed": false
106 | },
107 | "source": [
108 | "### 2. Defining the module"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 86,
114 | "metadata": {
115 | "collapsed": false
116 | },
117 | "outputs": [],
118 | "source": [
119 | "class SillyEC(nn.Module, HyperparametersMixin):\n",
120 | " def __init__(\n",
121 | " self,\n",
122 | " node_in_dim: int,\n",
123 | " edge_in_dim: int,\n",
124 | " hidden_dim: int = 12,\n",
125 | " ):\n",
126 | " super().__init__()\n",
127 | " self.save_hyperparameters()\n",
128 | " self.node_in_dim = node_in_dim\n",
129 | " self.edge_in_dim = edge_in_dim\n",
130 | " self.hidden_dim = hidden_dim\n",
131 | "\n",
132 | " self.fcnn = nn.Sequential(\n",
133 | " nn.Linear(edge_in_dim, hidden_dim),\n",
134 | " nn.ReLU(),\n",
135 | " nn.Linear(hidden_dim, hidden_dim),\n",
136 | " nn.ReLU(),\n",
137 | " nn.Linear(hidden_dim, 1),\n",
138 | " nn.Sigmoid(),\n",
139 | " )\n",
140 | "\n",
141 | " def forward(self, data):\n",
142 | " w = self.fcnn(data.edge_attr).squeeze()\n",
143 | " return {\"W\": w}"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 87,
149 | "metadata": {
150 | "collapsed": false
151 | },
152 | "outputs": [],
153 | "source": [
154 | "model = SillyEC(node_in_dim=6, edge_in_dim=4, hidden_dim=128)"
155 | ]
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "metadata": {
160 | "collapsed": false
161 | },
162 | "source": [
163 | "### 2. Setting up the loss functions and the lightning module"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 88,
169 | "metadata": {
170 | "collapsed": false
171 | },
172 | "outputs": [],
173 | "source": [
174 | "lmodel = ECModule(\n",
175 | " model=model,\n",
176 | " loss_fct=EdgeWeightFocalLoss(alpha=0.3),\n",
177 | " optimizer=partial(torch.optim.Adam, lr=1e-4),\n",
178 | ")"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {
184 | "collapsed": false
185 | },
186 | "source": [
187 | "### 3. Starting training"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 89,
193 | "metadata": {
194 | "collapsed": false
195 | },
196 | "outputs": [
197 | {
198 | "name": "stderr",
199 | "output_type": "stream",
200 | "text": [
201 | "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3 ...\n",
202 | "GPU available: False, used: False\n",
203 | "TPU available: False, using: 0 TPU cores\n",
204 | "IPU available: False, using: 0 IPUs\n",
205 | "HPU available: False, using: 0 HPUs\n",
206 | "\u001b[32m[16:06:55] INFO: DataLoader will load 28000 graphs (out of 28800 available).\u001b[0m\n",
207 | "\u001b[36m[16:06:55] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21974_s9.pt\u001b[0m\n",
208 | "\u001b[32m[16:06:56] INFO: DataLoader will load 100 graphs (out of 28800 available).\u001b[0m\n",
209 | "\u001b[36m[16:06:56] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21975_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21978_s11.pt\u001b[0m\n",
210 | "\n",
211 | " | Name | Type | Params\n",
212 | "-------------------------------------------------\n",
213 | "0 | model | SillyEC | 17.3 K\n",
214 | "1 | loss_fct | EdgeWeightFocalLoss | 0 \n",
215 | "-------------------------------------------------\n",
216 | "17.3 K Trainable params\n",
217 | "0 Non-trainable params\n",
218 | "17.3 K Total params\n",
219 | "0.069 Total estimated model params size (MB)\n"
220 | ]
221 | },
222 | {
223 | "name": "stdout",
224 | "output_type": "stream",
225 | "text": [
226 | "Sanity Checking: | | 0/? [00:00, ?it/s]"
227 | ]
228 | },
229 | {
230 | "name": "stderr",
231 | "output_type": "stream",
232 | "text": [
233 | "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.\n"
234 | ]
235 | },
236 | {
237 | "name": "stdout",
238 | "output_type": "stream",
239 | "text": [
240 | " \r"
241 | ]
242 | },
243 | {
244 | "name": "stderr",
245 | "output_type": "stream",
246 | "text": [
247 | "/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.\n"
248 | ]
249 | },
250 | {
251 | "name": "stdout",
252 | "output_type": "stream",
253 | "text": [
254 | "Epoch 0: 4%|█████▉ | 100/2800 [00:24<11:10, 4.03it/s, v_num=8, total_train=0.0646]"
255 | ]
256 | },
257 | {
258 | "data": {
259 | "text/html": [
260 | "\n"
261 | ],
262 | "text/plain": []
263 | },
264 | "metadata": {},
265 | "output_type": "display_data"
266 | },
267 | {
268 | "name": "stdout",
269 | "output_type": "stream",
270 | "text": [
271 | "\n",
272 | " \n",
273 | " \n",
274 | "\u001b[3m Validation epoch=0 \u001b[0m \n",
275 | "┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓\n",
276 | "┃\u001b[1m \u001b[0m\u001b[1mMetric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Error\u001b[0m\u001b[1m \u001b[0m┃\n",
277 | "┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩\n",
278 | "│ max_ba │ 0.81276 │ 0.00218 │\n",
279 | "│ max_ba_loc │ 0.42698 │ 0.00009 │\n",
280 | "│ max_ba_loc_pt0.5 │ 0.42698 │ 0.00009 │\n",
281 | "│ max_ba_loc_pt0.9 │ 0.42698 │ 0.00009 │\n",
282 | "│ max_ba_loc_pt1.5 │ 0.42663 │ 0.00017 │\n",
283 | "│ max_ba_pt0.5 │ 0.81276 │ 0.00218 │\n",
284 | "│ max_ba_pt0.9 │ 0.81276 │ 0.00218 │\n",
285 | "│ max_ba_pt1.5 │ 0.80939 │ 0.00313 │\n",
286 | "│ max_f1 │ 0.63106 │ 0.00708 │\n",
287 | "│ max_f1_loc │ 0.42714 │ 0.00000 │\n",
288 | "│ max_f1_loc_pt0.5 │ 0.42714 │ 0.00000 │\n",
289 | "│ max_f1_loc_pt0.9 │ 0.42714 │ 0.00000 │\n",
290 | "│ max_f1_loc_pt1.5 │ 0.42749 │ 0.00013 │\n",
291 | "│ max_f1_pt0.5 │ 0.63106 │ 0.00708 │\n",
292 | "│ max_f1_pt0.9 │ 0.63106 │ 0.00708 │\n",
293 | "│ max_f1_pt1.5 │ 0.53800 │ 0.00936 │\n",
294 | "│ max_mcc │ 0.53974 │ 0.00572 │\n",
295 | "│ max_mcc_loc │ 0.42719 │ 0.00005 │\n",
296 | "│ max_mcc_loc_pt0.5 │ 0.42719 │ 0.00005 │\n",
297 | "│ max_mcc_loc_pt0.9 │ 0.42719 │ 0.00005 │\n",
298 | "│ max_mcc_loc_pt1.5 │ 0.42779 │ 0.00017 │\n",
299 | "│ max_mcc_pt0.5 │ 0.53974 │ 0.00572 │\n",
300 | "│\u001b[1;95m \u001b[0m\u001b[1;95mmax_mcc_pt0.9 \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.53974\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.00572\u001b[0m\u001b[1;95m \u001b[0m│\n",
301 | "│ max_mcc_pt1.5 │ 0.48179 │ 0.00770 │\n",
302 | "│ roc_auc │ 0.87181 │ 0.00222 │\n",
303 | "│ roc_auc_0.001FPR │ 0.50003 │ 0.00008 │\n",
304 | "│ roc_auc_0.001FPR_pt0.5 │ 0.50003 │ 0.00008 │\n",
305 | "│ roc_auc_0.001FPR_pt0.9 │ 0.50003 │ 0.00008 │\n",
306 | "│ roc_auc_0.001FPR_pt1.5 │ 0.50805 │ 0.00270 │\n",
307 | "│ roc_auc_0.01FPR │ 0.51681 │ 0.00309 │\n",
308 | "│ roc_auc_0.01FPR_pt0.5 │ 0.51681 │ 0.00309 │\n",
309 | "│ roc_auc_0.01FPR_pt0.9 │ 0.51681 │ 0.00309 │\n",
310 | "│ roc_auc_0.01FPR_pt1.5 │ 0.53808 │ 0.00470 │\n",
311 | "│ roc_auc_pt0.5 │ 0.87181 │ 0.00222 │\n",
312 | "│ roc_auc_pt0.9 │ 0.87181 │ 0.00222 │\n",
313 | "│ roc_auc_pt1.5 │ 0.87136 │ 0.00296 │\n",
314 | "│\u001b[1;95m \u001b[0m\u001b[1;95mtotal \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.06456\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m nan\u001b[0m\u001b[1;95m \u001b[0m│\n",
315 | "│ total_train │ 0.06461 │ nan │\n",
316 | "│ tpr_eq_tnr │ 0.81249 │ 0.00221 │\n",
317 | "│ tpr_eq_tnr_loc │ 0.42693 │ 0.00010 │\n",
318 | "│ tpr_eq_tnr_loc_pt0.5 │ 0.42693 │ 0.00010 │\n",
319 | "│ tpr_eq_tnr_loc_pt0.9 │ 0.42693 │ 0.00010 │\n",
320 | "│ tpr_eq_tnr_loc_pt1.5 │ 0.42693 │ 0.00012 │\n",
321 | "│ tpr_eq_tnr_pt0.5 │ 0.81249 │ 0.00221 │\n",
322 | "│\u001b[1;95m \u001b[0m\u001b[1;95mtpr_eq_tnr_pt0.9 \u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.81249\u001b[0m\u001b[1;95m \u001b[0m│\u001b[1;95m \u001b[0m\u001b[1;95m0.00221\u001b[0m\u001b[1;95m \u001b[0m│\n",
323 | "│ tpr_eq_tnr_pt1.5 │ 0.80665 │ 0.00382 │\n",
324 | "└────────────────────────┴─────────┴─────────┘\n",
325 | "\n",
326 | "Epoch 0: 4%|█████▉ | 100/2800 [00:32<14:50, 3.03it/s, v_num=8, total_train=0.0646]"
327 | ]
328 | },
329 | {
330 | "name": "stderr",
331 | "output_type": "stream",
332 | "text": [
333 | "`Trainer.fit` stopped: `max_steps=100` reached.\n"
334 | ]
335 | },
336 | {
337 | "name": "stdout",
338 | "output_type": "stream",
339 | "text": [
340 | "Epoch 0: 4%|█████▉ | 100/2800 [00:32<14:50, 3.03it/s, v_num=8, total_train=0.0646]\n"
341 | ]
342 | }
343 | ],
344 | "source": [
345 | "trainer = Trainer(\n",
346 | " max_steps=100,\n",
347 | " val_check_interval=100,\n",
348 | " accelerator=\"cpu\",\n",
349 | " log_every_n_steps=1,\n",
350 | " callbacks=[PrintValidationMetrics()],\n",
351 | ")\n",
352 | "trainer.fit(model=lmodel, datamodule=dm)"
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "metadata": {
358 | "collapsed": false
359 | },
360 | "source": [
361 | "## With graphs built on-the-fly from point clouds"
362 | ]
363 | },
364 | {
365 | "cell_type": "markdown",
366 | "metadata": {
367 | "collapsed": false
368 | },
369 | "source": [
370 | "Step 1: Configure data module to load point clouds (rather than graphs).\n",
371 | "Step 2: Add `MLGraphConstructionFromChkpt` as preproc."
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "metadata": {
378 | "collapsed": false
379 | },
380 | "outputs": [],
381 | "source": [
382 | "lmodel = ECModule(\n",
383 | " model=model,\n",
384 | " loss_fct=EdgeWeightFocalLoss(alpha=0.3),\n",
385 | " optimizer=partial(torch.optim.Adam, lr=1e-4),\n",
386 | " preproc=MLGraphConstructionFromChkpt(\n",
387 | " ml_class_name=\"gnn_tracking.models.graph_construction.GraphConstructionFCNN\",\n",
388 | " ml_chkpt_path=\"/path/to/your/checkpoint\",\n",
389 | " ),\n",
390 | ")"
391 | ]
392 | },
393 | {
394 | "cell_type": "markdown",
395 | "metadata": {
396 | "collapsed": false
397 | },
398 | "source": [
399 | "Instead of `MLGraphConstructionFromChkpt` you can also take a look at `MLGraphConstruction` that simply takes a model (that you can instantiate in any way)."
400 | ]
401 | }
402 | ],
403 | "metadata": {
404 | "kernelspec": {
405 | "display_name": "Python 3 (ipykernel)",
406 | "language": "python",
407 | "name": "python3"
408 | },
409 | "language_info": {
410 | "codemirror_mode": {
411 | "name": "ipython",
412 | "version": 3
413 | },
414 | "file_extension": ".py",
415 | "mimetype": "text/x-python",
416 | "name": "python",
417 | "nbconvert_exporter": "python",
418 | "pygments_lexer": "ipython3",
419 | "version": "3.10.11"
420 | }
421 | },
422 | "nbformat": 4,
423 | "nbformat_minor": 0
424 | }
425 |
--------------------------------------------------------------------------------
/notebooks/040_three_shot_object_condensation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": false
7 | },
8 | "source": [
9 | "# Three-shot object condensation\n",
10 | "\n",
11 | "This sketches how to implement the pipleine of \"graph construction (GC)\" > \"edge classification (EC)\" > \"object condensation (OC)\".\n",
12 | "There are multiple options."
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {
18 | "collapsed": false
19 | },
20 | "source": [
21 | "## Using graphs on disk\n",
22 | "\n",
23 | "`020_one_shot_object_condensation.ipynb` built graphs using kNN as part of the `GravNetConv`. But if you directly load graphs from disk, you can simply use any GNN and everything will work."
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {
29 | "collapsed": false
30 | },
31 | "source": [
32 | "## From point clouds on disk and a pre-trained GC + EC\n",
33 | "\n",
34 | "We simply follow the last example from the EC notebook:"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {
40 | "collapsed": false
41 | },
42 | "source": [
43 | "```python3\n",
44 | "lmodel = TCModule(\n",
45 | " model=model,\n",
46 | " ...,\n",
47 | " preproc = MLGraphConstructionFromChkpt(\n",
48 | " ml_class_name=\"gnn_tracking.models.graph_construction.GraphConstructionFCNN\",\n",
49 | " ml_chkpt_path=\"/path/to/your/checkpoint\",\n",
50 | " ec_class_name=\"gnn_tracking.models.edge_classifier.ECForGraphTCN\",\n",
51 | " ec_chkpt_path=\"/path/to/your/checkpoint\",\n",
52 | " ),\n",
53 | ")\n",
54 | "```"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {
60 | "collapsed": false
61 | },
62 | "source": [
63 | "Other than the `preproc` step, everything can be set up like in the EC notebook."
64 | ]
65 | }
66 | ],
67 | "metadata": {
68 | "kernelspec": {
69 | "display_name": "Python 3",
70 | "language": "python",
71 | "name": "python3"
72 | },
73 | "language_info": {
74 | "codemirror_mode": {
75 | "name": "ipython",
76 | "version": 2
77 | },
78 | "file_extension": ".py",
79 | "mimetype": "text/x-python",
80 | "name": "python",
81 | "nbconvert_exporter": "python",
82 | "pygments_lexer": "ipython2",
83 | "version": "2.7.6"
84 | }
85 | },
86 | "nbformat": 4,
87 | "nbformat_minor": 0
88 | }
89 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 | minversion = "6.0"
3 | addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config", "--cov-branch"]
4 | xfail_strict = true
5 | testpaths = ["tests"]
6 |
7 | [tool.pycln]
8 | all = true
9 |
--------------------------------------------------------------------------------