├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── feature-request.md │ ├── framework-request.md │ └── question.md └── workflows │ ├── python-lint.yml │ ├── python-publish.yml │ └── python-tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── .readthedocs.yml ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── disent ├── __init__.py ├── dataset │ ├── __init__.py │ ├── _base.py │ ├── data │ │ ├── __init__.py │ │ ├── _episodes.py │ │ ├── _episodes__custom.py │ │ ├── _groundtruth.py │ │ ├── _groundtruth__cars3d.py │ │ ├── _groundtruth__dsprites.py │ │ ├── _groundtruth__dsprites_imagenet.py │ │ ├── _groundtruth__mpi3d.py │ │ ├── _groundtruth__norb.py │ │ ├── _groundtruth__shapes3d.py │ │ ├── _groundtruth__sprites.py │ │ ├── _groundtruth__xcolumns.py │ │ ├── _groundtruth__xyobject.py │ │ ├── _groundtruth__xysquares.py │ │ ├── _random__teapots3d.py │ │ └── _raw.py │ ├── sampling │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── _groundtruth__dist.py │ │ ├── _groundtruth__pair.py │ │ ├── _groundtruth__pair_orig.py │ │ ├── _groundtruth__single.py │ │ ├── _groundtruth__triplet.py │ │ ├── _groundtruth__walk.py │ │ ├── _random__any.py │ │ ├── _random__episodes.py │ │ └── _single.py │ ├── transform │ │ ├── __init__.py │ │ ├── _augment.py │ │ ├── _augment_disent.py │ │ ├── _transforms.py │ │ └── functional.py │ ├── util │ │ ├── __init__.py │ │ ├── datafile.py │ │ ├── formats │ │ │ ├── __init__.py │ │ │ ├── hdf5.py │ │ │ └── npz.py │ │ ├── state_space.py │ │ └── stats.py │ └── wrapper │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── _dither.py │ │ └── _masked.py ├── frameworks │ ├── __init__.py │ ├── _ae_mixin.py │ ├── _framework.py │ ├── ae │ │ ├── __init__.py │ │ ├── _supervised__adaneg_tae.py │ │ ├── _supervised__tae.py │ │ ├── _unsupervised__ae.py │ │ ├── _unsupervised__dotae.py │ │ └── _weaklysupervised__adaae.py │ ├── helper │ │ ├── __init__.py │ │ ├── latent_distributions.py │ │ ├── reconstructions.py │ │ └── util.py │ └── vae │ │ ├── __init__.py │ │ ├── _supervised__adaneg_tvae.py │ │ ├── _supervised__tvae.py │ │ ├── _unsupervised__betatcvae.py │ │ ├── _unsupervised__betavae.py │ │ ├── _unsupervised__dfcvae.py │ │ ├── _unsupervised__dipvae.py │ │ ├── _unsupervised__dotvae.py │ │ ├── _unsupervised__infovae.py │ │ ├── _unsupervised__vae.py │ │ └── _weaklysupervised__adavae.py ├── metrics │ ├── __init__.py │ ├── _dci.py │ ├── _factor_vae.py │ ├── _factored_components.py │ ├── _flatness.py │ ├── _mig.py │ ├── _sap.py │ ├── _unsupervised.py │ └── utils.py ├── model │ ├── __init__.py │ ├── _base.py │ └── ae │ │ ├── __init__.py │ │ ├── _linear.py │ │ ├── _norm_conv64.py │ │ ├── _vae_conv64.py │ │ └── _vae_fc.py ├── nn │ ├── __init__.py │ ├── activations.py │ ├── functional │ │ ├── __init__.py │ │ ├── _conv2d.py │ │ ├── _conv2d_kernels.py │ │ ├── _correlation.py │ │ ├── _dct.py │ │ ├── _mean.py │ │ ├── _norm.py │ │ ├── _other.py │ │ ├── _pca.py │ │ └── _util_generic.py │ ├── loss │ │ ├── __init__.py │ │ ├── kl.py │ │ ├── reduction.py │ │ ├── softsort.py │ │ ├── triplet.py │ │ └── triplet_mining.py │ ├── modules.py │ └── weights.py ├── registry │ ├── __init__.py │ └── _registry.py ├── schedule │ ├── __init__.py │ ├── _schedule.py │ └── lerp.py └── util │ ├── __init__.py │ ├── array.py │ ├── deprecate.py │ ├── function.py │ ├── imports.py │ ├── inout │ ├── __init__.py │ ├── cache.py │ ├── files.py │ ├── hashing.py │ ├── paths.py │ └── tar.py │ ├── iters.py │ ├── jit.py │ ├── lightning │ ├── __init__.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── _callback_log_metrics.py │ │ ├── 
_callback_print_progress.py │ │ ├── _callback_vis_dists.py │ │ ├── _callback_vis_latents.py │ │ ├── _callbacks_base.py │ │ └── _helper.py │ └── logger_util.py │ ├── math │ ├── __init__.py │ ├── dither.py │ ├── integer.py │ └── random.py │ ├── profiling.py │ ├── seeds.py │ ├── strings │ ├── __init__.py │ ├── colors.py │ └── fmt.py │ └── visualize │ ├── __init__.py │ ├── plot.py │ ├── vis_img.py │ ├── vis_latents.py │ └── vis_util.py ├── docs ├── examples │ ├── extend_experiment │ │ ├── code │ │ │ ├── __init__.py │ │ │ ├── groundtruth__xyblocks.py │ │ │ ├── random_data.py │ │ │ ├── weaklysupervised__si_adavae.py │ │ │ └── weaklysupervised__si_betavae.py │ │ ├── config │ │ │ ├── README.md │ │ │ ├── config_alt.yaml │ │ │ ├── config_alt_test.yaml │ │ │ ├── dataset │ │ │ │ ├── E--mask-dthr-cars3d.yaml │ │ │ │ ├── E--mask-dthr-dsprites.yaml │ │ │ │ ├── E--mask-dthr-pseudorandom.yaml │ │ │ │ ├── E--mask-dthr-shapes3d.yaml │ │ │ │ ├── E--mask-dthr-smallnorb.yaml │ │ │ │ ├── E--pseudorandom.yaml │ │ │ │ ├── E--xyblocks.yaml │ │ │ │ └── E--xyblocks_grey.yaml │ │ │ ├── framework │ │ │ │ ├── E--si-adavae.yaml │ │ │ │ └── E--si-betavae.yaml │ │ │ ├── run_location │ │ │ │ ├── cluster_rsync_tmp.yaml │ │ │ │ ├── cluster_shr.yaml │ │ │ │ └── cluster_tmp.yaml │ │ │ └── run_plugins │ │ │ │ └── default.yaml │ │ ├── run.py │ │ └── run.sh │ ├── mnist_example.py │ ├── overview_data.py │ ├── overview_dataset_loader.py │ ├── overview_dataset_pair.py │ ├── overview_dataset_pair_augment.py │ ├── overview_dataset_single.py │ ├── overview_framework_adagvae.py │ ├── overview_framework_ae.py │ ├── overview_framework_betavae.py │ ├── overview_framework_betavae_scheduled.py │ ├── overview_framework_train_val.py │ ├── overview_metrics.py │ ├── plotting_examples │ │ ├── __init__.py │ │ ├── compute_xysquares_overlap.py │ │ ├── plot_dataset_overlap_diffs.py │ │ ├── plot_dataset_traversal_dists.py │ │ ├── plot_dataset_traversals.py │ │ ├── plots │ │ │ └── .gitignore │ │ ├── save_dataset_animation.py │ │ ├── save_metric_visualisation.py │ │ └── util │ │ │ ├── __init__.py │ │ │ └── gadfly.mplstyle │ └── readme_example.py ├── img │ └── traversals │ │ ├── traversal-transpose__cars3d.jpg │ │ ├── traversal-transpose__dsprites-imagenet-bg-100.jpg │ │ ├── traversal-transpose__dsprites-imagenet-bg-50.jpg │ │ ├── traversal-transpose__dsprites-imagenet-fg-100.jpg │ │ ├── traversal-transpose__dsprites-imagenet-fg-50.jpg │ │ ├── traversal-transpose__dsprites.jpg │ │ ├── traversal-transpose__mpi3d-real.jpg │ │ ├── traversal-transpose__mpi3d-realistic.jpg │ │ ├── traversal-transpose__mpi3d-toy.jpg │ │ ├── traversal-transpose__shapes3d.jpg │ │ ├── traversal-transpose__smallnorb.jpg │ │ ├── traversal-transpose__sprites.jpg │ │ ├── traversal-transpose__xy-blocks.jpg │ │ ├── traversal-transpose__xy-object-shaded.jpg │ │ ├── traversal-transpose__xy-object.jpg │ │ ├── traversal-transpose__xy-single-square__spacing8.jpg │ │ └── traversal-transpose__xy-squares__spacing8.jpg ├── index.md ├── quickstart.md └── requirements.txt ├── experiment ├── __init__.py ├── config │ ├── __init__.py │ ├── augment │ │ ├── example.yaml │ │ └── none.yaml │ ├── config.yaml │ ├── config_test.yaml │ ├── dataset │ │ ├── _data_type_ │ │ │ ├── episodes.yaml │ │ │ ├── gt.yaml │ │ │ └── random.yaml │ │ ├── cars3d.yaml │ │ ├── dsprites-imagenet-bg-100.yaml │ │ ├── dsprites-imagenet-bg-25.yaml │ │ ├── dsprites-imagenet-bg-50.yaml │ │ ├── dsprites-imagenet-bg-75.yaml │ │ ├── dsprites-imagenet-fg-100.yaml │ │ ├── dsprites-imagenet-fg-25.yaml │ │ ├── dsprites-imagenet-fg-50.yaml │ 
│ ├── dsprites-imagenet-fg-75.yaml │ │ ├── dsprites-imagenet.yaml │ │ ├── dsprites.yaml │ │ ├── mpi3d_real.yaml │ │ ├── mpi3d_realistic.yaml │ │ ├── mpi3d_toy.yaml │ │ ├── shapes3d.yaml │ │ ├── smallnorb.yaml │ │ ├── sprites.yaml │ │ ├── sprites_all.yaml │ │ ├── xyobject.yaml │ │ ├── xyobject_grey.yaml │ │ ├── xyobject_shaded.yaml │ │ ├── xyobject_shaded_grey.yaml │ │ ├── xysquares.yaml │ │ ├── xysquares_grey.yaml │ │ └── xysquares_rgb.yaml │ ├── framework │ │ ├── _input_mode_ │ │ │ ├── pair.yaml │ │ │ ├── single.yaml │ │ │ ├── triplet.yaml │ │ │ └── weak_pair.yaml │ │ ├── adaae.yaml │ │ ├── adaae_os.yaml │ │ ├── adagvae_minimal_os.yaml │ │ ├── adanegtae.yaml │ │ ├── adanegtae_d.yaml │ │ ├── adanegtvae.yaml │ │ ├── adanegtvae_d.yaml │ │ ├── adanegtvae_d_aug.yaml │ │ ├── adavae.yaml │ │ ├── adavae_os.yaml │ │ ├── ae.yaml │ │ ├── betatcvae.yaml │ │ ├── betavae.yaml │ │ ├── dfcvae.yaml │ │ ├── dipvae.yaml │ │ ├── infovae.yaml │ │ ├── tae.yaml │ │ ├── tvae.yaml │ │ └── vae.yaml │ ├── metrics │ │ ├── all.yaml │ │ ├── fast.yaml │ │ ├── none.yaml │ │ └── test.yaml │ ├── model │ │ ├── linear.yaml │ │ ├── norm_conv64.yaml │ │ ├── vae_conv64.yaml │ │ └── vae_fc.yaml │ ├── optimizer │ │ ├── adabelief.yaml │ │ ├── adam.yaml │ │ ├── amsgrad.yaml │ │ ├── radam.yaml │ │ ├── rmsprop.yaml │ │ └── sgd.yaml │ ├── run_action │ │ ├── prepare_data.yaml │ │ ├── skip.yaml │ │ └── train.yaml │ ├── run_callbacks │ │ ├── all.yaml │ │ ├── none.yaml │ │ ├── test.yaml │ │ ├── vis.yaml │ │ ├── vis_debug.yaml │ │ ├── vis_fast.yaml │ │ ├── vis_quick.yaml │ │ ├── vis_skip_first.yaml │ │ └── vis_slow.yaml │ ├── run_launcher │ │ ├── local.yaml │ │ └── slurm.yaml │ ├── run_length │ │ ├── debug.yaml │ │ ├── epic.yaml │ │ ├── long.yaml │ │ ├── longmed.yaml │ │ ├── medium.yaml │ │ ├── short.yaml │ │ ├── test.yaml │ │ ├── tiny.yaml │ │ └── xtiny.yaml │ ├── run_location │ │ ├── local.yaml │ │ ├── local_cpu.yaml │ │ └── local_gpu.yaml │ ├── run_logging │ │ ├── none.yaml │ │ ├── wandb.yaml │ │ ├── wandb_fast.yaml │ │ ├── wandb_fast_offline.yaml │ │ └── wandb_slow.yaml │ ├── run_plugins │ │ └── default.yaml │ ├── sampling │ │ ├── _sampler_ │ │ │ ├── episodes__pair.yaml │ │ │ ├── episodes__single.yaml │ │ │ ├── episodes__triplet.yaml │ │ │ ├── episodes__weak_pair.yaml │ │ │ ├── gt__pair.yaml │ │ │ ├── gt__single.yaml │ │ │ ├── gt__triplet.yaml │ │ │ ├── gt__weak_pair.yaml │ │ │ ├── gt_dist__pair.yaml │ │ │ ├── gt_dist__single.yaml │ │ │ ├── gt_dist__triplet.yaml │ │ │ ├── gt_dist__weak_pair.yaml │ │ │ ├── random__pair.yaml │ │ │ ├── random__single.yaml │ │ │ ├── random__triplet.yaml │ │ │ └── random__weak_pair.yaml │ │ ├── default.yaml │ │ ├── default__bb.yaml │ │ ├── default__ran_l1.yaml │ │ ├── default__ran_l2.yaml │ │ ├── gt_dist__combined.yaml │ │ ├── gt_dist__combined_scaled.yaml │ │ ├── gt_dist__factors.yaml │ │ ├── gt_dist__manhat.yaml │ │ ├── gt_dist__manhat_scaled.yaml │ │ ├── gt_dist__random.yaml │ │ ├── none.yaml │ │ └── random.yaml │ └── schedule │ │ ├── adanegtvae_up_all.yaml │ │ ├── adanegtvae_up_all_full.yaml │ │ ├── adanegtvae_up_all_weak.yaml │ │ ├── adanegtvae_up_ratio.yaml │ │ ├── adanegtvae_up_ratio_full.yaml │ │ ├── adanegtvae_up_ratio_weak.yaml │ │ ├── adanegtvae_up_thresh.yaml │ │ ├── beta_cyclic.yaml │ │ ├── beta_cyclic_fast.yaml │ │ ├── beta_cyclic_slow.yaml │ │ ├── beta_decrease.yaml │ │ ├── beta_delay.yaml │ │ ├── beta_delay_long.yaml │ │ ├── beta_increase.yaml │ │ └── none.yaml ├── run.py └── util │ ├── __init__.py │ ├── _hydra_searchpath_plugin_ │ └── hydra_plugins │ │ └── searchpath_plugin.py │ ├── 
hydra_data.py │ ├── hydra_main.py │ ├── hydra_utils.py │ ├── path_utils.py │ └── run_utils.py ├── mkdocs.yml ├── pytest.ini ├── requirements-extra.txt ├── requirements-test.txt ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── test_000_import.py ├── test_data_similarity.py ├── test_dataset_formats.py ├── test_docs_examples.py ├── test_experiment.py ├── test_frameworks.py ├── test_math.py ├── test_math_generic.py ├── test_metrics.py ├── test_models.py ├── test_registry.py ├── test_samplers.py ├── test_state_space.py ├── test_to_img.py ├── test_transform.py └── util.py /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a bug report to help us improve 4 | title: "[BUG]: " 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behaviour. 15 | 16 | **Expected behaviour** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Additional context** 20 | Add any other context about the problem here, including any relevant system information and Python version. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea for this project 4 | title: "[FEATURE]: " 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/framework-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Framework Request 3 | about: Suggest a new framework to be added to this project 4 | title: "[FRAMEWORK]: " 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Which framework would you like added to this project?** 11 | - framework name 12 | - link to academic paper 13 | 14 | **Why should this framework be added?** 15 | What benefit is there to adding this framework? 16 | 17 | **Short summary of framework** 18 | What is the core algorithmic idea behind the framework, in simple terms? Please give a general overview rather than advanced algorithmic concepts. 19 | - e.g. The beta-VAE weights the regularisation term, aiming to improve disentanglement. 20 | - e.g. The DFC-VAE augments the reconstruction loss of the VAE/beta-VAE with a perceptual loss. 21 | 22 | **Which framework does this build upon, if any?** 23 | e.g. the beta-VAE extends the standard VAE 24 | e.g. the Adaptive VAEs extend the beta-VAE 25 | 26 | **Are you willing to submit a PR?** 27 | Are you willing to work on this implementation and submit a PR?
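For background on the beta-VAE example given above: the beta-VAE objective weights the KL regularisation term of the standard VAE by a factor beta (this is the standard formulation from the literature, not something defined by this template):

\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \beta \, D_{\mathrm{KL}}\!\left(q_\phi(z|x) \,\|\, p(z)\right)

Setting beta = 1 recovers the standard VAE, while beta > 1 places more weight on the regulariser, which tends to encourage disentangled latents at the cost of reconstruction quality.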
28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask a question about this project 4 | title: "[Q]: " 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Ask Away!** 11 | 12 | - Please double check the [docs](https://disent.dontpanic.sh) to make sure that your question is not already answered there. 13 | - Please double check the issues to make sure that your question has not been answered before. 14 | -------------------------------------------------------------------------------- /.github/workflows/python-lint.yml: -------------------------------------------------------------------------------- 1 | # make sure to update the corresponding configs: 2 | # - `.pre-commit-config.yaml` 3 | # - `requirements-dev.txt` 4 | 5 | name: lint 6 | 7 | on: [pull_request] 8 | 9 | jobs: 10 | black: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: psf/black@23.1.0 15 | with: 16 | options: "--check --verbose --diff --color --target-version=py38 --line-length=120" 17 | src: "." 18 | isort: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v3 22 | - uses: isort/isort-action@v1.1.0 23 | with: 24 | configuration: "--check --verbose --diff --color --py=38 --profile=black --line-length=120 --force-single-line-imports --skip-glob='disent/**__init__.py'" 25 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package 2 | # using Twine when a release is created 3 | 4 | name: publish 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Update version in setup.py to ${{ github.ref_name }} 22 | run: sed -i "s/{{VERSION_PLACEHOLDER}}/${{ github.ref_name }}/g" setup.py 23 | 24 | - name: Install dependencies 25 | run: | 26 | python3 -m pip install --upgrade pip 27 | python3 -m pip install setuptools wheel twine 28 | 29 | - name: Build and publish 30 | env: 31 | TWINE_USERNAME: __token__ 32 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 33 | run: | 34 | python3 setup.py sdist bdist_wheel 35 | python3 -m twine upload dist/* 36 | -------------------------------------------------------------------------------- /.github/workflows/python-tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, 2 | # then run tests over a variety of Python versions. 3 | 4 | name: tests 5 | 6 | on: [pull_request] 7 | 8 | jobs: 9 | test: 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | os: [ubuntu-latest] # [ubuntu-latest, windows-latest, macos-latest] 14 | python-version: ["3.8", "3.9", "3.10"] 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | # -- caching actually slows down the action! 
24 | # cache: 'pip' 25 | # cache-dependency-path: | 26 | # requirements*.txt 27 | 28 | - name: Install dependencies 29 | # torchsort needs torch first 30 | run: | 31 | python3 -m pip install --upgrade pip 32 | python3 -m pip install "torch>=2.0.0" 33 | python3 -m pip install -r requirements-test.txt 34 | 35 | - name: Test with pytest 36 | run: | 37 | python3 -m pytest --cov=disent tests/ 38 | 39 | - uses: codecov/codecov-action@v3 40 | with: 41 | token: ${{ secrets.CODECOV_TOKEN }} 42 | fail_ci_if_error: false 43 | # codecov automatically merges all generated files 44 | # if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9 45 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - { id: check-added-large-files, args: ["--maxkb=300"] } 6 | - { id: check-case-conflict } 7 | - { id: detect-private-key } 8 | - repo: https://github.com/PyCQA/isort 9 | rev: 5.12.0 10 | hooks: 11 | - id: isort 12 | args: ["--verbose", "--py=38", "--profile=black", "--line-length=120", "--force-single-line-imports"] 13 | exclude: ^disent/(.+/)?__init__\.py 14 | - repo: https://github.com/psf/black 15 | rev: 23.1.0 16 | hooks: 17 | - id: black 18 | args: ["--verbose", "--target-version=py38", "--line-length=120"] 19 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | ignore-patterns=__test__* 4 | 5 | [MESSAGES CONTROL] 6 | 7 | disable=line-too-long, 8 | missing-function-docstring, 9 | missing-module-docstring, 10 | missing-class-docstring, 11 | too-many-arguments, 12 | too-many-ancestors, 13 | too-many-locals, 14 | invalid-name, # TODO RE-ENABLE 15 | 16 | [BASIC] 17 | 18 | good-names=x,y -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | mkdocs: 9 | configuration: mkdocs.yml 10 | fail_on_warning: false 11 | 12 | # Optionally build your docs in additional formats such as PDF 13 | formats: all 14 | 15 | # Optionally set the version of Python and requirements required to build your docs 16 | python: 17 | version: 3.8 18 | install: 19 | - requirements: docs/requirements.txt 20 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Nathan Juraj Michlo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /disent/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /disent/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | # wrapper 26 | from disent.dataset._base import DisentDataset 27 | from disent.dataset._base import DisentIterDataset 28 | -------------------------------------------------------------------------------- /disent/dataset/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | # base sampler 26 | from disent.dataset.sampling._base import BaseDisentSampler 27 | 28 | # ground truth samplers 29 | from disent.dataset.sampling._groundtruth__dist import GroundTruthDistSampler 30 | from disent.dataset.sampling._groundtruth__pair import GroundTruthPairSampler 31 | from disent.dataset.sampling._groundtruth__pair_orig import GroundTruthPairOrigSampler 32 | from disent.dataset.sampling._groundtruth__single import GroundTruthSingleSampler 33 | from disent.dataset.sampling._groundtruth__triplet import GroundTruthTripleSampler 34 | from disent.dataset.sampling._groundtruth__walk import GroundTruthRandomWalkSampler 35 | 36 | # any dataset samplers 37 | from disent.dataset.sampling._single import SingleSampler 38 | from disent.dataset.sampling._random__any import RandomSampler 39 | 40 | # episode samplers 41 | from disent.dataset.sampling._random__episodes import RandomEpisodeSampler 42 | -------------------------------------------------------------------------------- /disent/dataset/sampling/_random__any.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | from typing import Tuple 26 | 27 | import numpy as np 28 | 29 | from disent.dataset.sampling._base import BaseDisentSampler 30 | 31 | # ========================================================================= # 32 | # Randomly Paired Dataset # 33 | # ========================================================================= # 34 | 35 | 36 | class RandomSampler(BaseDisentSampler): 37 | def uninit_copy(self) -> "RandomSampler": 38 | return RandomSampler(num_samples=self.num_samples) 39 | 40 | def __init__(self, num_samples=1): 41 | super().__init__(num_samples=num_samples) 42 | 43 | def _init(self, dataset): 44 | self._len = len(dataset) 45 | 46 | def _sample_idx(self, idx: int) -> Tuple[int, ...]: 47 | # sample indices 48 | return (idx, *np.random.randint(0, self._len, size=self._num_samples - 1)) 49 | 50 | 51 | # ========================================================================= # 52 | # End # 53 | # ========================================================================= # 54 | -------------------------------------------------------------------------------- /disent/dataset/sampling/_single.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
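A minimal sketch of how the RandomSampler defined in disent/dataset/sampling/_random__any.py above is typically driven. The public init(...) and call conventions are inferred from the private _init / _sample_idx hooks shown in that file and are assumptions, as is the my_dataset placeholder:

from disent.dataset.sampling import RandomSampler

sampler = RandomSampler(num_samples=3)  # one anchor index plus two random indices per sample
sampler.init(my_dataset)                # assumed public wrapper around _init; my_dataset is any sized dataset (hypothetical)
idxs = sampler(0)                       # assumed call wrapper around _sample_idx, e.g. returns (0, 412, 98)

The anchor index is always returned first, so downstream code can rely on element 0 of the tuple being the requested observation.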
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | from typing import Tuple 26 | 27 | import numpy as np 28 | 29 | from disent.dataset.sampling._base import BaseDisentSampler 30 | 31 | # ========================================================================= # 32 | # Randomly Paired Dataset # 33 | # ========================================================================= # 34 | 35 | 36 | class SingleSampler(BaseDisentSampler): 37 | def uninit_copy(self) -> "SingleSampler": 38 | return SingleSampler() 39 | 40 | def __init__(self): 41 | super().__init__(num_samples=1) 42 | 43 | def _init(self, dataset): 44 | pass 45 | 46 | def _sample_idx(self, idx: int) -> Tuple[int, ...]: 47 | return (idx,) 48 | 49 | 50 | # ========================================================================= # 51 | # End # 52 | # ========================================================================= # 53 | -------------------------------------------------------------------------------- /disent/dataset/transform/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | # transforms 26 | from disent.dataset.transform._transforms import CheckTensor 27 | from disent.dataset.transform._transforms import Noop 28 | from disent.dataset.transform._transforms import ToImgTensorF32 29 | from disent.dataset.transform._transforms import ToImgTensorU8 30 | from disent.dataset.transform._transforms import ToStandardisedTensor # deprecated 31 | from disent.dataset.transform._transforms import ToUint8Tensor # deprecated 32 | 33 | # augments 34 | from disent.dataset.transform._augment import FftGaussianBlur 35 | from disent.dataset.transform._augment import FftBoxBlur 36 | from disent.dataset.transform._augment import FftKernel 37 | 38 | # disent dataset augment 39 | from disent.dataset.transform._augment_disent import DisentDatasetTransform 40 | -------------------------------------------------------------------------------- /disent/dataset/util/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /disent/dataset/util/formats/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2022 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 
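The transforms exported from disent/dataset/transform/__init__.py above are intended to be combined with a sampler when constructing a DisentDataset. A short sketch, assuming DisentDataset accepts the raw data, sampler and transform as keyword arguments in this way (raw_data is a hypothetical placeholder for any of the datasets under disent/dataset/data/):

from disent.dataset import DisentDataset
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32

# wrap the raw data so that each retrieved observation is converted to a float32 image tensor
dataset = DisentDataset(raw_data, sampler=SingleSampler(), transform=ToImgTensorF32())
obs = dataset[0]  # a single transformed observation

ToImgTensorU8 behaves the same way but keeps observations as uint8 tensors, which can be useful when conversion to float is deferred until later in the pipeline.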
15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /disent/dataset/wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | # base wrapper 26 | from disent.dataset.wrapper._base import WrappedDataset 27 | 28 | # wrapper datasets 29 | from disent.dataset.wrapper._dither import DitheredDataset 30 | from disent.dataset.wrapper._masked import MaskedDataset 31 | -------------------------------------------------------------------------------- /disent/dataset/wrapper/_base.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 
15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | import logging 26 | 27 | from torch.utils.data import Dataset 28 | 29 | from disent.dataset.data import GroundTruthData 30 | 31 | log = logging.getLogger(__name__) 32 | 33 | 34 | # ========================================================================= # 35 | # Dithered Dataset # 36 | # ========================================================================= # 37 | 38 | 39 | class WrappedDataset(Dataset): 40 | def __len__(self): 41 | raise NotImplementedError 42 | 43 | def __getitem__(self, item): 44 | raise NotImplementedError 45 | 46 | @property 47 | def data(self) -> Dataset: 48 | raise NotImplementedError 49 | 50 | @property 51 | def gt_data(self) -> GroundTruthData: 52 | assert isinstance(self.data, GroundTruthData) 53 | return self.data 54 | 55 | 56 | # ========================================================================= # 57 | # END # 58 | # ========================================================================= # 59 | -------------------------------------------------------------------------------- /disent/frameworks/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
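The WrappedDataset base class in disent/dataset/wrapper/_base.py above is an abstract interface: subclasses must provide __len__, __getitem__ and the data property, while gt_data is derived automatically. A minimal sketch of a pass-through wrapper, assuming nothing beyond the interface shown in that file (the class name and the wrapped base_data are hypothetical):

from torch.utils.data import Dataset

from disent.dataset.wrapper import WrappedDataset


class PassthroughDataset(WrappedDataset):
    """Wraps another dataset without modifying its items."""

    def __init__(self, base_data: Dataset):
        self._base_data = base_data

    def __len__(self):
        return len(self._base_data)

    def __getitem__(self, item):
        return self._base_data[item]

    @property
    def data(self) -> Dataset:
        return self._base_data

The DitheredDataset and MaskedDataset wrappers exported alongside it provide the library's concrete implementations of this interface.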
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | 26 | # export 27 | from disent.frameworks._framework import DisentConfigurable 28 | from disent.frameworks._framework import DisentFramework 29 | -------------------------------------------------------------------------------- /disent/frameworks/ae/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | # supervised frameworks 26 | from disent.frameworks.ae._supervised__tae import TripletAe 27 | from disent.frameworks.ae._supervised__adaneg_tae import AdaNegTripletAe 28 | 29 | # unsupervised frameworks 30 | from disent.frameworks.ae._unsupervised__ae import Ae 31 | from disent.frameworks.ae._unsupervised__dotae import DataOverlapTripletAe 32 | 33 | # weakly supervised frameworks 34 | from disent.frameworks.ae._weaklysupervised__adaae import AdaAe 35 | -------------------------------------------------------------------------------- /disent/frameworks/helper/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /disent/frameworks/vae/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | # supervised frameworks 26 | from disent.frameworks.vae._supervised__tvae import TripletVae 27 | from disent.frameworks.vae._supervised__adaneg_tvae import AdaNegTripletVae 28 | 29 | # unsupervised frameworks 30 | from disent.frameworks.vae._unsupervised__betatcvae import BetaTcVae 31 | from disent.frameworks.vae._unsupervised__betavae import BetaVae 32 | from disent.frameworks.vae._unsupervised__dfcvae import DfcVae 33 | from disent.frameworks.vae._unsupervised__dipvae import DipVae 34 | from disent.frameworks.vae._unsupervised__infovae import InfoVae 35 | from disent.frameworks.vae._unsupervised__vae import Vae 36 | from disent.frameworks.vae._unsupervised__dotvae import DataOverlapTripletVae 37 | 38 | # weakly supervised frameworks 39 | from disent.frameworks.vae._weaklysupervised__adavae import AdaVae 40 | from disent.frameworks.vae._weaklysupervised__adavae import AdaGVaeMinimal 41 | -------------------------------------------------------------------------------- /disent/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | # Expose 26 | from ._dci import metric_dci 27 | from ._factor_vae import metric_factor_vae 28 | from ._mig import metric_mig 29 | from ._sap import metric_sap 30 | from ._unsupervised import metric_unsupervised 31 | 32 | # Michlo et al. 
33 | from disent.metrics._flatness import metric_flatness 34 | from disent.metrics._factored_components import metric_factored_components 35 | from disent.metrics._factored_components import metric_distances 36 | from disent.metrics._factored_components import metric_linearity 37 | -------------------------------------------------------------------------------- /disent/model/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | # encoders & decoders 26 | from disent.model._base import AutoEncoder 27 | from disent.model._base import DisentEncoder 28 | from disent.model._base import DisentDecoder 29 | -------------------------------------------------------------------------------- /disent/model/ae/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
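The metrics exported from disent/metrics/__init__.py above typically take a dataset and a representation function and return a dict of named scores. A sketch of this usage, where the get_repr convention, the exact keyword arguments and the trained module are assumptions rather than something confirmed by these files:

from disent.metrics import metric_dci, metric_mig

get_repr = lambda x: module.encode(x)  # module is a hypothetical trained framework (e.g. a VAE)

scores = {
    **metric_dci(dataset, get_repr, num_train=1000, num_test=500),  # assumed keyword arguments
    **metric_mig(dataset, get_repr, num_train=2000),                # assumed keyword arguments
}
print(scores)

Because each metric returns a dictionary, the results can be merged and logged together as shown.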
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | # encoders & decoders 26 | from disent.model.ae._vae_conv64 import DecoderConv64 27 | from disent.model.ae._vae_conv64 import EncoderConv64 28 | from disent.model.ae._norm_conv64 import DecoderConv64Norm 29 | from disent.model.ae._norm_conv64 import EncoderConv64Norm 30 | from disent.model.ae._vae_fc import DecoderFC 31 | from disent.model.ae._vae_fc import EncoderFC 32 | from disent.model.ae._linear import DecoderLinear 33 | from disent.model.ae._linear import EncoderLinear 34 | -------------------------------------------------------------------------------- /disent/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /disent/nn/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
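The encoders and decoders exported from disent/model/ae/__init__.py above are paired up by the AutoEncoder container from disent/model/__init__.py, and the resulting model is what the frameworks under disent/frameworks train. A sketch of assembling the 64x64 convolutional pair; the keyword names x_shape and z_size are assumptions based on typical usage rather than confirmed by these __init__ files:

from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64

# build a matching encoder/decoder pair for 3-channel 64x64 observations with a 9-dimensional latent space
model = AutoEncoder(
    encoder=EncoderConv64(x_shape=(3, 64, 64), z_size=9),  # assumed constructor arguments
    decoder=DecoderConv64(x_shape=(3, 64, 64), z_size=9),  # assumed constructor arguments
)

A framework such as Ae or BetaVae would then be constructed around this model together with its configuration object.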
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /disent/nn/modules.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | import lightning as L 26 | import torch 27 | 28 | # ========================================================================= # 29 | # Base Modules # 30 | # ========================================================================= # 31 | 32 | 33 | class DisentModule(torch.nn.Module): 34 | def _forward_unimplemented(self, *args): 35 | # Annoying fix applied by torch for Module.forward: 36 | # https://github.com/python/mypy/issues/8795 37 | raise RuntimeError("This should never run!") 38 | 39 | def forward(self, *args, **kwargs): 40 | raise NotImplementedError 41 | 42 | 43 | class DisentLightningModule(L.LightningModule): 44 | # make sure we don't get complaints about the missing methods! 
45 | # -- we prefer to use LightningDataModule 46 | train_dataloader = None 47 | test_dataloader = None 48 | val_dataloader = None 49 | predict_dataloader = None 50 | 51 | 52 | # ========================================================================= # 53 | # END # 54 | # ========================================================================= # 55 | -------------------------------------------------------------------------------- /disent/schedule/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
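A minimal usage sketch for `DisentModule` from `disent/nn/modules.py` above: subclasses are expected to override `forward`, since the base class deliberately raises `NotImplementedError`. The `Scale` module and its values are hypothetical, purely for illustration.

import torch
from disent.nn.modules import DisentModule

class Scale(DisentModule):
    """Toy module that multiplies its input by a constant."""
    def __init__(self, factor: float = 2.0):
        super().__init__()
        self.factor = factor
    def forward(self, x):
        # overriding forward is required -- the DisentModule base raises NotImplementedError
        return x * self.factor

assert torch.allclose(Scale()(torch.ones(3)), torch.full((3,), 2.0))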
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | from ._schedule import Schedule 26 | 27 | # schedules 28 | from ._schedule import ClipSchedule 29 | from ._schedule import CosineWaveSchedule 30 | from ._schedule import CyclicSchedule 31 | from ._schedule import LinearSchedule 32 | from ._schedule import NoopSchedule 33 | from ._schedule import MultiplySchedule 34 | from ._schedule import FixedValueSchedule 35 | from ._schedule import SingleSchedule 36 | 37 | 38 | # aliases 39 | from ._schedule import ClipSchedule as Clip 40 | from ._schedule import CosineWaveSchedule as CosineWave 41 | from ._schedule import CyclicSchedule as Cyclic 42 | from ._schedule import LinearSchedule as Linear 43 | from ._schedule import NoopSchedule as Noop 44 | from ._schedule import MultiplySchedule as Multiply 45 | from ._schedule import FixedValueSchedule as FixedValue 46 | from ._schedule import SingleSchedule as Single 47 | -------------------------------------------------------------------------------- /disent/util/function.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | 26 | # ========================================================================= # 27 | # Function Helper # 28 | # ========================================================================= # 29 | 30 | 31 | def wrapped_partial(func, *args, **kwargs): 32 | """ 33 | Like functools.partial but keeps the same __name__ and __doc__ 34 | on the returned function. 
35 | """ 36 | import functools 37 | 38 | partial_func = functools.partial(func, *args, **kwargs) 39 | functools.update_wrapper(partial_func, func) 40 | return partial_func 41 | 42 | 43 | # ========================================================================= # 44 | # END # 45 | # ========================================================================= # 46 | -------------------------------------------------------------------------------- /disent/util/inout/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /disent/util/inout/tar.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2023 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | import os 26 | from pathlib import Path 27 | from typing import Union 28 | 29 | 30 | def tar_safe_extract_all(in_file: Union[Path, str], out_dir: Union[Path, str]): 31 | import tarfile 32 | 33 | in_file = str(in_file) 34 | out_dir = str(out_dir) 35 | 36 | def _is_safe_to_extract(tar): 37 | for member in tar.getmembers(): 38 | # check inside directory 39 | abs_dir = os.path.abspath(out_dir) 40 | abs_targ = os.path.abspath(os.path.join(out_dir, member.name)) 41 | common_prefix = os.path.commonprefix([abs_dir, abs_targ]) 42 | # raise exception if not 43 | if common_prefix != abs_dir: 44 | raise Exception("Attempted path traversal in tar file") 45 | 46 | # check the members before extracting -- extractall on its own is unsafe 47 | with tarfile.open(in_file) as f: 48 | _is_safe_to_extract(f) 49 | f.extractall(out_dir) 50 | -------------------------------------------------------------------------------- /disent/util/jit.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2022 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | 26 | # ========================================================================= # 27 | #                     Numba Is An Optional Dependency                       # 28 | # ========================================================================= # 29 | 30 | 31 | def try_njit(*args, **kwargs): 32 | """ 33 | Wrapper around numba.njit 34 | - If numba is installed, then we JIT the decorated function 35 | - If numba is missing, then we do nothing and leave the function untouched! 36 | """ 37 | try: 38 | from numba import njit 39 | except ImportError: 40 | # dummy njit 41 | def njit(*args, **kwargs): 42 | def _wrapper(func): 43 | import warnings 44 | 45 | warnings.warn(f"failed to JIT compile: {func}, numba is not installed!") 46 | return func 47 | 48 | return _wrapper 49 | 50 | # try and JIT compile function!
51 | return njit(*args, **kwargs) 52 | 53 | 54 | # ========================================================================= # 55 | # END # 56 | # ========================================================================= # 57 | -------------------------------------------------------------------------------- /disent/util/lightning/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /disent/util/lightning/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
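A minimal usage sketch for `try_njit` from `disent/util/jit.py` above: when numba is installed the decorated loop is JIT compiled, otherwise a warning is emitted and the plain Python function is used unchanged. The `row_sums` function is hypothetical, purely for illustration.

import numpy as np
from disent.util.jit import try_njit

@try_njit(cache=True)  # numba options like cache are forwarded to njit, and simply unused by the fallback
def row_sums(x):
    # explicit loops like this are exactly what benefits from numba's nopython mode
    out = np.zeros(x.shape[0])
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[i] += x[i, j]
    return out

print(row_sums(np.arange(12.0).reshape(3, 4)))  # -> [ 6. 22. 38.]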
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic 26 | from disent.util.lightning.callbacks._callbacks_base import BaseCallbackTimed 27 | 28 | from disent.util.lightning.callbacks._callback_print_progress import LoggerProgressCallback 29 | from disent.util.lightning.callbacks._callback_log_metrics import VaeMetricLoggingCallback 30 | from disent.util.lightning.callbacks._callback_vis_latents import VaeLatentCycleLoggingCallback 31 | from disent.util.lightning.callbacks._callback_vis_dists import VaeGtDistsLoggingCallback 32 | -------------------------------------------------------------------------------- /disent/util/math/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /disent/util/math/integer.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2022 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | 26 | # ========================================================================= # 27 | # Working With Arbitrary Precision Integers # 28 | # ========================================================================= # 29 | 30 | 31 | def gcd(a: int, b: int) -> int: 32 | """ 33 | Compute the greatest common divisor of a and b 34 | TODO: not actually sure if this returns the correct values for zero or negative inputs? 35 | """ 36 | assert isinstance(a, int), f"number must be an int, got: {type(a)}" 37 | assert isinstance(b, int), f"number must be an int, got: {type(b)}" 38 | while b > 0: 39 | a, b = b, a % b 40 | return a 41 | 42 | 43 | def lcm(a: int, b: int) -> int: 44 | """ 45 | Compute the lowest common multiple of a and b 46 | TODO: not actually sure if this returns the correct values for zero or negative inputs? 47 | """ 48 | return (a * b) // gcd(a, b) 49 | 50 | 51 | # ========================================================================= # 52 | # End # 53 | # ========================================================================= # 54 | -------------------------------------------------------------------------------- /disent/util/strings/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
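A minimal usage sketch for `gcd` and `lcm` from `disent/util/math/integer.py` above. The helpers operate on plain Python ints, so large values do not overflow; as the TODOs note, behaviour for zero or negative inputs is not guaranteed, so the example sticks to positive integers.

from disent.util.math.integer import gcd, lcm

assert gcd(12, 18) == 6
assert lcm(4, 6) == 12
# python ints are arbitrary precision, so very large values are fine
assert lcm(2**64, 6) == 3 * 2**64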
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /disent/util/strings/colors.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | 26 | # ========================================================================= # 27 | #                               Ansi Colors                                 # 28 | # ========================================================================= # 29 | 30 | 31 | RST = "\033[0m" 32 | 33 | # light (bright) colors 34 | GRY = "\033[90m" 35 | lRED = "\033[91m" 36 | lGRN = "\033[92m" 37 | lYLW = "\033[93m" 38 | lBLU = "\033[94m" 39 | lMGT = "\033[95m" 40 | lCYN = "\033[96m" 41 | WHT = "\033[97m" 42 | 43 | # dark (standard) colors 44 | BLK = "\033[30m" 45 | RED = "\033[31m" 46 | GRN = "\033[32m" 47 | YLW = "\033[33m" 48 | BLU = "\033[34m" 49 | MGT = "\033[35m" 50 | CYN = "\033[36m" 51 | lGRY = "\033[37m" 52 | 53 | 54 | # ========================================================================= # 55 | #                                   END                                     # 56 | # ========================================================================= # 57 | -------------------------------------------------------------------------------- /disent/util/visualize/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software.
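A minimal usage sketch for the ANSI constants in `disent/util/strings/colors.py` above: wrap the text in a colour code and end with `RST`, otherwise the colour leaks into subsequent terminal output. The messages printed are purely illustrative.

from disent.util.strings import colors as c

print(f"{c.lGRN}ok{c.RST} ... {c.RED}failed{c.RST}")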
15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/code/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2022 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | 25 | import logging 26 | 27 | import disent.registry as R 28 | from docs.examples.extend_experiment.code.random_data import RandomData 29 | 30 | log = logging.getLogger(__name__) 31 | 32 | 33 | def register_to_disent(): 34 | log.info("Registering example with disent!") 35 | 36 | # DATASETS.setm[...] is an alias for DATASETS[...] that only sets the value if it does not already exist. 37 | # -- register_to_disent should be able to be called multiple times in the same run! 
38 | 39 | # register: datasets 40 | R.DATASETS.setm["pseudorandom"] = R.LazyImport("docs.examples.extend_experiment.code.random_data.RandomData") 41 | R.DATASETS.setm["xyblocks"] = R.LazyImport( 42 | "docs.examples.extend_experiment.code.groundtruth__xyblocks.XYBlocksData" 43 | ) 44 | 45 | # register: VAEs 46 | R.FRAMEWORKS.setm["si_ada_vae"] = R.LazyImport( 47 | "docs.examples.extend_experiment.code.weaklysupervised__si_adavae.SwappedInputAdaVae" 48 | ) 49 | R.FRAMEWORKS.setm["si_beta_vae"] = R.LazyImport( 50 | "docs.examples.extend_experiment.code.weaklysupervised__si_betavae.SwappedInputBetaVae" 51 | ) 52 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/README.md: -------------------------------------------------------------------------------- 1 | # Research - Config 2 | 3 | - These configs are added to the experiment search path such that any 4 | files are found and read before that of the default experiment config. 5 | This means that if a file has the same name, it will overwrite the default file! 6 | The search path is overridden by setting the `DISENT_CONFIGS_PREPEND` environment variable. 7 | 8 | - Additionally, we expose the research code by registering it with disent using the experiment 9 | plugin functionality. See `config/run_plugins`. The plugin will register each metric, framework 10 | or dataset with the `disent.registry`. Allowing easy use elsewhere through config entries. 11 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/config_alt.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ # defaults lists override entries from this file! 3 | # data 4 | - sampling: default__bb 5 | - dataset: xyobject 6 | - augment: none 7 | # system 8 | - framework: betavae 9 | - model: vae_conv64 10 | # training 11 | - optimizer: adam 12 | - schedule: none 13 | - metrics: all 14 | - run_length: long 15 | # logs 16 | - run_callbacks: vis 17 | - run_logging: wandb 18 | # runtime 19 | - run_location: cluster_shr 20 | - run_launcher: slurm 21 | - run_action: train 22 | # experiment 23 | - run_plugins: default 24 | 25 | settings: 26 | job: 27 | user: '${oc.env:USER}' 28 | project: 'DELETE' 29 | name: '${framework.name}:${settings.framework.recon_loss}|${dataset.name}:${sampling.name}|${trainer.max_steps}' 30 | seed: NULL 31 | framework: 32 | beta: 0.0316 33 | recon_loss: mse 34 | loss_reduction: mean # beta scaling 35 | framework_opt: 36 | latent_distribution: normal # only used by VAEs 37 | overlap_loss: NULL # only used for experimental dotvae and dorvae 38 | usage_ratio: 0.5 # only used by adversarial masked datasets 39 | model: 40 | z_size: 25 41 | weight_init: 'xavier_normal' # xavier_normal, default 42 | dataset: 43 | batch_size: 256 44 | optimizer: 45 | lr: 1e-3 46 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/config_alt_test.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ # defaults lists override entries from this file! 
3 | # data 4 | - sampling: default__bb 5 | - dataset: xyobject 6 | - augment: example 7 | # system 8 | - framework: betavae 9 | - model: linear 10 | # training 11 | - optimizer: adam 12 | - schedule: beta_cyclic 13 | - metrics: test 14 | - run_length: test 15 | # logs 16 | - run_callbacks: test 17 | - run_logging: none 18 | # runtime 19 | - run_location: local_cpu 20 | - run_launcher: local 21 | - run_action: train 22 | # experiment 23 | - run_plugins: default 24 | 25 | settings: 26 | job: 27 | user: 'invalid' 28 | project: 'invalid' 29 | name: '${framework.name}:${settings.framework.recon_loss}|${dataset.name}:${sampling.name}|${trainer.max_steps}' 30 | seed: NULL 31 | framework: 32 | beta: 0.0316 33 | recon_loss: mse 34 | loss_reduction: mean # beta scaling 35 | framework_opt: 36 | latent_distribution: normal # only used by VAEs 37 | overlap_loss: NULL # only used for experimental dotvae and dorvae 38 | usage_ratio: 0.5 # only used by adversarial masked datasets 39 | model: 40 | z_size: 25 41 | weight_init: 'xavier_normal' # xavier_normal, default 42 | dataset: 43 | batch_size: 5 44 | optimizer: 45 | lr: 1e-3 46 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/dataset/E--mask-dthr-cars3d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: random 3 | 4 | name: mask_dthr_cars3d 5 | 6 | data: 7 | _target_: disent.dataset.wrapper.DitheredDataset 8 | dither_n: 2 9 | keep_ratio: 0.5 10 | gt_data: 11 | _target_: disent.dataset.data.Cars3d64Data 12 | data_root: ${dsettings.storage.data_root} 13 | prepare: ${dsettings.dataset.prepare} 14 | 15 | transform: 16 | _target_: disent.dataset.transform.ToImgTensorF32 17 | mean: ${dataset.meta.vis_mean} 18 | std: ${dataset.meta.vis_std} 19 | 20 | meta: 21 | x_shape: [3, 64, 64] 22 | vis_mean: [0.8976676149976628, 0.8891658020067508, 0.885147515814868] 23 | vis_std: [0.22503195531503034, 0.2399461278981261, 0.24792106319684404] 24 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/dataset/E--mask-dthr-dsprites.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: random 3 | 4 | name: mask_dthr_dsprites 5 | 6 | data: 7 | _target_: disent.dataset.wrapper.DitheredDataset 8 | dither_n: 2 9 | keep_ratio: 0.5 10 | gt_data: 11 | _target_: disent.dataset.data.DSpritesData 12 | data_root: ${dsettings.storage.data_root} 13 | prepare: ${dsettings.dataset.prepare} 14 | in_memory: ${dsettings.dataset.try_in_memory} 15 | 16 | transform: 17 | _target_: disent.dataset.transform.ToImgTensorF32 18 | mean: ${dataset.meta.vis_mean} 19 | std: ${dataset.meta.vis_std} 20 | 21 | meta: 22 | x_shape: [1, 64, 64] 23 | vis_mean: [0.042494423521889584] 24 | vis_std: [0.19516645880626055] 25 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/dataset/E--mask-dthr-pseudorandom.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: random 3 | 4 | name: mask_dthr_pseudorandom 5 | 6 | data: 7 | _target_: disent.dataset.wrapper.DitheredDataset 8 | dither_n: 2 9 | keep_ratio: 0.5 10 | gt_data: 11 | _target_: docs.examples.extend_experiment.code.random_data.RandomData 12 | 13 | transform: 14 | _target_: disent.dataset.transform.ToImgTensorF32 15 | mean: ${dataset.meta.vis_mean} 16 | 
std: ${dataset.meta.vis_std} 17 | 18 | meta: 19 | x_shape: [1, 64, 64] 20 | vis_mean: [0.4999966931838419] 21 | vis_std: [0.2897895504502549] 22 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/dataset/E--mask-dthr-shapes3d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: random 3 | 4 | name: mask_dthr_shapes3d 5 | 6 | data: 7 | _target_: disent.dataset.wrapper.DitheredDataset 8 | dither_n: 2 9 | keep_ratio: 0.5 10 | gt_data: 11 | _target_: disent.dataset.data.Shapes3dData 12 | data_root: ${dsettings.storage.data_root} 13 | prepare: ${dsettings.dataset.prepare} 14 | in_memory: ${dsettings.dataset.try_in_memory} 15 | 16 | transform: 17 | _target_: disent.dataset.transform.ToImgTensorF32 18 | mean: ${dataset.meta.vis_mean} 19 | std: ${dataset.meta.vis_std} 20 | 21 | meta: 22 | x_shape: [3, 64, 64] 23 | vis_mean: [0.502584966788819, 0.5787597566089667, 0.6034499731859578] 24 | vis_std: [0.2940814043555559, 0.3443979087517214, 0.3661685981524748] 25 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/dataset/E--mask-dthr-smallnorb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: random 3 | 4 | name: mask_dthr_smallnorb 5 | 6 | data: 7 | _target_: disent.dataset.wrapper.DitheredDataset 8 | dither_n: 2 9 | keep_ratio: 0.5 10 | gt_data: 11 | _target_: disent.dataset.data.SmallNorb64Data 12 | data_root: ${dsettings.storage.data_root} 13 | prepare: ${dsettings.dataset.prepare} 14 | is_test: False 15 | 16 | transform: 17 | _target_: disent.dataset.transform.ToImgTensorF32 18 | mean: ${dataset.meta.vis_mean} 19 | std: ${dataset.meta.vis_std} 20 | 21 | meta: 22 | x_shape: [1, 64, 64] 23 | vis_mean: [0.7520918401088603] 24 | vis_std: [0.09563879016827262] 25 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/dataset/E--pseudorandom.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: pseudorandom 5 | 6 | data: 7 | _target_: docs.examples.extend_experiment.code.random_data.RandomData 8 | 9 | transform: 10 | _target_: disent.dataset.transform.ToImgTensorF32 11 | mean: ${dataset.meta.vis_mean} 12 | std: ${dataset.meta.vis_std} 13 | 14 | meta: 15 | x_shape: [1, 64, 64] 16 | vis_mean: [0.4999966931838419] 17 | vis_std: [0.2897895504502549] 18 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/dataset/E--xyblocks.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: xyblocks 5 | 6 | data: 7 | _target_: docs.examples.extend_experiment.code.groundtruth__xyblocks.XYBlocksData 8 | rgb: TRUE 9 | 10 | transform: 11 | _target_: disent.dataset.transform.ToImgTensorF32 12 | mean: ${dataset.meta.vis_mean} 13 | std: ${dataset.meta.vis_std} 14 | 15 | meta: 16 | x_shape: [3, 64, 64] 17 | vis_mean: [0.10040509259259259, 0.10040509259259259, 0.10040509259259259] 18 | vis_std: [0.21689087652106678, 0.21689087652106676, 0.21689087652106678] 19 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/dataset/E--xyblocks_grey.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: xyblocks_grey 5 | 6 | data: 7 | _target_: docs.examples.extend_experiment.code.groundtruth__xyblocks.XYBlocksData 8 | rgb: FALSE 9 | 10 | transform: 11 | _target_: disent.dataset.transform.ToImgTensorF32 12 | mean: ${dataset.meta.vis_mean} 13 | std: ${dataset.meta.vis_std} 14 | 15 | meta: 16 | x_shape: [1, 64, 64] 17 | vis_mean: "${exit:EXITING... please compute the vis_mean and vis_std}" 18 | vis_std: "${exit:EXITING... please compute the vis_mean and vis_std}" 19 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/framework/E--si-adavae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: pair 3 | 4 | name: si-adavae 5 | 6 | cfg: 7 | _target_: docs.examples.extend_experiment.code.weaklysupervised__si_adavae.SwappedInputAdaVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | # adavae 21 | ada_average_mode: gvae # gvae or ml-vae 22 | ada_thresh_mode: symmetric_kl 23 | ada_thresh_ratio: 0.5 24 | # swapped target 25 | swap_chance: 0.1 26 | 27 | meta: 28 | model_z_multiplier: 2 29 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/framework/E--si-betavae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: pair 3 | 4 | name: si-betavae 5 | 6 | cfg: 7 | _target_: docs.examples.extend_experiment.code.weaklysupervised__si_betavae.SwappedInputBetaVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | # swapped target 21 | swap_chance: 0.1 22 | 23 | meta: 24 | model_z_multiplier: 2 25 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/run_location/cluster_rsync_tmp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # -- this run_location reads data from the /tmp folder but does not prepare it! 4 | # | *NB* the difference from `cluster_tmp` is that this script `rsync`s the already 5 | # | prepared files from the network drive to the /tmp folder instead of preparing 6 | # | it all. This is useful if the data is expensive to download, and the data needs 7 | # | to be constantly read off the disk! 
8 | 9 | dsettings: 10 | trainer: 11 | cuda: NULL # auto-detect cuda, some nodes may be configured incorrectly 12 | storage: 13 | logs_dir: 'logs' 14 | data_root: ${rsync_dir:'${oc.env:HOME}/downloads/datasets','/tmp/${oc.env:USER}/datasets'} 15 | dataset: 16 | prepare: TRUE 17 | try_in_memory: FALSE 18 | launcher: 19 | partition: stampede 20 | array_parallelism: 16 21 | exclude: "cluster92,cluster94,cluster96" 22 | 23 | datamodule: 24 | gpu_augment: FALSE 25 | prepare_data_per_node: TRUE 26 | dataloader: 27 | num_workers: 16 28 | pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster! 29 | batch_size: ${settings.dataset.batch_size} 30 | 31 | hydra: 32 | job: 33 | name: 'disent' 34 | run: 35 | dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 36 | sweep: 37 | dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 38 | subdir: '${hydra.job.id}' # hydra.job.id is not available for dir 39 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/run_location/cluster_shr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # -- this run_location reads data from a network drive but does not prepare it! 4 | # | this is useful if we read data once into memory 5 | # | 6 | # | 7 | # | 8 | 9 | dsettings: 10 | trainer: 11 | cuda: NULL # auto-detect cuda, some nodes may be configured incorrectly 12 | storage: 13 | logs_dir: 'logs' 14 | data_root: '${oc.env:HOME}/downloads/datasets' # WE NEED TO BE VERY CAREFUL ABOUT USING A SHARED DRIVE 15 | dataset: 16 | prepare: FALSE # WE MUST PREPARE DATA MANUALLY BEFOREHAND 17 | try_in_memory: TRUE 18 | launcher: 19 | partition: stampede 20 | array_parallelism: 16 21 | exclude: "cluster92,cluster94,cluster96" 22 | 23 | datamodule: 24 | gpu_augment: FALSE 25 | prepare_data_per_node: TRUE 26 | dataloader: 27 | num_workers: 16 28 | pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster! 29 | batch_size: ${settings.dataset.batch_size} 30 | 31 | hydra: 32 | job: 33 | name: 'disent' 34 | run: 35 | dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 36 | sweep: 37 | dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 38 | subdir: '${hydra.job.id}' # hydra.job.id is not available for dir 39 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/run_location/cluster_tmp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # -- this run_location reads data from the /tmp folder and prepares it if it is missing 4 | # this is useful if we need to read data constantly from the disk 5 | # | 6 | # | 7 | # | 8 | 9 | dsettings: 10 | trainer: 11 | cuda: NULL # auto-detect cuda, some nodes may be configured incorrectly 12 | storage: 13 | logs_dir: 'logs' 14 | data_root: '/tmp/${oc.env:USER}/datasets' 15 | dataset: 16 | prepare: TRUE 17 | try_in_memory: TRUE 18 | launcher: 19 | partition: stampede 20 | array_parallelism: 16 21 | exclude: "cluster92,cluster94,cluster96" 22 | 23 | datamodule: 24 | gpu_augment: FALSE 25 | prepare_data_per_node: TRUE 26 | dataloader: 27 | num_workers: 16 28 | pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster! 
29 | batch_size: ${settings.dataset.batch_size} 30 | 31 | hydra: 32 | job: 33 | name: 'disent' 34 | run: 35 | dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 36 | sweep: 37 | dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 38 | subdir: '${hydra.job.id}' # hydra.job.id is not available for dir 39 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/config/run_plugins/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # call the listed functions here before the experiment is started 4 | # - make sure we register the research code to disent so that it can be used! 5 | experiment: 6 | plugins: 7 | - docs.examples.extend_experiment.code.register_to_disent 8 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/run.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from experiment.run import hydra_experiment 4 | 5 | # This python script is functionally equivalent to run.sh 6 | 7 | # run the experiment and append the new config search path 8 | # - for example: 9 | # $ python3 run.py dataset=E--pseudorandom framework=E--si-betavae 10 | if __name__ == "__main__": 11 | hydra_experiment(search_dirs_prepend=os.path.abspath(os.path.join(__file__, "../config"))) 12 | -------------------------------------------------------------------------------- /docs/examples/extend_experiment/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This bash script is functionally equivalent to run.py 4 | 5 | # get the various dirs relative to this file 6 | SCRIPT_DIR="$(realpath -s "$(dirname -- "${BASH_SOURCE[0]}")")" # get the current script dir 7 | DISENT_DIR="$(realpath -s "$SCRIPT_DIR/../../..")" # get the root directory for `disent` 8 | SEARCH_DIR="${DISENT_DIR}/docs/examples/extend_experiment/config" 9 | RUN_SCRIPT="${DISENT_DIR}/experiment/run.py" 10 | 11 | echo "DISENT_DIR=$DISENT_DIR" 12 | echo "SCRIPT_DIR=$SCRIPT_DIR" 13 | echo "SEARCH_DIR=$SEARCH_DIR" 14 | echo "RUN_SCRIPT=$RUN_SCRIPT" 15 | 16 | # run the experiment, passing arguments to this script to the experiment instead! 
17 | # - for example: 18 | # $ run.sh dataset=E--pseudorandom framework=E--si-betavae 19 | # - is equivalent to: 20 | # PYTHONPATH="$DISENT_DIR" DISENT_CONFIGS_PREPEND="$SEARCH_DIR" python3 "$RUN_SCRIPT" dataset=E--pseudorandom framework=E--si-betavae 21 | PYTHONPATH="$DISENT_DIR" DISENT_CONFIGS_PREPEND="$SEARCH_DIR" python3 "$RUN_SCRIPT" "$@" 22 | -------------------------------------------------------------------------------- /docs/examples/mnist_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import lightning as L 4 | from torch.utils.data import DataLoader 5 | from torchvision import datasets 6 | from tqdm import tqdm 7 | 8 | from disent.dataset import DisentDataset 9 | from disent.dataset.sampling import RandomSampler 10 | from disent.dataset.transform import ToImgTensorF32 11 | from disent.frameworks.vae import AdaVae 12 | from disent.model import AutoEncoder 13 | from disent.model.ae import DecoderFC 14 | from disent.model.ae import EncoderFC 15 | 16 | 17 | # modify the mnist dataset to only return images, not labels 18 | class MNIST(datasets.MNIST): 19 | def __getitem__(self, index): 20 | img, target = super().__getitem__(index) 21 | return img 22 | 23 | 24 | # make mnist dataset -- adjust num_samples here to match framework. TODO: add tests that can fail with a warning -- dataset downloading is not always reliable 25 | data_folder = os.path.abspath(os.path.join(__file__, "../data/dataset")) 26 | dataset_train = DisentDataset( 27 | MNIST(data_folder, train=True, download=True, transform=ToImgTensorF32()), sampler=RandomSampler(num_samples=2) 28 | ) 29 | dataset_test = MNIST(data_folder, train=False, download=True, transform=ToImgTensorF32()) 30 | 31 | # create the dataloaders 32 | # - if you use `num_workers != 0` in the DataLoader, then make sure to 33 | # wrap `trainer.fit` with `if __name__ == '__main__': ...` 34 | dataloader_train = DataLoader(dataset=dataset_train, batch_size=128, shuffle=True, num_workers=0) 35 | dataloader_test = DataLoader(dataset=dataset_test, batch_size=128, shuffle=True, num_workers=0) 36 | 37 | # create the model 38 | module = AdaVae( 39 | model=AutoEncoder( 40 | encoder=EncoderFC(x_shape=(1, 28, 28), z_size=9, z_multiplier=2), 41 | decoder=DecoderFC(x_shape=(1, 28, 28), z_size=9), 42 | ), 43 | cfg=AdaVae.cfg( 44 | optimizer="adam", 45 | optimizer_kwargs=dict(lr=1e-3), 46 | beta=4, 47 | recon_loss="mse", 48 | loss_reduction="mean_sum", # "mean_sum" is the traditional loss reduction mode, rather than "mean" 49 | ), 50 | ) 51 | 52 | # train the model 53 | trainer = L.Trainer( 54 | logger=False, enable_checkpointing=False, max_steps=2048 55 | ) # callbacks=[VaeLatentCycleLoggingCallback(every_n_steps=250, plt_show=True)] 56 | trainer.fit(module, dataloader_train) 57 | 58 | # move batches to the model's device & manually encode some observations 59 | for xs in tqdm(dataloader_test, desc="Custom Evaluation"): 60 | zs = module.encode(xs.to(module.device)) 61 | -------------------------------------------------------------------------------- /docs/examples/overview_data.py: -------------------------------------------------------------------------------- 1 | from disent.dataset.data import XYObjectData 2 | 3 | data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette="rgb_1") 4 | 5 | print(f"Number of observations: {len(data)} == {data.size}") 6 | print(f"Observation shape: {data.img_shape}") 7 | print(f"Num Factors: {data.num_factors}") 8 | print(f"Factor Names: 
{data.factor_names}") 9 | print(f"Factor Sizes: {data.factor_sizes}") 10 | 11 | for i, obs in enumerate(data): 12 | print( 13 | f"i={i}", 14 | f'pos: ({", ".join(data.factor_names)}) = {tuple(data.idx_to_pos(i))}', 15 | f"obs={obs.tolist()}", 16 | sep=" | ", 17 | ) 18 | -------------------------------------------------------------------------------- /docs/examples/overview_dataset_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from disent.dataset import DisentDataset 4 | from disent.dataset.data import XYObjectData 5 | from disent.dataset.sampling import GroundTruthPairOrigSampler 6 | from disent.dataset.transform import ToImgTensorF32 7 | 8 | # prepare the data 9 | data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette="rgb_1") 10 | dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToImgTensorF32()) 11 | dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=0) 12 | 13 | # iterate over single epoch 14 | for batch in dataloader: 15 | (x0, x1) = batch["x_targ"] 16 | print(x0.dtype, x0.min(), x0.max(), x0.shape) 17 | -------------------------------------------------------------------------------- /docs/examples/overview_dataset_pair.py: -------------------------------------------------------------------------------- 1 | from disent.dataset import DisentDataset 2 | from disent.dataset.data import XYObjectData 3 | from disent.dataset.sampling import GroundTruthPairOrigSampler 4 | from disent.dataset.transform import ToImgTensorF32 5 | 6 | # prepare the data 7 | data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette="rgb_1") 8 | dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToImgTensorF32()) 9 | 10 | # iterate over single epoch 11 | for obs in dataset: 12 | # pairs are contained in tuples of size 2 -- singles use tuples of size 1 so the API stays consistent 13 | (x0, x1) = obs["x_targ"] 14 | print(x0.dtype, x0.min(), x0.max(), x0.shape) 15 | -------------------------------------------------------------------------------- /docs/examples/overview_dataset_pair_augment.py: -------------------------------------------------------------------------------- 1 | from disent.dataset import DisentDataset 2 | from disent.dataset.data import XYObjectData 3 | from disent.dataset.sampling import GroundTruthPairSampler 4 | from disent.dataset.transform import FftBoxBlur 5 | from disent.dataset.transform import ToImgTensorF32 6 | 7 | # prepare the data 8 | data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette="rgb_1") 9 | dataset = DisentDataset( 10 | data, sampler=GroundTruthPairSampler(), transform=ToImgTensorF32(), augment=FftBoxBlur(radius=1, p=1.0) 11 | ) 12 | 13 | # iterate over single epoch 14 | for obs in dataset: 15 | # if augment is not specified, then the augmented 'x' key does not exist!
16 | (x0, x1), (x0_targ, x1_targ) = obs["x"], obs["x_targ"] 17 | print(x0.dtype, x0.min(), x0.max(), x0.shape) 18 | -------------------------------------------------------------------------------- /docs/examples/overview_dataset_single.py: -------------------------------------------------------------------------------- 1 | from disent.dataset import DisentDataset 2 | from disent.dataset.data import XYObjectData 3 | 4 | # prepare the data 5 | # - DisentDataset is a generic wrapper around torch Datasets that prepares 6 | # the data for the various frameworks according to some sampling strategy 7 | # by default this sampling strategy just returns the data at the given idx. 8 | data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette="rgb_1") 9 | dataset = DisentDataset(data, transform=None, augment=None) 10 | 11 | # iterate over single epoch 12 | for obs in dataset: 13 | # transform(data[i]) gives 'x_targ', then augment(x_targ) gives 'x' 14 | (x0,) = obs["x_targ"] 15 | print(x0.dtype, x0.min(), x0.max(), x0.shape) 16 | -------------------------------------------------------------------------------- /docs/examples/overview_framework_adagvae.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from torch.utils.data import DataLoader 3 | 4 | from disent.dataset import DisentDataset 5 | from disent.dataset.data import XYObjectData 6 | from disent.dataset.sampling import GroundTruthPairOrigSampler 7 | from disent.dataset.transform import ToImgTensorF32 8 | from disent.frameworks.vae import AdaVae 9 | from disent.model import AutoEncoder 10 | from disent.model.ae import DecoderConv64 11 | from disent.model.ae import EncoderConv64 12 | from disent.util import is_test_run # you can ignore and remove this 13 | 14 | # prepare the data 15 | data = XYObjectData() 16 | dataset = DisentDataset(data, GroundTruthPairOrigSampler(), transform=ToImgTensorF32()) 17 | dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=0) 18 | 19 | # create the pytorch lightning system 20 | module: L.LightningModule = AdaVae( 21 | model=AutoEncoder( 22 | encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), 23 | decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), 24 | ), 25 | cfg=AdaVae.cfg( 26 | optimizer="adam", 27 | optimizer_kwargs=dict(lr=1e-3), 28 | loss_reduction="mean_sum", 29 | beta=4, 30 | ada_average_mode="gvae", 31 | ada_thresh_mode="kl", 32 | ), 33 | ) 34 | 35 | # train the model 36 | trainer = L.Trainer(logger=False, enable_checkpointing=False, fast_dev_run=is_test_run()) 37 | trainer.fit(module, dataloader) 38 | -------------------------------------------------------------------------------- /docs/examples/overview_framework_ae.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from torch.utils.data import DataLoader 3 | 4 | from disent.dataset import DisentDataset 5 | from disent.dataset.data import XYObjectData 6 | from disent.dataset.transform import ToImgTensorF32 7 | from disent.frameworks.ae import Ae 8 | from disent.model import AutoEncoder 9 | from disent.model.ae import DecoderConv64 10 | from disent.model.ae import EncoderConv64 11 | from disent.util import is_test_run # you can ignore and remove this 12 | 13 | # prepare the data 14 | data = XYObjectData() 15 | dataset = DisentDataset(data, transform=ToImgTensorF32()) 16 | dataloader = DataLoader(dataset=dataset, batch_size=4, 
shuffle=True, num_workers=0) 17 | 18 | # create the pytorch lightning system 19 | module: L.LightningModule = Ae( 20 | model=AutoEncoder( 21 | encoder=EncoderConv64(x_shape=data.x_shape, z_size=6), 22 | decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), 23 | ), 24 | cfg=Ae.cfg(optimizer="adam", optimizer_kwargs=dict(lr=1e-3), loss_reduction="mean_sum"), 25 | ) 26 | 27 | # train the model 28 | trainer = L.Trainer(logger=False, enable_checkpointing=False, fast_dev_run=is_test_run()) 29 | trainer.fit(module, dataloader) 30 | -------------------------------------------------------------------------------- /docs/examples/overview_framework_betavae.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from torch.utils.data import DataLoader 3 | 4 | from disent.dataset import DisentDataset 5 | from disent.dataset.data import XYObjectData 6 | from disent.dataset.transform import ToImgTensorF32 7 | from disent.frameworks.vae import BetaVae 8 | from disent.model import AutoEncoder 9 | from disent.model.ae import DecoderConv64 10 | from disent.model.ae import EncoderConv64 11 | from disent.util import is_test_run # you can ignore and remove this 12 | 13 | # prepare the data 14 | data = XYObjectData() 15 | dataset = DisentDataset(data, transform=ToImgTensorF32()) 16 | dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=0) 17 | 18 | # create the pytorch lightning system 19 | module: L.LightningModule = BetaVae( 20 | model=AutoEncoder( 21 | encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), 22 | decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), 23 | ), 24 | cfg=BetaVae.cfg(optimizer="adam", optimizer_kwargs=dict(lr=1e-3), loss_reduction="mean_sum", beta=4), 25 | ) 26 | 27 | # train the model 28 | trainer = L.Trainer(logger=False, enable_checkpointing=False, fast_dev_run=is_test_run()) 29 | trainer.fit(module, dataloader) 30 | -------------------------------------------------------------------------------- /docs/examples/overview_framework_betavae_scheduled.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from torch.utils.data import DataLoader 3 | 4 | from disent.dataset import DisentDataset 5 | from disent.dataset.data import XYObjectData 6 | from disent.dataset.transform import ToImgTensorF32 7 | from disent.frameworks.vae import BetaVae 8 | from disent.model import AutoEncoder 9 | from disent.model.ae import DecoderConv64 10 | from disent.model.ae import EncoderConv64 11 | from disent.schedule import CyclicSchedule 12 | from disent.util import is_test_run # you can ignore and remove this 13 | 14 | # prepare the data 15 | data = XYObjectData() 16 | dataset = DisentDataset(data, transform=ToImgTensorF32()) 17 | dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=0) 18 | 19 | # create the pytorch lightning system 20 | module: L.LightningModule = BetaVae( 21 | model=AutoEncoder( 22 | encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), 23 | decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), 24 | ), 25 | cfg=BetaVae.cfg(optimizer="adam", optimizer_kwargs=dict(lr=1e-3), loss_reduction="mean_sum", beta=4), 26 | ) 27 | 28 | # register the scheduler with the DisentFramework 29 | # - cyclic scheduler from: https://arxiv.org/abs/1903.10145 30 | module.register_schedule( 31 | "beta", 32 | CyclicSchedule( 33 | period=1024, # repeat every: trainer.global_step % period 34 | ), 35 | ) 
36 | 37 | # train the model 38 | trainer = L.Trainer(logger=False, enable_checkpointing=False, fast_dev_run=is_test_run()) 39 | trainer.fit(module, dataloader) 40 | -------------------------------------------------------------------------------- /docs/examples/overview_metrics.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from torch.utils.data import DataLoader 3 | 4 | from disent.dataset import DisentDataset 5 | from disent.dataset.data import XYObjectData 6 | from disent.dataset.transform import ToImgTensorF32 7 | from disent.frameworks.vae import BetaVae 8 | from disent.metrics import metric_dci 9 | from disent.metrics import metric_mig 10 | from disent.model import AutoEncoder 11 | from disent.model.ae import DecoderConv64 12 | from disent.model.ae import EncoderConv64 13 | from disent.util import is_test_run 14 | 15 | data = XYObjectData() 16 | dataset = DisentDataset(data, transform=ToImgTensorF32(), augment=None) 17 | dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) 18 | 19 | 20 | def make_vae(beta): 21 | return BetaVae( 22 | model=AutoEncoder( 23 | encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), 24 | decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), 25 | ), 26 | cfg=BetaVae.cfg(optimizer="adam", optimizer_kwargs=dict(lr=1e-3), beta=beta), 27 | ) 28 | 29 | 30 | def train(module): 31 | trainer = L.Trainer(logger=False, enable_checkpointing=False, max_steps=256, fast_dev_run=is_test_run()) 32 | trainer.fit(module, dataloader) 33 | 34 | # we cannot guarantee which device the representation is on 35 | get_repr = lambda x: module.encode(x.to(module.device)) 36 | 37 | # evaluate 38 | return { 39 | **metric_dci( 40 | dataset, 41 | get_repr, 42 | num_train=10 if is_test_run() else 1000, 43 | num_test=5 if is_test_run() else 500, 44 | boost_mode="sklearn", 45 | ), 46 | **metric_mig(dataset, get_repr, num_train=20 if is_test_run() else 2000), 47 | } 48 | 49 | 50 | a_results = train(make_vae(beta=4)) 51 | b_results = train(make_vae(beta=0.01)) 52 | 53 | print("beta=4: ", a_results) 54 | print("beta=0.01:", b_results) 55 | -------------------------------------------------------------------------------- /docs/examples/plotting_examples/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2022 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /docs/examples/plotting_examples/plots/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.jpg 3 | *.gif 4 | -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__cars3d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__cars3d.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__dsprites-imagenet-bg-100.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__dsprites-imagenet-bg-100.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__dsprites-imagenet-bg-50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__dsprites-imagenet-bg-50.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__dsprites-imagenet-fg-100.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__dsprites-imagenet-fg-100.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__dsprites-imagenet-fg-50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__dsprites-imagenet-fg-50.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__dsprites.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__dsprites.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__mpi3d-real.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__mpi3d-real.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__mpi3d-realistic.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__mpi3d-realistic.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__mpi3d-toy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__mpi3d-toy.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__shapes3d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__shapes3d.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__smallnorb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__smallnorb.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__sprites.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__sprites.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__xy-blocks.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__xy-blocks.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__xy-object-shaded.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__xy-object-shaded.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__xy-object.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__xy-object.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__xy-single-square__spacing8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__xy-single-square__spacing8.jpg -------------------------------------------------------------------------------- /docs/img/traversals/traversal-transpose__xy-squares__spacing8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmichlo/disent/8f061a87076adeae8d6e5b0fa984b660cd40e026/docs/img/traversals/traversal-transpose__xy-squares__spacing8.jpg 
-------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Disent 2 | 3 | PyTorch Lightning disentanglement framework implementing various modular VAEs. 4 | 5 | Various unique optional features exist, including data augmentations, 6 | as well as the first (?) unofficial implementation of the TensorFlow-based [Ada-GVAE](https://github.com/google-research/disentanglement_lib). 7 | 8 | ## Goals 9 | 10 | Disent aims to fulfil the following criteria: 11 | 1. Provide **high quality**, **readable**, **consistent** and **easily comparable** implementations of frameworks 12 | 2. **Highlight differences** between framework implementations by overriding **hooks** and minimising duplicate code 13 | 3. Use **best practice** e.g. `torch.distributions` 14 | 4. Be extremely **flexible** & configurable 15 | 5. Load data from disk for low-memory systems 16 | 17 | ## Citing Disent 18 | 19 | Please use the following citation if you use Disent in your research: 20 | 21 | ```bibtex 22 | @Misc{Michlo2021Disent, 23 | author = {Nathan Juraj Michlo}, 24 | title = {Disent - A modular disentangled representation learning framework for pytorch}, 25 | howpublished = {Github}, 26 | year = {2021}, 27 | url = {https://github.com/nmichlo/disent} 28 | } 29 | ``` 30 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | mkdocs == 1.1.2 3 | mkdocstrings == 0.14.0 4 | mkdocs-material == 6.2.5 5 | mkdocs-git-revision-date-localized-plugin == 0.8 6 | # pygments == 2.7.4 7 | # pymdown-extensions == 8.1 8 | -------------------------------------------------------------------------------- /experiment/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE.
23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /experiment/config/__init__.py: -------------------------------------------------------------------------------- 1 | # for some unknown reason this file is required for tests/test_experiment.py to work 2 | # this is very odd. Although it might be fixed in hydra 1.1? 3 | -------------------------------------------------------------------------------- /experiment/config/augment/example.yaml: -------------------------------------------------------------------------------- 1 | name: basic 2 | 3 | augment_cls: 4 | _target_: torchvision.transforms.ColorJitter 5 | brightness: 0.1 6 | contrast: 0.1 7 | saturation: 0.1 8 | hue: 0.1 9 | -------------------------------------------------------------------------------- /experiment/config/augment/none.yaml: -------------------------------------------------------------------------------- 1 | name: none 2 | 3 | augment_cls: NULL 4 | -------------------------------------------------------------------------------- /experiment/config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ # defaults lists override entries from this file! 3 | # data 4 | - sampling: default__bb 5 | - dataset: xyobject 6 | - augment: none 7 | # system 8 | - framework: betavae 9 | - model: vae_conv64 10 | # training 11 | - optimizer: adam 12 | - schedule: none 13 | - metrics: all 14 | - run_length: long 15 | # logs 16 | - run_callbacks: vis 17 | - run_logging: none 18 | # runtime 19 | - run_location: local 20 | - run_launcher: local 21 | - run_action: train 22 | # experiment 23 | - run_plugins: default 24 | 25 | settings: 26 | job: 27 | user: '${oc.env:USER}' 28 | project: 'DELETE' 29 | name: '${framework.name}:${settings.framework.recon_loss}|${dataset.name}:${sampling.name}|${trainer.max_steps}' 30 | seed: NULL 31 | framework: 32 | beta: 0.0316 33 | recon_loss: mse 34 | loss_reduction: mean # beta scaling 35 | framework_opt: 36 | latent_distribution: normal # only used by VAEs 37 | model: 38 | z_size: 25 39 | weight_init: 'xavier_normal' # xavier_normal, default 40 | dataset: 41 | batch_size: 256 42 | optimizer: 43 | lr: 1e-3 44 | checkpoint: 45 | # load_checkpoint: NULL # NULL or string 46 | save_checkpoint: FALSE # boolean, save at end of run -- more advanced checkpointing can be done with a callback! 47 | -------------------------------------------------------------------------------- /experiment/config/config_test.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ # defaults lists override entries from this file! 
3 | # data 4 | - sampling: default__bb 5 | - dataset: xyobject 6 | - augment: example 7 | # system 8 | - framework: betavae 9 | - model: linear 10 | # training 11 | - optimizer: adam 12 | - schedule: beta_cyclic 13 | - metrics: test 14 | - run_length: test 15 | # logs 16 | - run_callbacks: test 17 | - run_logging: none 18 | # runtime 19 | - run_location: local_cpu 20 | - run_launcher: local 21 | - run_action: train 22 | # experiment 23 | - run_plugins: default 24 | 25 | settings: 26 | job: 27 | user: 'invalid' 28 | project: 'invalid' 29 | name: '${framework.name}:${settings.framework.recon_loss}|${dataset.name}:${sampling.name}|${trainer.max_steps}' 30 | seed: NULL 31 | framework: 32 | beta: 0.0316 33 | recon_loss: mse 34 | loss_reduction: mean # beta scaling 35 | framework_opt: 36 | latent_distribution: normal # only used by VAEs 37 | model: 38 | z_size: 25 39 | weight_init: 'xavier_normal' # xavier_normal, default 40 | dataset: 41 | batch_size: 5 42 | optimizer: 43 | lr: 1e-3 44 | checkpoint: 45 | # load_checkpoint: NULL # NULL or string 46 | save_checkpoint: TRUE # boolean, save at end of run -- more advanced checkpointing can be done with a callback! 47 | -------------------------------------------------------------------------------- /experiment/config/dataset/_data_type_/episodes.yaml: -------------------------------------------------------------------------------- 1 | # controlled by the data's defaults list 2 | name: episodes 3 | -------------------------------------------------------------------------------- /experiment/config/dataset/_data_type_/gt.yaml: -------------------------------------------------------------------------------- 1 | # controlled by the data's defaults list 2 | name: gt 3 | -------------------------------------------------------------------------------- /experiment/config/dataset/_data_type_/random.yaml: -------------------------------------------------------------------------------- 1 | # controlled by the data's defaults list 2 | name: random 3 | -------------------------------------------------------------------------------- /experiment/config/dataset/cars3d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: cars3d 5 | 6 | data: 7 | _target_: disent.dataset.data.Cars3d64Data 8 | data_root: ${dsettings.storage.data_root} 9 | prepare: ${dsettings.dataset.prepare} 10 | 11 | transform: 12 | _target_: disent.dataset.transform.ToImgTensorF32 13 | mean: ${dataset.meta.vis_mean} 14 | std: ${dataset.meta.vis_std} 15 | 16 | meta: 17 | x_shape: [3, 64, 64] 18 | vis_mean: [0.8976676149976628, 0.8891658020067508, 0.885147515814868] 19 | vis_std: [0.22503195531503034, 0.2399461278981261, 0.24792106319684404] 20 | -------------------------------------------------------------------------------- /experiment/config/dataset/dsprites-imagenet-bg-100.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: dsprites_imagenet_bg_100 5 | 6 | data: 7 | _target_: disent.dataset.data.DSpritesImagenetData 8 | visibility: 100 9 | mode: bg 10 | data_root: ${dsettings.storage.data_root} 11 | prepare: ${dsettings.dataset.prepare} 12 | in_memory: ${dsettings.dataset.try_in_memory} 13 | 14 | transform: 15 | _target_: disent.dataset.transform.ToImgTensorF32 16 | mean: ${dataset.meta.vis_mean} 17 | std: ${dataset.meta.vis_std} 18 | 19 | meta: 20 | x_shape: [3, 64, 64] 21 | vis_mean: [0.5020433619489952, 0.47206398913310593, 
0.42380018909780404] 22 | vis_std: [0.2505510666843685, 0.2500725980366869, 0.2562415603123114] 23 | -------------------------------------------------------------------------------- /experiment/config/dataset/dsprites-imagenet-bg-25.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: dsprites_imagenet_bg_25 5 | 6 | data: 7 | _target_: disent.dataset.data.DSpritesImagenetData 8 | visibility: 25 9 | mode: bg 10 | data_root: ${dsettings.storage.data_root} 11 | prepare: ${dsettings.dataset.prepare} 12 | in_memory: ${dsettings.dataset.try_in_memory} 13 | 14 | transform: 15 | _target_: disent.dataset.transform.ToImgTensorF32 16 | mean: ${dataset.meta.vis_mean} 17 | std: ${dataset.meta.vis_std} 18 | 19 | meta: 20 | x_shape: [3, 64, 64] 21 | vis_mean: [0.15596283852200074, 0.14847876264131535, 0.13644703866118635] 22 | vis_std: [0.18208653250875798, 0.18323109038468802, 0.18569624396763393] 23 | -------------------------------------------------------------------------------- /experiment/config/dataset/dsprites-imagenet-bg-50.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: dsprites_imagenet_bg_50 5 | 6 | data: 7 | _target_: disent.dataset.data.DSpritesImagenetData 8 | visibility: 50 9 | mode: bg 10 | data_root: ${dsettings.storage.data_root} 11 | prepare: ${dsettings.dataset.prepare} 12 | in_memory: ${dsettings.dataset.try_in_memory} 13 | 14 | transform: 15 | _target_: disent.dataset.transform.ToImgTensorF32 16 | mean: ${dataset.meta.vis_mean} 17 | std: ${dataset.meta.vis_std} 18 | 19 | meta: 20 | x_shape: [3, 64, 64] 21 | vis_mean: [0.271323621109491, 0.25634066038331416, 0.23223046934400662] 22 | vis_std: [0.18930391112143766, 0.19067969524425118, 0.19523218572886117] 23 | -------------------------------------------------------------------------------- /experiment/config/dataset/dsprites-imagenet-bg-75.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: dsprites_imagenet_bg_75 5 | 6 | data: 7 | _target_: disent.dataset.data.DSpritesImagenetData 8 | visibility: 75 9 | mode: bg 10 | data_root: ${dsettings.storage.data_root} 11 | prepare: ${dsettings.dataset.prepare} 12 | in_memory: ${dsettings.dataset.try_in_memory} 13 | 14 | transform: 15 | _target_: disent.dataset.transform.ToImgTensorF32 16 | mean: ${dataset.meta.vis_mean} 17 | std: ${dataset.meta.vis_std} 18 | 19 | meta: 20 | x_shape: [3, 64, 64] 21 | vis_mean: [0.38577296742807327, 0.3632825822323436, 0.3271231888851156] 22 | vis_std: [0.21392191050784257, 0.2146731716558466, 0.2204460568339597] 23 | -------------------------------------------------------------------------------- /experiment/config/dataset/dsprites-imagenet-fg-100.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: dsprites_imagenet_fg_100 5 | 6 | data: 7 | _target_: disent.dataset.data.DSpritesImagenetData 8 | visibility: 100 9 | mode: fg 10 | data_root: ${dsettings.storage.data_root} 11 | prepare: ${dsettings.dataset.prepare} 12 | in_memory: ${dsettings.dataset.try_in_memory} 13 | 14 | transform: 15 | _target_: disent.dataset.transform.ToImgTensorF32 16 | mean: ${dataset.meta.vis_mean} 17 | std: ${dataset.meta.vis_std} 18 | 19 | meta: 20 | x_shape: [3, 64, 64] 21 | vis_mean: [0.02067051643494642, 0.018688392816012946, 
0.01632900510079384] 22 | vis_std: [0.10271307751834059, 0.09390213983525653, 0.08377594259970281] 23 | -------------------------------------------------------------------------------- /experiment/config/dataset/dsprites-imagenet-fg-25.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: dsprites_imagenet_fg_25 5 | 6 | data: 7 | _target_: disent.dataset.data.DSpritesImagenetData 8 | visibility: 25 9 | mode: fg 10 | data_root: ${dsettings.storage.data_root} 11 | prepare: ${dsettings.dataset.prepare} 12 | in_memory: ${dsettings.dataset.try_in_memory} 13 | 14 | transform: 15 | _target_: disent.dataset.transform.ToImgTensorF32 16 | mean: ${dataset.meta.vis_mean} 17 | std: ${dataset.meta.vis_std} 18 | 19 | meta: 20 | x_shape: [3, 64, 64] 21 | vis_mean: [0.03697718115834816, 0.03648095993826591, 0.03589183623762013] 22 | vis_std: [0.17009317531572005, 0.16780075430655303, 0.16508779008691726] 23 | -------------------------------------------------------------------------------- /experiment/config/dataset/dsprites-imagenet-fg-50.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: dsprites_imagenet_fg_50 5 | 6 | data: 7 | _target_: disent.dataset.data.DSpritesImagenetData 8 | visibility: 50 9 | mode: fg 10 | data_root: ${dsettings.storage.data_root} 11 | prepare: ${dsettings.dataset.prepare} 12 | in_memory: ${dsettings.dataset.try_in_memory} 13 | 14 | transform: 15 | _target_: disent.dataset.transform.ToImgTensorF32 16 | mean: ${dataset.meta.vis_mean} 17 | std: ${dataset.meta.vis_std} 18 | 19 | meta: 20 | x_shape: [3, 64, 64] 21 | vis_mean: [0.031541090790578506, 0.030549541980176148, 0.029368756624861398] 22 | vis_std: [0.14609029304575144, 0.14150919987547408, 0.13607872227034773] 23 | -------------------------------------------------------------------------------- /experiment/config/dataset/dsprites-imagenet-fg-75.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: dsprites_imagenet_fg_75 5 | 6 | data: 7 | _target_: disent.dataset.data.DSpritesImagenetData 8 | visibility: 75 9 | mode: fg 10 | data_root: ${dsettings.storage.data_root} 11 | prepare: ${dsettings.dataset.prepare} 12 | in_memory: ${dsettings.dataset.try_in_memory} 13 | 14 | transform: 15 | _target_: disent.dataset.transform.ToImgTensorF32 16 | mean: ${dataset.meta.vis_mean} 17 | std: ${dataset.meta.vis_std} 18 | 19 | meta: 20 | x_shape: [3, 64, 64] 21 | vis_mean: [0.02606445677382044, 0.024577082627819637, 0.02280587082174753] 22 | vis_std: [0.12307153238282868, 0.11624914830767437, 0.1081911967745551] 23 | -------------------------------------------------------------------------------- /experiment/config/dataset/dsprites.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: dsprites 5 | 6 | data: 7 | _target_: disent.dataset.data.DSpritesData 8 | data_root: ${dsettings.storage.data_root} 9 | prepare: ${dsettings.dataset.prepare} 10 | in_memory: ${dsettings.dataset.try_in_memory} 11 | 12 | transform: 13 | _target_: disent.dataset.transform.ToImgTensorF32 14 | mean: ${dataset.meta.vis_mean} 15 | std: ${dataset.meta.vis_std} 16 | 17 | meta: 18 | x_shape: [1, 64, 64] 19 | vis_mean: [0.042494423521889584] 20 | vis_std: [0.19516645880626055] 21 | 
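The dataset configs above (e.g. `dsprites.yaml`) all follow the same pattern: a `data` node whose Hydra `_target_` names a `disent.dataset.data` class, a `transform` node, and a `meta` block whose `vis_mean`/`vis_std` are interpolated back into the transform. The sketch below shows roughly how such a node resolves once the `${dsettings...}` interpolations are filled in; it is not part of the repo, and the literal `data_root`/`prepare`/`in_memory` values are placeholder assumptions rather than values taken from these configs (the mean/std are copied from `dsprites.yaml`'s `meta`).

```python
# Minimal sketch (assumptions noted inline): instantiating a dataset config
# like dsprites.yaml by hand via Hydra, then wrapping it in a DisentDataset.
from hydra.utils import instantiate
from omegaconf import OmegaConf

from disent.dataset import DisentDataset

cfg = OmegaConf.create({
    "data": {
        "_target_": "disent.dataset.data.DSpritesData",
        "data_root": "data/dataset",  # assumed stand-in for ${dsettings.storage.data_root}
        "prepare": True,              # assumed stand-in for ${dsettings.dataset.prepare}
        "in_memory": False,           # assumed stand-in for ${dsettings.dataset.try_in_memory}
    },
    "transform": {
        "_target_": "disent.dataset.transform.ToImgTensorF32",
        "mean": [0.042494423521889584],  # dataset.meta.vis_mean from dsprites.yaml
        "std": [0.19516645880626055],    # dataset.meta.vis_std from dsprites.yaml
    },
})

# instantiate() builds the object named by `_target_`, passing the remaining keys as kwargs
data = instantiate(cfg.data)
dataset = DisentDataset(data, transform=instantiate(cfg.transform))

# quick sanity check: a single observation is a dict with an "x_targ" tuple
obs = dataset[0]
(x0,) = obs["x_targ"]
print(len(dataset), x0.dtype, x0.shape)
```

In the experiment runner itself this resolution is handled by Hydra's defaults list and config groups; the snippet above only illustrates what the `_target_` / interpolation machinery amounts to for a single dataset config.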
-------------------------------------------------------------------------------- /experiment/config/dataset/mpi3d_real.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: mpi3d_real 5 | 6 | data: 7 | _target_: disent.dataset.data.Mpi3dData 8 | data_root: ${dsettings.storage.data_root} 9 | prepare: ${dsettings.dataset.prepare} 10 | in_memory: ${dsettings.dataset.try_in_memory} 11 | subset: 'real' 12 | 13 | transform: 14 | _target_: disent.dataset.transform.ToImgTensorF32 15 | mean: ${dataset.meta.vis_mean} 16 | std: ${dataset.meta.vis_std} 17 | 18 | meta: 19 | x_shape: [3, 64, 64] 20 | vis_mean: [0.12848577057593918, 0.1648033279246875, 0.13971583058948006] 21 | vis_std: [0.09329210572942123, 0.09203401520672466, 0.10322983729706256] 22 | -------------------------------------------------------------------------------- /experiment/config/dataset/mpi3d_realistic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: mpi3d_realistic 5 | 6 | data: 7 | _target_: disent.dataset.data.Mpi3dData 8 | data_root: ${dsettings.storage.data_root} 9 | prepare: ${dsettings.dataset.prepare} 10 | in_memory: ${dsettings.dataset.try_in_memory} 11 | subset: 'realistic' 12 | 13 | transform: 14 | _target_: disent.dataset.transform.ToImgTensorF32 15 | mean: ${dataset.meta.vis_mean} 16 | std: ${dataset.meta.vis_std} 17 | 18 | meta: 19 | x_shape: [3, 64, 64] 20 | vis_mean: [0.17986945797157425, 0.20474678611954758, 0.18148154235228137] 21 | vis_std: [0.08746476487506775, 0.09330995331830938, 0.09242232801328121] 22 | -------------------------------------------------------------------------------- /experiment/config/dataset/mpi3d_toy.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: mpi3d_toy 5 | 6 | data: 7 | _target_: disent.dataset.data.Mpi3dData 8 | data_root: ${dsettings.storage.data_root} 9 | prepare: ${dsettings.dataset.prepare} 10 | in_memory: ${dsettings.dataset.try_in_memory} 11 | subset: 'toy' 12 | 13 | transform: 14 | _target_: disent.dataset.transform.ToImgTensorF32 15 | mean: ${dataset.meta.vis_mean} 16 | std: ${dataset.meta.vis_std} 17 | 18 | meta: 19 | x_shape: [3, 64, 64] 20 | vis_mean: [0.22437325567325045, 0.22141877351640138, 0.22625457849943273] 21 | vis_std: [0.0690013611690731, 0.06343387069571882, 0.07607519758722009] 22 | -------------------------------------------------------------------------------- /experiment/config/dataset/shapes3d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: 3dshapes 5 | 6 | data: 7 | _target_: disent.dataset.data.Shapes3dData 8 | data_root: ${dsettings.storage.data_root} 9 | prepare: ${dsettings.dataset.prepare} 10 | in_memory: ${dsettings.dataset.try_in_memory} 11 | 12 | transform: 13 | _target_: disent.dataset.transform.ToImgTensorF32 14 | mean: ${dataset.meta.vis_mean} 15 | std: ${dataset.meta.vis_std} 16 | 17 | meta: 18 | x_shape: [3, 64, 64] 19 | vis_mean: [0.502584966788819, 0.5787597566089667, 0.6034499731859578] 20 | vis_std: [0.2940814043555559, 0.3443979087517214, 0.3661685981524748] 21 | -------------------------------------------------------------------------------- /experiment/config/dataset/smallnorb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - 
_data_type_: gt 3 | 4 | name: smallnorb 5 | 6 | data: 7 | _target_: disent.dataset.data.SmallNorb64Data 8 | data_root: ${dsettings.storage.data_root} 9 | prepare: ${dsettings.dataset.prepare} 10 | is_test: False 11 | 12 | transform: 13 | _target_: disent.dataset.transform.ToImgTensorF32 14 | mean: ${dataset.meta.vis_mean} 15 | std: ${dataset.meta.vis_std} 16 | 17 | meta: 18 | x_shape: [1, 64, 64] 19 | vis_mean: [0.7520918401088603] 20 | vis_std: [0.09563879016827262] 21 | -------------------------------------------------------------------------------- /experiment/config/dataset/sprites.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: sprites 5 | 6 | data: 7 | _target_: disent.dataset.data.SpritesData 8 | data_root: ${dsettings.storage.data_root} 9 | prepare: ${dsettings.dataset.prepare} 10 | 11 | transform: 12 | _target_: disent.dataset.transform.ToImgTensorF32 13 | mean: ${dataset.meta.vis_mean} 14 | std: ${dataset.meta.vis_std} 15 | 16 | meta: 17 | x_shape: [3, 64, 64] 18 | vis_mean: [0.09906152159057463, 0.0778614646916404, 0.07261320645877936] 19 | vis_std: [0.23002326114948654, 0.19781224128167926, 0.18283647186482793] 20 | -------------------------------------------------------------------------------- /experiment/config/dataset/sprites_all.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: sprites 5 | 6 | data: 7 | _target_: disent.dataset.data.SpritesAllData 8 | data_root: ${dsettings.storage.data_root} 9 | prepare: ${dsettings.dataset.prepare} 10 | 11 | transform: 12 | _target_: disent.dataset.transform.ToImgTensorF32 13 | mean: ${dataset.meta.vis_mean} 14 | std: ${dataset.meta.vis_std} 15 | 16 | meta: 17 | x_shape: [3, 64, 64] 18 | vis_mean: [0.09933294682299235, 0.07689101333193574, 0.0724788139837905] 19 | vis_std: [0.22891812398973602, 0.19518729133092955, 0.18148902745291426] 20 | -------------------------------------------------------------------------------- /experiment/config/dataset/xyobject.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: xyobject 5 | 6 | data: 7 | _target_: disent.dataset.data.XYObjectData 8 | rgb: TRUE 9 | 10 | transform: 11 | _target_: disent.dataset.transform.ToImgTensorF32 12 | mean: ${dataset.meta.vis_mean} 13 | std: ${dataset.meta.vis_std} 14 | 15 | meta: 16 | x_shape: [3, 64, 64] 17 | vis_mean: [0.009818761549013288, 0.009818761549013288, 0.009818761549013288] 18 | vis_std: [0.052632363725245844, 0.05263236372524584, 0.05263236372524585] 19 | -------------------------------------------------------------------------------- /experiment/config/dataset/xyobject_grey.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: xyobject_grey 5 | 6 | data: 7 | _target_: disent.dataset.data.XYObjectData 8 | rgb: FALSE 9 | 10 | transform: 11 | _target_: disent.dataset.transform.ToImgTensorF32 12 | mean: ${dataset.meta.vis_mean} 13 | std: ${dataset.meta.vis_std} 14 | 15 | meta: 16 | x_shape: [1, 64, 64] 17 | vis_mean: "${exit:EXITING... please compute the vis_mean and vis_std}" 18 | vis_std: "${exit:EXITING... 
please compute the vis_mean and vis_std}" 19 | -------------------------------------------------------------------------------- /experiment/config/dataset/xyobject_shaded.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: xyobject_shaded 5 | 6 | data: 7 | _target_: disent.dataset.data.XYObjectShadedData 8 | rgb: TRUE 9 | 10 | transform: 11 | _target_: disent.dataset.transform.ToImgTensorF32 12 | mean: ${dataset.meta.vis_mean} 13 | std: ${dataset.meta.vis_std} 14 | 15 | meta: 16 | x_shape: [3, 64, 64] 17 | vis_mean: [0.009818761549013288, 0.009818761549013288, 0.009818761549013288] 18 | vis_std: [0.052632363725245844, 0.05263236372524584, 0.05263236372524585] 19 | -------------------------------------------------------------------------------- /experiment/config/dataset/xyobject_shaded_grey.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: xyobject_shaded_grey 5 | 6 | data: 7 | _target_: disent.dataset.data.XYObjectShadedData 8 | rgb: FALSE 9 | 10 | transform: 11 | _target_: disent.dataset.transform.ToImgTensorF32 12 | mean: ${dataset.meta.vis_mean} 13 | std: ${dataset.meta.vis_std} 14 | 15 | meta: 16 | x_shape: [1, 64, 64] 17 | vis_mean: "${exit:EXITING... please compute the vis_mean and vis_std}" 18 | vis_std: "${exit:EXITING... please compute the vis_mean and vis_std}" 19 | -------------------------------------------------------------------------------- /experiment/config/dataset/xysquares.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: xysquares_minimal 5 | 6 | data: 7 | _target_: disent.dataset.data.XYSquaresMinimalData 8 | 9 | transform: 10 | _target_: disent.dataset.transform.ToImgTensorF32 11 | mean: ${dataset.meta.vis_mean} 12 | std: ${dataset.meta.vis_std} 13 | 14 | meta: 15 | x_shape: [3, 64, 64] 16 | vis_mean: [0.015625, 0.015625, 0.015625] 17 | vis_std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854] 18 | -------------------------------------------------------------------------------- /experiment/config/dataset/xysquares_grey.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: xysquares_grey 5 | 6 | data: 7 | _target_: disent.dataset.data.XYSquaresData 8 | square_size: 8 # AFFECTS: mean and std 9 | image_size: 64 # usually ok to adjust 10 | grid_size: 8 # usually ok to adjust 11 | grid_spacing: 8 # usually ok to adjust 12 | num_squares: 3 # AFFECTS: mean and std 13 | rgb: FALSE # AFFECTS: mean and std 14 | 15 | transform: 16 | _target_: disent.dataset.transform.ToImgTensorF32 17 | mean: ${dataset.meta.vis_mean} 18 | std: ${dataset.meta.vis_std} 19 | 20 | meta: 21 | x_shape: [1, 64, 64] 22 | vis_mean: [0.046146392822265625] 23 | vis_std: [0.2096506119375896] 24 | -------------------------------------------------------------------------------- /experiment/config/dataset/xysquares_rgb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _data_type_: gt 3 | 4 | name: xysquares_rgb 5 | 6 | data: 7 | _target_: disent.dataset.data.XYSquaresData 8 | square_size: 8 # AFFECTS: mean and std 9 | image_size: 64 # usually ok to adjust 10 | grid_size: 8 # usually ok to adjust 11 | grid_spacing: 8 # usually ok to adjust 12 | num_squares: 3 # AFFECTS: mean and std 13 | rgb: TRUE # 
AFFECTS: mean and std 14 | 15 | transform: 16 | _target_: disent.dataset.transform.ToImgTensorF32 17 | mean: ${dataset.meta.vis_mean} 18 | std: ${dataset.meta.vis_std} 19 | 20 | meta: 21 | x_shape: [3, 64, 64] 22 | vis_mean: [0.015625, 0.015625, 0.015625] 23 | vis_std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854] 24 | -------------------------------------------------------------------------------- /experiment/config/framework/_input_mode_/pair.yaml: -------------------------------------------------------------------------------- 1 | # controlled by the framework's defaults list 2 | name: pair 3 | num: 2 4 | -------------------------------------------------------------------------------- /experiment/config/framework/_input_mode_/single.yaml: -------------------------------------------------------------------------------- 1 | # controlled by the framework's defaults list 2 | name: single 3 | num: 1 4 | -------------------------------------------------------------------------------- /experiment/config/framework/_input_mode_/triplet.yaml: -------------------------------------------------------------------------------- 1 | # controlled by the framework's defaults list 2 | name: triplet 3 | num: 3 4 | -------------------------------------------------------------------------------- /experiment/config/framework/_input_mode_/weak_pair.yaml: -------------------------------------------------------------------------------- 1 | # controlled by the framework's defaults list 2 | name: weak_pair 3 | num: 2 4 | -------------------------------------------------------------------------------- /experiment/config/framework/adaae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: pair 3 | 4 | name: adaae 5 | 6 | cfg: 7 | _target_: disent.frameworks.ae.AdaAe.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # disable various components 12 | detach_decoder: FALSE 13 | disable_rec_loss: FALSE 14 | disable_aug_loss: FALSE 15 | # adavae 16 | ada_thresh_ratio: 0.5 17 | 18 | meta: 19 | model_z_multiplier: 1 20 | -------------------------------------------------------------------------------- /experiment/config/framework/adaae_os.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: weak_pair # original sampling 3 | 4 | name: adaae 5 | 6 | cfg: 7 | _target_: disent.frameworks.ae.AdaAe.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # disable various components 12 | detach_decoder: FALSE 13 | disable_rec_loss: FALSE 14 | disable_aug_loss: FALSE 15 | # adavae 16 | ada_thresh_ratio: 0.5 17 | 18 | meta: 19 | model_z_multiplier: 1 20 | -------------------------------------------------------------------------------- /experiment/config/framework/adagvae_minimal_os.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: weak_pair 3 | 4 | name: adagvae_minimal_os 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.AdaGVaeMinimal.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 
17 | disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | 21 | meta: 22 | model_z_multiplier: 2 23 | -------------------------------------------------------------------------------- /experiment/config/framework/adanegtae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: triplet 3 | 4 | name: adanegtae 5 | 6 | cfg: 7 | _target_: disent.frameworks.ae.AdaNegTripletAe.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # disable various components 12 | detach_decoder: FALSE 13 | disable_rec_loss: FALSE 14 | disable_aug_loss: FALSE 15 | # tvae: triplet stuffs 16 | triplet_loss: triplet 17 | triplet_margin_min: 0.001 18 | triplet_margin_max: 1 19 | triplet_scale: 0.1 20 | triplet_p: 1 21 | # adavae 22 | ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << 23 | # ada_tvae - loss 24 | adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" 25 | 26 | meta: 27 | model_z_multiplier: 1 28 | -------------------------------------------------------------------------------- /experiment/config/framework/adanegtae_d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: single 3 | 4 | name: adanegtae_d 5 | 6 | cfg: 7 | _target_: disent.frameworks.ae.DataOverlapTripletAe.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # disable various components 12 | detach_decoder: FALSE 13 | disable_rec_loss: FALSE 14 | disable_aug_loss: FALSE 15 | # tvae: triplet stuffs 16 | triplet_loss: triplet 17 | triplet_margin_min: 0.001 18 | triplet_margin_max: 1 19 | triplet_scale: 0.1 20 | triplet_p: 1 21 | # adavae 22 | ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << 23 | # ada_tvae - loss 24 | adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" 25 | # dotvae 26 | overlap_loss: ${settings.framework_opt.overlap_loss} # any of the recon_loss values, or NULL to use the recon_loss value 27 | overlap_num: 512 28 | overlap_mine_ratio: 0.1 29 | overlap_mine_triplet_mode: 'none' # none, hard_neg, semi_hard_neg, hard_pos, easy_pos, ran:hard_neg+hard_pos <- etc, dynamically evaluated, can chain multiple "+"s 30 | # dotvae -- augment 31 | overlap_augment_mode: 'none' 32 | overlap_augment: NULL 33 | 34 | meta: 35 | model_z_multiplier: 2 36 | -------------------------------------------------------------------------------- /experiment/config/framework/adanegtvae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: triplet 3 | 4 | name: adanegtvae 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.AdaNegTripletVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | # tvae: triplet stuffs 21 | triplet_loss: triplet 22 | triplet_margin_min: 0.001 23 | triplet_margin_max: 1 24 | triplet_scale: 0.1 25 | triplet_p: 1 26 | # adavae 27 | ada_average_mode: gvae 28 
| ada_thresh_mode: dist # Only works for: adat_share_mask_mode == "posterior" --- kl, symmetric_kl, dist, sampled_dist 29 | ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << 30 | # ada_tvae - loss 31 | adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" 32 | # ada_tvae - averaging 33 | adat_share_mask_mode: posterior 34 | 35 | meta: 36 | model_z_multiplier: 2 37 | -------------------------------------------------------------------------------- /experiment/config/framework/adanegtvae_d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: single 3 | 4 | name: adanegtvae_d 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.DataOverlapTripletVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | # tvae: triplet stuffs 21 | triplet_loss: triplet 22 | triplet_margin_min: 0.001 23 | triplet_margin_max: 1 24 | triplet_scale: 0.1 25 | triplet_p: 1 26 | # adavae 27 | ada_average_mode: gvae 28 | ada_thresh_mode: dist # Only works for: adat_share_mask_mode == "posterior" --- kl, symmetric_kl, dist, sampled_dist 29 | ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << 30 | # ada_tvae - loss 31 | adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" 32 | # ada_tvae - averaging 33 | adat_share_mask_mode: posterior 34 | # dotvae 35 | overlap_loss: ${settings.framework_opt.overlap_loss} # any of the recon_loss values, or NULL to use the recon_loss value 36 | overlap_num: 512 37 | overlap_mine_ratio: 0.1 38 | overlap_mine_triplet_mode: 'none' # none, hard_neg, semi_hard_neg, hard_pos, easy_pos, ran:hard_neg+hard_pos <- etc, dynamically evaluated, can chain multiple "+"s 39 | # dotvae -- augment 40 | overlap_augment_mode: 'augment' # none, augment, augment_each (if overlap_augment is NULL, then it is the same as setting this to "none") 41 | overlap_augment: NULL 42 | 43 | meta: 44 | model_z_multiplier: 2 45 | -------------------------------------------------------------------------------- /experiment/config/framework/adavae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: pair 3 | 4 | name: adavae 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.AdaVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | # adavae 21 | ada_average_mode: gvae # gvae or ml-vae 22 | ada_thresh_mode: symmetric_kl 23 | ada_thresh_ratio: 0.5 24 | 25 | meta: 26 | model_z_multiplier: 2 27 | -------------------------------------------------------------------------------- /experiment/config/framework/adavae_os.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: 
weak_pair # original sampling 3 | 4 | name: adavae_os 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.AdaVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | # adavae 21 | ada_average_mode: gvae # gvae or ml-vae 22 | ada_thresh_mode: symmetric_kl 23 | ada_thresh_ratio: 0.5 24 | 25 | meta: 26 | model_z_multiplier: 2 27 | -------------------------------------------------------------------------------- /experiment/config/framework/ae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: single 3 | 4 | name: ae 5 | 6 | cfg: 7 | _target_: disent.frameworks.ae.Ae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # disable various components 12 | detach_decoder: FALSE 13 | disable_rec_loss: FALSE 14 | disable_aug_loss: FALSE 15 | 16 | meta: 17 | model_z_multiplier: 1 18 | -------------------------------------------------------------------------------- /experiment/config/framework/betatcvae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: single 3 | 4 | name: betatcvae 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.BetaTcVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Beta-TcVae 19 | beta: ${settings.framework.beta} 20 | 21 | meta: 22 | model_z_multiplier: 2 23 | -------------------------------------------------------------------------------- /experiment/config/framework/betavae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: single 3 | 4 | name: betavae 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.BetaVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | 21 | meta: 22 | model_z_multiplier: 2 23 | -------------------------------------------------------------------------------- /experiment/config/framework/dfcvae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: single 3 | 4 | name: dfcvae 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.DfcVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | 
disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | # dfcvae 21 | feature_layers: ['14', '24', '34', '43'] 22 | feature_inputs_mode: 'none' # none, clamp, assert 23 | 24 | meta: 25 | model_z_multiplier: 2 26 | -------------------------------------------------------------------------------- /experiment/config/framework/dipvae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: single 3 | 4 | name: dipvae 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.DipVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | # DIP-VAE 21 | dip_mode: 'ii' # "i" or "ii" 22 | dip_beta: 1.0 23 | lambda_d: 1.0 # diagonal weight 24 | lambda_od: 0.5 # off diagonal weight 25 | 26 | meta: 27 | model_z_multiplier: 2 28 | -------------------------------------------------------------------------------- /experiment/config/framework/infovae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: single 3 | 4 | name: infovae 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.InfoVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Info-VAE 19 | # info vae is not based off beta vae, but with 20 | # the correct parameter choice this can equal the beta vae 21 | info_alpha: -0.5 22 | info_lambda: 5.0 23 | info_kernel: "rbf" # rbf kernel is the only kernel currently 24 | 25 | meta: 26 | model_z_multiplier: 2 27 | -------------------------------------------------------------------------------- /experiment/config/framework/tae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: triplet 3 | 4 | name: tae 5 | 6 | cfg: 7 | _target_: disent.frameworks.ae.TripletAe.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # disable various components 12 | detach_decoder: FALSE 13 | disable_rec_loss: FALSE 14 | disable_aug_loss: FALSE 15 | # tvae: triplet stuffs 16 | triplet_loss: triplet 17 | triplet_margin_min: 0.001 18 | triplet_margin_max: 1 19 | triplet_scale: 0.1 20 | triplet_p: 1 21 | 22 | meta: 23 | model_z_multiplier: 1 24 | -------------------------------------------------------------------------------- /experiment/config/framework/tvae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: triplet 3 | 4 | name: tvae 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.TripletVae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | 
disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | # Beta-VAE 19 | beta: ${settings.framework.beta} 20 | # tvae: triplet stuffs 21 | triplet_loss: triplet 22 | triplet_margin_min: 0.001 23 | triplet_margin_max: 1 24 | triplet_scale: 0.1 25 | triplet_p: 1 26 | 27 | meta: 28 | model_z_multiplier: 2 29 | -------------------------------------------------------------------------------- /experiment/config/framework/vae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _input_mode_: single 3 | 4 | name: vae 5 | 6 | cfg: 7 | _target_: disent.frameworks.vae.Vae.cfg 8 | # base ae 9 | recon_loss: ${settings.framework.recon_loss} 10 | loss_reduction: ${settings.framework.loss_reduction} 11 | # base vae 12 | latent_distribution: ${settings.framework_opt.latent_distribution} 13 | # disable various components 14 | detach_decoder: FALSE 15 | disable_reg_loss: FALSE 16 | disable_rec_loss: FALSE 17 | disable_aug_loss: FALSE 18 | 19 | meta: 20 | model_z_multiplier: 2 21 | -------------------------------------------------------------------------------- /experiment/config/metrics/all.yaml: -------------------------------------------------------------------------------- 1 | metric_list: 2 | - flatness: {} 3 | - factored_components: {} 4 | - mig: {} 5 | - sap: {} 6 | - dci: 7 | every_n_steps: 7200 8 | on_final: TRUE 9 | - factor_vae: 10 | every_n_steps: 7200 11 | on_final: TRUE 12 | - unsupervised: {} 13 | 14 | # these are the default settings, these can be placed in the list above 15 | default_on_final: TRUE 16 | default_on_train: TRUE 17 | default_every_n_steps: 2400 18 | default_begin_first_step: FALSE 19 | -------------------------------------------------------------------------------- /experiment/config/metrics/fast.yaml: -------------------------------------------------------------------------------- 1 | metric_list: 2 | - flatness: {} 3 | - factored_components: {} 4 | - mig: {} 5 | - sap: {} 6 | - unsupervised: {} 7 | 8 | # these are the default settings, these can be placed in the list above 9 | default_on_final: TRUE 10 | default_on_train: TRUE 11 | default_every_n_steps: 2400 12 | default_begin_first_step: FALSE 13 | -------------------------------------------------------------------------------- /experiment/config/metrics/none.yaml: -------------------------------------------------------------------------------- 1 | metric_list: [] 2 | 3 | # these are the default settings, these can be placed in the list above 4 | default_on_final: TRUE 5 | default_on_train: TRUE 6 | default_every_n_steps: 1200 7 | default_begin_first_step: FALSE 8 | -------------------------------------------------------------------------------- /experiment/config/metrics/test.yaml: -------------------------------------------------------------------------------- 1 | metric_list: 2 | - flatness: 3 | every_n_steps: 110 4 | - factored_components: 5 | every_n_steps: 111 6 | - mig: 7 | every_n_steps: 112 8 | - sap: 9 | every_n_steps: 113 10 | - unsupervised: 11 | every_n_steps: 114 12 | - dci: 13 | every_n_steps: 115 14 | - factor_vae: 15 | every_n_steps: 116 16 | 17 | # these are the default settings, these can be placed in the list above 18 | default_on_final: FALSE 19 | default_on_train: TRUE 20 | default_every_n_steps: 200 21 | default_begin_first_step: FALSE 22 | -------------------------------------------------------------------------------- /experiment/config/model/linear.yaml: -------------------------------------------------------------------------------- 1 | name: 
linear 2 | 3 | model_cls: 4 | # weight initialisation 5 | _target_: disent.nn.weights.init_model_weights 6 | mode: ${settings.model.weight_init} 7 | model: 8 | # auto-encoder 9 | _target_: disent.model.AutoEncoder 10 | encoder: 11 | _target_: disent.model.ae.EncoderLinear 12 | x_shape: ${dataset.meta.x_shape} 13 | z_size: ${settings.model.z_size} 14 | z_multiplier: ${framework.meta.model_z_multiplier} 15 | decoder: 16 | _target_: disent.model.ae.DecoderLinear 17 | x_shape: ${dataset.meta.x_shape} 18 | z_size: ${settings.model.z_size} 19 | -------------------------------------------------------------------------------- /experiment/config/model/norm_conv64.yaml: -------------------------------------------------------------------------------- 1 | name: norm_conv64 2 | 3 | model_cls: 4 | # weight initialisation 5 | _target_: disent.nn.weights.init_model_weights 6 | mode: ${settings.model.weight_init} 7 | model: 8 | # auto-encoder 9 | _target_: disent.model.AutoEncoder 10 | encoder: 11 | _target_: disent.model.ae.EncoderConv64Norm 12 | x_shape: ${dataset.meta.x_shape} 13 | z_size: ${settings.model.z_size} 14 | z_multiplier: ${framework.meta.model_z_multiplier} 15 | activation: ${model.meta.activation} 16 | norm: ${model.meta.norm} 17 | norm_pre_act: ${model.meta.norm_pre_act} 18 | decoder: 19 | _target_: disent.model.ae.DecoderConv64Norm 20 | x_shape: ${dataset.meta.x_shape} 21 | z_size: ${settings.model.z_size} 22 | activation: ${model.meta.activation} 23 | norm: ${model.meta.norm} 24 | norm_pre_act: ${model.meta.norm_pre_act} 25 | 26 | meta: 27 | activation: swish # leaky_relu, relu 28 | norm: layer # batch, instance, layer, layer_chn, none 29 | norm_pre_act: TRUE 30 | -------------------------------------------------------------------------------- /experiment/config/model/vae_conv64.yaml: -------------------------------------------------------------------------------- 1 | name: vae_conv64 2 | 3 | model_cls: 4 | # weight initialisation 5 | _target_: disent.nn.weights.init_model_weights 6 | mode: ${settings.model.weight_init} 7 | model: 8 | # auto-encoder 9 | _target_: disent.model.AutoEncoder 10 | encoder: 11 | _target_: disent.model.ae.EncoderConv64 12 | x_shape: ${dataset.meta.x_shape} 13 | z_size: ${settings.model.z_size} 14 | z_multiplier: ${framework.meta.model_z_multiplier} 15 | decoder: 16 | _target_: disent.model.ae.DecoderConv64 17 | x_shape: ${dataset.meta.x_shape} 18 | z_size: ${settings.model.z_size} 19 | -------------------------------------------------------------------------------- /experiment/config/model/vae_fc.yaml: -------------------------------------------------------------------------------- 1 | name: vae_fc 2 | 3 | model_cls: 4 | # weight initialisation 5 | _target_: disent.nn.weights.init_model_weights 6 | mode: ${settings.model.weight_init} 7 | model: 8 | # auto-encoder 9 | _target_: disent.model.AutoEncoder 10 | encoder: 11 | _target_: disent.model.ae.EncoderFC 12 | x_shape: ${dataset.meta.x_shape} 13 | z_size: ${settings.model.z_size} 14 | z_multiplier: ${framework.meta.model_z_multiplier} 15 | decoder: 16 | _target_: disent.model.ae.DecoderFC 17 | x_shape: ${dataset.meta.x_shape} 18 | z_size: ${settings.model.z_size} 19 | -------------------------------------------------------------------------------- /experiment/config/optimizer/adabelief.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | framework: 4 | cfg: 5 | optimizer: torch_optimizer.AdaBelief 6 | optimizer_kwargs: 7 | lr: 
${settings.optimizer.lr} 8 | betas: [0.9, 0.999] 9 | eps: 1e-8 10 | weight_decay: 0 11 | 12 | amsgrad: False 13 | weight_decouple: False 14 | fixed_decay: False 15 | rectify: False 16 | -------------------------------------------------------------------------------- /experiment/config/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | framework: 4 | cfg: 5 | optimizer: torch.optim.Adam 6 | optimizer_kwargs: 7 | lr: ${settings.optimizer.lr} 8 | betas: [0.9, 0.999] 9 | eps: 1e-8 10 | weight_decay: 0 11 | 12 | amsgrad: False 13 | -------------------------------------------------------------------------------- /experiment/config/optimizer/amsgrad.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | framework: 4 | cfg: 5 | optimizer: torch.optim.Adam 6 | optimizer_kwargs: 7 | lr: ${settings.optimizer.lr} 8 | betas: [0.9, 0.999] 9 | eps: 1e-8 10 | weight_decay: 0 11 | 12 | amsgrad: True 13 | -------------------------------------------------------------------------------- /experiment/config/optimizer/radam.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | framework: 4 | cfg: 5 | optimizer: torch_optimizer.RAdam 6 | optimizer_kwargs: 7 | lr: ${settings.optimizer.lr} 8 | betas: [0.9, 0.999] 9 | eps: 1e-8 10 | weight_decay: 0 11 | -------------------------------------------------------------------------------- /experiment/config/optimizer/rmsprop.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | framework: 4 | cfg: 5 | optimizer: torch.optim.RMSprop 6 | optimizer_kwargs: 7 | lr: ${settings.optimizer.lr} # default was 1e-2 8 | alpha: 0.99 9 | eps: 1e-8 10 | weight_decay: 0 11 | 12 | momentum: 0 13 | centered: False 14 | -------------------------------------------------------------------------------- /experiment/config/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | framework: 4 | cfg: 5 | optimizer: torch.optim.SGD 6 | optimizer_kwargs: 7 | lr: ${settings.optimizer.lr} 8 | momentum: 0 9 | dampening: 0 10 | weight_decay: 0 11 | nesterov: False 12 | -------------------------------------------------------------------------------- /experiment/config/run_action/prepare_data.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | action: prepare_data 3 | 4 | # override settings from job/location 5 | dsettings: 6 | dataset: 7 | try_in_memory: FALSE 8 | prepare: TRUE 9 | -------------------------------------------------------------------------------- /experiment/config/run_action/skip.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | action: skip 3 | -------------------------------------------------------------------------------- /experiment/config/run_action/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | action: train 3 | -------------------------------------------------------------------------------- /experiment/config/run_callbacks/all.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | callbacks: 4 | 5 | latent_cycle: 6 | _target_: 
disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback 7 | seed: 7777 8 | every_n_steps: 3600 9 | begin_first_step: TRUE 10 | mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' 11 | log_wandb: ${logging.wandb.enabled} 12 | recon_mean: ${dataset.meta.vis_mean} 13 | recon_std: ${dataset.meta.vis_std} 14 | 15 | gt_dists: 16 | _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback 17 | seed: 7777 18 | every_n_steps: 3600 19 | traversal_repeats: 100 20 | begin_first_step: TRUE 21 | log_wandb: ${logging.wandb.enabled} 22 | batch_size: ${settings.dataset.batch_size} 23 | -------------------------------------------------------------------------------- /experiment/config/run_callbacks/none.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | callbacks: 4 | # empty! 5 | -------------------------------------------------------------------------------- /experiment/config/run_callbacks/test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | callbacks: 4 | 5 | latent_cycle: 6 | _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback 7 | seed: 7777 8 | every_n_steps: 3 9 | begin_first_step: FALSE 10 | mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' 11 | log_wandb: ${logging.wandb.enabled} 12 | recon_mean: ${dataset.meta.vis_mean} 13 | recon_std: ${dataset.meta.vis_std} 14 | 15 | gt_dists: 16 | _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback 17 | seed: 7777 18 | every_n_steps: 4 19 | traversal_repeats: 3 20 | begin_first_step: FALSE 21 | log_wandb: ${logging.wandb.enabled} 22 | batch_size: ${settings.dataset.batch_size} 23 | -------------------------------------------------------------------------------- /experiment/config/run_callbacks/vis.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | callbacks: 4 | 5 | latent_cycle: 6 | _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback 7 | seed: 7777 8 | every_n_steps: 3600 9 | begin_first_step: TRUE 10 | mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' 11 | log_wandb: ${logging.wandb.enabled} 12 | recon_mean: ${dataset.meta.vis_mean} 13 | recon_std: ${dataset.meta.vis_std} 14 | 15 | gt_dists: 16 | _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback 17 | seed: 7777 18 | every_n_steps: 3600 19 | traversal_repeats: 100 20 | begin_first_step: TRUE 21 | log_wandb: ${logging.wandb.enabled} 22 | batch_size: ${settings.dataset.batch_size} 23 | -------------------------------------------------------------------------------- /experiment/config/run_callbacks/vis_debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | callbacks: 4 | 5 | latent_cycle: 6 | _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback 7 | seed: 7777 8 | every_n_steps: 600 9 | begin_first_step: TRUE 10 | mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' 11 | log_wandb: ${logging.wandb.enabled} 12 | recon_mean: ${dataset.meta.vis_mean} 13 | recon_std: ${dataset.meta.vis_std} 14 | 15 | gt_dists: 16 | _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback 17 | seed: 7777 18 | every_n_steps: 600 19 | traversal_repeats: 50 20 | begin_first_step: FALSE 21 | log_wandb: ${logging.wandb.enabled} 22 | 
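The run_callbacks groups above (and the vis_* variants that follow) all configure the same two Lightning callbacks and differ only in how often they fire. They are ordinary callbacks, so outside of the hydra setup they can be handed straight to a Trainer; a sketch, where the scalar mean/std values are placeholders standing in for `${dataset.meta.vis_mean}` / `${dataset.meta.vis_std}`:

import lightning as L
from disent.util.lightning.callbacks import VaeLatentCycleLoggingCallback

latent_cycle = VaeLatentCycleLoggingCallback(
    seed=7777,
    every_n_steps=3600,
    begin_first_step=True,
    mode="minmax_interval_cycle",
    log_wandb=False,  # stand-in for ${logging.wandb.enabled}
    recon_mean=0.5,   # stand-in for ${dataset.meta.vis_mean}
    recon_std=0.5,    # stand-in for ${dataset.meta.vis_std}
)

trainer = L.Trainer(max_steps=28800, callbacks=[latent_cycle])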
batch_size: ${settings.dataset.batch_size} 23 | -------------------------------------------------------------------------------- /experiment/config/run_callbacks/vis_fast.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | callbacks: 4 | 5 | latent_cycle: 6 | _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback 7 | seed: 7777 8 | every_n_steps: 1800 9 | begin_first_step: TRUE 10 | mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' 11 | log_wandb: ${logging.wandb.enabled} 12 | recon_mean: ${dataset.meta.vis_mean} 13 | recon_std: ${dataset.meta.vis_std} 14 | 15 | gt_dists: 16 | _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback 17 | seed: 7777 18 | every_n_steps: 1800 19 | traversal_repeats: 100 20 | begin_first_step: TRUE 21 | log_wandb: ${logging.wandb.enabled} 22 | batch_size: ${settings.dataset.batch_size} 23 | -------------------------------------------------------------------------------- /experiment/config/run_callbacks/vis_quick.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | callbacks: 4 | 5 | latent_cycle: 6 | _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback 7 | seed: 7777 8 | every_n_steps: 600 9 | begin_first_step: TRUE 10 | mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' 11 | log_wandb: ${logging.wandb.enabled} 12 | recon_mean: ${dataset.meta.vis_mean} 13 | recon_std: ${dataset.meta.vis_std} 14 | 15 | gt_dists: 16 | _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback 17 | seed: 7777 18 | every_n_steps: 1800 19 | traversal_repeats: 50 20 | begin_first_step: FALSE 21 | log_wandb: ${logging.wandb.enabled} 22 | batch_size: ${settings.dataset.batch_size} 23 | -------------------------------------------------------------------------------- /experiment/config/run_callbacks/vis_skip_first.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | callbacks: 4 | 5 | latent_cycle: 6 | _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback 7 | seed: 7777 8 | every_n_steps: 3600 9 | begin_first_step: FALSE 10 | mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' 11 | log_wandb: ${logging.wandb.enabled} 12 | recon_mean: ${dataset.meta.vis_mean} 13 | recon_std: ${dataset.meta.vis_std} 14 | 15 | gt_dists: 16 | _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback 17 | seed: 7777 18 | every_n_steps: 3600 19 | traversal_repeats: 100 20 | begin_first_step: FALSE 21 | log_wandb: ${logging.wandb.enabled} 22 | batch_size: ${settings.dataset.batch_size} 23 | -------------------------------------------------------------------------------- /experiment/config/run_callbacks/vis_slow.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | callbacks: 4 | 5 | latent_cycle: 6 | _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback 7 | seed: 7777 8 | every_n_steps: 7200 9 | begin_first_step: TRUE 10 | mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' 11 | log_wandb: ${logging.wandb.enabled} 12 | recon_mean: ${dataset.meta.vis_mean} 13 | recon_std: ${dataset.meta.vis_std} 14 | 15 | gt_dists: 16 | _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback 17 | seed: 7777 18 | every_n_steps: 7200 19 | 
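Because the vis_* presets only change callback frequencies, switching between them is a single config-group override. A sketch using hydra's compose API (hydra-core 1.1 as pinned in requirements-extra.txt); the config directory path is an assumption, and the experiment's own entry point normally also registers extra search paths and OmegaConf resolvers before composing:

from hydra import compose, initialize_config_dir

with initialize_config_dir(config_dir="/path/to/disent/experiment/config"):
    cfg = compose(config_name="config", overrides=["run_callbacks=vis_fast", "run_length=short"])

print(cfg.callbacks.latent_cycle.every_n_steps)  # 1800 for vis_fast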
traversal_repeats: 100 20 | begin_first_step: TRUE 21 | log_wandb: ${logging.wandb.enabled} 22 | batch_size: ${settings.dataset.batch_size} 23 | -------------------------------------------------------------------------------- /experiment/config/run_launcher/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /hydra/launcher: basic 5 | -------------------------------------------------------------------------------- /experiment/config/run_launcher/slurm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /hydra/launcher: submitit_slurm 5 | 6 | hydra: 7 | launcher: 8 | partition: ${dsettings.launcher.partition} 9 | mem_gb: 0 10 | timeout_min: 1440 # minutes 11 | submitit_folder: '${hydra.sweep.dir}/%j' 12 | array_parallelism: ${dsettings.launcher.array_parallelism} 13 | exclude: ${dsettings.launcher.exclude} 14 | -------------------------------------------------------------------------------- /experiment/config/run_length/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | trainer: 4 | max_epochs: 3 5 | max_steps: 3 6 | -------------------------------------------------------------------------------- /experiment/config/run_length/epic.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | trainer: 4 | max_epochs: 230400 5 | max_steps: 230400 6 | -------------------------------------------------------------------------------- /experiment/config/run_length/long.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | trainer: 4 | max_epochs: 115200 5 | max_steps: 115200 6 | -------------------------------------------------------------------------------- /experiment/config/run_length/longmed.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | trainer: 4 | max_epochs: 86400 5 | max_steps: 86400 6 | -------------------------------------------------------------------------------- /experiment/config/run_length/medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | trainer: 4 | max_epochs: 57600 5 | max_steps: 57600 6 | -------------------------------------------------------------------------------- /experiment/config/run_length/short.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | trainer: 4 | max_epochs: 28800 5 | max_steps: 28800 6 | -------------------------------------------------------------------------------- /experiment/config/run_length/test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | trainer: 4 | max_epochs: 5 5 | max_steps: 5 6 | -------------------------------------------------------------------------------- /experiment/config/run_length/tiny.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | trainer: 4 | max_epochs: 14400 5 | max_steps: 14400 6 | -------------------------------------------------------------------------------- /experiment/config/run_length/xtiny.yaml: -------------------------------------------------------------------------------- 1 | # @package 
_global_ 2 | 3 | trainer: 4 | max_epochs: 7200 5 | max_steps: 7200 6 | -------------------------------------------------------------------------------- /experiment/config/run_location/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dsettings: 4 | trainer: 5 | cuda: NULL # `NULL` tries to use CUDA if it is available, otherwise defaulting to the CPU 6 | storage: 7 | logs_dir: 'logs' 8 | data_root: '/tmp/${oc.env:USER}/datasets' 9 | dataset: 10 | prepare: TRUE 11 | try_in_memory: TRUE 12 | 13 | datamodule: 14 | gpu_augment: FALSE 15 | prepare_data_per_node: TRUE 16 | dataloader: 17 | num_workers: 8 18 | pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster! 19 | batch_size: ${settings.dataset.batch_size} 20 | 21 | hydra: 22 | job: 23 | name: 'disent' 24 | run: 25 | dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 26 | sweep: 27 | dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 28 | subdir: '${hydra.job.id}' # hydra.job.id is not available for dir 29 | -------------------------------------------------------------------------------- /experiment/config/run_location/local_cpu.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dsettings: 4 | trainer: 5 | cuda: FALSE # The job will only use the CPU 6 | storage: 7 | logs_dir: 'logs' 8 | data_root: '/tmp/${oc.env:USER}/datasets' 9 | dataset: 10 | prepare: TRUE 11 | try_in_memory: TRUE 12 | 13 | datamodule: 14 | gpu_augment: FALSE 15 | prepare_data_per_node: TRUE 16 | dataloader: 17 | num_workers: 8 18 | pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster! 19 | batch_size: ${settings.dataset.batch_size} 20 | 21 | hydra: 22 | job: 23 | name: 'disent' 24 | run: 25 | dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 26 | sweep: 27 | dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 28 | subdir: '${hydra.job.id}' # hydra.job.id is not available for dir 29 | -------------------------------------------------------------------------------- /experiment/config/run_location/local_gpu.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dsettings: 4 | trainer: 5 | cuda: TRUE # `TRUE` forces cuda to be used. The job fails if cuda is not available! 6 | storage: 7 | logs_dir: 'logs' 8 | data_root: '/tmp/${oc.env:USER}/datasets' 9 | dataset: 10 | prepare: TRUE 11 | try_in_memory: TRUE 12 | 13 | datamodule: 14 | gpu_augment: FALSE 15 | prepare_data_per_node: TRUE 16 | dataloader: 17 | num_workers: 8 18 | pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster! 
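The run_location variants differ mainly in the `dsettings.trainer.cuda` flag, which is tri-state: NULL auto-detects, TRUE requires a GPU, FALSE forces the CPU (and, via the interpolation, also toggles dataloader `pin_memory`). The decision itself is made by the experiment runner; the sketch below only mirrors the behaviour documented in the config comments, not the actual implementation:

from typing import Optional

import torch

def resolve_cuda(cuda: Optional[bool]) -> bool:
    # NULL -> use CUDA only if it is available
    if cuda is None:
        return torch.cuda.is_available()
    # TRUE -> fail fast if CUDA was requested but is missing
    if cuda and not torch.cuda.is_available():
        raise RuntimeError("cuda=TRUE was set, but CUDA is not available!")
    # FALSE -> CPU only
    return cuda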
19 | batch_size: ${settings.dataset.batch_size} 20 | 21 | hydra: 22 | job: 23 | name: 'disent' 24 | run: 25 | dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 26 | sweep: 27 | dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' 28 | subdir: '${hydra.job.id}' # hydra.job.id is not available for dir 29 | -------------------------------------------------------------------------------- /experiment/config/run_logging/none.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /hydra/job_logging: colorlog 5 | - override /hydra/hydra_logging: colorlog 6 | 7 | trainer: 8 | log_every_n_steps: 50 9 | enable_progress_bar: FALSE # disable the builtin progress bar 10 | 11 | callbacks: 12 | progress: 13 | _target_: disent.util.lightning.callbacks.LoggerProgressCallback 14 | interval: 5 15 | 16 | logging: 17 | wandb: 18 | enabled: FALSE 19 | loggers: NULL 20 | -------------------------------------------------------------------------------- /experiment/config/run_logging/wandb.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /hydra/job_logging: colorlog 5 | - override /hydra/hydra_logging: colorlog 6 | 7 | trainer: 8 | log_every_n_steps: 100 9 | enable_progress_bar: FALSE # disable the builtin progress bar 10 | 11 | callbacks: 12 | progress: 13 | _target_: disent.util.lightning.callbacks.LoggerProgressCallback 14 | interval: 15 15 | 16 | logging: 17 | wandb: 18 | enabled: TRUE 19 | loggers: 20 | _target_: lightning.pytorch.loggers.WandbLogger 21 | offline: FALSE 22 | entity: ${settings.job.user} 23 | project: ${settings.job.project} 24 | name: ${settings.job.name} 25 | group: NULL 26 | tags: [] 27 | save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd 28 | # https://docs.wandb.ai/guides/track/launch#init-start-error 29 | settings: 30 | _target_: wandb.Settings 31 | start_method: "fork" # fork: linux/macos, thread: google colab 32 | -------------------------------------------------------------------------------- /experiment/config/run_logging/wandb_fast.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /hydra/job_logging: colorlog 5 | - override /hydra/hydra_logging: colorlog 6 | 7 | trainer: 8 | log_every_n_steps: 50 9 | enable_progress_bar: FALSE # disable the builtin progress bar 10 | 11 | callbacks: 12 | progress: 13 | _target_: disent.util.lightning.callbacks.LoggerProgressCallback 14 | interval: 5 15 | 16 | logging: 17 | wandb: 18 | enabled: TRUE 19 | loggers: 20 | _target_: lightning.pytorch.loggers.WandbLogger 21 | offline: FALSE 22 | entity: ${settings.job.user} 23 | project: ${settings.job.project} 24 | name: ${settings.job.name} 25 | group: NULL 26 | tags: [] 27 | save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd 28 | # https://docs.wandb.ai/guides/track/launch#init-start-error 29 | settings: 30 | _target_: wandb.Settings 31 | start_method: "fork" # fork: linux/macos, thread: google colab 32 | -------------------------------------------------------------------------------- /experiment/config/run_logging/wandb_fast_offline.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override 
/hydra/job_logging: colorlog 5 | - override /hydra/hydra_logging: colorlog 6 | 7 | trainer: 8 | log_every_n_steps: 50 9 | enable_progress_bar: FALSE # disable the builtin progress bar 10 | 11 | callbacks: 12 | progress: 13 | _target_: disent.util.lightning.callbacks.LoggerProgressCallback 14 | interval: 5 15 | 16 | logging: 17 | wandb: 18 | enabled: TRUE 19 | loggers: 20 | _target_: lightning.pytorch.loggers.WandbLogger 21 | offline: TRUE 22 | entity: ${settings.job.user} 23 | project: ${settings.job.project} 24 | name: ${settings.job.name} 25 | group: NULL 26 | tags: [] 27 | save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd 28 | # https://docs.wandb.ai/guides/track/launch#init-start-error 29 | settings: 30 | _target_: wandb.Settings 31 | start_method: "fork" # fork: linux/macos, thread: google colab 32 | -------------------------------------------------------------------------------- /experiment/config/run_logging/wandb_slow.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /hydra/job_logging: colorlog 5 | - override /hydra/hydra_logging: colorlog 6 | 7 | trainer: 8 | log_every_n_steps: 200 9 | enable_progress_bar: FALSE # disable the builtin progress bar 10 | 11 | callbacks: 12 | progress: 13 | _target_: disent.util.lightning.callbacks.LoggerProgressCallback 14 | interval: 30 15 | 16 | logging: 17 | wandb: 18 | enabled: TRUE 19 | loggers: 20 | _target_: lightning.pytorch.loggers.WandbLogger 21 | offline: FALSE 22 | entity: ${settings.job.user} 23 | project: ${settings.job.project} 24 | name: ${settings.job.name} 25 | group: NULL 26 | tags: [] 27 | save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd 28 | # https://docs.wandb.ai/guides/track/launch#init-start-error 29 | settings: 30 | _target_: wandb.Settings 31 | start_method: "fork" # fork: linux/macos, thread: google colab 32 | -------------------------------------------------------------------------------- /experiment/config/run_plugins/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # call the listed functions here before the experiment is started 4 | # - this can be used to register functions or metrics to the disent registry for example! 5 | experiment: 6 | plugins: [] 7 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/episodes__pair.yaml: -------------------------------------------------------------------------------- 1 | name: episodes__pair 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.RandomEpisodeSampler 5 | num_samples: 2 6 | # TODO: this needs to be updated to use the same API as ground_truth wrappers. 7 | sample_radius: 32 8 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/episodes__single.yaml: -------------------------------------------------------------------------------- 1 | name: episodes__single 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.RandomEpisodeSampler 5 | num_samples: 1 6 | # TODO: this needs to be updated to use the same API as ground_truth wrappers. 
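The `run_plugins` group a little further above lists import paths of callables that are invoked before the experiment starts, e.g. to register extra datasets, frameworks or metrics with the disent registry. A minimal, hypothetical plugin is sketched below; the module name is made up and the zero-argument call signature is an assumption:

# my_project/plugins.py (hypothetical module)
import logging

log = logging.getLogger(__name__)

def register_my_components() -> None:
    # registration with disent.registry would go here (registration calls not shown)
    log.info("custom disent components registered")

It would then be listed in an overriding config as, for example, `experiment.plugins: [my_project.plugins.register_my_components]`.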
7 | sample_radius: 32 8 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/episodes__triplet.yaml: -------------------------------------------------------------------------------- 1 | name: episodes__triplet 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.RandomEpisodeSampler 5 | num_samples: 3 6 | # TODO: this needs to be updated to use the same API as ground_truth wrappers. 7 | sample_radius: 32 8 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/episodes__weak_pair.yaml: -------------------------------------------------------------------------------- 1 | name: episodes__weak_pair 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.RandomEpisodeSampler 5 | num_samples: 2 6 | # TODO: this needs to be updated to use the same API as ground_truth wrappers. 7 | sample_radius: 32 8 | 9 | # ================================================== # 10 | # NOTE!!! THIS IS A DUMMY WRAPPER ,SO WE DON'T CRASH # 11 | # WHEN WE DO GRID SEARCHES WITH RUN EPISODE DATASETS # 12 | # ================================================== # 13 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/gt__pair.yaml: -------------------------------------------------------------------------------- 1 | name: gt__pair 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.GroundTruthPairSampler 5 | # factor sampling 6 | p_k_range: ${sampling.k} 7 | # radius sampling 8 | p_radius_range: ${sampling.k_radius} 9 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/gt__single.yaml: -------------------------------------------------------------------------------- 1 | name: gt__single 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.GroundTruthSingleSampler 5 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/gt__triplet.yaml: -------------------------------------------------------------------------------- 1 | name: gt__triplet 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.GroundTruthTripleSampler 5 | # factor sampling 6 | p_k_range: ${sampling.k} 7 | n_k_range: ${sampling.n_k} 8 | n_k_sample_mode: ${sampling.n_k_mode} 9 | n_k_is_shared: TRUE 10 | # radius sampling 11 | p_radius_range: ${sampling.k_radius} 12 | n_radius_range: ${sampling.n_k_radius} 13 | n_radius_sample_mode: ${sampling.n_k_radius_mode} 14 | # final checks 15 | swap_metric: ${sampling.swap_metric} 16 | swap_chance: ${sampling.swap_chance} 17 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/gt__weak_pair.yaml: -------------------------------------------------------------------------------- 1 | name: gt__weak_pair 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.GroundTruthPairOrigSampler 5 | # factor sampling 6 | p_k: ${sampling.k.1} 7 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/gt_dist__pair.yaml: -------------------------------------------------------------------------------- 1 | name: gt_dist__pair 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.GroundTruthDistSampler 5 | num_samples: 2 6 | triplet_sample_mode: ${sampling.triplet_sample_mode} # random, factors, manhattan, combined 7 | triplet_swap_chance: 
${sampling.triplet_swap_chance} 8 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/gt_dist__single.yaml: -------------------------------------------------------------------------------- 1 | name: gt_dist__single 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.GroundTruthDistSampler 5 | num_samples: 1 6 | triplet_sample_mode: ${sampling.triplet_sample_mode} # random, factors, manhattan, combined 7 | triplet_swap_chance: ${sampling.triplet_swap_chance} 8 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/gt_dist__triplet.yaml: -------------------------------------------------------------------------------- 1 | name: gt_dist__triplet 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.GroundTruthDistSampler 5 | num_samples: 3 6 | triplet_sample_mode: ${sampling.triplet_sample_mode} # random, factors, manhattan, combined 7 | triplet_swap_chance: ${sampling.triplet_swap_chance} 8 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/gt_dist__weak_pair.yaml: -------------------------------------------------------------------------------- 1 | name: gt_dist__weak_pair 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.GroundTruthDistSampler 5 | num_samples: 2 6 | triplet_sample_mode: ${sampling.triplet_sample_mode} # random, factors, manhattan, combined 7 | triplet_swap_chance: ${sampling.triplet_swap_chance} 8 | 9 | # ================================================== # 10 | # NOTE!!! THIS IS A DUMMY WRAPPER ,SO WE DON'T CRASH # 11 | # WHEN WE DO GRID SEARCHES WITH RUN EPISODE DATASETS # 12 | # ================================================== # 13 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/random__pair.yaml: -------------------------------------------------------------------------------- 1 | name: random__pair 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.RandomSampler 5 | num_samples: 2 6 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/random__single.yaml: -------------------------------------------------------------------------------- 1 | name: random__single 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.RandomSampler 5 | num_samples: 1 6 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/random__triplet.yaml: -------------------------------------------------------------------------------- 1 | name: random__triplet 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.RandomSampler 5 | num_samples: 3 6 | -------------------------------------------------------------------------------- /experiment/config/sampling/_sampler_/random__weak_pair.yaml: -------------------------------------------------------------------------------- 1 | name: random__weak_pair 2 | 3 | sampler_cls: 4 | _target_: disent.dataset.sampling.RandomSampler 5 | num_samples: 2 6 | 7 | # ================================================== # 8 | # NOTE!!! 
THIS IS A DUMMY WRAPPER ,SO WE DON'T CRASH # 9 | # WHEN WE DO GRID SEARCHES WITH RUN EPISODE DATASETS # 10 | # ================================================== # 11 | -------------------------------------------------------------------------------- /experiment/config/sampling/default.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: choose the default from the framework and dataset 2 | defaults: 3 | - _sampler_: ${dataset/_data_type_}__${framework/_input_mode_} 4 | 5 | name: default 6 | 7 | # this config forces an error to be thrown if 8 | # sampler config settings are required. 9 | -------------------------------------------------------------------------------- /experiment/config/sampling/default__bb.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: choose the default from the framework and dataset 2 | defaults: 3 | - _sampler_: ${dataset/_data_type_}__${framework/_input_mode_} 4 | 5 | name: default__bb 6 | 7 | # varying factors (if applicable for pairs) -- sample in range: [min, max] 8 | k: [0, -1] 9 | k_radius: [0, -1] 10 | # varying factors (if applicable for triplets) -- sample in range: [min, max] 11 | n_k: [0, -1] 12 | n_k_mode: 'bounded_below' 13 | n_k_radius: [0, -1] 14 | n_k_radius_mode: 'bounded_below' 15 | # swap incorrect samples 16 | swap_metric: NULL 17 | # swap positive and negative if possible 18 | swap_chance: NULL 19 | -------------------------------------------------------------------------------- /experiment/config/sampling/default__ran_l1.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: choose the default from the framework and dataset 2 | defaults: 3 | - _sampler_: ${dataset/_data_type_}__${framework/_input_mode_} 4 | 5 | name: default__ran_l1 6 | 7 | # varying factors (if applicable for pairs) -- sample in range: [min, max] 8 | k: [0, -1] 9 | k_radius: [0, -1] 10 | # varying factors (if applicable for triplets) -- sample in range: [min, max] 11 | n_k: [0, -1] 12 | n_k_mode: 'random' 13 | n_k_radius: [0, -1] 14 | n_k_radius_mode: 'random' 15 | # swap incorrect samples 16 | swap_metric: 'manhattan' 17 | # swap positive and negative if possible 18 | swap_chance: NULL 19 | -------------------------------------------------------------------------------- /experiment/config/sampling/default__ran_l2.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: choose the default from the framework and dataset 2 | defaults: 3 | - _sampler_: ${dataset/_data_type_}__${framework/_input_mode_} 4 | 5 | name: default__ran_l2 6 | 7 | # varying factors (if applicable for pairs) -- sample in range: [min, max] 8 | k: [0, -1] 9 | k_radius: [0, -1] 10 | # varying factors (if applicable for triplets) -- sample in range: [min, max] 11 | n_k: [0, -1] 12 | n_k_mode: 'random' 13 | n_k_radius: [0, -1] 14 | n_k_radius_mode: 'random' 15 | # swap incorrect samples 16 | swap_metric: 'euclidean' 17 | # swap positive and negative if possible 18 | swap_chance: NULL 19 | -------------------------------------------------------------------------------- /experiment/config/sampling/gt_dist__combined.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets 2 | defaults: 3 | - _sampler_: gt_dist__${framework/_input_mode_} 4 | 5 | name: 
gt_dist__combined 6 | 7 | triplet_sample_mode: "combined" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled 8 | triplet_swap_chance: 0 9 | -------------------------------------------------------------------------------- /experiment/config/sampling/gt_dist__combined_scaled.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets 2 | defaults: 3 | - _sampler_: gt_dist__${framework/_input_mode_} 4 | 5 | name: gt_dist__combined_scaled 6 | 7 | triplet_sample_mode: "combined_scaled" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled 8 | triplet_swap_chance: 0 9 | -------------------------------------------------------------------------------- /experiment/config/sampling/gt_dist__factors.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets 2 | defaults: 3 | - _sampler_: gt_dist__${framework/_input_mode_} 4 | 5 | name: gt_dist__factors 6 | 7 | triplet_sample_mode: "factors" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled 8 | triplet_swap_chance: 0 9 | -------------------------------------------------------------------------------- /experiment/config/sampling/gt_dist__manhat.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets 2 | defaults: 3 | - _sampler_: gt_dist__${framework/_input_mode_} 4 | 5 | name: gt_dist__manhat 6 | 7 | triplet_sample_mode: "manhattan" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled 8 | triplet_swap_chance: 0 9 | -------------------------------------------------------------------------------- /experiment/config/sampling/gt_dist__manhat_scaled.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets 2 | defaults: 3 | - _sampler_: gt_dist__${framework/_input_mode_} 4 | 5 | name: gt_dist__manhat_scaled 6 | 7 | triplet_sample_mode: 'manhattan_scaled' # random, factors, manhattan, manhattan_scaled, combined, combined_scaled 8 | triplet_swap_chance: 0 9 | -------------------------------------------------------------------------------- /experiment/config/sampling/gt_dist__random.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets 2 | defaults: 3 | - _sampler_: gt_dist__${framework/_input_mode_} 4 | 5 | name: gt_dist__random 6 | 7 | triplet_sample_mode: "random" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled 8 | triplet_swap_chance: 0 9 | -------------------------------------------------------------------------------- /experiment/config/sampling/none.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: force the user to choose a different sampling strategy 2 | defaults: 3 | - _sampler_: "${exit:EXITING... please specify in the defaults list a sampling method other than none}" 4 | 5 | name: none 6 | 7 | # this config forces an error to be thrown! 
This is to make 8 | # sure that we don't encounter errors when updating old configs. 9 | -------------------------------------------------------------------------------- /experiment/config/sampling/random.yaml: -------------------------------------------------------------------------------- 1 | # SPECIALIZATION: force the random strategy to be used as the dataset sampler 2 | defaults: 3 | - _sampler_: random__${framework/_input_mode_} 4 | 5 | name: random 6 | 7 | # this config forces an error to be thrown if 8 | # sampler config settings are required. 9 | -------------------------------------------------------------------------------- /experiment/config/schedule/adanegtvae_up_all.yaml: -------------------------------------------------------------------------------- 1 | name: adanegtvae_up_all 2 | 3 | schedule_items: 4 | adat_triplet_share_scale: 5 | _target_: disent.schedule.FixedValueSchedule 6 | value: 1.0 7 | schedule: 8 | _target_: disent.schedule.LinearSchedule 9 | start_step: 0 10 | end_step: ${trainer.max_steps} 11 | r_start: 1.0 # normal triplet 12 | r_end: 0.5 # ada weighted triplet 13 | ada_thresh_ratio: 14 | _target_: disent.schedule.FixedValueSchedule 15 | value: 0.5 16 | schedule: 17 | _target_: disent.schedule.LinearSchedule 18 | start_step: 0 19 | end_step: ${trainer.max_steps} 20 | r_start: 0.0 # none averaged 21 | r_end: 1.0 # all averaged, should this not be 0.5 the recommended value 22 | -------------------------------------------------------------------------------- /experiment/config/schedule/adanegtvae_up_all_full.yaml: -------------------------------------------------------------------------------- 1 | name: adanegtvae_up_all_full 2 | 3 | schedule_items: 4 | adat_triplet_share_scale: 5 | _target_: disent.schedule.FixedValueSchedule 6 | value: 1.0 7 | schedule: 8 | _target_: disent.schedule.LinearSchedule 9 | start_step: 0 10 | end_step: ${trainer.max_steps} 11 | r_start: 1.0 # normal triplet 12 | r_end: 0.0 # ada weighted triplet 13 | ada_thresh_ratio: 14 | _target_: disent.schedule.FixedValueSchedule 15 | value: 0.5 16 | schedule: 17 | _target_: disent.schedule.LinearSchedule 18 | start_step: 0 19 | end_step: ${trainer.max_steps} 20 | r_start: 0.0 # none averaged 21 | r_end: 1.0 # all averaged, should this not be 0.5 the recommended value 22 | -------------------------------------------------------------------------------- /experiment/config/schedule/adanegtvae_up_all_weak.yaml: -------------------------------------------------------------------------------- 1 | name: adanegtvae_up_all_weak 2 | 3 | schedule_items: 4 | adat_triplet_share_scale: 5 | _target_: disent.schedule.FixedValueSchedule 6 | value: 1.0 7 | schedule: 8 | _target_: disent.schedule.LinearSchedule 9 | start_step: 0 10 | end_step: ${trainer.max_steps} 11 | r_start: 1.0 # normal triplet 12 | r_end: 0.75 # ada weighted triplet 13 | ada_thresh_ratio: 14 | _target_: disent.schedule.FixedValueSchedule 15 | value: 0.5 16 | schedule: 17 | _target_: disent.schedule.LinearSchedule 18 | start_step: 0 19 | end_step: ${trainer.max_steps} 20 | r_start: 0.0 # none averaged 21 | r_end: 1.0 # all averaged, should this not be 0.5 the recommended value 22 | -------------------------------------------------------------------------------- /experiment/config/schedule/adanegtvae_up_ratio.yaml: -------------------------------------------------------------------------------- 1 | name: adanegtvae_up_ratio 2 | 3 | schedule_items: 4 | adat_triplet_share_scale: 5 | _target_: disent.schedule.FixedValueSchedule 6 | value: 1.0 7 | 
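In the schedule configs that follow, each key under `schedule_items` names a field of the framework's cfg (`adat_triplet_share_scale`, `ada_thresh_ratio`, later `beta`), and its value is instantiated into a schedule object that the experiment steps during training. The ramp used throughout is a plain LinearSchedule; built directly it looks like this, with the arguments mirroring the yaml keys and a concrete step count standing in for `${trainer.max_steps}`:

from disent.schedule import LinearSchedule

# ramp 0.0 -> 1.0 over the course of a `run_length=short` run (28800 steps)
ramp_up = LinearSchedule(start_step=0, end_step=28800, r_start=0.0, r_end=1.0)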
schedule: 8 | _target_: disent.schedule.LinearSchedule 9 | start_step: 0 10 | end_step: ${trainer.max_steps} 11 | r_start: 1.0 # normal triplet 12 | r_end: 0.5 # ada weighted triplet 13 | ada_thresh_ratio: 14 | _target_: disent.schedule.FixedValueSchedule 15 | value: 0.5 16 | schedule: NULL 17 | -------------------------------------------------------------------------------- /experiment/config/schedule/adanegtvae_up_ratio_full.yaml: -------------------------------------------------------------------------------- 1 | name: adanegtvae_up_ratio_full 2 | 3 | schedule_items: 4 | adat_triplet_share_scale: 5 | _target_: disent.schedule.FixedValueSchedule 6 | value: 1.0 7 | schedule: 8 | _target_: disent.schedule.LinearSchedule 9 | start_step: 0 10 | end_step: ${trainer.max_steps} 11 | r_start: 1.0 # normal triplet 12 | r_end: 0.0 # ada weighted triplet 13 | ada_thresh_ratio: 14 | _target_: disent.schedule.FixedValueSchedule 15 | value: 0.5 16 | schedule: NULL 17 | -------------------------------------------------------------------------------- /experiment/config/schedule/adanegtvae_up_ratio_weak.yaml: -------------------------------------------------------------------------------- 1 | name: adanegtvae_up_ratio_weak 2 | 3 | schedule_items: 4 | adat_triplet_share_scale: 5 | _target_: disent.schedule.FixedValueSchedule 6 | value: 1.0 7 | schedule: 8 | _target_: disent.schedule.LinearSchedule 9 | start_step: 0 10 | end_step: ${trainer.max_steps} 11 | r_start: 1.0 # normal triplet 12 | r_end: 0.75 # ada weighted triplet 13 | ada_thresh_ratio: 14 | _target_: disent.schedule.FixedValueSchedule 15 | value: 0.5 16 | schedule: NULL 17 | -------------------------------------------------------------------------------- /experiment/config/schedule/adanegtvae_up_thresh.yaml: -------------------------------------------------------------------------------- 1 | name: adanegtvae_up_thresh 2 | 3 | schedule_items: 4 | adat_triplet_share_scale: 5 | _target_: disent.schedule.FixedValueSchedule 6 | value: 0.5 7 | schedule: NULL 8 | # | 9 | # | 10 | # | 11 | # | 12 | # | 13 | ada_thresh_ratio: 14 | _target_: disent.schedule.FixedValueSchedule 15 | value: 0.5 16 | schedule: 17 | _target_: disent.schedule.LinearSchedule 18 | start_step: 0 19 | end_step: ${trainer.max_steps} 20 | r_start: 0.0 # none averaged 21 | r_end: 1.0 # all averaged, should this not be 0.5 the recommended value 22 | -------------------------------------------------------------------------------- /experiment/config/schedule/beta_cyclic.yaml: -------------------------------------------------------------------------------- 1 | name: beta_cyclic 2 | 3 | schedule_items: 4 | beta: 5 | _target_: disent.schedule.Cyclic 6 | period: 7200 7 | start_step: 3600 8 | repeats: NULL 9 | r_start: 0.001 10 | r_end: 1.0 11 | end_mode: 'end' 12 | mode: 'cosine' 13 | -------------------------------------------------------------------------------- /experiment/config/schedule/beta_cyclic_fast.yaml: -------------------------------------------------------------------------------- 1 | name: beta_cyclic_fast 2 | 3 | schedule_items: 4 | beta: 5 | _target_: disent.schedule.Cyclic 6 | period: 3600 7 | start_step: 3600 8 | repeats: NULL 9 | r_start: 0.001 10 | r_end: 1.0 11 | end_mode: 'end' 12 | mode: 'cosine' 13 | -------------------------------------------------------------------------------- /experiment/config/schedule/beta_cyclic_slow.yaml: -------------------------------------------------------------------------------- 1 | name: beta_cyclic_slow 2 | 3 | schedule_items: 4 | 
beta: 5 | _target_: disent.schedule.Cyclic 6 | period: 14400 7 | start_step: 3600 8 | repeats: NULL 9 | r_start: 0.001 10 | r_end: 1.0 11 | end_mode: 'end' 12 | mode: 'cosine' 13 | -------------------------------------------------------------------------------- /experiment/config/schedule/beta_decrease.yaml: -------------------------------------------------------------------------------- 1 | name: beta_decrease 2 | 3 | schedule_items: 4 | beta: 5 | _target_: disent.schedule.Single 6 | start_step: 0 7 | end_step: ${trainer.max_steps} 8 | r_start: 1.0 9 | r_end: 0.001 10 | mode: 'linear' 11 | -------------------------------------------------------------------------------- /experiment/config/schedule/beta_delay.yaml: -------------------------------------------------------------------------------- 1 | name: beta_increase 2 | 3 | schedule_items: 4 | beta: 5 | _target_: disent.schedule.Single 6 | start_step: 3600 7 | end_step: 7200 8 | r_start: 0.001 9 | r_end: 1.0 10 | mode: 'linear' 11 | -------------------------------------------------------------------------------- /experiment/config/schedule/beta_delay_long.yaml: -------------------------------------------------------------------------------- 1 | name: beta_increase 2 | 3 | schedule_items: 4 | beta: 5 | _target_: disent.schedule.Single 6 | start_step: 7200 7 | end_step: 14400 8 | r_start: 0.001 9 | r_end: 1.0 10 | mode: 'linear' 11 | -------------------------------------------------------------------------------- /experiment/config/schedule/beta_increase.yaml: -------------------------------------------------------------------------------- 1 | name: beta_increase 2 | 3 | schedule_items: 4 | beta: 5 | _target_: disent.schedule.Single 6 | start_step: 0 7 | end_step: ${trainer.max_steps} 8 | r_start: 0.001 9 | r_end: 1.0 10 | mode: 'linear' 11 | -------------------------------------------------------------------------------- /experiment/config/schedule/none.yaml: -------------------------------------------------------------------------------- 1 | name: none 2 | 3 | schedule_items: {} 4 | -------------------------------------------------------------------------------- /experiment/util/__init__.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 24 | -------------------------------------------------------------------------------- /experiment/util/_hydra_searchpath_plugin_/hydra_plugins/searchpath_plugin.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is currently hacky, and hopefully temporary! See: 3 | https://github.com/facebookresearch/hydra/issues/2001 4 | """ 5 | 6 | import logging 7 | import os 8 | 9 | from hydra.core.config_search_path import ConfigSearchPath 10 | from hydra.plugins.search_path_plugin import SearchPathPlugin 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | 15 | class DisentExperimentSearchPathPlugin(SearchPathPlugin): 16 | def manipulate_search_path(self, search_path: ConfigSearchPath) -> None: 17 | from experiment.util.hydra_main import _DISENT_CONFIG_DIRS 18 | 19 | # find paths 20 | paths = [ 21 | *os.environ.get("DISENT_CONFIGS_PREPEND", "").split(";"), 22 | *_DISENT_CONFIG_DIRS, 23 | *os.environ.get("DISENT_CONFIGS_APPEND", "").split(";"), 24 | ] 25 | # print information 26 | log.info(f" [disent-search-path-plugin]: Activated hydra plugin: {self.__class__.__name__}") 27 | log.info( 28 | f" [disent-search-path-plugin]: To register more search paths, adjust the `DISENT_CONFIGS_PREPEND` and `DISENT_CONFIGS_APPEND` environment variables!" 29 | ) 30 | # add paths 31 | for path in paths: 32 | if path: 33 | log.info(f" [disent-search-path] - {repr(path)}") 34 | search_path.append(provider="disent-searchpath-plugin", path=os.path.abspath(path)) 35 | -------------------------------------------------------------------------------- /experiment/util/hydra_utils.py: -------------------------------------------------------------------------------- 1 | # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 2 | # MIT License 3 | # 4 | # Copyright (c) 2021 Nathan Juraj Michlo 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
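The search-path plugin above is what makes the config tree extensible: any directories listed in the `DISENT_CONFIGS_PREPEND` / `DISENT_CONFIGS_APPEND` environment variables (';'-separated) are added to hydra's config search path, so external config groups can shadow or extend the built-in ones without touching this repository. For example (the directory is hypothetical):

import os

# must be set before hydra initialises, e.g. at the top of a custom run script
os.environ["DISENT_CONFIGS_PREPEND"] = "/absolute/path/to/my_configs"

The plugin logs every active search path at startup, which makes it easy to confirm that the override directory was actually picked up.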
--------------------------------------------------------------------------------
/experiment/util/hydra_utils.py:
--------------------------------------------------------------------------------
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
# MIT License
#
# Copyright (c) 2021 Nathan Juraj Michlo
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import logging
from copy import deepcopy
from typing import Optional
from typing import Sequence

import hydra
from omegaconf import DictConfig
from omegaconf import ListConfig
from omegaconf import OmegaConf

from disent.util.deprecate import deprecated

log = logging.getLogger(__name__)


# ========================================================================= #
# Helper                                                                    #
# ========================================================================= #


def make_non_strict(cfg: DictConfig):
    """
    Convert the config into a mutable version.
    """
    cfg = deepcopy(cfg)
    return OmegaConf.create({**cfg})


# ========================================================================= #
# END                                                                       #
# ========================================================================= #
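--------------------------------------------------------------------------------
Example sketch (not a repository file): using make_non_strict
--------------------------------------------------------------------------------
Hydra composes configs in struct mode, where assigning a key that does not
exist raises an error. `make_non_strict` deep-copies the config and rebuilds it
with `OmegaConf.create`, dropping the root-level struct lock so new keys can be
added afterwards. A short sketch with made-up values:

from omegaconf import OmegaConf

from experiment.util.hydra_utils import make_non_strict

cfg = OmegaConf.create({"framework": {"name": "betavae", "beta": 4}})
OmegaConf.set_struct(cfg, True)   # mimic a config composed by hydra
# cfg.framework.new_key = 1       # would raise: key not present in the struct config

cfg = make_non_strict(cfg)
cfg.framework.new_key = 1         # fine on the rebuilt, mutable config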
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
site_name: 🧶 Disent Docs
repo_url: https://github.com/nmichlo/disent
repo_name: nmichlo/disent
theme:
  name: material
  palette:
    scheme: default
    primary: pink
  icon:
    repo: fontawesome/brands/github
  logo: material/library
  favicon: images/favicon.png
plugins:
  - search
  - mkdocstrings  # reference functions and code in markdown `::: module.class.func`
  - git-revision-date-localized:  # visible last edit date on each page
      type: date
      fallback_to_build_date: false
markdown_extensions:
  - admonition
  - pymdownx.details
  - pymdownx.highlight
  - pymdownx.inlinehilite
  - pymdownx.superfences
  - pymdownx.snippets
  - pymdownx.tabbed
  - pymdownx.arithmatex:
      generic: true
  # THE !! CURRENTLY BREAKS READTHEDOCS
  # https://github.com/readthedocs/readthedocs.org/issues/7865
  # - pymdownx.emoji:
  #     emoji_index: !!python/name:materialx.emoji.twemoji
  #     emoji_generator: !!python/name:materialx.emoji.to_svg
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
[pytest]
minversion = 6.0
testpaths =
    tests
    disent
python_files =
    test_*.py
    __test__*.py
--------------------------------------------------------------------------------
/requirements-extra.txt:
--------------------------------------------------------------------------------
-r requirements.txt

# OPTIONAL DEPS INCLUDING
# - those referenced or imported in disent, but are optional.
# - those used in documentation examples
# - those used in experiment code or hydra config targets
# =============

# -- INPUT / OUTPUT -- #
GitPython>=3.0.0  # dataset downloads
imageio>=2.9.0    # required for wandb video logging
moviepy>=1.0.3    # required for wandb video logging
psutil>=5.8.0

# -- GRAPHING & LOGGING -- #
matplotlib>=3
wandb>=0.10.32

# -- CONFIGS -- #
omegaconf>=2.1.0  # only 2.1.0 supports nested variable interpolation eg. ${group.${group.key}}
hydra-core==1.1.1  # needs omegaconf
hydra-colorlog==1.1.0
hydra-submitit-launcher==1.1.6

# -- CONFIG TARGETS -- #
torch_optimizer>=0.3.0

# -- TORCH EXTENSIONS -- #
# requires pytorch to be installed first (duplicated in requirements-experiment.txt)
# - we need `nvcc` to be installed first, otherwise GPU kernel extensions will not be
#   compiled and the failure will be skipped silently. If `nvcc` is missing, install it, e.g.:
#   $ conda install -c nvidia cuda-nvcc
# - Make sure that the version of torch corresponds to the version of `nvcc`; torch needs
#   to be compiled against the same CUDA version! Install the correct version from:
#   https://pytorch.org/get-started/locally/. By default, torch compiled against CUDA 10.2
#   is installed, but `nvcc` will probably want to install 11.
#   CUDA 10.2 (as of 2022-03-15) EITHER OF:
#     $ pip3 install torch torchvision torchaudio
#     $ conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
#   CUDA 11.3 (as of 2022-03-15) EITHER OF:
#     $ pip3 install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
#     $ conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
# - I personally just manage my cuda version manually, installing the correct cudatoolkit
#   from: https://developer.nvidia.com/cuda-toolkit-archive
#   Then making sure that:
#     PATH contains: "/usr/local/cuda/bin"
#     LD_LIBRARY_PATH contains: "/usr/local/cuda/lib64"
torchsort>=0.1.4

# -- JIT -- #
numba>=0.50.0  # optimised sampling

# -- Gradient Boosting -- #
# lightgbm
# xgboost
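--------------------------------------------------------------------------------
Example sketch (not a repository file): checking that torch and nvcc agree
--------------------------------------------------------------------------------
The torchsort notes above boil down to one check: the CUDA version that torch
was built against should match the system `nvcc`. A small sketch that prints
both side by side so a mismatch is easy to spot:

import subprocess

import torch

print("torch:", torch.__version__)
print("torch built with CUDA:", torch.version.cuda)        # None for CPU-only builds
print("CUDA available at runtime:", torch.cuda.is_available())

try:
    # `nvcc --version` prints something like "Cuda compilation tools, release 11.3, ..."
    print(subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout)
except FileNotFoundError:
    print("nvcc not found on PATH - GPU kernel extensions will not be compiled")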
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
-r requirements.txt
-r requirements-extra.txt

# OTHER REQUIREMENTS:
# ==================

pytest>=6.2.4
pytest-cov>=2.12.1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
pip>=21.0

# -- DATA SCIENCE & ML -- #
numpy>=1.19.0
torch>=2.0.0
torchvision>=0.15.0
lightning>=2.0.0
scipy>=1.7.0
scikit-learn>=1.0.0

# -- INPUT & OUTPUT -- #
Pillow>=8.2.0
h5py>=2.10.0  # tensorflow (as of 2.4) does not support h5py 3+
requests

# -- UTILITY -- #
tqdm>=4.60.0
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
# MIT License
#
# Copyright (c) 2021 Nathan Juraj Michlo
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
--------------------------------------------------------------------------------
/tests/test_000_import.py:
--------------------------------------------------------------------------------
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
# MIT License
#
# Copyright (c) 2022 Nathan Juraj Michlo
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~


# THIS TEST FILE SHOULD ALWAYS BE LOADED AND RUN FIRST
from disent.frameworks.vae import BetaVae


def test_000_import():
    assert BetaVae
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
# MIT License
#
# Copyright (c) 2021 Nathan Juraj Michlo
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
import pytest
import torch

from disent.model import AutoEncoder
from disent.model import DisentDecoder
from disent.model import DisentEncoder
from disent.model.ae import *


@pytest.mark.parametrize(
    ["encoder_cls", "decoder_cls"],
    [
        [EncoderConv64, DecoderConv64],
        [EncoderConv64Norm, DecoderConv64Norm],
        [EncoderFC, DecoderFC],
        [EncoderLinear, DecoderLinear],
    ],
)
def test_ae_models(encoder_cls: DisentEncoder, decoder_cls: DisentDecoder):
    x_shape, z_size = (3, 64, 64), 8
    # create model
    vae = AutoEncoder(
        encoder_cls(x_shape=x_shape, z_size=z_size, z_multiplier=2),
        decoder_cls(x_shape=x_shape, z_size=z_size, z_multiplier=1),
    )
    # feed forward
    with torch.no_grad():
        x = torch.randn(1, *x_shape, dtype=torch.float32)
        assert x.shape == (1, *x_shape)
        z0, z1 = vae.encode(x, chunk=True)
        assert z0.shape == (1, z_size)
        assert z1.shape == (1, z_size)
        y = vae.decode(z0)
        assert y.shape == (1, *x_shape)
--------------------------------------------------------------------------------
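Example sketch (not a repository file): what z_multiplier=2 is for
--------------------------------------------------------------------------------
In the test above the encoder is built with `z_multiplier=2`, so
`encode(..., chunk=True)` returns two `(1, z_size)` halves while the decoder
consumes a single `z_size` vector. The sketch below treats those halves as the
mean and log-variance of a Gaussian, the usual VAE convention (assumed here
rather than taken from disent's own VAE helpers), and decodes a reparameterised
sample.

import torch

from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64

x_shape, z_size = (3, 64, 64), 8
vae = AutoEncoder(
    EncoderConv64(x_shape=x_shape, z_size=z_size, z_multiplier=2),
    DecoderConv64(x_shape=x_shape, z_size=z_size, z_multiplier=1),
)

with torch.no_grad():
    x = torch.randn(1, *x_shape)
    z_mean, z_logvar = vae.encode(x, chunk=True)
    # reparameterisation trick: sample z from N(z_mean, exp(z_logvar))
    z = z_mean + torch.randn_like(z_mean) * torch.exp(0.5 * z_logvar)
    recon = vae.decode(z)
    assert recon.shape == (1, *x_shape)
--------------------------------------------------------------------------------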