├── .env.example ├── .github ├── ISSUE_TEMPLATE │ ├── 0-bug-report.yaml │ ├── 1-feature-request.yaml │ ├── 2-documentation.yaml │ ├── 3-refactor.yaml │ └── config.yaml ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ ├── code-quality-main.yaml │ ├── code-quality-pr.yaml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .project-root ├── CHANGELOG.md ├── CITATION.cff ├── LICENSE ├── README.md ├── configs ├── callbacks │ ├── default.yaml │ ├── early_stopping.yaml │ ├── gradient_accumulator.yaml │ ├── lr_monitor.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ ├── none.yaml │ └── rich_progress_bar.yaml ├── datamodule │ ├── panoptic │ │ ├── dales.yaml │ │ ├── dales_nano.yaml │ │ ├── kitti360.yaml │ │ ├── kitti360_nano.yaml │ │ ├── s3dis.yaml │ │ ├── s3dis_nano.yaml │ │ ├── s3dis_room.yaml │ │ ├── s3dis_with_stuff.yaml │ │ ├── s3dis_with_stuff_nano.yaml │ │ ├── scannet.yaml │ │ └── scannet_nano.yaml │ └── semantic │ │ ├── _features.yaml │ │ ├── dales.yaml │ │ ├── dales_nano.yaml │ │ ├── default.yaml │ │ ├── kitti360.yaml │ │ ├── kitti360_nano.yaml │ │ ├── s3dis.yaml │ │ ├── s3dis_nano.yaml │ │ ├── s3dis_room.yaml │ │ ├── scannet.yaml │ │ └── scannet_nano.yaml ├── debug │ ├── default.yaml │ ├── fdr.yaml │ ├── limit.yaml │ ├── overfit.yaml │ └── profiler.yaml ├── eval.yaml ├── experiment │ ├── panoptic │ │ ├── dales.yaml │ │ ├── dales_11g.yaml │ │ ├── dales_nano.yaml │ │ ├── kitti360.yaml │ │ ├── kitti360_11g.yaml │ │ ├── kitti360_nano.yaml │ │ ├── s3dis.yaml │ │ ├── s3dis_11g.yaml │ │ ├── s3dis_nano.yaml │ │ ├── s3dis_room.yaml │ │ ├── s3dis_with_stuff.yaml │ │ ├── s3dis_with_stuff_11g.yaml │ │ ├── s3dis_with_stuff_nano.yaml │ │ ├── scannet.yaml │ │ ├── scannet_11g.yaml │ │ └── scannet_nano.yaml │ └── semantic │ │ ├── dales.yaml │ │ ├── dales_11g.yaml │ │ ├── dales_nano.yaml │ │ ├── kitti360.yaml │ │ ├── kitti360_11g.yaml │ │ ├── kitti360_nano.yaml │ │ ├── s3dis.yaml │ │ ├── s3dis_11g.yaml │ │ ├── s3dis_nano.yaml │ │ ├── s3dis_room.yaml │ │ ├── scannet.yaml │ │ ├── scannet_11g.yaml │ │ └── scannet_nano.yaml ├── extras │ └── default.yaml ├── hparams_search │ └── mnist_optuna.yaml ├── hydra │ └── default.yaml ├── local │ └── .gitkeep ├── logger │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ ├── panoptic │ │ ├── _instance.yaml │ │ ├── nano-2.yaml │ │ ├── nano-3.yaml │ │ ├── spt-2.yaml │ │ ├── spt-3.yaml │ │ └── spt.yaml │ └── semantic │ │ ├── _attention.yaml │ │ ├── _down.yaml │ │ ├── _point.yaml │ │ ├── _up.yaml │ │ ├── default.yaml │ │ ├── nano-2.yaml │ │ ├── nano-3.yaml │ │ ├── spt-2.yaml │ │ ├── spt-3.yaml │ │ └── spt.yaml ├── paths │ └── default.yaml ├── train.yaml └── trainer │ ├── cpu.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ ├── default.yaml │ ├── gpu.yaml │ └── mps.yaml ├── docs ├── data_structures.md ├── datasets.md ├── logging.md └── visualization.md ├── install.sh ├── media ├── dales │ ├── input.png │ ├── pano_gt.png │ ├── sem_gt.png │ └── sem_gt_demo.png ├── kitti360 │ ├── input.png │ ├── pano_gt.png │ └── sem_gt.png ├── s3dis │ ├── input.png │ ├── pano_gt.png │ └── sem_gt.png ├── scannet │ ├── input.png │ ├── pano_gt.png │ └── sem_gt.png ├── superpoint_transformer_tutorial.pdf ├── teaser_spt.jpg ├── teaser_spt.png ├── teaser_supercluster.png ├── visualizations.7z ├── viz_100k.png ├── viz_10k.png ├── viz_errors.png ├── viz_p2.png ├── viz_pos.png ├── viz_radius_center.png ├── viz_rgb.png ├── viz_select.png ├── viz_x.png └── viz_y.png ├── notebooks ├── 
demo.ipynb ├── demo_nag.ipynb ├── demo_nag.pt ├── demo_panoptic_parametrization.ipynb └── superpoint_transformer_tutorial.ipynb ├── scripts ├── schedule.sh └── setup_dependencies.py ├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── cluster.py │ ├── csr.py │ ├── data.py │ ├── instance.py │ └── nag.py ├── datamodules │ ├── __init__.py │ ├── base.py │ ├── components │ │ └── __init__.py │ ├── dales.py │ ├── kitti360.py │ ├── s3dis.py │ ├── s3dis_room.py │ └── scannet.py ├── datasets │ ├── __init__.py │ ├── base.py │ ├── dales.py │ ├── dales_config.py │ ├── kitti360.py │ ├── kitti360_config.py │ ├── s3dis.py │ ├── s3dis_config.py │ ├── s3dis_room.py │ ├── scannet.py │ └── scannet_config.py ├── debug.py ├── dependencies │ └── __init__.py ├── eval.py ├── loader │ ├── __init__.py │ └── dataloader.py ├── loss │ ├── __init__.py │ ├── bce.py │ ├── focal.py │ ├── l1.py │ ├── l2.py │ ├── lovasz.py │ ├── multi.py │ └── weighted.py ├── metrics │ ├── __init__.py │ ├── mean_average_precision.py │ ├── panoptic.py │ ├── semantic.py │ └── weighted_li.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── mlp.py │ │ └── spt.py │ ├── panoptic.py │ └── semantic.py ├── nn │ ├── __init__.py │ ├── attention.py │ ├── dropout.py │ ├── fusion.py │ ├── instance.py │ ├── mlp.py │ ├── norm.py │ ├── pool.py │ ├── position_encoding.py │ ├── stage.py │ ├── transformer.py │ └── unpool.py ├── optim │ ├── __init__.py │ └── lr_scheduler.py ├── train.py ├── transforms │ ├── __init__.py │ ├── data.py │ ├── debug.py │ ├── device.py │ ├── geometry.py │ ├── graph.py │ ├── instance.py │ ├── neighbors.py │ ├── partition.py │ ├── point.py │ ├── sampling.py │ └── transforms.py ├── utils │ ├── __init__.py │ ├── color.py │ ├── configs.py │ ├── cpu.py │ ├── download.py │ ├── dropout.py │ ├── edge.py │ ├── encoding.py │ ├── features.py │ ├── geometry.py │ ├── graph.py │ ├── ground.py │ ├── histogram.py │ ├── hydra.py │ ├── instance.py │ ├── io.py │ ├── keys.py │ ├── list.py │ ├── loss.py │ ├── memory.py │ ├── multiprocessing.py │ ├── neighbors.py │ ├── nn.py │ ├── output_panoptic.py │ ├── output_semantic.py │ ├── parameter.py │ ├── partition.py │ ├── point.py │ ├── pylogger.py │ ├── rich_utils.py │ ├── scannet.py │ ├── scatter.py │ ├── semantic.py │ ├── sparse.py │ ├── tensor.py │ ├── time.py │ ├── utils.py │ ├── wandb.py │ └── widgets.py └── visualization │ ├── __init__.py │ └── visualization.py └── tests ├── __init__.py ├── conftest.py ├── helpers ├── __init__.py ├── package_available.py ├── run_if.py └── run_sh_command.py ├── test_configs.py ├── test_eval.py ├── test_sweeps.py └── test_train.py /.env.example: -------------------------------------------------------------------------------- 1 | # example of file for storing private and user specific environment variables, like keys or system paths 2 | # rename it to ".env" (excluded from version control by default) 3 | # .env is loaded by train.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | MY_VAR="/home/user/my/system/path" 7 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/1-feature-request.yaml: -------------------------------------------------------------------------------- 1 | name: "🚀 Feature Request" 2 | description: Propose a new feature 3 | title: Title of your feature request 4 | labels: ["feature"] 5 | 6 | body: 7 | 8 | - type: markdown 9 | attributes: 10 | value: Thanks for taking the time to fill out this feature 
request report 🙏 ! 11 | 12 | - type: checkboxes 13 | attributes: 14 | label: ✅ Code of conduct checklist 15 | description: > 16 | Before submitting a feature request, please make sure you went through 17 | the following steps. 18 | options: 19 | - label: "🌱 I am using the **_latest version_** of the [code](https://github.com/drprojects/superpoint_transformer/tree/master)." 20 | required: true 21 | - label: "📙 I **_thoroughly_** went through the [README](https://github.com/drprojects/superpoint_transformer/blob/master/README.md), but could not find the feature I need there." 22 | required: true 23 | - label: "📘 I **_thoroughly_** went through the tutorial [slides](media/superpoint_transformer_tutorial.pdf), [notebook](notebooks/superpoint_transformer_tutorial.ipynb), and [video](https://www.youtube.com/watch?v=2qKhpQs9gJw), but could not find the feature I need there." 24 | required: true 25 | - label: "📗 I **_thoroughly_** went through the [documentation](https://github.com/drprojects/superpoint_transformer/tree/master/docs), but could not find the feature I need there." 26 | required: true 27 | - label: "📜 I went through the **_docstrings_** and **_comments_** in the [source code](https://github.com/drprojects/superpoint_transformer/tree/master) parts relevant to my problem, but could not find the feature I need there." 28 | required: true 29 | - label: "👩‍🔧 I searched for [**_similar issues_**](https://github.com/drprojects/superpoint_transformer/issues), but could not find the feature I need there." 30 | required: true 31 | - label: "⭐ Since I am showing interest in the project, I took the time to give the [repo](https://github.com/drprojects/superpoint_transformer/tree/master) a ⭐ to show support. **Please do, it means a lot to us !**" 32 | required: true 33 | 34 | - type: textarea 35 | attributes: 36 | label: 🚀 The feature, motivation and pitch 37 | description: > 38 | A clear and concise description of the feature proposal. Please outline 39 | the motivation for the proposal. Is your feature request related to a 40 | specific problem ? e.g., *"I'm working on X and would like Y to be 41 | possible"*. If this is related to another GitHub issue, please link here 42 | too. 43 | validations: 44 | required: true 45 | 46 | - type: textarea 47 | attributes: 48 | label: 🔀 Alternatives 49 | description: > 50 | A description of any alternative solutions or features you've 51 | considered, if any. 52 | 53 | - type: textarea 54 | attributes: 55 | label: 📚 Additional context 56 | description: > 57 | Add any other context or screenshots about the feature request. 58 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/2-documentation.yaml: -------------------------------------------------------------------------------- 1 | name: "📚 Typos and Documentation Fixes" 2 | description: Tell us about how we can improve our documentation 3 | title: Title of your documentation/typo fix/request 4 | labels: ["documentation"] 5 | 6 | body: 7 | 8 | - type: markdown 9 | attributes: 10 | value: Thanks for taking the time to fill out this documentation report 🙏 ! 11 | 12 | - type: checkboxes 13 | attributes: 14 | label: ✅ Code of conduct checklist 15 | description: > 16 | Before submitting a bug, please make sure you went through the following 17 | steps. 18 | options: 19 | - label: "🌱 I am using the **_latest version_** of the [code](https://github.com/drprojects/superpoint_transformer/tree/master)." 
20 | required: true 21 | - label: "📙 I went through the [README](https://github.com/drprojects/superpoint_transformer/blob/master/README.md), but could not find the appropriate documentation there." 22 | required: true 23 | - label: "📘 I went through the tutorial [slides](media/superpoint_transformer_tutorial.pdf), [notebook](notebooks/superpoint_transformer_tutorial.ipynb), and [video](https://www.youtube.com/watch?v=2qKhpQs9gJw), but could not find the appropriate documentation there." 24 | required: true 25 | - label: "📗 I went through the [documentation](https://github.com/drprojects/superpoint_transformer/tree/master/docs), but could not find the appropriate documentation there." 26 | required: true 27 | - label: "📜 I went through the **_docstrings_** and **_comments_** in the [source code](https://github.com/drprojects/superpoint_transformer/tree/master) parts relevant to my problem, but could not find the appropriate documentation there." 28 | required: true 29 | - label: "👩‍🔧 I searched for [**_similar issues_**](https://github.com/drprojects/superpoint_transformer/issues), but could not find the appropriate documentation there." 30 | required: true 31 | - label: "⭐ Since I am showing interest in the project, I took the time to give the [repo](https://github.com/drprojects/superpoint_transformer/tree/master) a ⭐ to show support. **Please do, it means a lot to us !**" 32 | required: true 33 | 34 | - type: textarea 35 | attributes: 36 | label: 📚 Describe the documentation issue 37 | description: | 38 | A clear and concise description of the issue. 39 | validations: 40 | required: true 41 | 42 | - type: textarea 43 | attributes: 44 | label: Suggest a potential alternative/fix 45 | description: | 46 | Tell us how we could improve the documentation in this regard. 47 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/3-refactor.yaml: -------------------------------------------------------------------------------- 1 | name: "🛠 Refactor" 2 | description: Suggest a code refactor or deprecation 3 | title: Title of your refactor/deprecation report 4 | labels: refactor 5 | 6 | body: 7 | 8 | - type: markdown 9 | attributes: 10 | value: Thanks for taking the time to fill out this refactor report 🙏 ! 11 | 12 | - type: textarea 13 | attributes: 14 | label: 🛠 Proposed Refactor 15 | description: | 16 | A clear and concise description of the refactor proposal. Please outline 17 | the motivation for the proposal. If this is related to another GitHub 18 | issue, please link here too. 19 | validations: 20 | required: true 21 | 22 | - type: textarea 23 | attributes: 24 | label: Suggest a potential alternative/fix 25 | description: | 26 | Tell us how we could improve the code in this regard. 27 | validations: 28 | required: true 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yaml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | 9 | 10 | Fixes #\ 11 | 12 | ## Before submitting 13 | 14 | - [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? 
15 | - [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? 16 | - [ ] Did you list all the **breaking changes** introduced by this pull request? 17 | - [ ] Did you **test your PR locally** with `pytest` command? 18 | - [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command? 19 | 20 | ## Did you have fun? 21 | 22 | Make sure you had fun coding 🙃 23 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | ignore: 13 | - dependency-name: "pytorch-lightning" 14 | update-types: ["version-update:semver-patch"] 15 | - dependency-name: "torchmetrics" 16 | update-types: ["version-update:semver-patch"] 17 | -------------------------------------------------------------------------------- /.github/workflows/code-quality-main.yaml: -------------------------------------------------------------------------------- 1 | # Same as `code-quality-pr.yaml` but triggered on commit to main branch 2 | # and runs on all files (instead of only the changed ones) 3 | 4 | name: Code Quality Main 5 | 6 | on: 7 | push: 8 | branches: [main] 9 | 10 | jobs: 11 | code-quality: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v2 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v2 20 | 21 | - name: Run pre-commits 22 | uses: pre-commit/action@v2.0.3 23 | -------------------------------------------------------------------------------- /.github/workflows/code-quality-pr.yaml: -------------------------------------------------------------------------------- 1 | # This workflow finds which files were changed, prints them, 2 | # and runs `pre-commit` on those files. 
3 | 4 | # Inspired by the sktime library: 5 | # https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml 6 | 7 | name: Code Quality PR 8 | 9 | on: 10 | pull_request: 11 | branches: [main, "release/*"] 12 | 13 | jobs: 14 | code-quality: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v2 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | 24 | - name: Find modified files 25 | id: file_changes 26 | uses: trilom/file-changes-action@v1.2.4 27 | with: 28 | output: " " 29 | 30 | - name: List modified files 31 | run: echo '${{ steps.file_changes.outputs.files}}' 32 | 33 | - name: Run pre-commits 34 | uses: pre-commit/action@v2.0.3 35 | with: 36 | extra_args: --files ${{ steps.file_changes.outputs.files}} 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | jobs: 10 | run_tests: 11 | runs-on: ${{ matrix.os }} 12 | 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: ["ubuntu-latest", "macos-latest"] 17 | python-version: ["3.7", "3.8", "3.9", "3.10"] 18 | 19 | timeout-minutes: 10 20 | 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v3 24 | 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install -r requirements.txt 34 | pip install pytest 35 | pip install sh 36 | 37 | - name: List dependencies 38 | run: | 39 | python -m pip list 40 | 41 | - name: Run pytest 42 | run: | 43 | pytest -v 44 | 45 | run_tests_windows: 46 | runs-on: ${{ matrix.os }} 47 | 48 | strategy: 49 | fail-fast: false 50 | matrix: 51 | os: ["windows-latest"] 52 | python-version: ["3.7", "3.8", "3.9", "3.10"] 53 | 54 | timeout-minutes: 10 55 | 56 | steps: 57 | - name: Checkout 58 | uses: actions/checkout@v3 59 | 60 | - name: Set up Python ${{ matrix.python-version }} 61 | uses: actions/setup-python@v3 62 | with: 63 | python-version: ${{ matrix.python-version }} 64 | 65 | - name: Install dependencies 66 | run: | 67 | python -m pip install --upgrade pip 68 | pip install -r requirements.txt 69 | pip install pytest 70 | 71 | - name: List dependencies 72 | run: | 73 | python -m pip list 74 | 75 | - name: Run pytest 76 | run: | 77 | pytest -v 78 | 79 | # upload code coverage report 80 | code-coverage: 81 | runs-on: ubuntu-latest 82 | 83 | steps: 84 | - name: Checkout 85 | uses: actions/checkout@v2 86 | 87 | - name: Set up Python 3.10 88 | uses: actions/setup-python@v2 89 | with: 90 | python-version: "3.10" 91 | 92 | - name: Install dependencies 93 | run: | 94 | python -m pip install --upgrade pip 95 | pip install -r requirements.txt 96 | pip install pytest 97 | pip install pytest-cov[toml] 98 | pip install sh 99 | 100 | - name: Run tests and collect coverage 101 | run: pytest --cov src # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER 102 | 103 | - name: Upload coverage to Codecov 104 | uses: codecov/codecov-action@v3 105 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | 
*.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # Lightning-Hydra-Template 148 | configs/local/default.yaml 149 | data/ 150 | logs/ 151 | .env 152 | .autoenv 153 | 154 | 155 | # Superpoint Transformer project 156 | .idea 157 | src/dependencies/parallel_cut_pursuit 158 | src/dependencies/grid_graph 159 | src/dependencies/FRNN 160 | src/dependencies/point_geometric_features 161 | src/dependencies/ 162 | *.pyc 163 | *DS_Store* 164 | *.vscode 165 | *.so 166 | *.cmake 167 | *__pycache__* 168 | *.ipynb* 169 | notebooks/* 170 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.3.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-added-large-files 19 | 20 | # python code formatting 21 | - repo: https://github.com/psf/black 22 | rev: 22.6.0 23 | hooks: 24 | - id: black 25 | args: [--line-length, "99"] 26 | 27 | # python import sorting 28 | - repo: https://github.com/PyCQA/isort 29 | rev: 5.10.1 30 | hooks: 31 | - id: isort 32 | args: ["--profile", "black", "--filter-files"] 33 | 34 | # python upgrading syntax to newer version 35 | - repo: https://github.com/asottile/pyupgrade 36 | rev: v2.32.1 37 | hooks: 38 | - id: pyupgrade 39 | args: [--py38-plus] 40 | 41 | # python docstring formatting 42 | - repo: https://github.com/myint/docformatter 43 | rev: v1.4 44 | hooks: 45 | - id: docformatter 46 | args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] 47 | 48 | # python check (PEP8), programming errors and code complexity 49 | - repo: https://github.com/PyCQA/flake8 50 | rev: 4.0.1 51 | hooks: 52 | - id: flake8 53 | args: 54 | [ 55 | "--extend-ignore", 56 | "E203,E402,E501,F401,F841", 57 | "--exclude", 58 | "logs/*,data/*", 59 | ] 60 | 61 | # python security linter 62 | - repo: https://github.com/PyCQA/bandit 63 | rev: "1.7.1" 64 | hooks: 65 | - id: bandit 66 | args: ["-s", "B101"] 67 | 68 | # yaml formatting 69 | - repo: https://github.com/pre-commit/mirrors-prettier 70 | rev: v2.7.1 71 | hooks: 72 | - id: prettier 73 | types: [yaml] 74 | 75 | # shell scripts linter 76 | - repo: https://github.com/shellcheck-py/shellcheck-py 77 | rev: v0.8.0.4 78 | hooks: 79 | - id: shellcheck 80 | 81 | # 
md formatting 82 | - repo: https://github.com/executablebooks/mdformat 83 | rev: 0.7.14 84 | hooks: 85 | - id: mdformat 86 | args: ["--number"] 87 | additional_dependencies: 88 | - mdformat-gfm 89 | - mdformat-tables 90 | - mdformat_frontmatter 91 | # - mdformat-toc 92 | # - mdformat-black 93 | 94 | # word spelling linter 95 | - repo: https://github.com/codespell-project/codespell 96 | rev: v2.1.0 97 | hooks: 98 | - id: codespell 99 | args: 100 | - --skip=logs/**,data/**,*.ipynb 101 | # - --ignore-words-list=abc,def 102 | 103 | # jupyter notebook cell output clearing 104 | - repo: https://github.com/kynan/nbstripout 105 | rev: 0.5.0 106 | hooks: 107 | - id: nbstripout 108 | 109 | # jupyter notebook linting 110 | - repo: https://github.com/nbQA-dev/nbQA 111 | rev: 1.4.0 112 | hooks: 113 | - id: nbqa-black 114 | args: ["--line-length=99"] 115 | - id: nbqa-isort 116 | args: ["--profile=black"] 117 | - id: nbqa-flake8 118 | args: 119 | [ 120 | "--extend-ignore=E203,E402,E501,F401,F841", 121 | "--exclude=logs/*,data/*", 122 | ] 123 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 5 | 6 | ## \[2.1.0\] - 2024-11-07 7 | 8 | ### Added 9 | 10 | - Added a [CITATION.cff](CITATION.cff) 11 | - Added a [CHANGELOG.md](CHANGELOG.md) 12 | - Added support for serialization of `CSRBatch`, `Batch` and `NAGBatch` objects 13 | - Added support for inferring how to un-batch some `Batch` attributes, even if 14 | not present when `Batch.from_data_list()` was initially called 15 | - Added helper for S3DIS 6-fold metrics computation for semantic segmentation 16 | - Moved to `pgeof==0.3.0` 17 | - Released a Superpoint Transformer 🧑‍🏫 tutorial with 18 | [slides](media/superpoint_transformer_tutorial.pdf), 19 | [notebook](notebooks/superpoint_transformer_tutorial.ipynb), 20 | and [video](https://www.youtube.com/watch?v=2qKhpQs9gJw) 21 | - Added more documentation throughout the [docs](docs) and in the code 22 | - Added some documentation for our [interactive visualization tool](docs/visualization.md) 23 | 24 | ### Changed 25 | 26 | - Breaking Change: modified the serialization behavior of the data structures. 27 | You will need to re-run all your datasets' preprocessing 28 | - Remove `SampleSubNodes` from the validation and test transforms to ensure the 29 | validation and test forward passes are deterministic 30 | 31 | ### Deprecated 32 | 33 | ### Fixed 34 | 35 | - Fixed several bugs, some of which introduced by recent commits... 36 | - Fixed some installation issues 37 | 38 | ### Removed 39 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "Please cite our papers if you use this code in your own work." 
3 | title: "Superpoint Transformer" 4 | authors: 5 | - family-names: "Robert" 6 | given-names: "Damien" 7 | - family-names: "Raguet" 8 | given-names: "Hugo" 9 | - family-names: "Landrieu" 10 | given-names: "Loic" 11 | date-released: 2023-06-15 12 | license: MIT 13 | url: "https://github.com/drprojects/superpoint_transformer" 14 | preferred-citation: 15 | type: conference-paper 16 | authors: 17 | - family-names: "Robert" 18 | given-names: "Damien" 19 | orcid: https://orcid.org/0000-0003-0983-7053 20 | - family-names: "Raguet" 21 | given-names: "Hugo" 22 | orcid: https://orcid.org/0000-0002-4598-6710 23 | - family-names: "Landrieu" 24 | given-names: "Loic" 25 | orcid: https://orcid.org/0000-0002-7738-8141 26 | conference: "ICCV" 27 | title: "Efficient 3D Semantic Segmentation with Superpoint Transformer" 28 | year: 2023 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Damien ROBERT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - early_stopping.yaml 4 | - model_summary.yaml 5 | - rich_progress_bar.yaml 6 | - lr_monitor.yaml 7 | - gradient_accumulator.yaml 8 | - _self_ 9 | 10 | model_checkpoint: 11 | dirpath: ${paths.output_dir}/checkpoints 12 | filename: "epoch_{epoch:03d}" 13 | monitor: ${optimized_metric} 14 | mode: "max" 15 | save_last: True 16 | auto_insert_metric_name: False 17 | 18 | early_stopping: 19 | monitor: ${optimized_metric} 20 | patience: 500 21 | mode: "max" 22 | 23 | model_summary: 24 | max_depth: -1 25 | 26 | gradient_accumulator: 27 | scheduling: 28 | 0: 1 29 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html 2 | 3 | # Monitor a metric and stop training when it stops improving. 4 | # Look at the above link for more detailed information. 5 | early_stopping: 6 | _target_: pytorch_lightning.callbacks.EarlyStopping 7 | monitor: ??? 
# quantity to be monitored, must be specified !!! 8 | min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement 9 | patience: 3 # number of checks with no improvement after which training will be stopped 10 | verbose: False # verbosity mode 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | strict: True # whether to crash the training if monitor is not found in the validation metrics 13 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 14 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 15 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 16 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 17 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 18 | -------------------------------------------------------------------------------- /configs/callbacks/gradient_accumulator.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/stable/advanced/training_tricks.html 2 | 3 | # Accumulate gradients across multiple batches, to use smaller batches 4 | # Scheduling expects a dictionary of {epoch: num_batch} indicating how 5 | # to accumulate gradients 6 | gradient_accumulator: 7 | _target_: pytorch_lightning.callbacks.GradientAccumulationScheduler 8 | scheduling: 9 | 0: 2 10 | -------------------------------------------------------------------------------- /configs/callbacks/lr_monitor.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html 2 | 3 | # Monitor and log the learning rate as the training goes 4 | lr_monitor: 5 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 6 | logging_interval: 'epoch' # supports 'epoch', 'step', and null 7 | log_momentum: True 8 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html 2 | 3 | # Save the model periodically by monitoring a quantity. 4 | # Look at the above link for more detailed information. 
5 | model_checkpoint: 6 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 7 | dirpath: null # directory to save the model file 8 | filename: null # checkpoint filename 9 | monitor: null # name of the logged metric which determines when model is improving 10 | verbose: False # verbosity mode 11 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 12 | save_top_k: 1 # save k best models (determined by above metric) 13 | mode: "min" # "max" means higher metric value is better, can be also "min" 14 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 15 | save_weights_only: False # if True, then only the model’s weights will be saved 16 | every_n_train_steps: null # number of training steps between checkpoints 17 | train_time_interval: null # checkpoints are monitored at the specified time interval 18 | every_n_epochs: null # number of epochs between checkpoints 19 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 20 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html 2 | 3 | # Generates a summary of all layers in a LightningModule with rich text formatting. 4 | # Look at the above link for more detailed information. 5 | model_summary: 6 | _target_: pytorch_lightning.callbacks.RichModelSummary 7 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 8 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html 2 | 3 | # Create a progress bar with rich text formatting. 4 | # Look at the above link for more detailed information. 5 | rich_progress_bar: 6 | _target_: pytorch_lightning.callbacks.RichProgressBar 7 | -------------------------------------------------------------------------------- /configs/datamodule/panoptic/dales.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/dales.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. 
However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 20 # maximum distance of neighbors for each superpoint in the instance graph -------------------------------------------------------------------------------- /configs/datamodule/panoptic/dales_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/dales_nano.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 20 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /configs/datamodule/panoptic/kitti360.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/kitti360.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 8 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /configs/datamodule/panoptic/kitti360_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/kitti360_nano.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 8 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /configs/datamodule/panoptic/s3dis.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. 
However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /configs/datamodule/panoptic/s3dis_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis_nano.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /configs/datamodule/panoptic/s3dis_room.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis_room.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /configs/datamodule/panoptic/s3dis_with_stuff.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Specify whether S3DIS should have only 'thing' classes (default) or if 11 | # 'ceiling', 'wall', and 'floor' should be treated as 'stuff' 12 | with_stuff: True 13 | 14 | # For now, we also need to specify the stuff labels here, not for the 15 | # datamodule, but rather for the model config to catch 16 | stuff_classes: [0, 1, 2] 17 | 18 | # Instance graph parameters 19 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 20 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 21 | -------------------------------------------------------------------------------- /configs/datamodule/panoptic/s3dis_with_stuff_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis_nano.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. 
However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Specify whether S3DIS should have only 'thing' classes (default) or if 11 | # 'ceiling', 'wall', and 'floor' should be treated as 'stuff' 12 | with_stuff: True 13 | 14 | # For now, we also need to specify the stuff labels here, not for the 15 | # datamodule, but rather for the model config to catch 16 | stuff_classes: [0, 1, 2] 17 | 18 | # Instance graph parameters 19 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 20 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 21 | -------------------------------------------------------------------------------- /configs/datamodule/panoptic/scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/scannet.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /configs/datamodule/panoptic/scannet_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/scannet_nano.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. 
However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /configs/datamodule/semantic/dales_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/dales.yaml 4 | 5 | # point features used for training 6 | point_hf: [] 7 | 8 | # segment-wise features computed at preprocessing 9 | segment_mean_hf: 10 | - 'intensity' 11 | - 'linearity' 12 | - 'planarity' 13 | - 'scattering' 14 | - 'verticality' 15 | - 'elevation' 16 | -------------------------------------------------------------------------------- /configs/datamodule/semantic/kitti360_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/kitti360.yaml 4 | 5 | # point features used for training 6 | point_hf: [] 7 | 8 | # segment-wise features computed at preprocessing 9 | segment_mean_hf: 10 | - 'hsv' 11 | - 'linearity' 12 | - 'planarity' 13 | - 'scattering' 14 | - 'verticality' 15 | - 'elevation' 16 | -------------------------------------------------------------------------------- /configs/datamodule/semantic/s3dis_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis.yaml 4 | 5 | # point features used for training 6 | point_hf: [] 7 | 8 | # segment-wise features computed at preprocessing 9 | segment_mean_hf: 10 | - 'linearity' 11 | - 'planarity' 12 | - 'scattering' 13 | - 'verticality' 14 | - 'elevation' 15 | - 'rgb' 16 | -------------------------------------------------------------------------------- /configs/datamodule/semantic/s3dis_room.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis.yaml 4 | 5 | # Room-wise learning on the S3DIS dataset 6 | _target_: src.datamodules.s3dis_room.S3DISRoomDataModule 7 | 8 | dataloader: 9 | batch_size: 8 10 | 11 | sample_graph_k: -1 # skip subgraph sampling; to directly use the whole room 12 | -------------------------------------------------------------------------------- /configs/datamodule/semantic/scannet_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/scannet.yaml 4 | 5 | # point features used for training 6 | point_hf: [] 7 | 8 | # segment-wise features computed at preprocessing 9 | segment_mean_hf: 10 | - 'linearity' 11 | - 'planarity' 12 | - 'scattering' 13 | - 'verticality' 14 | - 'elevation' 15 | - 'rgb' 16 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | 
callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | datamodule: 34 | dataloader: 35 | num_workers: 0 # debuggers don't like multiprocessing 36 | pin_memory: False # disable gpu memory pin 37 | -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - datamodule: semantic/s3dis.yaml 6 | - model: semantic/spt-2.yaml 7 | - logger: null 8 | - trainer: default.yaml 9 | - paths: default.yaml 10 | - extras: default.yaml 11 | - hydra: default.yaml 12 | 13 | # experiment configs allow for version control of specific hyperparameters 14 | # e.g. best hyperparameters for given model and datamodule 15 | - experiment: null 16 | 17 | # optional local config for machine/user specific settings 18 | # it's optional since it doesn't need to exist and is excluded from version control 19 | - optional local: default.yaml 20 | 21 | task_name: "eval" 22 | 23 | tags: ["dev"] 24 | 25 | # compile model for faster training with pytorch >=2.1.0 26 | compile: False 27 | 28 | # passing checkpoint path is necessary for evaluation 29 | ckpt_path: ??? 
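# Illustrative usage (an assumption, not part of the original config): with the standard
# Lightning-Hydra entry point, the mandatory ??? value above is typically supplied as a
# Hydra command-line override, e.g. with a placeholder checkpoint path:
#   python src/eval.py ckpt_path=/path/to/your/checkpoint.ckpt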
30 | 31 | # float32 precision operations (torch>=2.0) 32 | # see https://pytorch.org/docs/2.0/generated/torch.set_float32_matmul_precision.html 33 | float32_matmul_precision: high 34 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/dales.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/dales 5 | 6 | defaults: 7 | - override /datamodule: panoptic/dales.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 400 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | partitioner: 23 | regularization: 20 24 | x_weight: 5e-2 25 | cutoff: 100 26 | 27 | edge_affinity_loss_lambda: 10 28 | 29 | partition_every_n_epoch: 10 30 | 31 | logger: 32 | wandb: 33 | project: "spt_dales" 34 | name: "SPT-64" 35 | 36 | # metric based on which models will be selected 37 | optimized_metric: "val/pq" 38 | 39 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 40 | # being potentially different 41 | callbacks: 42 | model_checkpoint: 43 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 44 | 45 | early_stopping: 46 | strict: False 47 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/dales_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/dales_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/dales configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce mode performance 19 | # (less diversity in the spherical samples) 20 | 21 | defaults: 22 | - override /datamodule: panoptic/dales.yaml 23 | - override /model: panoptic/spt-2.yaml 24 | - override /trainer: gpu.yaml 25 | 26 | # all parameters below will be merged with parameters from default configurations set above 27 | # this allows you to overwrite only specified parameters 28 | 29 | datamodule: 30 | xy_tiling: 5 # split each cloud into xy_tiling²=25 tiles, based on a regular XY grid. Reduces preprocessing- and inference-time GPU memory 31 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4. 
Reduces train-time GPU memory 32 | 33 | trainer: 34 | max_epochs: 288 # to keep same nb of steps: 25/9x more tiles, 2-step gradient accumulation -> epochs * 2 * 9 / 25 35 | 36 | model: 37 | optimizer: 38 | lr: 0.01 39 | weight_decay: 1e-4 40 | 41 | partitioner: 42 | regularization: 20 43 | x_weight: 5e-2 44 | cutoff: 100 45 | 46 | edge_affinity_loss_lambda: 10 47 | 48 | partition_every_n_epoch: 10 49 | 50 | logger: 51 | wandb: 52 | project: "spt_dales" 53 | name: "SPT-64" 54 | 55 | # metric based on which models will be selected 56 | optimized_metric: "val/pq" 57 | 58 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 59 | # being potentially different 60 | callbacks: 61 | model_checkpoint: 62 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 63 | 64 | early_stopping: 65 | strict: False 66 | 67 | gradient_accumulator: 68 | scheduling: 69 | 0: 70 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 71 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/dales_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/dales_nano 5 | 6 | defaults: 7 | - override /datamodule: panoptic/dales_nano.yaml 8 | - override /model: panoptic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 400 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | partitioner: 23 | regularization: 20 24 | x_weight: 5e-2 25 | cutoff: 100 26 | 27 | edge_affinity_loss_lambda: 10 28 | 29 | partition_every_n_epoch: 10 30 | 31 | logger: 32 | wandb: 33 | project: "spt_dales" 34 | name: "NANO" 35 | 36 | # metric based on which models will be selected 37 | optimized_metric: "val/pq" 38 | 39 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 40 | # being potentially different 41 | callbacks: 42 | model_checkpoint: 43 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 44 | 45 | early_stopping: 46 | strict: False 47 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/kitti360.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/kitti360 5 | 6 | defaults: 7 | - override /datamodule: panoptic/kitti360.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 200 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | _down_dim: [ 128, 128, 128, 128 ] 23 | _up_dim: [ 128, 128, 128 ] 24 | 25 | net: 26 | no_ffn: False 27 | down_ffn_ratio: 1 28 | 29 | 30 | partitioner: 31 | regularization: 10 32 | x_weight: 5e-2 33 | cutoff: 1 34 | 35 | partition_every_n_epoch: 10 36 | 37 | logger: 38 | wandb: 39 | project: "spt_kitti360" 40 | name: "SPT-128" 41 | 42 | # metric based on which models will be selected 43 | 
optimized_metric: "val/pq" 44 | 45 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 46 | # being potentially different 47 | callbacks: 48 | model_checkpoint: 49 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 50 | 51 | early_stopping: 52 | strict: False 53 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/kitti360_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/kitti360_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/kitti360 configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce mode performance 19 | # (less diversity in the spherical samples) 20 | 21 | 22 | defaults: 23 | - override /datamodule: panoptic/kitti360.yaml 24 | - override /model: panoptic/spt-2.yaml 25 | - override /trainer: gpu.yaml 26 | 27 | # all parameters below will be merged with parameters from default configurations set above 28 | # this allows you to overwrite only specified parameters 29 | 30 | datamodule: 31 | pc_tiling: 2 # split each cloud into 2^pc_tiling=4 tiles, based on their principal components. Reduces preprocessing- and inference-time GPU memory 32 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4. 
Reduces train-time GPU memory 33 | 34 | trainer: 35 | max_epochs: 100 # to keep same nb of steps: 4x more tiles, 2-step gradient accumulation -> epochs/2 36 | 37 | model: 38 | optimizer: 39 | lr: 0.01 40 | weight_decay: 1e-4 41 | 42 | _down_dim: [ 128, 128, 128, 128 ] 43 | _up_dim: [ 128, 128, 128 ] 44 | 45 | net: 46 | no_ffn: False 47 | down_ffn_ratio: 1 48 | 49 | partitioner: 50 | regularization: 10 51 | x_weight: 5e-2 52 | cutoff: 1 53 | 54 | partition_every_n_epoch: 10 55 | 56 | logger: 57 | wandb: 58 | project: "spt_kitti360" 59 | name: "SPT-128" 60 | 61 | # metric based on which models will be selected 62 | optimized_metric: "val/pq" 63 | 64 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 65 | # being potentially different 66 | callbacks: 67 | model_checkpoint: 68 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 69 | 70 | early_stopping: 71 | strict: False 72 | 73 | gradient_accumulator: 74 | scheduling: 75 | 0: 76 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 77 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/kitti360_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/kitti360_nano 5 | 6 | defaults: 7 | - override /datamodule: panoptic/kitti360_nano.yaml 8 | - override /model: panoptic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 200 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | _down_dim: [ 32, 32, 32, 32 ] 23 | _up_dim: [ 32, 32, 32 ] 24 | _node_mlp_out: 32 25 | _h_edge_mlp_out: 32 26 | 27 | partitioner: 28 | regularization: 10 29 | x_weight: 5e-2 30 | cutoff: 1 31 | 32 | partition_every_n_epoch: 10 33 | 34 | logger: 35 | wandb: 36 | project: "spt_kitti360" 37 | name: "NANO-32" 38 | 39 | # metric based on which models will be selected 40 | optimized_metric: "val/pq" 41 | 42 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 43 | # being potentially different 44 | callbacks: 45 | model_checkpoint: 46 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 47 | 48 | early_stopping: 49 | strict: False 50 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/s3dis.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis 5 | 6 | defaults: 7 | - override /datamodule: panoptic/s3dis.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | partitioner: 23 | regularization: 20 24 | x_weight: 5e-2 25 | cutoff: 300 26 | 27 | partition_every_n_epoch: 20 28 | 29 | logger: 30 | wandb: 31 | project: "spt_s3dis" 32 | name: "SPT-64" 33 | 34 | # metric based on which models will be 
selected 35 | optimized_metric: "val/pq" 36 | 37 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 38 | # being potentially different 39 | callbacks: 40 | model_checkpoint: 41 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 42 | 43 | early_stopping: 44 | strict: False 45 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/s3dis_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/s3dis configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce mode performance 19 | # (less diversity in the spherical samples) 20 | 21 | defaults: 22 | - override /datamodule: panoptic/s3dis.yaml 23 | - override /model: panoptic/spt-2.yaml 24 | - override /trainer: gpu.yaml 25 | 26 | # all parameters below will be merged with parameters from default configurations set above 27 | # this allows you to overwrite only specified parameters 28 | 29 | datamodule: 30 | xy_tiling: 3 # split each cloud into xy_tiling^2=9 tiles, based on a regular XY grid. Reduces preprocessing- and inference-time GPU memory 31 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4. 
Reduces train-time GPU memory 32 | 33 | trainer: 34 | max_epochs: 500 # to keep same nb of steps: 8x more tiles, 2-step gradient accumulation -> epochs/4 35 | 36 | model: 37 | optimizer: 38 | lr: 0.1 39 | weight_decay: 1e-2 40 | 41 | partitioner: 42 | regularization: 20 43 | x_weight: 5e-2 44 | cutoff: 300 45 | 46 | partition_every_n_epoch: 5 47 | 48 | logger: 49 | wandb: 50 | project: "spt_s3dis" 51 | name: "SPT-64" 52 | 53 | # metric based on which models will be selected 54 | optimized_metric: "val/pq" 55 | 56 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 57 | # being potentially different 58 | callbacks: 59 | model_checkpoint: 60 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 61 | 62 | early_stopping: 63 | strict: False 64 | 65 | gradient_accumulator: 66 | scheduling: 67 | 0: 68 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 69 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/s3dis_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_nano 5 | 6 | defaults: 7 | - override /datamodule: panoptic/s3dis_nano.yaml 8 | - override /model: panoptic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | partitioner: 23 | regularization: 20 24 | x_weight: 5e-2 25 | cutoff: 300 26 | 27 | partition_every_n_epoch: 20 28 | 29 | logger: 30 | wandb: 31 | project: "spt_s3dis" 32 | name: "NANO" 33 | 34 | # metric based on which models will be selected 35 | optimized_metric: "val/pq" 36 | 37 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 38 | # being potentially different 39 | callbacks: 40 | model_checkpoint: 41 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 42 | 43 | early_stopping: 44 | strict: False 45 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/s3dis_room.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_room 5 | 6 | defaults: 7 | - override /datamodule: panoptic/s3dis_room.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | partitioner: 23 | regularization: 20 24 | x_weight: 5e-2 25 | cutoff: 300 26 | 27 | partition_every_n_epoch: 20 28 | 29 | logger: 30 | wandb: 31 | project: "spt_s3dis_room" 32 | name: "SPT-64" 33 | 34 | # metric based on which models will be selected 35 | optimized_metric: "val/pq" 36 | 37 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 38 | # being potentially different 39 | callbacks: 40 | model_checkpoint: 41 | every_n_epochs: 
${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 42 | 43 | early_stopping: 44 | strict: False 45 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/s3dis_with_stuff.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_with_stuff 5 | 6 | defaults: 7 | - override /datamodule: panoptic/s3dis_with_stuff.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | partitioner: 23 | regularization: 10 24 | x_weight: 5e-2 25 | cutoff: 300 26 | 27 | partition_every_n_epoch: 20 28 | 29 | logger: 30 | wandb: 31 | project: "spt_s3dis" 32 | name: "SPT-64" 33 | 34 | # metric based on which models will be selected 35 | optimized_metric: "val/pq" 36 | 37 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 38 | # being potentially different 39 | callbacks: 40 | model_checkpoint: 41 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 42 | 43 | early_stopping: 44 | strict: False 45 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/s3dis_with_stuff_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_with_stuff_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/s3dis configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce mode performance 19 | # (less diversity in the spherical samples) 20 | 21 | defaults: 22 | - override /datamodule: panoptic/s3dis_with_stuff.yaml 23 | - override /model: panoptic/spt-2.yaml 24 | - override /trainer: gpu.yaml 25 | 26 | # all parameters below will be merged with parameters from default configurations set above 27 | # this allows you to overwrite only specified parameters 28 | 29 | datamodule: 30 | xy_tiling: 3 # split each cloud into xy_tiling^2=9 tiles, based on a regular XY grid. Reduces preprocessing- and inference-time GPU memory 31 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4. 
Reduces train-time GPU memory 32 | 33 | trainer: 34 | max_epochs: 500 # to keep same nb of steps: 8x more tiles, 2-step gradient accumulation -> epochs/4 35 | 36 | model: 37 | optimizer: 38 | lr: 0.1 39 | weight_decay: 1e-2 40 | 41 | partitioner: 42 | regularization: 10 43 | x_weight: 5e-2 44 | cutoff: 300 45 | 46 | partition_every_n_epoch: 5 47 | 48 | logger: 49 | wandb: 50 | project: "spt_s3dis" 51 | name: "SPT-64" 52 | 53 | # metric based on which models will be selected 54 | optimized_metric: "val/pq" 55 | 56 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 57 | # being potentially different 58 | callbacks: 59 | model_checkpoint: 60 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 61 | 62 | early_stopping: 63 | strict: False 64 | 65 | gradient_accumulator: 66 | scheduling: 67 | 0: 68 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 69 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/s3dis_with_stuff_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_with_stuff_nano 5 | 6 | defaults: 7 | - override /datamodule: panoptic/s3dis_with_stuff_nano.yaml 8 | - override /model: panoptic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | partitioner: 23 | regularization: 10 24 | x_weight: 5e-2 25 | cutoff: 300 26 | 27 | partition_every_n_epoch: 20 28 | 29 | logger: 30 | wandb: 31 | project: "spt_s3dis" 32 | name: "NANO" 33 | 34 | # metric based on which models will be selected 35 | optimized_metric: "val/pq" 36 | 37 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 38 | # being potentially different 39 | callbacks: 40 | model_checkpoint: 41 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 42 | 43 | early_stopping: 44 | strict: False 45 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/scannet 5 | 6 | defaults: 7 | - override /datamodule: panoptic/scannet.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 100 16 | check_val_every_n_epoch: 2 17 | 18 | model: 19 | optimizer: 20 | lr: 0.01 21 | weight_decay: 1e-4 22 | 23 | scheduler: 24 | num_warmup: 2 25 | 26 | _node_mlp_out: 64 27 | _h_edge_mlp_out: 64 28 | _down_dim: [ 128, 128, 128, 128 ] 29 | _up_dim: [ 128, 128, 128 ] 30 | net: 31 | no_ffn: False 32 | down_ffn_ratio: 1 33 | down_num_heads: 32 34 | 35 | partitioner: 36 | regularization: 20 37 | x_weight: 5e-2 38 | cutoff: 300 39 | 40 | edge_affinity_loss_lambda: 10 41 | 42 | partition_every_n_epoch: 4 43 | 44 | logger: 45 | wandb: 46 | project: 
"spt_scannet" 47 | name: "SPT-128" 48 | 49 | # metric based on which models will be selected 50 | optimized_metric: "val/pq" 51 | 52 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 53 | # being potentially different 54 | callbacks: 55 | model_checkpoint: 56 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 57 | 58 | early_stopping: 59 | strict: False 60 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/scannet_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/scannet_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/panoptic/scannet configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - reduce the number of samples in each batch (facilitates training 11 | # on smaller GPUs) 12 | # To keep the total number of training steps consistent with the default 13 | # configuration, while keeping informative gradient despite the smaller 14 | # batches, we use gradient accumulation and reduce the number of epochs. 15 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 16 | # (more raw data reading steps), and slightly reduce mode performance 17 | # (less diversity in the spherical samples) 18 | 19 | defaults: 20 | - override /datamodule: panoptic/scannet.yaml 21 | - override /model: panoptic/spt-2.yaml 22 | - override /trainer: gpu.yaml 23 | 24 | # all parameters below will be merged with parameters from default configurations set above 25 | # this allows you to overwrite only specified parameters 26 | 27 | datamodule: 28 | dataloader: 29 | batch_size: 1 30 | 31 | callbacks: 32 | model_checkpoint: 33 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 34 | 35 | early_stopping: 36 | strict: False 37 | 38 | gradient_accumulator: 39 | scheduling: 40 | 0: 41 | 4 # accumulate gradient every 4 batches, to make up for reduced batch size 42 | 43 | trainer: 44 | max_epochs: 100 # to keep the same number of steps -> epochs unchanged 45 | check_val_every_n_epoch: 2 46 | 47 | model: 48 | optimizer: 49 | lr: 0.01 50 | weight_decay: 1e-4 51 | 52 | scheduler: 53 | num_warmup: 2 54 | 55 | _node_mlp_out: 64 56 | _h_edge_mlp_out: 64 57 | _down_dim: [ 128, 128, 128, 128 ] 58 | _up_dim: [ 128, 128, 128 ] 59 | net: 60 | no_ffn: False 61 | down_ffn_ratio: 1 62 | down_num_heads: 32 63 | 64 | partitioner: 65 | regularization: 20 66 | x_weight: 5e-2 67 | cutoff: 300 68 | 69 | edge_affinity_loss_lambda: 10 70 | 71 | partition_every_n_epoch: 4 72 | 73 | logger: 74 | wandb: 75 | project: "spt_scannet" 76 | name: "SPT-128" 77 | 78 | # metric based on which models will be selected 79 | optimized_metric: "val/pq" 80 | -------------------------------------------------------------------------------- /configs/experiment/panoptic/scannet_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/scannet_nano 5 | 6 | defaults: 7 | - override /datamodule: panoptic/scannet_nano.yaml 8 | - override /model: panoptic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged 
with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 100 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | scheduler: 23 | num_warmup: 2 24 | 25 | _node_mlp_out: 32 26 | _h_edge_mlp_out: 32 27 | _down_dim: [ 32, 32, 32, 32 ] 28 | _up_dim: [ 32, 32, 32 ] 29 | net: 30 | no_ffn: False 31 | down_ffn_ratio: 1 32 | 33 | partitioner: 34 | regularization: 20 35 | x_weight: 5e-2 36 | cutoff: 300 37 | 38 | edge_affinity_loss_lambda: 10 39 | 40 | partition_every_n_epoch: 4 41 | 42 | logger: 43 | wandb: 44 | project: "spt_scannet" 45 | name: "NANO" 46 | 47 | # metric based on which models will be selected 48 | optimized_metric: "val/pq" 49 | 50 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 51 | # being potentially different 52 | callbacks: 53 | model_checkpoint: 54 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 55 | 56 | early_stopping: 57 | strict: False 58 | -------------------------------------------------------------------------------- /configs/experiment/semantic/dales.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/dales 5 | 6 | defaults: 7 | - override /datamodule: semantic/dales.yaml 8 | - override /model: semantic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 400 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | logger: 23 | wandb: 24 | project: "spt_dales" 25 | name: "SPT-64" -------------------------------------------------------------------------------- /configs/experiment/semantic/dales_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/dales_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/dales configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce mode performance 19 | # (less diversity in the spherical samples) 20 | 21 | defaults: 22 | - override /datamodule: semantic/dales.yaml 23 | - override /model: semantic/spt-2.yaml 24 | - override /trainer: gpu.yaml 25 | 26 | # all parameters below will be merged with parameters from default configurations set above 27 | # this allows you to overwrite only specified parameters 28 | 29 | datamodule: 30 | xy_tiling: 5 # split each cloud into xy_tiling²=25 tiles, based on a regular XY grid. 
Reduces preprocessing- and inference-time GPU memory 31 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4. Reduces train-time GPU memory 32 | 33 | callbacks: 34 | gradient_accumulator: 35 | scheduling: 36 | 0: 37 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 38 | 39 | trainer: 40 | max_epochs: 288 # to keep same nb of steps: 25/9x more tiles, 2-step gradient accumulation -> epochs * 2 * 9 / 25 41 | 42 | model: 43 | optimizer: 44 | lr: 0.01 45 | weight_decay: 1e-4 46 | 47 | logger: 48 | wandb: 49 | project: "spt_dales" 50 | name: "SPT-64" 51 | -------------------------------------------------------------------------------- /configs/experiment/semantic/dales_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/dales_nano 5 | 6 | defaults: 7 | - override /datamodule: semantic/dales_nano.yaml 8 | - override /model: semantic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 400 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | logger: 23 | wandb: 24 | project: "spt_dales" 25 | name: "NANO" 26 | -------------------------------------------------------------------------------- /configs/experiment/semantic/kitti360.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/kitti360 5 | 6 | defaults: 7 | - override /datamodule: semantic/kitti360.yaml 8 | - override /model: semantic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 200 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | _down_dim: [ 128, 128, 128, 128 ] 23 | _up_dim: [ 128, 128, 128 ] 24 | 25 | net: 26 | no_ffn: False 27 | down_ffn_ratio: 1 28 | 29 | logger: 30 | wandb: 31 | project: "spt_kitti360" 32 | name: "SPT-128" 33 | -------------------------------------------------------------------------------- /configs/experiment/semantic/kitti360_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/kitti360_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/kitti360 configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 
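# As a rough sketch of the arithmetic (assuming the default
# experiment/semantic/kitti360 schedule of 200 epochs shown above):
# 2^pc_tiling = 4 tiles mean roughly 4x more batches per epoch, while
# 2-step gradient accumulation halves the optimizer steps per batch, so
# each epoch runs about 4 / 2 = 2x more optimizer steps than the
# default; halving max_epochs to 200 / 2 = 100 keeps the total step
# count roughly unchanged.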
17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce mode performance 19 | # (less diversity in the spherical samples) 20 | 21 | 22 | defaults: 23 | - override /datamodule: semantic/kitti360.yaml 24 | - override /model: semantic/spt-2.yaml 25 | - override /trainer: gpu.yaml 26 | 27 | # all parameters below will be merged with parameters from default configurations set above 28 | # this allows you to overwrite only specified parameters 29 | 30 | datamodule: 31 | pc_tiling: 2 # split each cloud into 2^pc_tiling=4 tiles, based on their principal components. Reduces preprocessing- and inference-time GPU memory 32 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4. Reduces train-time GPU memory 33 | 34 | callbacks: 35 | gradient_accumulator: 36 | scheduling: 37 | 0: 38 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 39 | 40 | trainer: 41 | max_epochs: 100 # to keep same nb of steps: 4x more tiles, 2-step gradient accumulation -> epochs/2 42 | 43 | model: 44 | optimizer: 45 | lr: 0.01 46 | weight_decay: 1e-4 47 | 48 | _down_dim: [ 128, 128, 128, 128 ] 49 | _up_dim: [ 128, 128, 128 ] 50 | 51 | net: 52 | no_ffn: False 53 | down_ffn_ratio: 1 54 | 55 | logger: 56 | wandb: 57 | project: "spt_kitti360" 58 | name: "SPT-128" 59 | -------------------------------------------------------------------------------- /configs/experiment/semantic/kitti360_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/kitti360_nano 5 | 6 | defaults: 7 | - override /datamodule: semantic/kitti360_nano.yaml 8 | - override /model: semantic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 200 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | _down_dim: [ 32, 32, 32, 32 ] 23 | _up_dim: [ 32, 32, 32 ] 24 | _node_mlp_out: 32 25 | _h_edge_mlp_out: 32 26 | 27 | logger: 28 | wandb: 29 | project: "spt_kitti360" 30 | name: "NANO-32" 31 | -------------------------------------------------------------------------------- /configs/experiment/semantic/s3dis.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/s3dis 5 | 6 | defaults: 7 | - override /datamodule: semantic/s3dis.yaml 8 | - override /model: semantic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | logger: 23 | wandb: 24 | project: "spt_s3dis" 25 | name: "SPT-64" 26 | -------------------------------------------------------------------------------- /configs/experiment/semantic/s3dis_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/s3dis_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # 
training procedure comparable with the default 8 | # experiment/semantic/s3dis configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce mode performance 19 | # (less diversity in the spherical samples) 20 | 21 | defaults: 22 | - override /datamodule: semantic/s3dis.yaml 23 | - override /model: semantic/spt-2.yaml 24 | - override /trainer: gpu.yaml 25 | 26 | # all parameters below will be merged with parameters from default configurations set above 27 | # this allows you to overwrite only specified parameters 28 | 29 | datamodule: 30 | xy_tiling: 3 # split each cloud into xy_tiling^2=9 tiles, based on a regular XY grid. Reduces preprocessing- and inference-time GPU memory 31 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4. Reduces train-time GPU memory 32 | 33 | callbacks: 34 | gradient_accumulator: 35 | scheduling: 36 | 0: 37 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 38 | 39 | trainer: 40 | max_epochs: 500 # to keep same nb of steps: 8x more tiles, 2-step gradient accumulation -> epochs/4 41 | 42 | model: 43 | optimizer: 44 | lr: 0.1 45 | weight_decay: 1e-2 46 | 47 | logger: 48 | wandb: 49 | project: "spt_s3dis" 50 | name: "SPT-64" 51 | -------------------------------------------------------------------------------- /configs/experiment/semantic/s3dis_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/s3dis_nano 5 | 6 | defaults: 7 | - override /datamodule: semantic/s3dis_nano.yaml 8 | - override /model: semantic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | logger: 23 | wandb: 24 | project: "spt_s3dis" 25 | name: "NANO" 26 | -------------------------------------------------------------------------------- /configs/experiment/semantic/s3dis_room.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/s3dis_room 5 | 6 | defaults: 7 | - override /datamodule: semantic/s3dis_room.yaml 8 | - override /model: semantic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | logger: 23 | wandb: 24 | project: "spt_s3dis_room" 25 | name: "SPT-64" 26 | 
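All the semantic experiment configs above follow the same minimal pattern: override the datamodule, model and trainer groups in defaults, then re-specify only the handful of values that differ (epochs, optimizer, wandb project and name). As a hedged illustration of that pattern (the file name, the my_dataset datamodule and all values below are hypothetical, not files from this repository), a new experiment could look like:

```yaml
# @package _global_

# hypothetical example following the pattern of the semantic experiment
# configs above; my_dataset and its datamodule config do not exist here.
# it would be run with:
# python train.py experiment=semantic/my_dataset

defaults:
  - override /datamodule: semantic/my_dataset.yaml
  - override /model: semantic/spt-2.yaml
  - override /trainer: gpu.yaml

# only the parameters that differ from the defaults need to be listed;
# everything else is merged from the configurations selected above

trainer:
  max_epochs: 400

model:
  optimizer:
    lr: 0.01
    weight_decay: 1e-4

logger:
  wandb:
    project: "spt_my_dataset"
    name: "SPT-64"
```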
-------------------------------------------------------------------------------- /configs/experiment/semantic/scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/scannet 5 | 6 | defaults: 7 | - override /datamodule: semantic/scannet.yaml 8 | - override /model: semantic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 100 16 | check_val_every_n_epoch: 2 17 | 18 | model: 19 | optimizer: 20 | lr: 0.01 21 | weight_decay: 1e-4 22 | 23 | scheduler: 24 | num_warmup: 2 25 | 26 | _node_mlp_out: 64 27 | _h_edge_mlp_out: 64 28 | _down_dim: [ 128, 128, 128, 128 ] 29 | _up_dim: [ 128, 128, 128 ] 30 | net: 31 | no_ffn: False 32 | down_ffn_ratio: 1 33 | down_num_heads: 32 34 | 35 | 36 | logger: 37 | wandb: 38 | project: "spt_scannet" 39 | name: "SPT-128" 40 | -------------------------------------------------------------------------------- /configs/experiment/semantic/scannet_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/scannet_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/scannet configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - reduce the number of samples in each batch (facilitates training 11 | # on smaller GPUs) 12 | # To keep the total number of training steps consistent with the default 13 | # configuration, while keeping informative gradient despite the smaller 14 | # batches, we use gradient accumulation and reduce the number of epochs. 
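# As a rough sketch of the arithmetic (assuming the default scannet
# dataloader batch_size is 4, which is not shown in this file): with
# batch_size reduced to 1 and gradients accumulated over 4 batches, the
# effective batch size stays at 1 x 4 = 4 and the number of optimizer
# steps per epoch is unchanged, which is why max_epochs stays at 100
# below.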
15 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 16 | # (more raw data reading steps), and slightly reduce mode performance 17 | # (less diversity in the spherical samples) 18 | 19 | defaults: 20 | - override /datamodule: semantic/scannet.yaml 21 | - override /model: semantic/spt-2.yaml 22 | - override /trainer: gpu.yaml 23 | 24 | # all parameters below will be merged with parameters from default configurations set above 25 | # this allows you to overwrite only specified parameters 26 | 27 | datamodule: 28 | dataloader: 29 | batch_size: 1 30 | 31 | callbacks: 32 | gradient_accumulator: 33 | scheduling: 34 | 0: 35 | 4 # accumulate gradient every 4 batches, to make up for reduced batch size 36 | 37 | trainer: 38 | max_epochs: 100 # to keep the same number of steps -> epochs unchanged 39 | check_val_every_n_epoch: 2 40 | 41 | model: 42 | optimizer: 43 | lr: 0.01 44 | weight_decay: 1e-4 45 | 46 | scheduler: 47 | num_warmup: 2 48 | 49 | _node_mlp_out: 64 50 | _h_edge_mlp_out: 64 51 | _down_dim: [ 128, 128, 128, 128 ] 52 | _up_dim: [ 128, 128, 128 ] 53 | net: 54 | no_ffn: False 55 | down_ffn_ratio: 1 56 | down_num_heads: 32 57 | 58 | logger: 59 | wandb: 60 | project: "spt_scannet" 61 | name: "SPT-128" 62 | -------------------------------------------------------------------------------- /configs/experiment/semantic/scannet_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/scannet_nano 5 | 6 | defaults: 7 | - override /datamodule: semantic/scannet_nano.yaml 8 | - override /model: semantic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 100 16 | check_val_every_n_epoch: 2 17 | 18 | model: 19 | optimizer: 20 | lr: 0.01 21 | weight_decay: 1e-4 22 | 23 | scheduler: 24 | num_warmup: 2 25 | 26 | _node_mlp_out: 32 27 | _h_edge_mlp_out: 32 28 | _down_dim: [ 32, 32, 32, 32 ] 29 | _up_dim: [ 32, 32, 32 ] 30 | net: 31 | no_ffn: False 32 | down_ffn_ratio: 1 33 | 34 | logger: 35 | wandb: 36 | project: "spt_scannet" 37 | name: "NANO" 38 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 
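# (for instance, a sweep over one of the panoptic experiments above,
# which select their models on "val/pq", would use that key here
# instead of the template's "val/acc_best" below)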
11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | datamodule.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | 
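Each logger config in this directory nests its settings under a key named after the logger (comet, csv, wandb, ...), which is what lets many_loggers.yaml, the next file, combine several loggers simply by listing their configs in its defaults. As a hedged sketch (the file name is hypothetical and not part of this repository), a leaner combination keeping only the CSV and Weights & Biases loggers could be written the same way:

```yaml
# configs/logger/csv_and_wandb.yaml (hypothetical example)
# train with only the CSV and Weights & Biases loggers enabled

defaults:
  - csv.yaml
  - wandb.yaml
```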
-------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "superpoint_transformer" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/model/panoptic/_instance.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | 3 | _target_: src.models.panoptic.PanopticSegmentationModule 4 | 5 | # Stuff class indices must be specified for instantiation. 6 | # Concretely, stuff_classes is recovered from the datamodule config 7 | stuff_classes: ${datamodule.stuff_classes} 8 | 9 | # Minimum size for an instance to be taken into account in the instance 10 | # segmentation metrics 11 | min_instance_size: ${datamodule.min_instance_size} 12 | 13 | # Make the point encoder slightly smaller than for the default SPT. 
This 14 | # may slightly affect semantic segmentation results but allows fitting 15 | # into 32G with the edge affinity head 16 | _point_mlp: [32, 64, 64] # point encoder layers 17 | 18 | # Edge affinity prediction head for instance/panoptic graph clustering. 19 | # Importantly, we pass `dims` to characterize the MLP layers. The size 20 | # of the first layer is directly computed from the config 21 | edge_affinity_head: 22 | _target_: src.nn.MLP 23 | dims: ${eval:'[ ${model._up_dim}[-1] * 2, 32, 16, 1 ]'} 24 | activation: 25 | _target_: torch.nn.LeakyReLU 26 | norm: null 27 | last_norm: False 28 | last_activation: False 29 | 30 | # Instance/panoptic partitioner module. See the `InstancePartitioner` 31 | # documentation for more details on the available parameters 32 | partitioner: 33 | _target_: src.nn.instance.InstancePartitioner 34 | 35 | # Frequency at which the partition should be computed. If lower or equal 36 | # to 0, the partition will only be computed at the last training epoch 37 | partition_every_n_epoch: -1 38 | 39 | # If True, the instance metrics will never be computed. If only panoptic 40 | # metrics are of interest, this can save considerable training and 41 | # evaluation time, as instance metrics computation is relatively slow 42 | no_instance_metrics: True 43 | 44 | # If True, the instance segmentation metrics will not be computed on the 45 | # train set. This allows saving some computation and training time 46 | no_instance_metrics_on_train_set: True 47 | 48 | # Edge affinity loss 49 | edge_affinity_criterion: 50 | _target_: src.loss.BCEWithLogitsLoss 51 | weight: null 52 | 53 | # Weights for insisting on certain cases in the edge affinity loss: 54 | # - 0: same-class same-object edges 55 | # - 1: same-class different-object edges 56 | # - 2: different-class same-object edges 57 | # - 3: different-class different-object edges 58 | edge_affinity_loss_weights: [1, 1, 1, 1] 59 | 60 | # Node offset loss 61 | node_offset_criterion: 62 | _target_: src.loss.WeightedL2Loss 63 | 64 | # Weights for combining the semantic segmentation loss with the node 65 | # offset and edge affinity losses 66 | edge_affinity_loss_lambda: 1 67 | node_offset_loss_lambda: 1 68 | -------------------------------------------------------------------------------- /configs/model/panoptic/nano-2.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/nano-2.yaml 4 | - /model/panoptic/_instance.yaml 5 | -------------------------------------------------------------------------------- /configs/model/panoptic/nano-3.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/nano-3.yaml 4 | - /model/panoptic/_instance.yaml 5 | -------------------------------------------------------------------------------- /configs/model/panoptic/spt-2.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt-2.yaml 4 | - /model/panoptic/_instance.yaml 5 | -------------------------------------------------------------------------------- /configs/model/panoptic/spt-3.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt-3.yaml 4 | - /model/panoptic/_instance.yaml 5 | -------------------------------------------------------------------------------- 
/configs/model/panoptic/spt.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt.yaml 4 | - /model/panoptic/_instance.yaml 5 | -------------------------------------------------------------------------------- /configs/model/semantic/_attention.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/default.yaml 4 | 5 | # Partial spt config specifically for the attention blocks 6 | net: 7 | activation: 8 | _target_: torch.nn.LeakyReLU 9 | norm: 10 | _target_: src.nn.GraphNorm 11 | _partial_: True 12 | pre_norm: True 13 | no_sa: False 14 | no_ffn: True 15 | qk_dim: 4 16 | qkv_bias: True 17 | qk_scale: null 18 | in_rpe_dim: ${eval:'${model._h_edge_mlp_out} if ${model._h_edge_mlp_out} else ${model._h_edge_hf_dim}'} 19 | k_rpe: True 20 | q_rpe: True 21 | v_rpe: True 22 | k_delta_rpe: False 23 | q_delta_rpe: False 24 | qk_share_rpe: False 25 | q_on_minus_rpe: False 26 | stages_share_rpe: False 27 | blocks_share_rpe: False 28 | heads_share_rpe: False 29 | -------------------------------------------------------------------------------- /configs/model/semantic/_down.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/default.yaml 4 | 5 | # Partial spt config specifically for the encoder 6 | net: 7 | down_dim: ${model._down_dim} 8 | down_pool_dim: ${eval:'[${model._point_mlp}[-1]] + ${model._down_dim}[:-1]'} 9 | down_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._point_mlp}[-1] * (not ${model.net.nano}) + ${datamodule.num_hf_segment} * (${model.net.nano} and not ${model.net.use_node_hf})] + [${model._down_dim}[0]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[0]] + [${model._down_dim}[1]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[1]] + [${model._down_dim}[2]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[2]] + [${model._down_dim}[3]] * ${model._mlp_depth} ]'} 10 | down_out_mlp: null 11 | down_mlp_drop: null 12 | down_num_heads: 16 13 | down_num_blocks: 3 14 | down_ffn_ratio: 1 15 | down_residual_drop: null 16 | down_attn_drop: null 17 | down_drop_path: null 18 | -------------------------------------------------------------------------------- /configs/model/semantic/_point.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/default.yaml 4 | 5 | # Partial spt config specifically for the point encoder 6 | net: 7 | point_mlp: ${eval:'[${model._point_hf_dim}] + ${model._point_mlp}'} 8 | point_drop: null 9 | -------------------------------------------------------------------------------- /configs/model/semantic/_up.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/default.yaml 4 | 5 | # Partial spt config specifically for the decoder 6 | net: 7 | up_dim: ${model._up_dim} 8 | up_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._down_dim}[-1] + ${model._down_dim}[-2]] + [${model._up_dim}[0]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._up_dim}[0] + ${model._down_dim}[-3]] + [${model._up_dim}[1]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._up_dim}[1] + ${model._down_dim}[-4]] + [${model._up_dim}[2]] * 
${model._mlp_depth} ]'} 9 | up_out_mlp: ${model.net.down_out_mlp} 10 | up_mlp_drop: ${model.net.down_mlp_drop} 11 | up_num_heads: ${model.net.down_num_heads} 12 | up_num_blocks: 1 13 | up_ffn_ratio: ${model.net.down_ffn_ratio} 14 | up_residual_drop: ${model.net.down_residual_drop} 15 | up_attn_drop: ${model.net.down_attn_drop} 16 | up_drop_path: ${model.net.down_drop_path} 17 | -------------------------------------------------------------------------------- /configs/model/semantic/default.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | 3 | _target_: src.models.semantic.SemanticSegmentationModule 4 | 5 | num_classes: ${datamodule.num_classes} 6 | sampling_loss: False 7 | loss_type: 'ce_kl' # supports 'ce', 'wce', 'kl', 'ce_kl', 'wce_kl' 8 | weighted_loss: True 9 | init_linear: null # defaults to xavier_uniform initialization 10 | init_rpe: null # defaults to xavier_uniform initialization 11 | multi_stage_loss_lambdas: [1, 50] # weights for the multi-stage loss 12 | transformer_lr_scale: 0.1 13 | gc_every_n_steps: 0 14 | 15 | # Every N epoch, the model may store to disk predictions for some 16 | # tracked validation batch of interest. This assumes the validation 17 | # dataloader is non-stochastic. Additionally, the model may store to 18 | # disk predictions for some or all the test batches. 19 | track_val_every_n_epoch: 10 # trigger the tracking every N epoch 20 | track_val_idx: null # index of the validation batch to track. If -1, all the validation batches will be tracked, at every `track_val_every_n_epoch` epoch 21 | track_test_idx: null # index of the test batch to track. If -1, all the test batches will be tracked 22 | 23 | optimizer: 24 | _target_: torch.optim.AdamW 25 | _partial_: True 26 | lr: 0.01 27 | weight_decay: 1e-4 28 | 29 | scheduler: 30 | _target_: src.optim.CosineAnnealingLRWithWarmup 31 | _partial_: True 32 | T_max: ${eval:'${trainer.max_epochs} - ${model.scheduler.num_warmup}'} 33 | eta_min: 1e-6 34 | warmup_init_lr: 1e-6 35 | num_warmup: 20 36 | warmup_strategy: 'cos' 37 | 38 | criterion: 39 | _target_: torch.nn.CrossEntropyLoss 40 | ignore_index: ${datamodule.num_classes} 41 | 42 | # Parameters declared here to facilitate tuning configs. 
Those are only 43 | # used here for config interpolation but will/should actually fall in 44 | # the ignored kwargs of the SemanticSegmentationModule 45 | _point_mlp: [32, 64, 128] # point encoder layers 46 | _node_mlp_out: 32 # size of level-1+ handcrafted node features after MLP, set to 'null' to use directly the raw features 47 | _h_edge_mlp_out: 32 # size of level-1+ handcrafted horizontal edge features after MLP, set to 'null' to use directly the raw features 48 | _v_edge_mlp_out: 32 # size of level-1+ handcrafted vertical edge features after MLP, set to 'null' to use directly the raw features 49 | 50 | _point_hf_dim: ${eval:'${model.net.use_pos} * 3 + ${datamodule.num_hf_point} + ${model.net.use_diameter_parent}'} # size of handcrafted level-0 node features (points) 51 | _node_hf_dim: ${eval:'${model.net.use_node_hf} * ${datamodule.num_hf_segment}'} # size of handcrafted level-1+ node features before node MLP 52 | _node_injection_dim: ${eval:'${model.net.use_pos} * 3 + ${model.net.use_diameter} + ${model.net.use_diameter_parent} + (${model._node_mlp_out} if ${model._node_mlp_out} and ${model.net.use_node_hf} and ${model._node_hf_dim} > 0 else ${model._node_hf_dim})'} # size of parent level-1+ node features for Stage injection input 53 | _h_edge_hf_dim: ${datamodule.num_hf_edge} # size of level-1+ handcrafted horizontal edge features 54 | _v_edge_hf_dim: ${datamodule.num_hf_v_edge} # size of level-1+ handcrafted vertical edge features 55 | 56 | _down_dim: [64, 64, 64, 64] # encoder stage dimensions 57 | _up_dim: [64, 64, 64] # decoder stage dimensions 58 | _mlp_depth: 2 # default nb of layers in all MLPs (i.e. MLP depth) 59 | 60 | net: ??? 61 | -------------------------------------------------------------------------------- /configs/model/semantic/nano-2.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt-2.yaml 4 | 5 | _down_dim: [16, 16, 16, 16] 6 | _up_dim: [16, 16, 16] 7 | _node_mlp_out: 16 8 | _h_edge_mlp_out: 16 9 | 10 | net: 11 | nano: True 12 | qk_dim: 2 13 | -------------------------------------------------------------------------------- /configs/model/semantic/nano-3.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt-3.yaml 4 | 5 | net: 6 | nano: True 7 | -------------------------------------------------------------------------------- /configs/model/semantic/spt-2.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt.yaml 4 | 5 | net: 6 | down_dim: ${eval:'${model._down_dim}[:2]'} 7 | down_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._point_mlp}[-1] * (not ${model.net.nano}) + ${datamodule.num_hf_segment} * (${model.net.nano} and not ${model.net.use_node_hf})] + [${model._down_dim}[0]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[0]] + [${model._down_dim}[1]] * ${model._mlp_depth} ]'} 8 | up_dim: ${eval:'${model._up_dim}[:1]'} 9 | up_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._down_dim}[-1] + ${model._down_dim}[-2]] + [${model._up_dim}[0]] * ${model._mlp_depth} ]'} 10 | -------------------------------------------------------------------------------- /configs/model/semantic/spt-3.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt.yaml 4 | 5 | 
net: 6 | down_dim: ${eval:'${model._down_dim}[:3]'} 7 | down_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._point_mlp}[-1] * (not ${model.net.nano}) + ${datamodule.num_hf_segment} * (${model.net.nano} and not ${model.net.use_node_hf})] + [${model._down_dim}[0]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[0]] + [${model._down_dim}[1]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[1]] + [${model._down_dim}[2]] * ${model._mlp_depth} ]'} 8 | up_dim: ${eval:'${model._up_dim}[:2]'} 9 | up_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._down_dim}[-1] + ${model._down_dim}[-2]] + [${model._up_dim}[0]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._up_dim}[0] + ${model._down_dim}[-3]] + [${model._up_dim}[1]] * ${model._mlp_depth} ]'} 10 | -------------------------------------------------------------------------------- /configs/model/semantic/spt.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/default.yaml 4 | - /model/semantic/_point.yaml 5 | - /model/semantic/_down.yaml 6 | - /model/semantic/_up.yaml 7 | - /model/semantic/_attention.yaml 8 | 9 | net: 10 | _target_: src.models.components.spt.SPT 11 | 12 | nano: False 13 | node_mlp: ${eval:'[${model._node_hf_dim}] + [${model._node_mlp_out}] * ${model._mlp_depth} if ${model._node_mlp_out} and ${model.net.use_node_hf} and ${model._node_hf_dim} > 0 else None'} 14 | h_edge_mlp: ${eval:'[${model._h_edge_hf_dim}] + [${model._h_edge_mlp_out}] * ${model._mlp_depth} if ${model._h_edge_mlp_out} else None'} 15 | v_edge_mlp: ${eval:'[${model._v_edge_hf_dim}] + [${model._v_edge_mlp_out}] * ${model._mlp_depth} if ${model._v_edge_mlp_out} else None'} 16 | share_hf_mlps: False 17 | mlp_activation: 18 | _target_: torch.nn.LeakyReLU 19 | mlp_norm: 20 | _target_: src.nn.GraphNorm 21 | _partial_: True 22 | 23 | use_pos: True # whether features should include position (with unit-sphere normalization wrt siblings) 24 | use_node_hf: True # whether features should include node handcrafted features (after optional node_mlp, if features are actually loaded by the datamodule) 25 | use_diameter: False # whether features should include the superpoint's diameter (from unit-sphere normalization wrt siblings) 26 | use_diameter_parent: True # whether features should include diameter of the superpoint's parent (from unit-sphere normalization wrt siblings) 27 | pool: 'max' # pooling across the cluster, supports 'max', 'mean', 'min' 28 | unpool: 'index' 29 | fusion: 'cat' 30 | norm_mode: 'graph' 31 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | 
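As an aside on the `${eval:'...'}` expressions used in the `spt-2.yaml` / `spt-3.yaml` model configs above: the following minimal Python sketch (not part of the repository) spells out what the `spt-2` slicing reduces to with the defaults `_down_dim: [64, 64, 64, 64]`, `_up_dim: [64, 64, 64]`, `_point_mlp: [32, 64, 128]` and `_mlp_depth: 2`, assuming `nano: False`. The `node_injection_dim` value below is a hypothetical placeholder, since the real value is interpolated from the `use_*` flags and the datamodule's handcrafted feature counts.

```python
# Rough Python equivalent of the ${eval:'...'} expressions in spt-2.yaml.
# Illustrative sketch only: `node_injection_dim` is a made-up placeholder,
# and the nano-related terms are dropped since nano=False.
node_injection_dim = 8           # hypothetical, resolved from use_* flags at config time
point_mlp_out = 128              # last entry of _point_mlp: [32, 64, 128]
mlp_depth = 2

down_dim = [64, 64, 64, 64][:2]  # spt-2 keeps 2 encoder stages -> [64, 64]
up_dim = [64, 64, 64][:1]        # and a single decoder stage   -> [64]

# Input MLP of each encoder stage: input width followed by mlp_depth layer widths
down_in_mlp = [
    [node_injection_dim + point_mlp_out] + [down_dim[0]] * mlp_depth,
    [node_injection_dim + down_dim[0]] + [down_dim[1]] * mlp_depth,
]

# Input MLP of the decoder stage, fusing the skip features from the encoder
up_in_mlp = [
    [node_injection_dim + down_dim[-1] + down_dim[-2]] + [up_dim[0]] * mlp_depth,
]

print(down_in_mlp)  # [[136, 64, 64], [72, 64, 64]] with the placeholder value above
print(up_in_mlp)    # [[136, 64, 64]]
```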
-------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - datamodule: semantic/s3dis.yaml 8 | - model: semantic/spt-2.yaml 9 | - callbacks: default.yaml 10 | - logger: wandb.yaml # null for default, set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default.yaml 12 | - paths: default.yaml 13 | - extras: default.yaml 14 | - hydra: default.yaml 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. the best hyperparameters for a given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default.yaml 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | # metric based on which models will be selected 34 | optimized_metric: "val/miou" 35 | 36 | # tags to help you identify your experiments 37 | # you can overwrite this in experiment configs 38 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 39 | # appending lists from command line is currently not supported :( 40 | # https://github.com/facebookresearch/hydra/issues/1547 41 | tags: ["dev"] 42 | 43 | # set False to skip model training 44 | train: True 45 | 46 | # evaluate on test set, using best model weights achieved during training 47 | # lightning chooses best weights based on the metric specified in checkpoint callback 48 | test: True 49 | 50 | # compile model for faster training with pytorch >=2.1.0 51 | compile: False 52 | 53 | # simply provide checkpoint path to resume training 54 | ckpt_path: null 55 | 56 | # seed for random number generators in pytorch, numpy and python.random 57 | seed: null 58 | 59 | # float32 precision operations (torch>=2.0) 60 | # see https://pytorch.org/docs/2.0/generated/torch.set_float32_matmul_precision.html 61 | float32_matmul_precision: high 62 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # use "ddp_spawn" instead of "ddp", 5 | # it's slower but normal "ddp" currently doesn't work ideally with hydra 6 | # https://github.com/facebookresearch/hydra/issues/2070 7 | # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn 8 | strategy: ddp_spawn 9 | 10 | accelerator: gpu 11 | devices: 4 12 | num_nodes: 1 13 | sync_batchnorm: True 14 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 100 7 | 8 | accelerator: cpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | # precision: 16 13 | 14 | # perform a validation loop every N training epochs 15 | check_val_every_n_epoch: 10 16 | 17 | # set True to ensure deterministic results 18 | # makes training slower but gives more reproducibility than just setting seeds 19 | deterministic: False 20 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: gpu 5 | devices: 1 6 | 7 | # mixed precision for extra speed-up 8 | # precision: 16 9 | # precision: bf16 10 | precision: 32 11 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /docs/logging.md: -------------------------------------------------------------------------------- 1 | # Logging 2 | 3 | ## Structure of `logs/` directory 4 | Logs directory structure. 5 | 6 | Your logs will be saved under the following structure: 7 | 8 | ``` 9 | └── logs 10 | ├── {{train, eval}} # Task name 11 | │ ├── runs # Logs generated by single runs 12 | │ │ ├── YYYY-MM-DD_HH-MM-SS # Datetime of the run 13 | │ │ │ ├── .hydra # Hydra logs 14 | │ │ │ ├── csv # Csv logs 15 | │ │ │ ├── wandb # Weights&Biases logs 16 | │ │ │ ├── checkpoints # Training checkpoints 17 | │ │ │ └── ... # Any other thing saved during training 18 | │ │ └── ... 19 | │ │ 20 | │ └── multiruns # Logs generated by multiruns (ie using --multirun) 21 | │ ├── YYYY-MM-DD_HH-MM-SS # Datetime of the multirun 22 | │ │ ├──1 # Multirun job number 23 | │ │ ├──2 24 | │ │ └── ... 25 | │ └── ... 26 | │ 27 | └── debug # Logs generated when debugging config is attached 28 | └── ... 29 | 30 | ``` 31 | 32 | ## Setting up your own `data/` and `logs/` paths 33 | The `data/` and `logs/` directories will store all your datasets and training 34 | logs. By default, these are placed in the repository directory. 35 | 36 | Since this may take some space, or your heavy data may be stored elsewhere, you 37 | may specify other paths for these directories by creating a 38 | `configs/local/default.yaml` file (the name expected by the `optional local: default.yaml` entry in `configs/train.yaml`) containing the following: 39 | 40 | ```yaml 41 | # @package paths 42 | 43 | # path to data directory 44 | data_dir: /path/to/your/data/ 45 | 46 | # path to logging directory 47 | log_dir: /path/to/your/logs/ 48 | ``` 49 | 50 | ## Logger options 51 | By default, your logs will automatically be uploaded to 52 | [Weights and Biases](https://wandb.ai), from where you can track and compare 53 | your experiments. 54 | 55 | Other loggers are available in `configs/logger/`.
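The simplest way to switch is a command-line override such as `python src/train.py logger=csv` or `logger=tensorboard`, as noted in `configs/train.yaml`. For completeness, here is a minimal sketch (not part of the repository) of doing the same programmatically with Hydra's compose API; it assumes the snippet is run from the repository root and that the installed Hydra version accepts the `version_base` argument.

```python
# Minimal sketch, not part of the repository: compose the training config
# with a different logger. Assumes it is run from the repository root and a
# Hydra version supporting `version_base`.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="configs"):
    # Same effect as the CLI override `logger=csv`
    cfg = compose(config_name="train", overrides=["logger=csv"])

# Print the logger node without resolving interpolations such as
# ${paths.output_dir}, which are only resolved at runtime
print(OmegaConf.to_yaml(cfg.logger, resolve=False))
```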
See 56 | [Lightning-Hydra](https://github.com/ashleve/lightning-hydra-template) for more 57 | information. 58 | -------------------------------------------------------------------------------- /media/dales/input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/dales/input.png -------------------------------------------------------------------------------- /media/dales/pano_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/dales/pano_gt.png -------------------------------------------------------------------------------- /media/dales/sem_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/dales/sem_gt.png -------------------------------------------------------------------------------- /media/dales/sem_gt_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/dales/sem_gt_demo.png -------------------------------------------------------------------------------- /media/kitti360/input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/kitti360/input.png -------------------------------------------------------------------------------- /media/kitti360/pano_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/kitti360/pano_gt.png -------------------------------------------------------------------------------- /media/kitti360/sem_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/kitti360/sem_gt.png -------------------------------------------------------------------------------- /media/s3dis/input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/s3dis/input.png -------------------------------------------------------------------------------- /media/s3dis/pano_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/s3dis/pano_gt.png -------------------------------------------------------------------------------- /media/s3dis/sem_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/s3dis/sem_gt.png -------------------------------------------------------------------------------- /media/scannet/input.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/scannet/input.png -------------------------------------------------------------------------------- /media/scannet/pano_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/scannet/pano_gt.png -------------------------------------------------------------------------------- /media/scannet/sem_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/scannet/sem_gt.png -------------------------------------------------------------------------------- /media/superpoint_transformer_tutorial.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/superpoint_transformer_tutorial.pdf -------------------------------------------------------------------------------- /media/teaser_spt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/teaser_spt.jpg -------------------------------------------------------------------------------- /media/teaser_spt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/teaser_spt.png -------------------------------------------------------------------------------- /media/teaser_supercluster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/teaser_supercluster.png -------------------------------------------------------------------------------- /media/visualizations.7z: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/visualizations.7z -------------------------------------------------------------------------------- /media/viz_100k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/viz_100k.png -------------------------------------------------------------------------------- /media/viz_10k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/viz_10k.png -------------------------------------------------------------------------------- /media/viz_errors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/viz_errors.png -------------------------------------------------------------------------------- /media/viz_p2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/viz_p2.png -------------------------------------------------------------------------------- /media/viz_pos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/viz_pos.png -------------------------------------------------------------------------------- /media/viz_radius_center.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/viz_radius_center.png -------------------------------------------------------------------------------- /media/viz_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/viz_rgb.png -------------------------------------------------------------------------------- /media/viz_select.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/viz_select.png -------------------------------------------------------------------------------- /media/viz_x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/viz_x.png -------------------------------------------------------------------------------- /media/viz_y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/media/viz_y.png -------------------------------------------------------------------------------- /notebooks/demo_nag.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/notebooks/demo_nag.pt -------------------------------------------------------------------------------- /scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python src/train.py trainer.max_epochs=5 logger=csv 6 | 7 | python src/train.py trainer.max_epochs=10 logger=csv 8 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .debug import is_debug_enabled, debug, set_debug 2 | import src.data 3 | import src.datasets 4 | import src.datamodules 5 | import src.loader 6 | import src.metrics 7 | import src.models 8 | import src.nn 9 | import src.transforms 10 | import src.utils 11 | import src.visualization 12 | 13 | __version__ = '0.0.1' 14 | 15 | __all__ = [ 16 | 'is_debug_enabled', 17 | 'debug', 18 | 'set_debug', 19 | 'src', 20 | '__version__', 21 | ] 22 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | 
from .csr import * 2 | from .cluster import * 3 | from .instance import * 4 | from .data import * 5 | from .nag import * 6 | -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/src/datamodules/__init__.py -------------------------------------------------------------------------------- /src/datamodules/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/src/datamodules/components/__init__.py -------------------------------------------------------------------------------- /src/datamodules/dales.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from src.datamodules.base import BaseDataModule 3 | from src.datasets import DALES, MiniDALES 4 | 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | class DALESDataModule(BaseDataModule): 10 | """LightningDataModule for DALES dataset. 11 | 12 | A DataModule implements 5 key methods: 13 | 14 | def prepare_data(self): 15 | # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) 16 | # download data, pre-process, split, save to disk, etc... 17 | def setup(self, stage): 18 | # things to do on every process in DDP 19 | # load data, set variables, etc... 20 | def train_dataloader(self): 21 | # return train dataloader 22 | def val_dataloader(self): 23 | # return validation dataloader 24 | def test_dataloader(self): 25 | # return test dataloader 26 | def teardown(self): 27 | # called on every process in DDP 28 | # clean up after fit or test 29 | 30 | This allows you to share a full dataset without explaining how to download, 31 | split, transform and process the data. 32 | 33 | Read the docs: 34 | https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html 35 | """ 36 | _DATASET_CLASS = DALES 37 | _MINIDATASET_CLASS = MiniDALES 38 | 39 | 40 | if __name__ == "__main__": 41 | import hydra 42 | import omegaconf 43 | import pyrootutils 44 | 45 | root = str(pyrootutils.setup_root(__file__, pythonpath=True)) 46 | cfg = omegaconf.OmegaConf.load(root + "/configs/datamodule/semantic/dales.yaml") 47 | cfg.data_dir = root + "/data" 48 | _ = hydra.utils.instantiate(cfg) 49 | -------------------------------------------------------------------------------- /src/datamodules/kitti360.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from src.datamodules.base import BaseDataModule 3 | from src.datasets import KITTI360, MiniKITTI360 4 | 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | class KITTI360DataModule(BaseDataModule): 10 | """LightningDataModule for KITTI360 dataset. 11 | 12 | A DataModule implements 5 key methods: 13 | 14 | def prepare_data(self): 15 | # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) 16 | # download data, pre-process, split, save to disk, etc... 17 | def setup(self, stage): 18 | # things to do on every process in DDP 19 | # load data, set variables, etc... 
20 | def train_dataloader(self): 21 | # return train dataloader 22 | def val_dataloader(self): 23 | # return validation dataloader 24 | def test_dataloader(self): 25 | # return test dataloader 26 | def teardown(self): 27 | # called on every process in DDP 28 | # clean up after fit or test 29 | 30 | This allows you to share a full dataset without explaining how to download, 31 | split, transform and process the data. 32 | 33 | Read the docs: 34 | https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html 35 | """ 36 | _DATASET_CLASS = KITTI360 37 | _MINIDATASET_CLASS = MiniKITTI360 38 | 39 | 40 | if __name__ == "__main__": 41 | import hydra 42 | import omegaconf 43 | import pyrootutils 44 | 45 | root = str(pyrootutils.setup_root(__file__, pythonpath=True)) 46 | cfg = omegaconf.OmegaConf.load(root + "/configs/datamodule/semantic/kitti360.yaml") 47 | cfg.data_dir = root + "/data" 48 | _ = hydra.utils.instantiate(cfg) 49 | -------------------------------------------------------------------------------- /src/datamodules/s3dis.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from src.datamodules.base import BaseDataModule 3 | from src.datasets import S3DIS, MiniS3DIS 4 | 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | class S3DISDataModule(BaseDataModule): 10 | """LightningDataModule for S3DIS dataset. 11 | 12 | A DataModule implements 5 key methods: 13 | 14 | def prepare_data(self): 15 | # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) 16 | # download data, pre-process, split, save to disk, etc... 17 | def setup(self, stage): 18 | # things to do on every process in DDP 19 | # load data, set variables, etc... 20 | def train_dataloader(self): 21 | # return train dataloader 22 | def val_dataloader(self): 23 | # return validation dataloader 24 | def test_dataloader(self): 25 | # return test dataloader 26 | def teardown(self): 27 | # called on every process in DDP 28 | # clean up after fit or test 29 | 30 | This allows you to share a full dataset without explaining how to download, 31 | split, transform and process the data. 32 | 33 | Read the docs: 34 | https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html 35 | """ 36 | _DATASET_CLASS = S3DIS 37 | _MINIDATASET_CLASS = MiniS3DIS 38 | 39 | 40 | if __name__ == "__main__": 41 | import hydra 42 | import omegaconf 43 | import pyrootutils 44 | 45 | root = str(pyrootutils.setup_root(__file__, pythonpath=True)) 46 | cfg = omegaconf.OmegaConf.load(root + "/configs/datamodule/semantic/s3dis.yaml") 47 | cfg.data_dir = root + "/data" 48 | _ = hydra.utils.instantiate(cfg) 49 | -------------------------------------------------------------------------------- /src/datamodules/s3dis_room.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from src.datamodules.base import BaseDataModule 3 | from src.datasets import S3DISRoom, MiniS3DISRoom 4 | 5 | 6 | # Occasional Dataloader issues with S3DISRoomDataModule on some 7 | # machines. Hack to solve this: 8 | # https://stackoverflow.com/questions/73125231/pytorch-dataloaders-bad-file-descriptor-and-eof-for-workers0 9 | import torch.multiprocessing 10 | torch.multiprocessing.set_sharing_strategy('file_system') 11 | 12 | 13 | log = logging.getLogger(__name__) 14 | 15 | 16 | class S3DISRoomDataModule(BaseDataModule): 17 | """LightningDataModule for room-wise S3DIS dataset. 
18 | 19 | A DataModule implements 5 key methods: 20 | 21 | def prepare_data(self): 22 | # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) 23 | # download data, pre-process, split, save to disk, etc... 24 | def setup(self, stage): 25 | # things to do on every process in DDP 26 | # load data, set variables, etc... 27 | def train_dataloader(self): 28 | # return train dataloader 29 | def val_dataloader(self): 30 | # return validation dataloader 31 | def test_dataloader(self): 32 | # return test dataloader 33 | def teardown(self): 34 | # called on every process in DDP 35 | # clean up after fit or test 36 | 37 | This allows you to share a full dataset without explaining how to download, 38 | split, transform and process the data. 39 | 40 | Read the docs: 41 | https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html 42 | """ 43 | _DATASET_CLASS = S3DISRoom 44 | _MINIDATASET_CLASS = MiniS3DISRoom 45 | 46 | 47 | if __name__ == "__main__": 48 | import hydra 49 | import omegaconf 50 | import pyrootutils 51 | 52 | root = str(pyrootutils.setup_root(__file__, pythonpath=True)) 53 | cfg = omegaconf.OmegaConf.load(root + "/configs/datamodule/semantic/s3dis_room.yaml") 54 | cfg.data_dir = root + "/data" 55 | _ = hydra.utils.instantiate(cfg) 56 | -------------------------------------------------------------------------------- /src/datamodules/scannet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from src.datamodules.base import BaseDataModule 3 | from src.datasets import ScanNet, MiniScanNet 4 | 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | class ScanNetDataModule(BaseDataModule): 10 | """LightningDataModule for ScanNet dataset. 11 | 12 | A DataModule implements 5 key methods: 13 | 14 | def prepare_data(self): 15 | # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP) 16 | # download data, pre-process, split, save to disk, etc... 17 | def setup(self, stage): 18 | # things to do on every process in DDP 19 | # load data, set variables, etc... 20 | def train_dataloader(self): 21 | # return train dataloader 22 | def val_dataloader(self): 23 | # return validation dataloader 24 | def test_dataloader(self): 25 | # return test dataloader 26 | def teardown(self): 27 | # called on every process in DDP 28 | # clean up after fit or test 29 | 30 | This allows you to share a full dataset without explaining how to download, 31 | split, transform and process the data. 
32 | 33 | Read the docs: 34 | https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html 35 | """ 36 | _DATASET_CLASS = ScanNet 37 | _MINIDATASET_CLASS = MiniScanNet 38 | 39 | 40 | if __name__ == "__main__": 41 | import hydra 42 | import omegaconf 43 | import pyrootutils 44 | 45 | root = str(pyrootutils.setup_root(__file__, pythonpath=True)) 46 | cfg = omegaconf.OmegaConf.load(root + "/configs/datamodule/semantic/scannet.yaml") 47 | cfg.data_dir = root + "/data" 48 | _ = hydra.utils.instantiate(cfg) 49 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | IGNORE_LABEL: int = -1 2 | from .base import * 3 | from .dales import * 4 | from .kitti360 import * 5 | from .s3dis import * 6 | from .s3dis_room import * 7 | from .scannet import * 8 | -------------------------------------------------------------------------------- /src/datasets/dales_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | ######################################################################## 5 | # Download information # 6 | ######################################################################## 7 | 8 | FORM_URL = 'https://docs.google.com/forms/d/e/1FAIpQLSefhHMMvN0Uwjnj_vWQgYSvtFOtaoGFWsTIcRuBTnP09NHR7A/viewform?fbzx=5530674395784263977' 9 | 10 | # DALES in LAS format 11 | LAS_TAR_NAME = 'dales_semantic_segmentation_las.tar.gz' 12 | LAS_UNTAR_NAME = "dales_las" 13 | 14 | # DALES in PLY format 15 | PLY_TAR_NAME = 'dales_semantic_segmentation_ply.tar.gz' 16 | PLY_UNTAR_NAME = "dales_ply" 17 | 18 | # DALES in PLY, only version with intensity and instance labels 19 | OBJECTS_TAR_NAME = 'DALESObjects.tar.gz' 20 | OBJECTS_UNTAR_NAME = "DALESObjects" 21 | 22 | 23 | ######################################################################## 24 | # Data splits # 25 | ######################################################################## 26 | 27 | # The validation set was arbitrarily chosen as the x last train tiles: 28 | TILES = { 29 | 'train': [ 30 | '5080_54435_new', 31 | '5190_54400_new', 32 | '5105_54460_new', 33 | '5130_54355_new', 34 | '5165_54395_new', 35 | '5185_54390_new', 36 | '5180_54435_new', 37 | '5085_54320_new', 38 | '5100_54495_new', 39 | '5110_54320_new', 40 | '5140_54445_new', 41 | '5105_54405_new', 42 | '5185_54485_new', 43 | '5165_54390_new', 44 | '5145_54460_new', 45 | '5110_54460_new', 46 | '5180_54485_new', 47 | '5150_54340_new', 48 | '5145_54405_new', 49 | '5145_54470_new', 50 | '5160_54330_new', 51 | '5135_54495_new', 52 | '5145_54480_new', 53 | '5115_54480_new', 54 | '5110_54495_new', 55 | '5095_54440_new'], 56 | 57 | 'val': [ 58 | '5145_54340_new', 59 | '5095_54455_new', 60 | '5110_54475_new'], 61 | 62 | 'test': [ 63 | '5080_54470_new', 64 | '5100_54440_new', 65 | '5140_54390_new', 66 | '5080_54400_new', 67 | '5155_54335_new', 68 | '5150_54325_new', 69 | '5120_54445_new', 70 | '5135_54435_new', 71 | '5175_54395_new', 72 | '5100_54490_new', 73 | '5135_54430_new']} 74 | 75 | 76 | ######################################################################## 77 | # Labels # 78 | ######################################################################## 79 | 80 | DALES_NUM_CLASSES = 8 81 | 82 | ID2TRAINID = np.asarray([8, 0, 1, 2, 3, 4, 5, 6, 7]) 83 | 84 | CLASS_NAMES = [ 85 | 'Ground', 86 | 'Vegetation', 87 | 'Cars', 88 | 'Trucks', 89 | 'Power lines', 90 | 'Fences', 91 | 'Poles', 92 | 
'Buildings', 93 | 'Unknown'] 94 | 95 | CLASS_COLORS = np.asarray([ 96 | [243, 214, 171], # sunset 97 | [ 70, 115, 66], # fern green 98 | [233, 50, 239], 99 | [243, 238, 0], 100 | [190, 153, 153], 101 | [ 0, 233, 11], 102 | [239, 114, 0], 103 | [214, 66, 54], # vermillon 104 | [ 0, 8, 116]]) 105 | 106 | # For instance segmentation 107 | MIN_OBJECT_SIZE = 100 108 | THING_CLASSES = [2, 3, 4, 5, 6, 7] 109 | STUFF_CLASSES = [i for i in range(DALES_NUM_CLASSES) if not i in THING_CLASSES] 110 | -------------------------------------------------------------------------------- /src/debug.py: -------------------------------------------------------------------------------- 1 | # Copied from: 2 | 3 | __debug_flag__ = {'enabled': False} 4 | 5 | 6 | def is_debug_enabled(): 7 | r"""Returns :obj:`True`, if the debug mode is enabled.""" 8 | return __debug_flag__['enabled'] 9 | 10 | 11 | def set_debug_enabled(mode): 12 | __debug_flag__['enabled'] = mode 13 | 14 | 15 | class debug(object): 16 | r"""Context-manager that enables the debug mode to help track down 17 | errors and separate usage errors from real bugs. 18 | 19 | Example: 20 | 21 | >>> with src.debug(): 22 | ... out = model(data.x, data.edge_index) 23 | """ 24 | 25 | def __init__(self): 26 | self.prev = is_debug_enabled() 27 | 28 | def __enter__(self): 29 | set_debug_enabled(True) 30 | 31 | def __exit__(self, *args): 32 | set_debug_enabled(self.prev) 33 | return False 34 | 35 | 36 | class set_debug(object): 37 | r"""Context-manager that sets the debug mode on or off. 38 | 39 | :class:`set_debug` will enable or disable the debug mode based on 40 | its argument :attr:`mode`. 41 | It can be used as a context-manager or as a function. 42 | 43 | See :class:`debug` above for more details. 44 | """ 45 | 46 | def __init__(self, mode): 47 | self.prev = is_debug_enabled() 48 | set_debug_enabled(mode) 49 | 50 | def __enter__(self): 51 | pass 52 | 53 | def __exit__(self, *args): 54 | set_debug_enabled(self.prev) 55 | return False 56 | -------------------------------------------------------------------------------- /src/dependencies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/src/dependencies/__init__.py -------------------------------------------------------------------------------- /src/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import * 2 | -------------------------------------------------------------------------------- /src/loader/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader as TorchDataLoader 2 | 3 | 4 | __all__ = ['DataLoader'] 5 | 6 | def __identity__(batch_list): 7 | """ 8 | fix for windows, where lambda can't be pickled. 9 | We have to use a top level function 10 | see: 11 | https://discuss.pytorch.org/t/cant-pickle-local-object-dataloader-init-locals-lambda/31857/10?page=2 12 | https://docs.python.org/3/library/pickle.html#what-can-be-pickled-and-unpickled 13 | """ 14 | return batch_list 15 | 16 | class DataLoader(TorchDataLoader): 17 | """Same as torch DataLoader except that the default behaviour for 18 | `collate_fn=None` is a simple identity. (i.e. the DataLoader will 19 | return a list of elements by default). 
This approach is meant to 20 | move the CPU-hungry NAG.from_nag_list (in particular, the level-0 21 | Data.from_nag_list) to GPU. This is instead taken care of in the 22 | 'DataModule.on_after_batch_transfer' hook, which calls the dataset 23 | 'on_device_transform'. 24 | 25 | Use `collate_fn=NAG.from_data_list` if you want the CPU to do this 26 | operation (but beware of collisions with our 27 | 'DataModule.on_after_batch_transfer' implementation. 28 | """ 29 | def __init__(self, *args, collate_fn=None, **kwargs): 30 | if collate_fn is None: 31 | collate_fn = __identity__ 32 | super().__init__(*args, collate_fn=collate_fn, **kwargs) -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .bce import * 2 | from .multi import * 3 | from .lovasz import * 4 | from .focal import * 5 | from .l1 import * 6 | from .l2 import * 7 | -------------------------------------------------------------------------------- /src/loss/bce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import BCEWithLogitsLoss as TorchBCEWithLogitsLoss 4 | from src.loss.weighted import WeightedLossMixIn 5 | 6 | 7 | __all__ = ['WeightedBCEWithLogitsLoss', 'BCEWithLogitsLoss'] 8 | 9 | 10 | class WeightedBCEWithLogitsLoss(WeightedLossMixIn, TorchBCEWithLogitsLoss): 11 | """Weighted BCE loss between predicted and target offsets. This is 12 | basically the BCEWithLogitsLoss except that positive weights must be 13 | passed at forward time to give more importance to some items. 14 | 15 | Besides, we remove the constraint of passing `pos_weight` as a 16 | Tensor. This simplifies instantiation with hydra. 17 | """ 18 | 19 | def __init__(self, *args, pos_weight=None, **kwargs): 20 | if pos_weight is not None and not isinstance(pos_weight, Tensor): 21 | pos_weight = torch.as_tensor(pos_weight) 22 | super().__init__( 23 | *args, pos_weight=pos_weight, reduction='none', **kwargs) 24 | 25 | def load_state_dict(self, state_dict, strict=True): 26 | """Normal `load_state_dict` behavior, except for the shared 27 | `pos_weight`. 28 | """ 29 | # Get the weight from the state_dict 30 | pos_weight = state_dict.get('pos_weight') 31 | state_dict.pop('pos_weight') 32 | 33 | # Normal load_state_dict, ignoring pos_weight 34 | out = super().load_state_dict(state_dict, strict=strict) 35 | 36 | # Set the pos_weight 37 | self.pos_weight = pos_weight 38 | 39 | return out 40 | 41 | 42 | class BCEWithLogitsLoss(WeightedBCEWithLogitsLoss): 43 | """BCE loss between predicted and target offsets. 44 | 45 | The forward signature allows using this loss as a weighted loss, 46 | with input weights ignored. 47 | """ 48 | 49 | def forward(self, input, target, weight): 50 | return super().forward(input, target, None) 51 | -------------------------------------------------------------------------------- /src/loss/l1.py: -------------------------------------------------------------------------------- 1 | from torch.nn import L1Loss as TorchL1Loss 2 | from src.loss.weighted import WeightedLossMixIn 3 | 4 | 5 | __all__ = ['WeightedL1Loss', 'L1Loss'] 6 | 7 | 8 | class WeightedL1Loss(WeightedLossMixIn, TorchL1Loss): 9 | """Weighted L1 loss between predicted and target offsets. This is 10 | basically the L1Loss except that positive weights must be passed at 11 | forward time to give more importance to some items. 
12 | """ 13 | 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, reduction='none', **kwargs) 16 | 17 | 18 | class L1Loss(WeightedL1Loss): 19 | """L1 loss between predicted and target offsets. 20 | 21 | The forward signature allows using this loss as a weighted loss, 22 | with input weights ignored. 23 | """ 24 | 25 | def forward(self, input, target, weight): 26 | return super().forward(input, target, None) 27 | -------------------------------------------------------------------------------- /src/loss/l2.py: -------------------------------------------------------------------------------- 1 | from torch.nn import MSELoss as TorchL2Loss 2 | from src.loss.weighted import WeightedLossMixIn 3 | 4 | 5 | __all__ = ['WeightedL2Loss', 'L2Loss'] 6 | 7 | 8 | class WeightedL2Loss(WeightedLossMixIn, TorchL2Loss): 9 | """Weighted mean squared error (ie L2 loss) between predicted and 10 | target offsets. This is basically the MSELoss except that positive 11 | weights must be passed at forward time to give more importance to 12 | some items. 13 | """ 14 | 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, reduction='none', **kwargs) 17 | 18 | 19 | class L2Loss(WeightedL2Loss): 20 | """Mean squared error (ie L2 loss) between predicted and target 21 | offsets. 22 | 23 | The forward signature allows using this loss as a weighted loss, 24 | with input weights ignored. 25 | """ 26 | 27 | def forward(self, input, target, weight): 28 | return super().forward(input, target, None) 29 | -------------------------------------------------------------------------------- /src/loss/multi.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | __all__ = ['MultiLoss'] 5 | 6 | 7 | class MultiLoss(nn.Module): 8 | """Wrapper to compute the weighted sum of multiple criteria 9 | 10 | :param criteria: List(callable) 11 | List of criteria 12 | :param lambdas: List(str) 13 | 14 | """ 15 | 16 | def __init__(self, criteria, lambdas): 17 | super().__init__() 18 | assert len(criteria) == len(lambdas) 19 | self.criteria = nn.ModuleList(criteria) 20 | self.lambdas = lambdas 21 | 22 | def __len__(self): 23 | return len(self.criteria) 24 | 25 | def to(self, *args, **kwargs): 26 | for i in range(len(self)): 27 | self.criteria[i] = self.criteria[i].to(*args, **kwargs) 28 | self.lambdas[i] = self.lambdas[i].to(*args, **kwargs) 29 | 30 | def extra_repr(self) -> str: 31 | return f'lambdas={self.lambdas}' 32 | 33 | def forward(self, a, b, **kwargs): 34 | loss = 0 35 | for lamb, criterion, a_, b_ in zip(self.lambdas, self.criteria, a, b): 36 | loss = loss + lamb * criterion(a_, b_, **kwargs) 37 | return loss 38 | 39 | @property 40 | def weight(self): 41 | """MultiLoss supports `weight` if all its criteria support it. 42 | """ 43 | return self.criteria[0].weight 44 | 45 | @weight.setter 46 | def weight(self, weight): 47 | """MultiLoss supports `weight` if all its criteria support it. 48 | """ 49 | for i in range(len(self)): 50 | self.criteria[i].weight = weight 51 | 52 | def state_dict(self, *args, destination=None, prefix='', keep_vars=False): 53 | """Normal `state_dict` behavior, except for the shared criterion 54 | weights, which are not saved under `prefix.criteria.i.weight` 55 | but under `prefix.weight`. 
56 | """ 57 | destination = super().state_dict( 58 | *args, destination=destination, prefix=prefix, keep_vars=keep_vars) 59 | 60 | # Remove the 'weight' from the criteria 61 | for i in range(len(self)): 62 | destination.pop(f"{prefix}criteria.{i}.weight") 63 | 64 | # Only save the global shared weight 65 | destination[f"{prefix}weight"] = self.weight 66 | 67 | return destination 68 | 69 | def load_state_dict(self, state_dict, strict=True): 70 | """Normal `load_state_dict` behavior, except for the shared 71 | criterion weights, which are not saved under `criteria.i.weight` 72 | but under `prefix.weight`. 73 | """ 74 | # Get the weight from the state_dict 75 | old_format = state_dict.get('criteria.0.weight') 76 | new_format = state_dict.get('weight') 77 | weight = new_format if new_format is not None else old_format 78 | for k in [f"criteria.{i}.weight" for i in range(len(self))]: 79 | if k in state_dict.keys(): 80 | state_dict.pop(k) 81 | 82 | # Normal load_state_dict, ignoring self.criteria.0.weight and 83 | # self.weight 84 | out = super().load_state_dict(state_dict, strict=strict) 85 | 86 | # Set the weight 87 | self.weight = weight 88 | 89 | return out 90 | -------------------------------------------------------------------------------- /src/loss/weighted.py: -------------------------------------------------------------------------------- 1 | __all__ = ['WeightedLossMixIn'] 2 | 3 | 4 | class WeightedLossMixIn: 5 | """A mix-in for converting a torch loss into an item-weighted loss. 6 | """ 7 | def forward(self, input, target, weight): 8 | if weight is not None: 9 | assert weight.ge(0).all(), "Weights must be positive." 10 | assert weight.gt(0).any(), "At least one weight must be non-zero." 11 | 12 | # Compute the loss, without reduction 13 | loss = super().forward(input, target) 14 | if loss.dim() == 1: 15 | loss = loss.view(-1, 1) 16 | 17 | # Sum the loss terms across the spatial dimension, so the 18 | # downstream averaging does not normalize by the number of 19 | # dimensions 20 | loss = loss.sum(dim=1).view(-1, 1) 21 | 22 | # If weights are None, fallback to normal unweighted L2 loss 23 | if weight is None: 24 | return loss.mean() 25 | 26 | # Compute the weighted mean 27 | return (loss * (weight / weight.sum()).view(-1, 1)).sum() 28 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .semantic import * 2 | from .mean_average_precision import * 3 | from .panoptic import * 4 | from .weighted_li import * 5 | -------------------------------------------------------------------------------- /src/metrics/weighted_li.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from torch import Tensor 3 | from typing import Tuple 4 | from torchmetrics import MeanSquaredError, MeanAbsoluteError 5 | from torchmetrics.utilities.checks import _check_same_shape 6 | 7 | 8 | log = logging.getLogger(__name__) 9 | 10 | 11 | __all__ = ['WeightedL2Error', 'WeightedL1Error', 'L2Error', 'L1Error'] 12 | 13 | 14 | def _weighted_Li_error_update( 15 | pred: Tensor, 16 | target: Tensor, 17 | weight: Tensor, 18 | norm: int 19 | ) -> Tuple[Tensor, int]: 20 | """Update and returns variables required to compute weighted L1 21 | error. 
22 | 23 | Args: 24 | pred: Predicted tensor 25 | target: Ground truth tensor 26 | weight: weight tensor 27 | norm: `i` for Li norm (`i` >= 0) 28 | """ 29 | if weight is not None: 30 | assert weight.dim() == 1 31 | assert weight.numel() == pred.shape[0] 32 | assert norm >= 0 33 | 34 | _check_same_shape(pred, target) 35 | 36 | a = pred - target 37 | sum_dims = tuple(range(1, a.dim())) 38 | if norm == 0: 39 | a = a.any(dim=1).float().sum(dim=sum_dims) 40 | elif norm == 1: 41 | a = a.abs().sum(dim=sum_dims) 42 | else: 43 | a = a.pow(norm).sum(dim=sum_dims) 44 | 45 | sum_error = (weight * a).sum() if weight is not None else a.sum() 46 | sum_weight = weight.sum() if weight is not None else pred.shape[0] 47 | 48 | return sum_error, sum_weight 49 | 50 | 51 | class WeightedL2Error(MeanSquaredError): 52 | """Simply torchmetrics' MeanSquaredError (ie L2 loss) with 53 | item-weighted mean to give more importance to some items. 54 | """ 55 | 56 | def update(self, pred: Tensor, target: Tensor, weight: Tensor) -> None: 57 | """Update state with predictions, targets, and weights.""" 58 | sum_squared_error, sum_weight = _weighted_Li_error_update( 59 | pred, target, weight, 2) 60 | 61 | self.sum_squared_error += sum_squared_error 62 | self.total = self.total + sum_weight 63 | 64 | 65 | class WeightedL1Error(MeanAbsoluteError): 66 | """Simply torchmetrics' MeanAbsoluteError (ie L1 loss) with 67 | item-weighted mean to give more importance to some items. 68 | """ 69 | 70 | def update(self, pred: Tensor, target: Tensor, weight: Tensor) -> None: 71 | """Update state with predictions, targets, and weights.""" 72 | sum_abs_error, sum_weight = _weighted_Li_error_update( 73 | pred, target, weight, 1) 74 | 75 | self.sum_abs_error += sum_abs_error 76 | self.total = self.total + sum_weight 77 | 78 | 79 | class L2Error(WeightedL2Error): 80 | """Simply torchmetrics' MeanSquaredError (ie L2 loss) with summation 81 | instead of mean along the feature dimensions. 82 | """ 83 | def update(self, pred: Tensor, target: Tensor) -> None: 84 | """Update state with predictions and targets.""" 85 | super().update(pred, target, None) 86 | 87 | 88 | class L1Error(WeightedL1Error): 89 | """Simply torchmetrics' MeanAbsoluteError (ie L1 loss) with 90 | summation instead of mean along the feature dimensions. 91 | """ 92 | def update(self, pred: Tensor, target: Tensor) -> None: 93 | """Update state with predictions and targets.""" 94 | super().update(pred, target, None) 95 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/components/__init__.py: -------------------------------------------------------------------------------- 1 | """Model components groups architectures ready to be used a `net` in a 2 | LightningModule. These are complex architectures, on which a 3 | LightningModule can add heads and train for different types of tasks. 
4 | """ 5 | from .spt import * 6 | from .mlp import * 7 | -------------------------------------------------------------------------------- /src/models/components/mlp.py: -------------------------------------------------------------------------------- 1 | from torch_scatter import scatter 2 | from torch import nn 3 | from src.data import NAG 4 | from src.nn import MLP, BatchNorm 5 | 6 | 7 | __all__ = ['NodeMLP'] 8 | 9 | 10 | class NodeMLP(nn.Module): 11 | """Simple MLP on the handcrafted features of the level-i in a NAG. 12 | This is used as a baseline to test how expressive handcrafted 13 | features are. 14 | """ 15 | 16 | def __init__( 17 | self, dims, level=0, activation=nn.LeakyReLU(), norm=BatchNorm, 18 | drop=None, norm_mode='graph'): 19 | 20 | super().__init__() 21 | 22 | self.level = level 23 | self.mlp = MLP(dims, activation=activation, norm=norm, drop=drop) 24 | self.norm_mode = norm_mode 25 | 26 | @property 27 | def out_dim(self): 28 | return self.mlp.out_dim 29 | 30 | def forward(self, nag): 31 | assert isinstance(nag, NAG) 32 | assert nag.num_levels > self.level 33 | 34 | # Compute node features from the handcrafted features 35 | norm_index = nag[self.i_level].norm_index(mode=self.norm_mode) 36 | x = self.mlp(nag[self.level].x, batch=norm_index) 37 | 38 | # If node level is 1, output level-1 features 39 | if self.level == 1: 40 | return x 41 | 42 | # If node level is 0, max-pool to produce level-1 features 43 | if self.level == 0: 44 | return scatter( 45 | x, nag[0].super_index, dim=0, dim_size=nag[1].num_nodes, 46 | reduce='max') 47 | 48 | # If node level is larger than 1, distribute parent features to 49 | # level-1 nodes 50 | super_index = nag.get_super_index(self.level, low=1) 51 | return x[super_index] 52 | -------------------------------------------------------------------------------- /src/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .norm import * 2 | from .mlp import * 3 | from .pool import * 4 | from .unpool import * 5 | from .attention import * 6 | from .fusion import * 7 | from .dropout import * 8 | from .transformer import * 9 | from .stage import * 10 | from .position_encoding import * 11 | from .instance import * 12 | -------------------------------------------------------------------------------- /src/nn/dropout.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | __all__ = ['DropPath'] 5 | 6 | 7 | def drop_path(x, drop_prob=0, training=False, scale_by_keep=True): 8 | """Drop paths (Stochastic Depth) per sample (when applied in main 9 | path of residual blocks). 10 | 11 | credit: https://github.com/rwightman/pytorch-image-models 12 | """ 13 | if drop_prob == 0. or not training: 14 | return x 15 | keep_prob = 1 - drop_prob 16 | # work with diff dim tensors, not just 2D ConvNets 17 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 18 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 19 | if keep_prob > 0.0 and scale_by_keep: 20 | random_tensor.div_(keep_prob) 21 | return x * random_tensor 22 | 23 | 24 | class DropPath(nn.Module): 25 | """Drop paths (Stochastic Depth) per sample (when applied in main 26 | path of residual blocks). 
27 | 28 | credit: https://github.com/rwightman/pytorch-image-models 29 | """ 30 | 31 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 32 | super().__init__() 33 | self.drop_prob = drop_prob 34 | self.scale_by_keep = scale_by_keep 35 | 36 | def forward(self, x): 37 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 38 | 39 | def extra_repr(self): 40 | return f'drop_prob={round(self.drop_prob,3):0.3f}' 41 | -------------------------------------------------------------------------------- /src/nn/fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | __all__ = ['CatFusion', 'AdditiveFusion', 'TakeFirstFusion', 'TakeSecondFusion'] 6 | 7 | 8 | def fusion_factory(mode): 9 | """Return the fusion class from an input string. 10 | 11 | :param mode: str 12 | """ 13 | if mode in ['cat', 'concatenate', 'concatenation', '|']: 14 | return CatFusion() 15 | elif mode in ['residual', 'additive', '+']: 16 | return AdditiveFusion() 17 | elif mode in ['first', '1', '1st']: 18 | return TakeFirstFusion() 19 | elif mode in ['second', '2', '2nd']: 20 | return TakeSecondFusion() 21 | else: 22 | raise NotImplementedError(f"Unknown mode='{mode}'") 23 | 24 | 25 | class BaseFusion(nn.Module): 26 | def forward(self, x1, x2): 27 | if x1 is None and x2 is None: 28 | return 29 | if x1 is None: 30 | return x2 31 | if x2 is None: 32 | return x1 33 | return self._func(x1, x2) 34 | 35 | def _func(self, x1, x2): 36 | raise NotImplementedError 37 | 38 | 39 | class CatFusion(BaseFusion): 40 | def _func(self, x1, x2): 41 | return torch.cat((x1, x2), dim=1) 42 | 43 | 44 | class AdditiveFusion(BaseFusion): 45 | def _func(self, x1, x2): 46 | return x1 + x2 47 | 48 | 49 | class TakeFirstFusion(BaseFusion): 50 | def _func(self, x1, x2): 51 | return x1 52 | 53 | 54 | class TakeSecondFusion(BaseFusion): 55 | def _func(self, x1, x2): 56 | return x2 57 | 58 | -------------------------------------------------------------------------------- /src/nn/unpool.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | __all__ = ['IndexUnpool'] 5 | 6 | 7 | class IndexUnpool(nn.Module): 8 | """Simple unpooling operation that redistributes i+1-level features 9 | to i-level nodes based on their indexing. 10 | """ 11 | 12 | def forward(self, x, idx): 13 | return x.index_select(0, idx) 14 | -------------------------------------------------------------------------------- /src/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import * 2 | -------------------------------------------------------------------------------- /src/transforms/debug.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from src.data import NAG 3 | from src.transforms import Transform 4 | 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | __all__ = ['HelloWorld'] 10 | 11 | 12 | class HelloWorld(Transform): 13 | _IN_TYPE = NAG 14 | _OUT_TYPE = NAG 15 | 16 | def _process(self, nag): 17 | log.info("\n**** Hello World ! 
****\n") 18 | return nag 19 | -------------------------------------------------------------------------------- /src/transforms/device.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.transforms import Transform 3 | from src.data import NAG 4 | 5 | 6 | __all__ = ['DataTo', 'NAGTo'] 7 | 8 | 9 | class DataTo(Transform): 10 | """Move Data object to specified device.""" 11 | 12 | def __init__(self, device): 13 | if not isinstance(device, torch.device): 14 | device = torch.device(device) 15 | self.device = device 16 | 17 | def _process(self, data): 18 | if data.device == self.device: 19 | return data 20 | return data.to(self.device) 21 | 22 | 23 | class NAGTo(Transform): 24 | """Move Data object to specified device.""" 25 | 26 | _IN_TYPE = NAG 27 | _OUT_TYPE = NAG 28 | 29 | def __init__(self, device): 30 | if not isinstance(device, torch.device): 31 | device = torch.device(device) 32 | self.device = device 33 | 34 | def _process(self, nag): 35 | if nag.device == self.device: 36 | return nag 37 | return nag.to(self.device) 38 | -------------------------------------------------------------------------------- /src/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | from torch_geometric.transforms import BaseTransform 3 | 4 | from src.data import Data 5 | 6 | 7 | __all__ = ['Transform'] 8 | 9 | 10 | class Transform(BaseTransform): 11 | """Transform on `_IN_TYPE` returning `_OUT_TYPE`.""" 12 | 13 | _IN_TYPE = Data 14 | _OUT_TYPE = Data 15 | _NO_REPR = [] 16 | 17 | def _process(self, x: _IN_TYPE): 18 | raise NotImplementedError 19 | 20 | def __call__(self, x: Union[_IN_TYPE, List]): 21 | assert isinstance(x, (self._IN_TYPE, list)) 22 | if isinstance(x, list): 23 | return [self.__call__(e) for e in x] 24 | return self._process(x) 25 | 26 | @property 27 | def _repr_dict(self): 28 | return {k: v for k, v in self.__dict__.items() if k not in self._NO_REPR} 29 | 30 | def __repr__(self): 31 | attr_repr = ', '.join([f'{k}={v}' for k, v in self._repr_dict.items()]) 32 | return f'{self.__class__.__name__}({attr_repr})' 33 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .point import * 2 | from .keys import * 3 | from .color import * 4 | from .configs import * 5 | from .dropout import * 6 | from .hydra import * 7 | from .list import * 8 | from .tensor import * 9 | from .cpu import * 10 | from .features import * 11 | from .geometry import * 12 | from .io import * 13 | from .neighbors import * 14 | from .partition import * 15 | from .sparse import * 16 | from .edge import * 17 | from .pylogger import get_pylogger 18 | from .rich_utils import enforce_tags, print_config_tree 19 | from .utils import * 20 | from .histogram import * 21 | from .loss import * 22 | from .memory import * 23 | from .nn import * 24 | from .scatter import * 25 | from .encoding import * 26 | from .time import * 27 | from .multiprocessing import * 28 | from .wandb import * 29 | from .parameter import * 30 | from .graph import * 31 | from .semantic import * 32 | from .instance import * 33 | from .output_panoptic import * 34 | from .output_semantic import * 35 | from .widgets import * 36 | from .ground import * 37 | -------------------------------------------------------------------------------- /src/utils/configs.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import pyrootutils 4 | 5 | 6 | __all__ = ['get_config_structure'] 7 | 8 | 9 | def get_config_structure(start_directory=None, indent=0, verbose=False): 10 | """Parse a config file structure in search for .yaml files 11 | """ 12 | # If not provided, search the project configs directory 13 | if start_directory is None: 14 | root = str(pyrootutils.setup_root( 15 | search_from='', 16 | indicator=[".git", "README.md"], 17 | pythonpath=True, 18 | dotenv=True)) 19 | start_directory = osp.join(root, 'configs') 20 | 21 | # Structure to store the file hierarchy: 22 | # - first value is a dictionary of directories 23 | # - second value is a list of yaml files 24 | struct = ({}, []) 25 | 26 | # Recursively gather files and directories in the current directory 27 | for item in os.listdir(start_directory): 28 | item_path = os.path.join(start_directory, item) 29 | 30 | if os.path.isdir(item_path): 31 | if verbose: 32 | print(f"{' ' * indent}Directory: {item}") 33 | struct[0][item] = get_config_structure( 34 | start_directory=item_path, indent=indent + 1) 35 | 36 | elif os.path.isfile(item_path): 37 | filename, extension = osp.splitext(item) 38 | if extension == '.yaml': 39 | struct[1].append(filename) 40 | if verbose: 41 | print(f"{' ' * indent}File: {item}") 42 | 43 | return struct 44 | -------------------------------------------------------------------------------- /src/utils/cpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | 5 | 6 | __all__ = ['available_cpu_count'] 7 | 8 | 9 | def available_cpu_count(): 10 | """ Number of available virtual or physical CPUs on this system, i.e. 
11 | user/real as output by time(1) when called with an optimally scaling 12 | userspace-only program""" 13 | 14 | # cpuset 15 | # cpuset may restrict the number of *available* processors 16 | try: 17 | m = re.search(r'(?m)^Cpus_allowed:\s*(.*)$', 18 | open('/proc/self/status').read()) 19 | if m: 20 | res = bin(int(m.group(1).replace(',', ''), 16)).count('1') 21 | if res > 0: 22 | return res 23 | except IOError: 24 | pass 25 | 26 | # Python 2.6+ 27 | try: 28 | import multiprocessing 29 | return multiprocessing.cpu_count() 30 | except (ImportError, NotImplementedError): 31 | pass 32 | 33 | # https://github.com/giampaolo/psutil 34 | try: 35 | import psutil 36 | return psutil.cpu_count() # psutil.NUM_CPUS on old versions 37 | except (ImportError, AttributeError): 38 | pass 39 | 40 | # POSIX 41 | try: 42 | res = int(os.sysconf('SC_NPROCESSORS_ONLN')) 43 | 44 | if res > 0: 45 | return res 46 | except (AttributeError, ValueError): 47 | pass 48 | 49 | # Windows 50 | try: 51 | res = int(os.environ['NUMBER_OF_PROCESSORS']) 52 | 53 | if res > 0: 54 | return res 55 | except (KeyError, ValueError): 56 | pass 57 | 58 | # jython 59 | try: 60 | from java.lang import Runtime 61 | runtime = Runtime.getRuntime() 62 | res = runtime.availableProcessors() 63 | if res > 0: 64 | return res 65 | except ImportError: 66 | pass 67 | 68 | # BSD 69 | try: 70 | sysctl = subprocess.Popen(['sysctl', '-n', 'hw.ncpu'], 71 | stdout=subprocess.PIPE) 72 | scStdout = sysctl.communicate()[0] 73 | res = int(scStdout) 74 | 75 | if res > 0: 76 | return res 77 | except (OSError, ValueError): 78 | pass 79 | 80 | # Linux 81 | try: 82 | res = open('/proc/cpuinfo').read().count('processor\t:') 83 | 84 | if res > 0: 85 | return res 86 | except IOError: 87 | pass 88 | 89 | # Solaris 90 | try: 91 | pseudoDevices = os.listdir('/devices/pseudo/') 92 | res = 0 93 | for pd in pseudoDevices: 94 | if re.match(r'^cpuid@[0-9]+$', pd): 95 | res += 1 96 | 97 | if res > 0: 98 | return res 99 | except OSError: 100 | pass 101 | 102 | # Other UNIXes (heuristic) 103 | try: 104 | try: 105 | dmesg = open('/var/run/dmesg.boot').read() 106 | except IOError: 107 | dmesgProcess = subprocess.Popen(['dmesg'], stdout=subprocess.PIPE) 108 | dmesg = dmesgProcess.communicate()[0] 109 | 110 | res = 0 111 | while '\ncpu' + str(res) + ':' in dmesg: 112 | res += 1 113 | 114 | if res > 0: 115 | return res 116 | except OSError: 117 | pass 118 | 119 | raise Exception('Can not determine number of CPUs on this system') 120 | -------------------------------------------------------------------------------- /src/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from six.moves import urllib 4 | import ssl 5 | import subprocess 6 | 7 | 8 | def download_url(url, folder, log=True): 9 | """Download the content of an URL to a specific folder. 10 | 11 | :param url: string 12 | :param folder: string 13 | :param log: bool 14 | If `False`, will not print anything to the console. 
15 | :return: 16 | """ 17 | filename = url.rpartition("/")[2] 18 | path = osp.join(folder, filename) 19 | if osp.exists(path): # pragma: no cover 20 | if log: 21 | print("Using existing file", filename) 22 | return path 23 | if log: 24 | print("Downloading", url) 25 | try: 26 | os.makedirs(folder) 27 | except OSError: 28 | pass 29 | context = ssl._create_unverified_context() 30 | data = urllib.request.urlopen(url, context=context) 31 | with open(path, "wb") as f: 32 | f.write(data.read()) 33 | return path 34 | 35 | 36 | def run_command(cmd): 37 | """Run a command-line process from Python and print its outputs in 38 | an online fashion. 39 | 40 | Credit: https://www.endpointdev.com/blog/2015/01/getting-realtime-output-using-python/ 41 | """ 42 | # Create the process 43 | p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) 44 | # p = subprocess.run(cmd, shell=True) 45 | 46 | # Poll process.stdout to show stdout live 47 | while True: 48 | output = p.stdout.readline() 49 | if p.poll() is not None: 50 | break 51 | if output: 52 | print(output.strip()) 53 | rc = p.poll() 54 | print('Done') 55 | print('') 56 | 57 | return rc 58 | -------------------------------------------------------------------------------- /src/utils/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = ['dropout'] 5 | 6 | 7 | def dropout(a, p=0.5, dim=1, inplace=False, to_mean=False): 8 | n = a.shape[dim] 9 | to_drop = torch.where(torch.rand(n, device=a.device).detach() < p)[0] 10 | out = a if inplace else a.clone() 11 | 12 | 13 | if not to_mean: 14 | out.index_fill_(dim, to_drop, 0) 15 | return out 16 | 17 | if dim == 1: 18 | out[:, to_drop] = a.mean(dim=0)[to_drop] 19 | return out 20 | 21 | out[to_drop] = a.mean(dim=0) 22 | return out 23 | -------------------------------------------------------------------------------- /src/utils/edge.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.nn.pool.consecutive import consecutive_cluster 2 | from src.utils.sparse import indices_to_pointers 3 | from src.utils.tensor import arange_interleave 4 | 5 | 6 | __all__ = ['edge_index_to_uid', 'edge_wise_points'] 7 | 8 | 9 | def edge_index_to_uid(edge_index): 10 | """Compute consecutive unique identifiers for the edges. This may be 11 | needed for scatter operations. 12 | """ 13 | assert edge_index.dim() == 2 14 | assert edge_index.shape[0] == 2 15 | source = edge_index[0] 16 | target = edge_index[1] 17 | edge_uid = source * (max(source.max(), target.max()) + 1) + target 18 | edge_uid = consecutive_cluster(edge_uid)[0] 19 | return edge_uid 20 | 21 | 22 | def edge_wise_points(points, index, edge_index): 23 | """Given a graph of point segments, compute the concatenation of 24 | points belonging to either source or target segments for each edge 25 | of the segment graph. This operation arises when dealing with 26 | pairwise relationships between point segments.
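For illustration (shapes only, the tensors themselves are assumed): with `points` of shape (N, D), `index` of shape (N,) mapping each point to its segment, and `edge_index` of shape (2, E) over segments,

    (S_points, S_points_idx, S_uid), (T_points, T_points_idx, T_uid) = \
        edge_wise_points(points, index, edge_index)

gathers, for every edge, the points of its source (resp. target) segment, their indices in `points`, and the consecutive edge identifier they belong to.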
27 | 28 | Warning: the output tensors might be memory-intensive 29 | 30 | :param points: (N, D) tensor 31 | Points 32 | :param index: (N) LongTensor 33 | Segment index, for each point 34 | :param edge_index: (2, E) LongTensor 35 | Edges of the segment graph 36 | """ 37 | assert points.dim() == 2 38 | assert index.dim() == 1 39 | assert points.shape[0] == index.shape[0] 40 | assert edge_index.dim() == 2 41 | assert edge_index.shape[0] == 2 42 | assert edge_index.max() <= index.max() 43 | 44 | # We define the segments in the first row of edge_index as 'source' 45 | # segments, while the elements of the second row are 'target' 46 | # segments. The corresponding variables are prepended with 's_' and 47 | # 't_' for clarity 48 | s_idx = edge_index[0] 49 | t_idx = edge_index[1] 50 | 51 | # Compute consecutive unique identifiers for the edges 52 | uid = edge_index_to_uid(edge_index) 53 | 54 | # Compute the pointers and ordering to express the segments and the 55 | # points they hold in CSR format 56 | pointers, order = indices_to_pointers(index) 57 | 58 | # Compute the size of each segment 59 | segment_size = index.bincount() 60 | 61 | # Expand the edge variables to point-edge values. That is, the 62 | # concatenation of all the source -or target- points for each edge. 63 | # The corresponding variables are prepended with 'S_' and 'T_' for 64 | # clarity 65 | def expand(source=True): 66 | x_idx = s_idx if source else t_idx 67 | size = segment_size[x_idx] 68 | start = pointers[:-1][x_idx] 69 | X_points_idx = order[arange_interleave(size, start=start)] 70 | X_points = points[X_points_idx] 71 | X_uid = uid.repeat_interleave(size, dim=0) 72 | return X_points, X_points_idx, X_uid 73 | 74 | S_points, S_points_idx, S_uid = expand(source=True) 75 | T_points, T_points_idx, T_uid = expand(source=False) 76 | 77 | return (S_points, S_points_idx, S_uid), (T_points, T_points_idx, T_uid) 78 | -------------------------------------------------------------------------------- /src/utils/encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = ['fourier_position_encoder'] 5 | 6 | 7 | def fourier_position_encoder(pos, dim, f_min=1e-1, f_max=1e1): 8 | """ 9 | Heuristic: keeping ```f_min = 1 / f_max``` ensures that roughly 50% 10 | of the encoding dimensions are untouched and free to use. This is 11 | important when the positional encoding is added to learned feature 12 | embeddings. If the positional encoding uses too much of the encoding 13 | dimensions, it may be detrimental for the embeddings. 14 | 15 | The default `f_min` and `f_max` values are set so as to ensure 16 | a '~50% use of the encoding dimensions' and a '~1e-3 precision in 17 | the position encoding if pos is 1D'. 18 | 19 | :param pos: [M, M] Tensor 20 | Positions are expected to be in [-1, 1] 21 | :param dim: int 22 | Number of encoding dimensions, size of the encoding space. Note 23 | that increasing this is NOT the most direct way of improving 24 | spatial encoding precision or compactness. See `f_min` and 25 | `f_max` instead 26 | :param f_min: float 27 | Lower bound for the frequency range. Rules how much 'room' the 28 | positional encodings leave in the encoding space for additive 29 | embeddings 30 | :param f_max: float 31 | Upper bound for the frequency range. Rules how precise the 32 | encoding can be. 
Increase this if you need to capture finer 33 | spatial details 34 | :return: 35 | """ 36 | assert pos.abs().max() <= 1, "Positions must be in [-1, 1]" 37 | assert 1 <= pos.dim() <= 2, "Positions must be a 1D or 2D tensor" 38 | 39 | # We preferably operate 2D tensors 40 | if pos.dim() == 1: 41 | pos = pos.view(-1, 1) 42 | 43 | # Make sure M divides dim 44 | N, M = pos.shape 45 | D = dim // M 46 | # assert dim % M == 0, "`dim` must be a multiple of the number of input spatial dimensions" 47 | # assert D % 2 == 0, "`dim / M` must be a even number" 48 | 49 | # To avoid uncomfortable border effects with -1 and +1 coordinates 50 | # having the same (or very close) encodings, we convert [-1, 1] 51 | # coordinates to [-π/2, π/2] for safety 52 | pos = pos * torch.pi / 2 53 | 54 | # Compute frequencies on a logarithmic range from f_min to f_max 55 | device = pos.device 56 | f_min = torch.tensor([f_min], device=device) 57 | f_max = torch.tensor([f_max], device=device) 58 | w = torch.logspace(f_max.log(), f_min.log(), D, device=device) 59 | 60 | # Compute sine and cosine encodings 61 | pos_enc = pos.view(N, M, 1) * w.view(1, -1) 62 | pos_enc[:, :, ::2] = pos_enc[:, :, ::2].cos() 63 | pos_enc[:, :, 1::2] = pos_enc[:, :, 1::2].sin() 64 | pos_enc = pos_enc.view(N, -1) 65 | 66 | # In case dim is not a multiple of 2 * M, we pad missing dimensions 67 | # with zeros 68 | if pos_enc.shape[1] < dim: 69 | zeros = torch.zeros(N, dim - pos_enc.shape[1], device=device) 70 | pos_enc = torch.hstack((pos_enc, zeros)) 71 | 72 | return pos_enc 73 | -------------------------------------------------------------------------------- /src/utils/features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.utils.color import to_float_rgb 3 | 4 | 5 | __all__ = ['rgb2hsv', 'rgb2lab'] 6 | 7 | 8 | def rgb2hsv(rgb, epsilon=1e-10): 9 | """Convert a 2D tensor of RGB colors int [0, 255] or float [0, 1] to 10 | HSV format. 11 | 12 | Credit: https://www.linuxtut.com/en/20819a90872275811439 13 | """ 14 | assert rgb.ndim == 2 15 | assert rgb.shape[1] == 3 16 | 17 | rgb = rgb.clone() 18 | 19 | # Convert colors to float in [0, 1] 20 | rgb = to_float_rgb(rgb) 21 | 22 | r, g, b = rgb[:, 0], rgb[:, 1], rgb[:, 2] 23 | max_rgb, argmax_rgb = rgb.max(1) 24 | min_rgb, argmin_rgb = rgb.min(1) 25 | 26 | max_min = max_rgb - min_rgb + epsilon 27 | 28 | h1 = 60.0 * (g - r) / max_min + 60.0 29 | h2 = 60.0 * (b - g) / max_min + 180.0 30 | h3 = 60.0 * (r - b) / max_min + 300.0 31 | 32 | h = torch.stack((h2, h3, h1), dim=0).gather( 33 | dim=0, index=argmin_rgb.unsqueeze(0)).squeeze(0) 34 | s = max_min / (max_rgb + epsilon) 35 | v = max_rgb 36 | 37 | return torch.stack((h, s, v), dim=1) 38 | 39 | 40 | def rgb2lab(rgb): 41 | """Convert a tensor of RGB colors int[0, 255] or float [0, 1] to LAB 42 | colors. 
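Example (input values are illustrative):

    rgb = torch.tensor([[255, 0, 0], [128, 128, 128]])
    lab = rgb2lab(rgb)  # (2, 3) tensor of [L, a, b] values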
43 | 44 | Reimplemented from: 45 | https://gist.github.com/manojpandey/f5ece715132c572c80421febebaf66ae 46 | """ 47 | rgb = rgb.clone() 48 | device = rgb.device 49 | 50 | # Convert colors to float in [0, 1] 51 | rgb = to_float_rgb(rgb) 52 | 53 | # Prepare RGB to XYZ 54 | mask = rgb > 0.04045 55 | rgb[mask] = ((rgb[mask] + 0.055) / 1.055) ** 2.4 56 | rgb[~mask] = rgb[~mask] / 12.92 57 | rgb *= 100 58 | 59 | # RGB to XYZ conversion 60 | m = torch.tensor([ 61 | [0.4124, 0.2126, 0.0193], 62 | [0.3576, 0.7152, 0.1192], 63 | [0.1805, 0.0722, 0.9505]], device=device) 64 | xyz = (rgb @ m).round(decimals=4) 65 | 66 | # Observer=2°, Illuminant=D6 67 | # ref_X=95.047, ref_Y=100.000, ref_Z=108.883 68 | scale = torch.tensor([[95.047, 100.0, 108.883]], device=device) 69 | xyz /= scale 70 | 71 | # Prepare XYZ for LAB 72 | mask = xyz > 0.008856 73 | xyz[mask] = xyz[mask] ** (1 / 3.) 74 | xyz[~mask] = 7.787 * xyz[~mask] + 1 / 7.25 75 | 76 | # XYZ to LAB conversion 77 | lab = torch.zeros_like(xyz) 78 | m = torch.tensor([ 79 | [0, 500, 0], 80 | [116, -500, 200], 81 | [0, 0, -200]], device=device, dtype=torch.float) 82 | lab = xyz @ m 83 | lab[:, 0] -= 16 84 | lab = lab.round(decimals=4) 85 | 86 | return lab 87 | -------------------------------------------------------------------------------- /src/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | __all__ = [ 6 | 'cross_product_matrix', 'rodrigues_rotation_matrix', 'base_vectors_3d'] 7 | 8 | 9 | def cross_product_matrix(k): 10 | """Compute the cross-product matrix of a vector k. 11 | 12 | Credit: https://github.com/torch-points3d/torch-points3d 13 | """ 14 | return torch.tensor( 15 | [[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]], device=k.device) 16 | 17 | 18 | def rodrigues_rotation_matrix(axis, theta_degrees): 19 | """Given an axis and a rotation angle, compute the rotation matrix 20 | using the Rodrigues formula. 21 | 22 | Source : https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula 23 | Credit: https://github.com/torch-points3d/torch-points3d 24 | """ 25 | axis = axis / axis.norm() 26 | K = cross_product_matrix(axis) 27 | t = torch.tensor([theta_degrees / 180. * np.pi], device=axis.device) 28 | R = torch.eye(3, device=axis.device) \ 29 | + torch.sin(t) * K + (1 - torch.cos(t)) * K.mm(K) 30 | return R 31 | 32 | 33 | def base_vectors_3d(x): 34 | """Compute orthonormal bases for a set of 3D vectors. The 1st base 35 | vector is the normalized input vector, while the 2nd and 3rd vectors 36 | are constructed in the corresponding orthogonal plane. Note that 37 | this problem is underconstrained and, as such, any rotation of the 38 | output base around the 1st vector is a valid orthonormal base. 39 | """ 40 | assert x.dim() == 2 41 | assert x.shape[1] == 3 42 | 43 | # First direction is along x 44 | a = x 45 | 46 | # If x is 0 vector (norm=0), arbitrarily put a to (1, 0, 0) 47 | a[torch.where(a.norm(dim=1) == 0)[0]] = torch.tensor( 48 | [[1, 0, 0]], dtype=x.dtype, device=x.device) 49 | 50 | # Safely normalize a 51 | a = a / a.norm(dim=1).view(-1, 1) 52 | 53 | # Build a vector orthogonal to a 54 | b = torch.vstack((a[:, 1] - a[:, 2], a[:, 2] - a[:, 0], a[:, 0] - a[:, 1])).T 55 | 56 | # In the same fashion as when building a, the second base vector 57 | # may be 0 by construction (i.e. a is of type (v, v, v)). 
So we need 58 | # to deal with this edge case by setting 59 | b[torch.where(b.norm(dim=1) == 0)[0]] = torch.tensor( 60 | [[2, 1, -1]], dtype=x.dtype, device=x.device) 61 | 62 | # Safely normalize b 63 | b /= b.norm(dim=1).view(-1, 1) 64 | 65 | # Cross product of a and b to build the 3rd base vector 66 | c = torch.linalg.cross(a, b) 67 | 68 | return torch.cat((a.unsqueeze(1), b.unsqueeze(1), c.unsqueeze(1)), dim=1) 69 | -------------------------------------------------------------------------------- /src/utils/histogram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_scatter import scatter_add 3 | 4 | 5 | __all__ = ['histogram_to_atomic', 'atomic_to_histogram'] 6 | 7 | 8 | def histogram_to_atomic(gt, pred): 9 | """Convert ground truth and predictions at a segment level (i.e. 10 | ground truth is 2D tensor carrying histogram of labels in each 11 | segment), to pointwise 1D ground truth and predictions. 12 | 13 | :param gt: 1D or 2D torch.Tensor 14 | :param pred: 1D or 2D torch.Tensor 15 | """ 16 | assert gt.dim() <= 2 17 | 18 | # Edge cases where nothing happens 19 | if gt.dim() == 1: 20 | return gt, pred 21 | if gt.shape[1] == 1: 22 | return gt.squeeze(1), pred 23 | 24 | # Initialization 25 | num_nodes, num_classes = gt.shape 26 | device = pred.device 27 | 28 | # Flatten the pointwise ground truth 29 | point_gt = torch.arange( 30 | num_classes, device=device).repeat(num_nodes).repeat_interleave( 31 | gt.flatten()) 32 | 33 | # Expand the pointwise ground truth 34 | point_pred = pred.repeat_interleave(gt.sum(dim=1), dim=0) 35 | 36 | return point_gt, point_pred 37 | 38 | 39 | def atomic_to_histogram(item, idx, n_bins=None): 40 | """Convert point-level positive integer data to histograms of 41 | segment-level labels, based on idx. 42 | 43 | :param item: 1D or 2D torch.Tensor 44 | :param idx: 1D torch.Tensor 45 | """ 46 | assert item.ge(0).all(), \ 47 | "Mean aggregation only supports positive integers" 48 | assert item.dtype in [torch.uint8, torch.int, torch.long], \ 49 | "Mean aggregation only supports positive integers" 50 | assert item.ndim <= 2, \ 51 | "Voting and histograms are only supported for 1D and " \ 52 | "2D tensors" 53 | 54 | # Initialization 55 | n_bins = item.max() + 1 if n_bins is None else n_bins 56 | 57 | # Temporarily convert input item to long 58 | in_dtype = item.dtype 59 | item = item.long() 60 | 61 | # Important: if values are already 2D, we consider them to 62 | # be histograms and will simply scatter_add them 63 | if item.ndim == 2: 64 | return scatter_add(item, idx, dim=0) 65 | 66 | # Convert values to one-hot encoding. 
Values are temporarily offset 67 | # to 0 to save some memory and compute in one-hot encoding and 68 | # scatter_add 69 | offset = item.min() 70 | item = torch.nn.functional.one_hot(item - offset) 71 | 72 | # Count number of occurrence of each value 73 | hist = scatter_add(item, idx, dim=0) 74 | N = hist.shape[0] 75 | device = hist.device 76 | 77 | # Prepend 0 columns to the histogram for bins removed due to 78 | # offsetting 79 | bins_before = torch.zeros( 80 | N, offset, device=device, dtype=torch.long) 81 | hist = torch.cat((bins_before, hist), dim=1) 82 | 83 | # Append columns to the histogram for unobserved classes/bins 84 | bins_after = torch.zeros( 85 | N, n_bins - hist.shape[1], device=device, 86 | dtype=torch.long) 87 | hist = torch.cat((hist, bins_after), dim=1) 88 | 89 | # Restore input dtype 90 | hist = hist.to(in_dtype) 91 | 92 | return hist 93 | -------------------------------------------------------------------------------- /src/utils/hydra.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | from hydra import initialize, compose 3 | from hydra.core.global_hydra import GlobalHydra 4 | 5 | 6 | __all__ = ['init_config'] 7 | 8 | 9 | def init_config(config_name='train.yaml', overrides=[]): 10 | # Registering the "eval" resolver allows for advanced config 11 | # interpolation with arithmetic operations: 12 | # https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html 13 | from omegaconf import OmegaConf 14 | if not OmegaConf.has_resolver('eval'): 15 | OmegaConf.register_new_resolver('eval', eval) 16 | 17 | GlobalHydra.instance().clear() 18 | pyrootutils.setup_root(".", pythonpath=True) 19 | with initialize(version_base='1.2', config_path="../../configs"): 20 | cfg = compose(config_name=config_name, overrides=overrides) 21 | return cfg 22 | -------------------------------------------------------------------------------- /src/utils/keys.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | 4 | __all__ = [ 5 | 'POINT_FEATURES', 'SEGMENT_BASE_FEATURES', 'SUBEDGE_FEATURES', 6 | 'ON_THE_FLY_HORIZONTAL_FEATURES', 'ON_THE_FLY_VERTICAL_FEATURES', 7 | 'sanitize_keys'] 8 | 9 | 10 | POINT_FEATURES = [ 11 | 'rgb', 12 | 'hsv', 13 | 'lab', 14 | 'density', 15 | 'linearity', 16 | 'planarity', 17 | 'scattering', 18 | 'verticality', 19 | 'elevation', 20 | 'normal', 21 | 'length', 22 | 'surface', 23 | 'volume', 24 | 'curvature', 25 | 'intensity', 26 | 'pos_room'] 27 | 28 | SEGMENT_BASE_FEATURES = [ 29 | 'linearity', 30 | 'planarity', 31 | 'scattering', 32 | 'verticality', 33 | 'curvature', 34 | 'log_length', 35 | 'log_surface', 36 | 'log_volume', 37 | 'normal', 38 | 'log_size'] 39 | 40 | SUBEDGE_FEATURES = [ 41 | 'mean_off', 42 | 'std_off', 43 | 'mean_dist'] 44 | 45 | ON_THE_FLY_HORIZONTAL_FEATURES = [ 46 | 'mean_off', 47 | 'std_off', 48 | 'mean_dist', 49 | 'angle_source', 50 | 'angle_target', 51 | 'centroid_dir', 52 | 'centroid_dist', 53 | 'normal_angle', 54 | 'log_length', 55 | 'log_surface', 56 | 'log_volume', 57 | 'log_size'] 58 | 59 | ON_THE_FLY_VERTICAL_FEATURES = [ 60 | 'centroid_dir', 61 | 'centroid_dist', 62 | 'normal_angle', 63 | 'log_length', 64 | 'log_surface', 65 | 'log_volume', 66 | 'log_size'] 67 | 68 | 69 | def sanitize_keys(keys, default=[]): 70 | """Sanitize an iterable of string key into a sorted list of unique 71 | keys. This is necessary for consistently hashing key list arguments 72 | of some transforms. 
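For instance, following these rules:

    sanitize_keys(['rgb', 'hsv', 'rgb'])   # -> ('hsv', 'rgb')
    sanitize_keys('intensity')             # -> ('intensity',)
    sanitize_keys(None, default=['rgb'])   # -> ('rgb',)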
73 | """ 74 | # Convert to list of keys 75 | if isinstance(keys, str): 76 | out = [keys] 77 | elif isinstance(keys, Iterable): 78 | out = list(keys) 79 | else: 80 | out = list(default) 81 | 82 | assert all(isinstance(x, str) for x in out), \ 83 | f"Input 'keys' must be a string or an iterable of strings, but some " \ 84 | f"non-string elements were found in '{keys}'" 85 | 86 | # Remove duplicates and sort elements 87 | out = tuple(sorted(list(set(out)))) 88 | 89 | return out 90 | -------------------------------------------------------------------------------- /src/utils/list.py: -------------------------------------------------------------------------------- 1 | __all__ = ['listify', 'listify_with_reference'] 2 | 3 | 4 | def listify(obj): 5 | """Convert `obj` to nested lists. 6 | """ 7 | if obj is None or isinstance(obj, str): 8 | return obj 9 | if not hasattr(obj, '__len__'): 10 | return obj 11 | if hasattr(obj, 'dim') and obj.dim() == 0: 12 | return obj 13 | if len(obj) == 0: 14 | return obj 15 | return [listify(x) for x in obj] 16 | 17 | 18 | def listify_with_reference(arg_ref, *args): 19 | """listify `arg_ref` and the `args`, while ensuring that the length 20 | of `args` match the length of `arg_ref`. This is typically needed 21 | for parsing the input arguments of a function from an OmegaConf. 22 | """ 23 | arg_ref = listify(arg_ref) 24 | args_out = [listify(a) for a in args] 25 | 26 | if arg_ref is None: 27 | return [], *([] for _ in args) 28 | 29 | if not isinstance(arg_ref, list): 30 | return [arg_ref], *[[a] for a in args_out] 31 | 32 | if len(arg_ref) == 0: 33 | return [], *([] for _ in args) 34 | 35 | for i, a in enumerate(args_out): 36 | if not isinstance(a, list): 37 | a = [a] 38 | if len(a) != len(arg_ref): 39 | a = a * len(arg_ref) 40 | args_out[i] = a 41 | 42 | return arg_ref, *args_out 43 | -------------------------------------------------------------------------------- /src/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = ['loss_with_sample_weights', 'loss_with_target_histogram'] 5 | 6 | 7 | def loss_with_sample_weights(criterion, pred, y, weights): 8 | assert weights.dim() == 1 9 | assert pred.shape[0] == y.shape[0] == weights.shape[0] 10 | 11 | reduction_backup = criterion.reduction 12 | criterion.reduction = 'none' 13 | 14 | weights = weights.float() / weights.sum() 15 | 16 | loss = criterion(pred, y) 17 | loss = loss.sum(dim=1) if loss.dim() > 1 else loss 18 | loss = (loss * weights).sum() 19 | 20 | criterion.reduction = reduction_backup 21 | 22 | return loss 23 | 24 | 25 | def loss_with_target_histogram(criterion, pred, y_hist): 26 | assert pred.dim() == 2 27 | assert y_hist.dim() == 2 28 | assert pred.shape[0] == y_hist.shape[0] 29 | 30 | y_mask = y_hist != 0 31 | logits_flat = pred.repeat_interleave(y_mask.sum(dim=1), dim=0) 32 | y_flat = torch.where(y_mask)[1] 33 | weights = y_hist[y_mask] 34 | 35 | loss = loss_with_sample_weights( 36 | criterion, logits_flat, y_flat, weights) 37 | 38 | return loss 39 | -------------------------------------------------------------------------------- /src/utils/memory.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | 4 | 5 | __all__ = ['print_memory_size', 'garbage_collection_cuda'] 6 | 7 | 8 | def print_memory_size(a): 9 | assert isinstance(a, torch.Tensor) 10 | memory = a.element_size() * a.nelement() 11 | if memory > 1024 * 1024 * 1024: 12 | print(f'Memory: {memory / (1024 * 
1024 * 1024):0.3f} Gb') 13 | return 14 | if memory > 1024 * 1024: 15 | print(f'Memory: {memory / (1024 * 1024):0.3f} Mb') 16 | return 17 | if memory > 1024: 18 | print(f'Memory: {memory / 1024:0.3f} Kb') 19 | return 20 | print(f'Memory: {memory:0.3f} bytes') 21 | 22 | 23 | def is_oom_error(exception: BaseException) -> bool: 24 | return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception) 25 | 26 | 27 | # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py 28 | def is_cuda_out_of_memory(exception: BaseException) -> bool: 29 | return ( 30 | isinstance(exception, RuntimeError) 31 | and len(exception.args) == 1 32 | and "CUDA" in exception.args[0] 33 | and "out of memory" in exception.args[0] 34 | ) 35 | 36 | 37 | # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py 38 | def is_cudnn_snafu(exception: BaseException) -> bool: 39 | # For/because of https://github.com/pytorch/pytorch/issues/4107 40 | return ( 41 | isinstance(exception, RuntimeError) 42 | and len(exception.args) == 1 43 | and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0] 44 | ) 45 | 46 | 47 | # based on https://github.com/BlackHC/toma/blob/master/toma/cpu_memory.py 48 | def is_out_of_cpu_memory(exception: BaseException) -> bool: 49 | return ( 50 | isinstance(exception, RuntimeError) 51 | and len(exception.args) == 1 52 | and "DefaultCPUAllocator: can't allocate memory" in exception.args[0] 53 | ) 54 | 55 | 56 | # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py 57 | def garbage_collection_cuda() -> None: 58 | """Garbage collection Torch (CUDA) memory.""" 59 | gc.collect() 60 | try: 61 | # This is the last thing that should cause an OOM error, but seemingly it can. 62 | torch.cuda.empty_cache() 63 | except RuntimeError as exception: 64 | if not is_oom_error(exception): 65 | # Only handle OOM errors 66 | raise 67 | -------------------------------------------------------------------------------- /src/utils/multiprocessing.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from itertools import repeat 3 | 4 | 5 | __all__ = ['starmap_with_kwargs'] 6 | 7 | 8 | def starmap_with_kwargs(fn, args_iter, kwargs_iter, processes=4): 9 | """By default, starmap only accepts args and not kwargs. This is a 10 | helper to get around this problem. 11 | 12 | :param fn: callable 13 | The function to starmap 14 | :param args_iter: iterable 15 | Iterable of the args 16 | :param kwargs_iter: iterable or dict 17 | Kwargs for `fn`. If an iterable is passed, the corresponding 18 | kwargs will be passed to each process. If a dictionary is 19 | passed, these same kwargs will be repeated and passed to all 20 | processes. 
NB: this behavior only works for kwargs, if the same 21 | args need to be passed to the `fn`, the adequate iterable must 22 | be passed as input 23 | :param processes: int 24 | Number of processes 25 | :return: 26 | """ 27 | # Prepare kwargs 28 | if kwargs_iter is None: 29 | kwargs_iter = repeat({}) 30 | if isinstance(kwargs_iter, dict): 31 | kwargs_iter = repeat(kwargs_iter) 32 | 33 | # Apply fn in multiple processes 34 | with multiprocessing.get_context("spawn").Pool(processes=processes) as pool: 35 | args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter) 36 | out = pool.starmap(apply_args_and_kwargs, args_for_starmap) 37 | 38 | return out 39 | 40 | def apply_args_and_kwargs(fn, args, kwargs): 41 | return fn(*args, **kwargs) 42 | -------------------------------------------------------------------------------- /src/utils/parameter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | __all__ = ['LearnableParameter'] 5 | 6 | 7 | class LearnableParameter(nn.Parameter): 8 | """A simple class to be used for learnable parameters (e.g. learnable 9 | position encodings, queries, keys, ...). Using this is useful to use 10 | custom weight initialization. 11 | """ 12 | 13 | -------------------------------------------------------------------------------- /src/utils/partition.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.nn.pool.consecutive import consecutive_cluster 2 | from src.utils.point import is_xyz_tensor 3 | 4 | 5 | __all__ = ['xy_partition'] 6 | 7 | 8 | def xy_partition(pos, grid, consecutive=True): 9 | """Partition a point cloud based on a regular XY grid. Returns, for 10 | each point, the index of the grid cell it falls into. 11 | 12 | :param pos: Tensor 13 | Point cloud 14 | :param grid: float 15 | Grid size 16 | :param consecutive: bool 17 | Whether the grid cell indices should be consecutive. That is to 18 | say all indices in [0, idx_max] are used. 
Note that this may 19 | prevent trivially mapping an index value back to the 20 | corresponding XY coordinates 21 | :return: 22 | """ 23 | assert is_xyz_tensor(pos) 24 | 25 | # Compute the (i, j) coordinates on the XY grid size 26 | i = pos[:, 0].div(grid, rounding_mode='trunc').long() 27 | j = pos[:, 1].div(grid, rounding_mode='trunc').long() 28 | 29 | # Shift coordinates to positive integer to avoid negatives 30 | # clashing with our downstream indexing mechanism 31 | i -= i.min() 32 | j -= j.min() 33 | 34 | # Compute a "manual" partition based on the grid coordinates 35 | super_index = i * (max(i.max(), j.max()) + 1) + j 36 | 37 | # If required, update the used indices to be consecutive 38 | if consecutive: 39 | super_index = consecutive_cluster(super_index)[0] 40 | 41 | return super_index 42 | -------------------------------------------------------------------------------- /src/utils/point.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = ['is_xyz_tensor'] 5 | 6 | 7 | def is_xyz_tensor(xyz): 8 | if not isinstance(xyz, torch.Tensor): 9 | return False 10 | if not xyz.dim() == 2: 11 | return False 12 | return xyz.shape[1] == 3 13 | -------------------------------------------------------------------------------- /src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pytorch_lightning.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name=__name__) -> logging.Logger: 7 | """Initializes multi-GPU-friendly python command line logger.""" 8 | 9 | logger = logging.getLogger(name) 10 | 11 | # this ensures all logging levels get marked with the rank zero decorator 12 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 13 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 14 | for level in logging_levels: 15 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 16 | 17 | return logger 18 | -------------------------------------------------------------------------------- /src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from omegaconf import DictConfig, OmegaConf, open_dict 9 | from pytorch_lightning.utilities import rank_zero_only 10 | from rich.prompt import Prompt 11 | 12 | from src.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "datamodule", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints content of DictConfig using Rich library and its tree structure. 33 | 34 | Args: 35 | cfg (DictConfig): Configuration composed by Hydra. 36 | print_order (Sequence[str]): Determines in what order config components are printed. 37 | resolve (bool): Whether to resolve reference fields of DictConfig. 38 | save_to_file (bool): Whether to export config to the hydra output folder. 
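Example (mirroring the module's `__main__` block; the config name is illustrative):

    with initialize(version_base="1.2", config_path="../../configs"):
        cfg = compose(config_name="train.yaml")
        print_config_tree(cfg, resolve=False)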
39 | """ 40 | 41 | style = "dim" 42 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 43 | 44 | queue = [] 45 | 46 | # add fields from `print_order` to queue 47 | for field in print_order: 48 | queue.append(field) if field in cfg else log.warning( 49 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 50 | ) 51 | 52 | # add all the other fields to queue (not specified in `print_order`) 53 | for field in cfg: 54 | if field not in queue: 55 | queue.append(field) 56 | 57 | # generate config tree from queue 58 | for field in queue: 59 | branch = tree.add(field, style=style, guide_style=style) 60 | 61 | config_group = cfg[field] 62 | if isinstance(config_group, DictConfig): 63 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 64 | else: 65 | branch_content = str(config_group) 66 | 67 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 68 | 69 | # print config tree 70 | rich.print(tree) 71 | 72 | # save config tree to file 73 | if save_to_file: 74 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 75 | rich.print(tree, file=file) 76 | 77 | 78 | @rank_zero_only 79 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 80 | """Prompts user to input tags from command line if no tags are provided in config.""" 81 | 82 | if not cfg.get("tags"): 83 | if "id" in HydraConfig().cfg.hydra.job: 84 | raise ValueError("Specify tags before launching a multirun!") 85 | 86 | log.warning("No tags provided in config. Prompting user to input tags...") 87 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 88 | tags = [t.strip() for t in tags.split(",") if t != ""] 89 | 90 | with open_dict(cfg): 91 | cfg.tags = tags 92 | 93 | log.info(f"Tags: {cfg.tags}") 94 | 95 | if save_to_file: 96 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 97 | rich.print(cfg.tags, file=file) 98 | 99 | 100 | if __name__ == "__main__": 101 | from hydra import compose, initialize 102 | 103 | with initialize(version_base="1.2", config_path="../../configs"): 104 | cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) 105 | print_config_tree(cfg, resolve=False, save_to_file=False) 106 | -------------------------------------------------------------------------------- /src/utils/time.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from time import time 3 | 4 | 5 | __all__ = ['timer'] 6 | 7 | 8 | def timer(f, *args, text='', text_size=64, **kwargs): 9 | if isinstance(text, str) and len(text) > 0: 10 | text = text 11 | elif hasattr(f, '__name__'): 12 | text = f.__name__ 13 | elif hasattr(f, '__class__'): 14 | text = f.__class__.__name__ 15 | else: 16 | text = '' 17 | torch.cuda.synchronize() 18 | start = time() 19 | out = f(*args, **kwargs) 20 | torch.cuda.synchronize() 21 | padding = '.' 
* (text_size - len(text)) 22 | print(f'{text}{padding}: {time() - start:0.3f}s') 23 | return out 24 | -------------------------------------------------------------------------------- /src/utils/wandb.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import torch 3 | 4 | 5 | __all__ = ['wandb_confusion_matrix'] 6 | 7 | 8 | def wandb_confusion_matrix(cm, class_names=None, title=None): 9 | """Replaces the "normal" wandb way of logging a confusion matrix: 10 | 11 | https://github.com/wandb/wandb/blob/main/wandb/plot/confusion_matrix.py 12 | 13 | Indeed, the native wandb confusion matrix logging requires the 14 | element-wise prediction and ground truth. This is not adapted when 15 | we already have the confusion matrix at hand or that the number of 16 | elements is too large (e.g. point clouds). 17 | 18 | :param cm: 19 | :return: 20 | """ 21 | assert isinstance(cm, torch.Tensor) 22 | assert cm.dim() == 2 23 | assert cm.shape[0] == cm.shape[1] 24 | assert not cm.is_floating_point() 25 | 26 | # Move confusion matrix to CPU and convert to list 27 | cm = cm.cpu().tolist() 28 | num_classes = len(cm) 29 | 30 | # Prepare class names 31 | if class_names is None: 32 | class_names = [f"Class_{i}" for i in range(0, num_classes)] 33 | 34 | # Convert to wandb table format 35 | data = [] 36 | for i in range(num_classes): 37 | for j in range(num_classes): 38 | data.append([class_names[i], class_names[j], cm[i][j]]) 39 | 40 | columns = ["Actual", "Predicted", "nPredictions"] 41 | return wandb.plot_table( 42 | "wandb/confusion_matrix/v1", 43 | wandb.Table(columns=columns, data=data), 44 | {x: x for x in columns}, 45 | {"title": title or ""}) 46 | -------------------------------------------------------------------------------- /src/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization import show 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | import pytest 3 | from hydra import compose, initialize 4 | from hydra.core.global_hydra import GlobalHydra 5 | from omegaconf import DictConfig, open_dict 6 | 7 | 8 | @pytest.fixture(scope="package") 9 | def cfg_train_global() -> DictConfig: 10 | with initialize(version_base="1.2", config_path="../configs"): 11 | cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[]) 12 | 13 | # set defaults for all tests 14 | with open_dict(cfg): 15 | cfg.paths.root_dir = str(pyrootutils.find_root()) 16 | cfg.trainer.max_epochs = 1 17 | cfg.trainer.limit_train_batches = 0.01 18 | cfg.trainer.limit_val_batches = 0.1 19 | cfg.trainer.limit_test_batches = 0.1 20 | cfg.trainer.accelerator = "cpu" 21 | cfg.trainer.devices = 1 22 | cfg.datamodule.num_workers = 0 23 | cfg.datamodule.pin_memory = False 24 | cfg.extras.print_config = False 25 | cfg.extras.enforce_tags = False 26 | cfg.logger = None 27 | 28 | return cfg 29 | 30 | 31 | @pytest.fixture(scope="package") 32 | def cfg_eval_global() -> DictConfig: 33 | with initialize(version_base="1.2", config_path="../configs"): 
34 | cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."]) 35 | 36 | # set defaults for all tests 37 | with open_dict(cfg): 38 | cfg.paths.root_dir = str(pyrootutils.find_root()) 39 | cfg.trainer.max_epochs = 1 40 | cfg.trainer.limit_test_batches = 0.1 41 | cfg.trainer.accelerator = "cpu" 42 | cfg.trainer.devices = 1 43 | cfg.datamodule.num_workers = 0 44 | cfg.datamodule.pin_memory = False 45 | cfg.extras.print_config = False 46 | cfg.extras.enforce_tags = False 47 | cfg.logger = None 48 | 49 | return cfg 50 | 51 | 52 | # this is called by each test which uses `cfg_train` arg 53 | # each test generates its own temporary logging path 54 | @pytest.fixture(scope="function") 55 | def cfg_train(cfg_train_global, tmp_path) -> DictConfig: 56 | cfg = cfg_train_global.copy() 57 | 58 | with open_dict(cfg): 59 | cfg.paths.output_dir = str(tmp_path) 60 | cfg.paths.log_dir = str(tmp_path) 61 | 62 | yield cfg 63 | 64 | GlobalHydra.instance().clear() 65 | 66 | 67 | # this is called by each test which uses `cfg_eval` arg 68 | # each test generates its own temporary logging path 69 | @pytest.fixture(scope="function") 70 | def cfg_eval(cfg_eval_global, tmp_path) -> DictConfig: 71 | cfg = cfg_eval_global.copy() 72 | 73 | with open_dict(cfg): 74 | cfg.paths.output_dir = str(tmp_path) 75 | cfg.paths.log_dir = str(tmp_path) 76 | 77 | yield cfg 78 | 79 | GlobalHydra.instance().clear() 80 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drprojects/superpoint_transformer/a0f753b35b86e06d426113bdeac9b0123b220aa3/tests/helpers/__init__.py -------------------------------------------------------------------------------- /tests/helpers/package_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import pkg_resources 4 | from pytorch_lightning.utilities.xla_device import XLADeviceUtils 5 | 6 | 7 | def _package_available(package_name: str) -> bool: 8 | """Check if a package is available in your environment.""" 9 | try: 10 | return pkg_resources.require(package_name) is not None 11 | except pkg_resources.DistributionNotFound: 12 | return False 13 | 14 | 15 | _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() 16 | 17 | _IS_WINDOWS = platform.system() == "Windows" 18 | 19 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") 20 | 21 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") 22 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") 23 | 24 | _WANDB_AVAILABLE = _package_available("wandb") 25 | _NEPTUNE_AVAILABLE = _package_available("neptune") 26 | _COMET_AVAILABLE = _package_available("comet_ml") 27 | _MLFLOW_AVAILABLE = _package_available("mlflow") 28 | -------------------------------------------------------------------------------- /tests/helpers/run_sh_command.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from tests.helpers.package_available import _SH_AVAILABLE 6 | 7 | if _SH_AVAILABLE: 8 | import sh 9 | 10 | 11 | def run_sh_command(command: List[str]): 12 | """Default method for executing shell commands with pytest and sh package.""" 13 | msg = None 14 | try: 15 | sh.python(command) 16 | except sh.ErrorReturnCode as e: 17 | msg = e.stderr.decode() 18 | if msg: 19 | 
pytest.fail(msg=msg) 20 | -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.core.hydra_config import HydraConfig 3 | from omegaconf import DictConfig 4 | 5 | 6 | def test_train_config(cfg_train: DictConfig): 7 | assert cfg_train 8 | assert cfg_train.datamodule 9 | assert cfg_train.model 10 | assert cfg_train.trainer 11 | 12 | HydraConfig().set_config(cfg_train) 13 | 14 | hydra.utils.instantiate(cfg_train.datamodule) 15 | hydra.utils.instantiate(cfg_train.model) 16 | hydra.utils.instantiate(cfg_train.trainer) 17 | 18 | 19 | def test_eval_config(cfg_eval: DictConfig): 20 | assert cfg_eval 21 | assert cfg_eval.datamodule 22 | assert cfg_eval.model 23 | assert cfg_eval.trainer 24 | 25 | HydraConfig().set_config(cfg_eval) 26 | 27 | hydra.utils.instantiate(cfg_eval.datamodule) 28 | hydra.utils.instantiate(cfg_eval.model) 29 | hydra.utils.instantiate(cfg_eval.trainer) 30 | -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from hydra.core.hydra_config import HydraConfig 5 | from omegaconf import open_dict 6 | 7 | from src.eval import evaluate 8 | from src.train import train 9 | 10 | 11 | @pytest.mark.slow 12 | def test_train_eval(tmp_path, cfg_train, cfg_eval): 13 | """Train for 1 epoch with `train.py` and evaluate with `eval.py`""" 14 | assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir 15 | 16 | with open_dict(cfg_train): 17 | cfg_train.trainer.max_epochs = 1 18 | cfg_train.test = True 19 | 20 | HydraConfig().set_config(cfg_train) 21 | train_metric_dict, _ = train(cfg_train) 22 | 23 | assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") 24 | 25 | with open_dict(cfg_eval): 26 | cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 27 | 28 | HydraConfig().set_config(cfg_eval) 29 | test_metric_dict, _ = evaluate(cfg_eval) 30 | 31 | assert test_metric_dict["test/acc"] > 0.0 32 | assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001 33 | -------------------------------------------------------------------------------- /tests/test_sweeps.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_if import RunIf 4 | from tests.helpers.run_sh_command import run_sh_command 5 | 6 | startfile = "src/train.py" 7 | overrides = ["logger=[]"] 8 | 9 | 10 | @RunIf(sh=True) 11 | @pytest.mark.slow 12 | def test_experiments(tmp_path): 13 | """Test running all available experiment configs with fast_dev_run=True.""" 14 | command = [ 15 | startfile, 16 | "-m", 17 | "experiment=glob(*)", 18 | "hydra.sweep.dir=" + str(tmp_path), 19 | "++trainer.fast_dev_run=true", 20 | ] + overrides 21 | run_sh_command(command) 22 | 23 | 24 | @RunIf(sh=True) 25 | @pytest.mark.slow 26 | def test_hydra_sweep(tmp_path): 27 | """Test default hydra sweep.""" 28 | command = [ 29 | startfile, 30 | "-m", 31 | "hydra.sweep.dir=" + str(tmp_path), 32 | "model.optimizer.lr=0.005,0.01", 33 | "++trainer.fast_dev_run=true", 34 | ] + overrides 35 | 36 | run_sh_command(command) 37 | 38 | 39 | @RunIf(sh=True) 40 | @pytest.mark.slow 41 | def test_hydra_sweep_ddp_sim(tmp_path): 42 | """Test default hydra sweep with ddp sim.""" 43 | command = [ 44 | startfile, 45 | 
"-m", 46 | "hydra.sweep.dir=" + str(tmp_path), 47 | "trainer=ddp_sim", 48 | "trainer.max_epochs=3", 49 | "+trainer.limit_train_batches=0.01", 50 | "+trainer.limit_val_batches=0.1", 51 | "+trainer.limit_test_batches=0.1", 52 | "model.optimizer.lr=0.005,0.01,0.02", 53 | ] + overrides 54 | run_sh_command(command) 55 | 56 | 57 | @RunIf(sh=True) 58 | @pytest.mark.slow 59 | def test_optuna_sweep(tmp_path): 60 | """Test optuna sweep.""" 61 | command = [ 62 | startfile, 63 | "-m", 64 | "hparams_search=mnist_optuna", 65 | "hydra.sweep.dir=" + str(tmp_path), 66 | "hydra.sweeper.n_trials=10", 67 | "hydra.sweeper.sampler.n_startup_trials=5", 68 | "++trainer.fast_dev_run=true", 69 | ] + overrides 70 | run_sh_command(command) 71 | 72 | 73 | @RunIf(wandb=True, sh=True) 74 | @pytest.mark.slow 75 | def test_optuna_sweep_ddp_sim_wandb(tmp_path): 76 | """Test optuna sweep with wandb and ddp sim.""" 77 | command = [ 78 | startfile, 79 | "-m", 80 | "hparams_search=mnist_optuna", 81 | "hydra.sweep.dir=" + str(tmp_path), 82 | "hydra.sweeper.n_trials=5", 83 | "trainer=ddp_sim", 84 | "trainer.max_epochs=3", 85 | "+trainer.limit_train_batches=0.01", 86 | "+trainer.limit_val_batches=0.1", 87 | "+trainer.limit_test_batches=0.1", 88 | "logger=wandb", 89 | ] 90 | run_sh_command(command) 91 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from hydra.core.hydra_config import HydraConfig 5 | from omegaconf import open_dict 6 | 7 | from src.train import train 8 | from tests.helpers.run_if import RunIf 9 | 10 | 11 | def test_train_fast_dev_run(cfg_train): 12 | """Run for 1 train, val and test step.""" 13 | HydraConfig().set_config(cfg_train) 14 | with open_dict(cfg_train): 15 | cfg_train.trainer.fast_dev_run = True 16 | cfg_train.trainer.accelerator = "cpu" 17 | train(cfg_train) 18 | 19 | 20 | @RunIf(min_gpus=1) 21 | def test_train_fast_dev_run_gpu(cfg_train): 22 | """Run for 1 train, val and test step on GPU.""" 23 | HydraConfig().set_config(cfg_train) 24 | with open_dict(cfg_train): 25 | cfg_train.trainer.fast_dev_run = True 26 | cfg_train.trainer.accelerator = "gpu" 27 | train(cfg_train) 28 | 29 | 30 | @RunIf(min_gpus=1) 31 | @pytest.mark.slow 32 | def test_train_epoch_gpu_amp(cfg_train): 33 | """Train 1 epoch on GPU with mixed-precision.""" 34 | HydraConfig().set_config(cfg_train) 35 | with open_dict(cfg_train): 36 | cfg_train.trainer.max_epochs = 1 37 | cfg_train.trainer.accelerator = "cpu" 38 | cfg_train.trainer.precision = 16 39 | train(cfg_train) 40 | 41 | 42 | @pytest.mark.slow 43 | def test_train_epoch_double_val_loop(cfg_train): 44 | """Train 1 epoch with validation loop twice per epoch.""" 45 | HydraConfig().set_config(cfg_train) 46 | with open_dict(cfg_train): 47 | cfg_train.trainer.max_epochs = 1 48 | cfg_train.trainer.val_check_interval = 0.5 49 | train(cfg_train) 50 | 51 | 52 | @pytest.mark.slow 53 | def test_train_ddp_sim(cfg_train): 54 | """Simulate DDP (Distributed Data Parallel) on 2 CPU processes.""" 55 | HydraConfig().set_config(cfg_train) 56 | with open_dict(cfg_train): 57 | cfg_train.trainer.max_epochs = 2 58 | cfg_train.trainer.accelerator = "cpu" 59 | cfg_train.trainer.devices = 2 60 | cfg_train.trainer.strategy = "ddp_spawn" 61 | train(cfg_train) 62 | 63 | 64 | @pytest.mark.slow 65 | def test_train_resume(tmp_path, cfg_train): 66 | """Run 1 epoch, finish, and resume for another epoch.""" 67 | with open_dict(cfg_train): 68 | 
cfg_train.trainer.max_epochs = 1 69 | 70 | HydraConfig().set_config(cfg_train) 71 | metric_dict_1, _ = train(cfg_train) 72 | 73 | files = os.listdir(tmp_path / "checkpoints") 74 | assert "last.ckpt" in files 75 | assert "epoch_000.ckpt" in files 76 | 77 | with open_dict(cfg_train): 78 | cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 79 | cfg_train.trainer.max_epochs = 2 80 | 81 | metric_dict_2, _ = train(cfg_train) 82 | 83 | files = os.listdir(tmp_path / "checkpoints") 84 | assert "epoch_001.ckpt" in files 85 | assert "epoch_002.ckpt" not in files 86 | 87 | assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"] 88 | assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"] 89 | --------------------------------------------------------------------------------